Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
5cf3e56f
"vscode:/vscode.git/clone" did not exist on "57f9685dc3257a5cbfa593ffe5ca531bb3e53149"
Unverified
Commit
5cf3e56f
authored
Apr 03, 2020
by
moto
Committed by
GitHub
Apr 03, 2020
Browse files
Extract batch test from test_transforms and move to the dedicated module (#501)
parent
0f8fa5f8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
210 additions
and
192 deletions
+210
-192
test/test_batch_consistency.py
test/test_batch_consistency.py
+210
-0
test/test_transforms.py
test/test_transforms.py
+0
-192
No files found.
test/test_batch_consistency.py
View file @
5cf3e56f
...
@@ -7,6 +7,7 @@ import torchaudio
...
@@ -7,6 +7,7 @@ import torchaudio
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
import
common_utils
import
common_utils
from
common_utils
import
AudioBackendScope
,
BACKENDS
def
_test_batch_shape
(
functional
,
tensor
,
*
args
,
**
kwargs
):
def
_test_batch_shape
(
functional
,
tensor
,
*
args
,
**
kwargs
):
...
@@ -102,3 +103,212 @@ class TestFunctional(unittest.TestCase):
...
@@ -102,3 +103,212 @@ class TestFunctional(unittest.TestCase):
[[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
]]
[[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
]]
])
])
_test_batch
(
F
.
istft
,
stft
,
n_fft
=
4
,
length
=
4
)
_test_batch
(
F
.
istft
,
stft
,
n_fft
=
4
,
length
=
4
)
class
TestTransforms
(
unittest
.
TestCase
):
"""Test suite for classes defined in `transforms` module"""
def
test_batch_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
AmplitudeToDB
()(
spec
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
AmplitudeToDB
()(
spec
.
repeat
(
3
,
1
,
1
))
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
)
def
test_batch_Resample
(
self
):
waveform
=
torch
.
randn
(
2
,
2786
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
Resample
()(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
Resample
()(
waveform
.
repeat
(
3
,
1
,
1
))
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
)
def
test_batch_MelScale
(
self
):
specgram
=
torch
.
randn
(
2
,
31
,
2786
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
MelScale
()(
specgram
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
MelScale
()(
specgram
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, 201, 1394)
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
)
def
test_batch_InverseMelScale
(
self
):
n_mels
=
32
n_stft
=
5
mel_spec
=
torch
.
randn
(
2
,
n_mels
,
32
)
**
2
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
InverseMelScale
(
n_stft
,
n_mels
)(
mel_spec
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
InverseMelScale
(
n_stft
,
n_mels
)(
mel_spec
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, n_mels, 32)
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here.
assert
torch
.
allclose
(
computed
,
expected
,
atol
=
1.0
)
def
test_batch_compute_deltas
(
self
):
specgram
=
torch
.
randn
(
2
,
31
,
2786
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
ComputeDeltas
()(
specgram
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
ComputeDeltas
()(
specgram
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, 201, 1394)
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
)
def
test_batch_mulaw
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
# (2, 278756), 44100
# Single then transform then batch
waveform_encoded
=
torchaudio
.
transforms
.
MuLawEncoding
()(
waveform
)
expected
=
waveform_encoded
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
# Batch then transform
waveform_batched
=
waveform
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
computed
=
torchaudio
.
transforms
.
MuLawEncoding
()(
waveform_batched
)
# shape = (3, 2, 201, 1394)
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
)
# Single then transform then batch
waveform_decoded
=
torchaudio
.
transforms
.
MuLawDecoding
()(
waveform_encoded
)
expected
=
waveform_decoded
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
MuLawDecoding
()(
computed
)
# shape = (3, 2, 201, 1394)
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
)
def
test_batch_spectrogram
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
# (2, 278756), 44100
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
Spectrogram
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
Spectrogram
()(
waveform
.
repeat
(
3
,
1
,
1
))
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
)
def
test_batch_melspectrogram
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
# (2, 278756), 44100
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
MelSpectrogram
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
MelSpectrogram
()(
waveform
.
repeat
(
3
,
1
,
1
))
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
)
@
unittest
.
skipIf
(
"sox"
not
in
BACKENDS
,
"sox not available"
)
@
AudioBackendScope
(
"sox"
)
def
test_batch_mfcc
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.mp3'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
MFCC
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
MFCC
()(
waveform
.
repeat
(
3
,
1
,
1
))
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
,
atol
=
1e-5
)
def
test_batch_TimeStretch
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
# (2, 278756), 44100
kwargs
=
{
'n_fft'
:
2048
,
'hop_length'
:
512
,
'win_length'
:
2048
,
'window'
:
torch
.
hann_window
(
2048
),
'center'
:
True
,
'pad_mode'
:
'reflect'
,
'normalized'
:
True
,
'onesided'
:
True
,
}
rate
=
2
complex_specgrams
=
torch
.
stft
(
waveform
,
**
kwargs
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
TimeStretch
(
fixed_rate
=
rate
,
n_freq
=
1025
,
hop_length
=
512
,
)(
complex_specgrams
).
repeat
(
3
,
1
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
TimeStretch
(
fixed_rate
=
rate
,
n_freq
=
1025
,
hop_length
=
512
,
)(
complex_specgrams
.
repeat
(
3
,
1
,
1
,
1
,
1
))
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
,
atol
=
1e-5
)
def
test_batch_Fade
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
# (2, 278756), 44100
fade_in_len
=
3000
fade_out_len
=
3000
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
Fade
(
fade_in_len
,
fade_out_len
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
Fade
(
fade_in_len
,
fade_out_len
)(
waveform
.
repeat
(
3
,
1
,
1
))
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
)
def
test_batch_Vol
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
# (2, 278756), 44100
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
Vol
(
gain
=
1.1
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
Vol
(
gain
=
1.1
)(
waveform
.
repeat
(
3
,
1
,
1
))
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
torch
.
allclose
(
computed
,
expected
)
test/test_transforms.py
View file @
5cf3e56f
...
@@ -44,18 +44,6 @@ class Tester(unittest.TestCase):
...
@@ -44,18 +44,6 @@ class Tester(unittest.TestCase):
waveform_exp
=
transforms
.
MuLawDecoding
(
quantization_channels
)(
waveform_mu
)
waveform_exp
=
transforms
.
MuLawDecoding
(
quantization_channels
)(
waveform_mu
)
self
.
assertTrue
(
waveform_exp
.
min
()
>=
-
1.
and
waveform_exp
.
max
()
<=
1.
)
self
.
assertTrue
(
waveform_exp
.
min
()
>=
-
1.
and
waveform_exp
.
max
()
<=
1.
)
def
test_batch_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
# Single then transform then batch
expected
=
transforms
.
AmplitudeToDB
()(
spec
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
AmplitudeToDB
()(
spec
.
repeat
(
3
,
1
,
1
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_AmplitudeToDB
(
self
):
def
test_AmplitudeToDB
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
...
@@ -175,18 +163,6 @@ class Tester(unittest.TestCase):
...
@@ -175,18 +163,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
torch_mfcc_norm_none
.
allclose
(
norm_check
))
self
.
assertTrue
(
torch_mfcc_norm_none
.
allclose
(
norm_check
))
def
test_batch_Resample
(
self
):
waveform
=
torch
.
randn
(
2
,
2786
)
# Single then transform then batch
expected
=
transforms
.
Resample
()(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
Resample
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_resample_size
(
self
):
def
test_resample_size
(
self
):
input_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'sinewave.wav'
)
input_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'sinewave.wav'
)
waveform
,
sample_rate
=
torchaudio
.
load
(
input_path
)
waveform
,
sample_rate
=
torchaudio
.
load
(
input_path
)
...
@@ -242,174 +218,6 @@ class Tester(unittest.TestCase):
...
@@ -242,174 +218,6 @@ class Tester(unittest.TestCase):
computed
=
transform
(
specgram
)
computed
=
transform
(
specgram
)
self
.
assertTrue
(
computed
.
shape
==
specgram
.
shape
,
(
computed
.
shape
,
specgram
.
shape
))
self
.
assertTrue
(
computed
.
shape
==
specgram
.
shape
,
(
computed
.
shape
,
specgram
.
shape
))
def
test_batch_MelScale
(
self
):
specgram
=
torch
.
randn
(
2
,
31
,
2786
)
# Single then transform then batch
expected
=
transforms
.
MelScale
()(
specgram
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
MelScale
()(
specgram
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, 201, 1394)
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_batch_InverseMelScale
(
self
):
n_fft
=
8
n_mels
=
32
n_stft
=
5
mel_spec
=
torch
.
randn
(
2
,
n_mels
,
32
)
**
2
# Single then transform then batch
expected
=
transforms
.
InverseMelScale
(
n_stft
,
n_mels
)(
mel_spec
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
InverseMelScale
(
n_stft
,
n_mels
)(
mel_spec
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, n_mels, 32)
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here.
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
,
atol
=
1.0
))
def
test_batch_compute_deltas
(
self
):
specgram
=
torch
.
randn
(
2
,
31
,
2786
)
# Single then transform then batch
expected
=
transforms
.
ComputeDeltas
()(
specgram
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
ComputeDeltas
()(
specgram
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, 201, 1394)
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_batch_mulaw
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
# (2, 278756), 44100
# Single then transform then batch
waveform_encoded
=
transforms
.
MuLawEncoding
()(
waveform
)
expected
=
waveform_encoded
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
# Batch then transform
waveform_batched
=
waveform
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
computed
=
transforms
.
MuLawEncoding
()(
waveform_batched
)
# shape = (3, 2, 201, 1394)
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
# Single then transform then batch
waveform_decoded
=
transforms
.
MuLawDecoding
()(
waveform_encoded
)
expected
=
waveform_decoded
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
MuLawDecoding
()(
computed
)
# shape = (3, 2, 201, 1394)
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_batch_spectrogram
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
# Single then transform then batch
expected
=
transforms
.
Spectrogram
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
Spectrogram
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_batch_melspectrogram
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
# Single then transform then batch
expected
=
transforms
.
MelSpectrogram
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
MelSpectrogram
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
@
unittest
.
skipIf
(
"sox"
not
in
BACKENDS
,
"sox not available"
)
@
AudioBackendScope
(
"sox"
)
def
test_batch_mfcc
(
self
):
test_filepath
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'steam-train-whistle-daniel_simon.mp3'
)
waveform
,
sample_rate
=
torchaudio
.
load
(
test_filepath
)
# Single then transform then batch
expected
=
transforms
.
MFCC
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
MFCC
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
,
atol
=
1e-5
))
def
test_batch_TimeStretch
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
kwargs
=
{
'n_fft'
:
2048
,
'hop_length'
:
512
,
'win_length'
:
2048
,
'window'
:
torch
.
hann_window
(
2048
),
'center'
:
True
,
'pad_mode'
:
'reflect'
,
'normalized'
:
True
,
'onesided'
:
True
,
}
rate
=
2
complex_specgrams
=
torch
.
stft
(
waveform
,
**
kwargs
)
# Single then transform then batch
expected
=
transforms
.
TimeStretch
(
fixed_rate
=
rate
,
n_freq
=
1025
,
hop_length
=
512
)(
complex_specgrams
).
repeat
(
3
,
1
,
1
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
TimeStretch
(
fixed_rate
=
rate
,
n_freq
=
1025
,
hop_length
=
512
)(
complex_specgrams
.
repeat
(
3
,
1
,
1
,
1
,
1
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
,
atol
=
1e-5
))
def
test_batch_Fade
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
fade_in_len
=
3000
fade_out_len
=
3000
# Single then transform then batch
expected
=
transforms
.
Fade
(
fade_in_len
,
fade_out_len
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
Fade
(
fade_in_len
,
fade_out_len
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_batch_Vol
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
# Single then transform then batch
expected
=
transforms
.
Vol
(
gain
=
1.1
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
transforms
.
Vol
(
gain
=
1.1
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment