Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
37fcf96b
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "cd4f02bed8f3dccd22ab49d67ba96a5147a48bc0"
Unverified
Commit
37fcf96b
authored
Jul 07, 2022
by
Frank Lee
Committed by
GitHub
Jul 07, 2022
Browse files
[fx] fixed timm tracing result misalignment (#1225)
parent
b6cb5a47
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
13 deletions
+18
-13
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+18
-13
No files found.
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
View file @
37fcf96b
...
@@ -11,13 +11,15 @@ from torch.fx import GraphModule
...
@@ -11,13 +11,15 @@ from torch.fx import GraphModule
def
trace_and_compare
(
model_cls
,
tracer
,
data
,
meta_args
=
None
):
def
trace_and_compare
(
model_cls
,
tracer
,
data
,
meta_args
=
None
):
# trace
# trace
model
=
model_cls
()
model
=
model_cls
()
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
# convert to eval for inference
# convert to eval for inference
# it is important to set it to eval mode before tracing
# without this statement, the torch.nn.functional.batch_norm will always be in training mode
model
.
eval
()
model
.
eval
()
gm
.
eval
()
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
# run forward
# run forward
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -39,11 +41,14 @@ def test_timm_models_without_control_flow():
...
@@ -39,11 +41,14 @@ def test_timm_models_without_control_flow():
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
deterministic
=
True
MODEL_LIST
=
[
MODEL_LIST
=
[
tm
.
resnest
.
resnest50d
,
tm
.
beit
.
beit_base_patch16_224
,
tm
.
cait
.
cait_s24_224
,
tm
.
convmixer
.
convmixer_768_32
,
tm
.
resnest
.
resnest50d
,
tm
.
efficientnet
.
efficientnetv2_m
,
tm
.
resmlp_12_224
,
tm
.
vision_transformer
.
vit_base_patch16_224
tm
.
beit
.
beit_base_patch16_224
,
tm
.
cait
.
cait_s24_224
,
# results not aligned
tm
.
convmixer
.
convmixer_768_32
,
# tm.deit_base_distilled_patch16_224,
tm
.
efficientnet
.
efficientnetv2_m
,
tm
.
resmlp_12_224
,
tm
.
vision_transformer
.
vit_base_patch16_224
,
tm
.
deit_base_distilled_patch16_224
,
]
]
tracer
=
ColoTracer
()
tracer
=
ColoTracer
()
...
@@ -60,11 +65,11 @@ def test_timm_models_with_control_flow():
...
@@ -60,11 +65,11 @@ def test_timm_models_with_control_flow():
MODEL_LIST_WITH_CONTROL_FLOW
=
[
MODEL_LIST_WITH_CONTROL_FLOW
=
[
tm
.
convnext
.
convnext_base
,
tm
.
convnext
.
convnext_base
,
tm
.
vgg
.
vgg11
,
tm
.
vgg
.
vgg11
,
tm
.
dpn
.
dpn68
,
tm
.
densenet
.
densenet121
,
tm
.
rexnet
.
rexnet_100
,
# results not aligned
# not traceable
# tm.dpn.dpn68,
# tm.densenet.densenet121,
# tm.rexnet.rexnet_100,
# tm.swin_transformer.swin_base_patch4_window7_224
# tm.swin_transformer.swin_base_patch4_window7_224
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment