Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
adf5054f
Unverified
Commit
adf5054f
authored
Aug 01, 2022
by
Frank Lee
Committed by
GitHub
Aug 01, 2022
Browse files
[fx] fixed torchaudio conformer tracing (#1392)
parent
7d629392
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
9 deletions
+18
-9
tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py
..._tracer/test_torchaudio_model/test_torchaudio_tacotron.py
+1
-1
tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py
...acer/test_torchaudio_model/test_torchaudio_transformer.py
+9
-3
tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py
..._tracer/test_torchaudio_model/test_torchaudio_wave2vec.py
+1
-1
tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
..._fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
+7
-4
No files found.
tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py
View file @
adf5054f
...
...
@@ -31,7 +31,7 @@ def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
)
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
skip
(
"Tracing failed"
)
def
test_tacotron_model
():
n_mels
=
80
n_batch
=
3
...
...
tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py
View file @
adf5054f
...
...
@@ -4,7 +4,6 @@ from torchaudio.models import Emformer, Conformer
import
pytest
@
pytest
.
mark
.
skip
def
test_conformer
():
input_dim
=
80
batch_size
=
10
...
...
@@ -27,10 +26,17 @@ def test_conformer():
input
=
torch
.
rand
(
batch_size
,
int
(
lengths
.
max
()),
input_dim
)
return
dict
(
input
=
input
,
lengths
=
lengths
)
trace_and_compare
(
model
,
data_gen
,
need_meta
=
False
,
need_concrete
=
True
)
def
kwargs_transform
(
data
):
new_data
=
{}
for
k
,
v
in
data
.
items
():
new_data
[
f
'
{
k
}
_1'
]
=
v
return
new_data
@
pytest
.
mark
.
skip
trace_and_compare
(
model
,
data_gen
,
need_meta
=
False
,
need_concrete
=
True
,
kwargs_transform
=
kwargs_transform
)
@
pytest
.
mark
.
skip
(
"Tracing failed"
)
def
test_emformer
():
input_dim
=
128
batch_size
=
10
...
...
tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py
View file @
adf5054f
...
...
@@ -40,7 +40,7 @@ def _smoke_test(model, device):
trace_and_compare
(
model
,
data_gen
,
need_meta
=
True
,
need_concrete
=
False
)
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
skip
(
"Tracing failed"
)
def
test_wav2vec
():
for
model_fn
in
MODEL_LIST
:
_smoke_test
(
model_fn
(),
'cpu'
)
...
...
tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
View file @
adf5054f
...
...
@@ -3,21 +3,24 @@ import torch
from
torch.fx
import
GraphModule
,
Tracer
def
trace_and_compare
(
model
,
data_gen
,
need_meta
=
False
,
need_concrete
=
False
):
def
trace_and_compare
(
model
,
data_gen
,
need_meta
=
False
,
need_concrete
=
False
,
kwargs_transform
=
False
):
data
=
data_gen
()
concrete_args
=
data
if
need_concrete
else
{}
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
data
.
items
()}
if
need_meta
else
{}
tracer
=
ColoTracer
()
model
.
eval
()
graph
=
tracer
.
trace
(
root
=
model
,
concrete_args
=
concrete_args
,
meta_args
=
meta_args
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
model
.
eval
()
gm
.
eval
()
with
torch
.
no_grad
():
non_fx_out
=
model
(
**
data
)
if
kwargs_transform
:
data
=
kwargs_transform
(
data
)
fx_out
=
gm
(
**
data
)
if
isinstance
(
fx_out
,
tuple
):
for
non_fx
,
fx
in
zip
(
non_fx_out
,
fx_out
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment