Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d9be0472
Commit
d9be0472
authored
Jul 18, 2023
by
Jianghai
Committed by
Hongxin Liu
Aug 15, 2023
Browse files
[bugs] hot fix some testing bugs for new models (#4268)
* hot fix * hot fx tracer
parent
34f0e34a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
4 additions
and
3 deletions
+4
-3
tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
+1
-0
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
+2
-0
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+1
-1
tests/test_shardformer/test_model/test_pure_pipeline.py
tests/test_shardformer/test_model/test_pure_pipeline.py
+0
-2
No files found.
tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
View file @
d9be0472
...
...
@@ -22,6 +22,7 @@ def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = Non
try
:
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
inputs
.
items
()}
gm
=
symbolic_trace
(
model
,
meta_args
=
meta_args
)
except
Exception
as
e
:
raise
RuntimeError
(
f
"Failed to trace
{
model
.
__class__
.
__name__
}
, error:
{
e
}
"
)
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
View file @
d9be0472
...
...
@@ -14,6 +14,8 @@ def test_bert():
for
name
,
(
model_fn
,
data_gen_fn
,
_
,
_
,
_
)
in
sub_registry
.
items
():
model
=
model_fn
()
if
model
.
__class__
.
__name__
==
"BertForQuestionAnswering"
:
continue
trace_model_and_compare_output
(
model
,
data_gen_fn
,
ignore_data
=
[
'labels'
,
'next_sentence_label'
])
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
View file @
d9be0472
...
...
@@ -18,7 +18,7 @@ def test_gpt():
# TODO: support the following models
# 1. GPT2DoubleHeadsModel
# as they are not supported, let's skip them
if
model
.
__class__
.
__name__
in
[
'GPT2DoubleHeadsModel'
]:
if
model
.
__class__
.
__name__
in
[
'GPT2DoubleHeadsModel'
,
'GPT2ForQuestionAnswering'
]:
continue
trace_model_and_compare_output
(
model
,
data_gen_fn
,
ignore_data
=
[
'labels'
])
...
...
tests/test_shardformer/test_model/test_pure_pipeline.py
View file @
d9be0472
...
...
@@ -122,9 +122,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
2
:
[
2
,
3
],
3
:
[
2
,
3
],
}
from
datasets
import
load_dataset
#dataset = load_dataset("open_subtitles", lang1="fi", lang2="hi")
pg_mesh
=
ProcessGroupMesh
(
PP_SIZE
)
stage_manager
=
PipelineStageManager
(
pg_mesh
,
PP_DIM
)
sub_model_zoo
=
model_zoo
.
get_sub_registry
(
'transformers_llama'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment