Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
7a0d4192
Unverified
Commit
7a0d4192
authored
May 06, 2020
by
kunalb6
Committed by
GitHub
May 06, 2020
Browse files
make seed parameterized (#614)
Co-authored-by:
Kunal Bhandari
<
bkunal@fb.com
>
Closes #610
parent
c80d9a71
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
6 deletions
+6
-6
test/test_batch_consistency.py
test/test_batch_consistency.py
+6
-6
No files found.
test/test_batch_consistency.py
View file @
7a0d4192
...
@@ -9,23 +9,23 @@ import torchaudio.functional as F
...
@@ -9,23 +9,23 @@ import torchaudio.functional as F
import
common_utils
import
common_utils
def
_test_batch_consistency
(
functional
,
tensor
,
*
args
,
batch_size
=
1
,
atol
=
1e-8
,
rtol
=
1e-5
,
**
kwargs
):
def
_test_batch_consistency
(
functional
,
tensor
,
*
args
,
batch_size
=
1
,
atol
=
1e-8
,
rtol
=
1e-5
,
seed
=
42
,
**
kwargs
):
# run then batch the result
# run then batch the result
torch
.
random
.
manual_seed
(
42
)
torch
.
random
.
manual_seed
(
seed
)
expected
=
functional
(
tensor
.
clone
(),
*
args
,
**
kwargs
)
expected
=
functional
(
tensor
.
clone
(),
*
args
,
**
kwargs
)
expected
=
expected
.
repeat
([
batch_size
]
+
[
1
]
*
expected
.
dim
())
expected
=
expected
.
repeat
([
batch_size
]
+
[
1
]
*
expected
.
dim
())
# batch the input and run
# batch the input and run
torch
.
random
.
manual_seed
(
42
)
torch
.
random
.
manual_seed
(
seed
)
pattern
=
[
batch_size
]
+
[
1
]
*
tensor
.
dim
()
pattern
=
[
batch_size
]
+
[
1
]
*
tensor
.
dim
()
computed
=
functional
(
tensor
.
repeat
(
pattern
),
*
args
,
**
kwargs
)
computed
=
functional
(
tensor
.
repeat
(
pattern
),
*
args
,
**
kwargs
)
torch
.
testing
.
assert_allclose
(
computed
,
expected
,
rtol
=
rtol
,
atol
=
atol
)
torch
.
testing
.
assert_allclose
(
computed
,
expected
,
rtol
=
rtol
,
atol
=
atol
)
def
_test_batch
(
functional
,
tensor
,
*
args
,
atol
=
1e-8
,
rtol
=
1e-5
,
**
kwargs
):
def
_test_batch
(
functional
,
tensor
,
*
args
,
atol
=
1e-8
,
rtol
=
1e-5
,
seed
=
42
,
**
kwargs
):
_test_batch_consistency
(
functional
,
tensor
,
*
args
,
batch_size
=
1
,
atol
=
atol
,
rtol
=
rtol
,
**
kwargs
)
_test_batch_consistency
(
functional
,
tensor
,
*
args
,
batch_size
=
1
,
atol
=
atol
,
rtol
=
rtol
,
seed
=
seed
,
**
kwargs
)
_test_batch_consistency
(
functional
,
tensor
,
*
args
,
batch_size
=
3
,
atol
=
atol
,
rtol
=
rtol
,
**
kwargs
)
_test_batch_consistency
(
functional
,
tensor
,
*
args
,
batch_size
=
3
,
atol
=
atol
,
rtol
=
rtol
,
seed
=
seed
,
**
kwargs
)
class
TestFunctional
(
unittest
.
TestCase
):
class
TestFunctional
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment