Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
b2475d8c
Unverified
Commit
b2475d8c
authored
Jul 15, 2022
by
Frank Lee
Committed by
GitHub
Jul 15, 2022
Browse files
[fx] fixed unit tests for torch 1.12 (#1327)
parent
d49708ae
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
2 additions
and
18 deletions
+2
-18
colossalai/fx/tracer/_tracer_utils.py
colossalai/fx/tracer/_tracer_utils.py
+0
-1
requirements/requirements-test.txt
requirements/requirements-test.txt
+1
-0
tests/test_fx/test_pipeline/test_hf_model/test_albert.py
tests/test_fx/test_pipeline/test_hf_model/test_albert.py
+0
-1
tests/test_fx/test_pipeline/test_hf_model/test_bert.py
tests/test_fx/test_pipeline/test_hf_model/test_bert.py
+0
-1
tests/test_fx/test_pipeline/test_hf_model/test_gpt.py
tests/test_fx/test_pipeline/test_hf_model/test_gpt.py
+0
-1
tests/test_fx/test_pipeline/test_hf_model/test_opt.py
tests/test_fx/test_pipeline/test_hf_model/test_opt.py
+0
-1
tests/test_fx/test_pipeline/test_hf_model/test_t5.py
tests/test_fx/test_pipeline/test_hf_model/test_t5.py
+0
-1
tests/test_fx/test_pipeline/test_timm_model/test_timm.py
tests/test_fx/test_pipeline/test_timm_model/test_timm.py
+0
-2
tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
...est_fx/test_pipeline/test_torchvision/test_torchvision.py
+0
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
+0
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
+0
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+0
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+0
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+0
-1
tests/test_fx/test_tracer/test_patched_module.py
tests/test_fx/test_tracer/test_patched_module.py
+1
-1
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+0
-2
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
...t_tracer/test_torchvision_model/test_torchvision_model.py
+0
-1
No files found.
colossalai/fx/tracer/_tracer_utils.py
View file @
b2475d8c
...
...
@@ -22,7 +22,6 @@ def extract_meta(*args, **kwargs):
if
isinstance
(
val
,
MetaDeviceAttribute
):
return
'meta'
elif
isinstance
(
val
,
ColoProxy
):
assert
val
.
meta_data
is
not
None
return
val
.
meta_data
return
val
...
...
requirements/requirements-test.txt
View file @
b2475d8c
pytest
torchvision
transformers
timm
titans
tests/test_fx/test_pipeline/test_hf_model/test_albert.py
View file @
b2475d8c
...
...
@@ -7,7 +7,6 @@ BATCH_SIZE = 2
SEQ_LENGHT
=
16
@
pytest
.
mark
.
skip
(
"error with pytorch 1.10"
)
def
test_single_sentence_albert
():
MODEL_LIST
=
[
transformers
.
AlbertModel
,
...
...
tests/test_fx/test_pipeline/test_hf_model/test_bert.py
View file @
b2475d8c
...
...
@@ -7,7 +7,6 @@ BATCH_SIZE = 2
SEQ_LENGHT
=
16
@
pytest
.
mark
.
skip
(
"error with pytorch 1.10"
)
def
test_single_sentence_bert
():
MODEL_LIST
=
[
transformers
.
BertModel
,
...
...
tests/test_fx/test_pipeline/test_hf_model/test_gpt.py
View file @
b2475d8c
...
...
@@ -9,7 +9,6 @@ NUM_EPOCHS = 2
NUM_CHUNKS
=
1
@
pytest
.
mark
.
skip
(
"error with pytorch 1.10"
)
def
test_gpt
():
MODEL_LIST
=
[
transformers
.
GPT2Model
,
...
...
tests/test_fx/test_pipeline/test_hf_model/test_opt.py
View file @
b2475d8c
...
...
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
SEQ_LENGHT
=
16
@
pytest
.
mark
.
skip
(
"error with pytorch 1.10"
)
def
test_opt
():
MODEL_LIST
=
[
transformers
.
OPTModel
,
...
...
tests/test_fx/test_pipeline/test_hf_model/test_t5.py
View file @
b2475d8c
...
...
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
SEQ_LENGHT
=
16
@
pytest
.
mark
.
skip
(
"error with pytorch 1.10"
)
def
test_t5
():
MODEL_LIST
=
[
transformers
.
T5Model
,
...
...
tests/test_fx/test_pipeline/test_timm_model/test_timm.py
View file @
b2475d8c
...
...
@@ -7,7 +7,6 @@ except:
from
timm_utils
import
split_model_and_compare_output
@
pytest
.
mark
.
skip
(
'skip as timm is required'
)
def
test_timm_models_without_control_flow
():
MODEL_LIST
=
[
...
...
@@ -28,7 +27,6 @@ def test_timm_models_without_control_flow():
split_model_and_compare_output
(
model
,
data
)
@
pytest
.
mark
.
skip
(
'skip as timm is required'
)
def
test_timm_models_with_control_flow
():
torch
.
backends
.
cudnn
.
deterministic
=
True
...
...
tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
View file @
b2475d8c
...
...
@@ -19,7 +19,6 @@ torch.manual_seed(MANUAL_SEED)
torch
.
backends
.
cudnn
.
deterministic
=
True
@
pytest
.
mark
.
skip
(
'skip as torchvision is required'
)
def
test_torchvision_models
():
MODEL_LIST
=
[
tm
.
vgg11
,
tm
.
resnet18
,
tm
.
densenet121
,
tm
.
mobilenet_v3_small
,
tm
.
resnext50_32x4d
,
tm
.
wide_resnet50_2
,
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
View file @
b2475d8c
...
...
@@ -34,7 +34,6 @@ def test_single_sentence_albert():
trace_model_and_compare_output
(
model
,
data_gen
)
@
pytest
.
mark
.
skip
(
"error with pytorch 1.10"
)
def
test_multi_sentence_albert
():
config
=
transformers
.
AlbertConfig
(
hidden_size
=
128
,
num_hidden_layers
=
2
,
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
View file @
b2475d8c
...
...
@@ -31,7 +31,6 @@ def test_single_sentence_bert():
trace_model_and_compare_output
(
model
,
data_gen
)
@
pytest
.
mark
.
skip
(
"error with pytorch 1.10"
)
def
test_multi_sentence_bert
():
config
=
transformers
.
BertConfig
(
hidden_size
=
128
,
num_hidden_layers
=
2
,
num_attention_heads
=
4
,
intermediate_size
=
256
)
tokenizer
=
transformers
.
BertTokenizer
.
from_pretrained
(
"bert-base-uncased"
)
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
View file @
b2475d8c
...
...
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
SEQ_LENGHT
=
16
@
pytest
.
mark
.
skip
(
"error with pytorch 1.10"
)
def
test_gpt
():
MODEL_LIST
=
[
transformers
.
GPT2Model
,
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
View file @
b2475d8c
...
...
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
SEQ_LENGHT
=
16
@
pytest
.
mark
.
skip
(
"error with pytorch 1.10"
)
def
test_opt
():
MODEL_LIST
=
[
transformers
.
OPTModel
,
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
View file @
b2475d8c
...
...
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
SEQ_LENGHT
=
16
@
pytest
.
mark
.
skip
(
"error with pytorch 1.10"
)
def
test_t5
():
MODEL_LIST
=
[
transformers
.
T5Model
,
...
...
tests/test_fx/test_tracer/test_patched_module.py
View file @
b2475d8c
...
...
@@ -40,7 +40,7 @@ def test_embedding():
_assert_output_shape
(
data
,
ln
,
patched_module
.
torch_nn_normalize
,
False
,
data
.
shape
)
# test group norm
gn
=
torch
.
nn
.
GroupNorm
(
4
,
num_channels
=
2
)
gn
=
torch
.
nn
.
GroupNorm
(
4
,
num_channels
=
8
)
_assert_output_shape
(
data
,
gn
,
patched_module
.
torch_nn_normalize
,
False
,
data
.
shape
)
# test batch norm 1d
...
...
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
View file @
b2475d8c
...
...
@@ -36,7 +36,6 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
fx_out
,
non_fx_out
),
f
'
{
model
.
__class__
.
__name__
}
has inconsistent outputs,
{
fx_out
}
vs
{
non_fx_out
}
'
@
pytest
.
mark
.
skip
(
'skip as timm is required'
)
def
test_timm_models_without_control_flow
():
torch
.
backends
.
cudnn
.
deterministic
=
True
...
...
@@ -58,7 +57,6 @@ def test_timm_models_without_control_flow():
trace_and_compare
(
model_cls
,
tracer
,
data
)
@
pytest
.
mark
.
skip
(
'skip as timm is required'
)
def
test_timm_models_with_control_flow
():
torch
.
backends
.
cudnn
.
deterministic
=
True
...
...
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
View file @
b2475d8c
...
...
@@ -8,7 +8,6 @@ from colossalai.fx import ColoTracer
from
torch.fx
import
GraphModule
@
pytest
.
mark
.
skip
(
'skip as torchvision is required'
)
def
test_torchvision_models
():
MODEL_LIST
=
[
tm
.
vgg11
,
tm
.
resnet18
,
tm
.
densenet121
,
tm
.
mobilenet_v3_small
,
tm
.
resnext50_32x4d
,
tm
.
wide_resnet50_2
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment