Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
55811708
Unverified
Commit
55811708
authored
Jul 07, 2022
by
Frank Lee
Committed by
GitHub
Jul 07, 2022
Browse files
[fx] fixed huggingface OPT and T5 results misalignment (#1227)
parent
2b7dca44
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
9 deletions
+19
-9
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+12
-1
tests/test_fx/test_tracer/test_hf_model/utils.py
tests/test_fx/test_tracer/test_hf_model/utils.py
+7
-8
No files found.
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
View file @
55811708
...
...
@@ -23,9 +23,20 @@ def test_t5():
kwargs
=
dict
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
)
return
kwargs
def
data_gen_for_encoder_only
():
input_ids
=
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGHT
),
dtype
=
torch
.
int64
)
kwargs
=
dict
(
input_ids
=
input_ids
)
return
kwargs
for
model_cls
in
MODEL_LIST
:
model
=
model_cls
(
config
=
config
)
trace_model_and_compare_output
(
model
,
data_gen
)
if
isinstance
(
model
,
transformers
.
T5EncoderModel
):
data_gen_func
=
data_gen_for_encoder_only
else
:
data_gen_func
=
data_gen
trace_model_and_compare_output
(
model
,
data_gen_func
)
if
__name__
==
'__main__'
:
...
...
tests/test_fx/test_tracer/test_hf_model/utils.py
View file @
55811708
...
...
@@ -6,8 +6,12 @@ from torch.utils._pytree import tree_flatten
def
trace_model_and_compare_output
(
model
,
data_gen
):
tracer
=
ColoTracer
()
# must turn on eval mode to ensure the output is consistent
model
.
eval
()
# make sure that the model is traceable
tracer
=
ColoTracer
()
try
:
kwargs
=
data_gen
()
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
kwargs
.
items
()}
...
...
@@ -17,17 +21,12 @@ def trace_model_and_compare_output(model, data_gen):
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
# check output
inputs
=
data_gen
()
# must turn on eval mode to ensure the output is consistent
gm
.
eval
()
model
.
eval
()
# run forward
inputs
=
data_gen
()
non_fx_out
=
model
(
**
inputs
)
fx_out
=
gm
(
**
inputs
)
# check output
for
k
in
non_fx_out
.
keys
():
if
torch
.
is_tensor
(
fx_out
[
k
]):
assert
torch
.
equal
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment