Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
b6cb5a47
Unverified
Commit
b6cb5a47
authored
Jul 07, 2022
by
Frank Lee
Committed by
GitHub
Jul 07, 2022
Browse files
[fx] added timm model tracing testing (#1221)
parent
280a8124
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
125 additions
and
4 deletions
+125
-4
colossalai/fx/tracer/meta_patch/patched_function.py
colossalai/fx/tracer/meta_patch/patched_function.py
+41
-1
colossalai/fx/tracer/meta_patch/patched_module.py
colossalai/fx/tracer/meta_patch/patched_module.py
+2
-2
tests/test_fx/test_coloproxy.py
tests/test_fx/test_coloproxy.py
+0
-1
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+82
-0
No files found.
colossalai/fx/tracer/meta_patch/patched_function.py
View file @
b6cb5a47
from
curses
import
meta
import
operator
import
torch
from
.registry
import
meta_patched_function
...
...
@@ -99,7 +100,6 @@ def torch_abs(input, *, out=None):
@
meta_patched_function
.
register
(
torch
.
nn
.
functional
.
relu
)
def
torch_nn_func_relu
(
input
,
inplace
=
False
):
assert
not
inplace
,
'inplace is not supported yet'
return
torch
.
empty
(
input
.
shape
,
device
=
'meta'
)
...
...
@@ -178,3 +178,43 @@ def torch_unsqueeze(input, dim):
@
meta_patched_function
.
register
(
torch
.
Tensor
.
unsqueeze
)
def
torch_tensor_unsqueeze
(
self
,
dim
):
return
torch_unsqueeze
(
self
,
dim
)
@
meta_patched_function
.
register
(
torch
.
nn
.
functional
.
layer_norm
)
def
torch_nn_func_layernorm
(
input
,
normalized_shape
,
weight
=
None
,
bias
=
None
,
eps
=
1e-05
):
return
torch
.
empty
(
input
.
shape
,
device
=
'meta'
)
@
meta_patched_function
.
register
(
torch
.
nn
.
functional
.
batch_norm
)
def
torch_nn_func_batchnorm
(
input
,
running_mean
,
running_var
,
weight
=
None
,
bias
=
None
,
training
=
False
,
momentum
=
0.1
,
eps
=
1e-05
):
return
torch
.
empty
(
input
.
shape
,
device
=
'meta'
)
@
meta_patched_function
.
register
(
torch
.
var_mean
)
def
torch_var_mean
(
input
,
dim
,
unbiased
=
True
,
keepdim
=
False
,
*
,
out
=
None
):
assert
out
is
None
,
'saving to out is not supported yet'
var
=
torch
.
empty
(
1
).
squeeze
(
0
).
to
(
'meta'
)
mean
=
torch
.
empty
(
1
).
squeeze
(
0
).
to
(
'meta'
)
return
var
,
mean
@
meta_patched_function
.
register
(
torch
.
cat
)
def
torch_cat
(
tensors
,
dim
=
None
,
axis
=
None
,
*
,
out
=
None
):
if
dim
is
None
and
axis
is
None
:
dim
=
0
if
dim
is
None
and
axis
is
not
None
:
dim
=
axis
if
dim
<
0
:
dim
=
tensors
[
0
].
dim
()
+
dim
shapes
=
[
t
.
shape
for
t
in
tensors
]
shape
=
list
(
shapes
[
0
])
concatenated_dim
=
sum
(
shape
[
dim
]
for
shape
in
shapes
)
final_shape
=
shape
[:
dim
]
+
[
concatenated_dim
]
+
shape
[
dim
+
1
:]
return
torch
.
empty
(
final_shape
,
device
=
"meta"
)
colossalai/fx/tracer/meta_patch/patched_module.py
View file @
b6cb5a47
...
...
@@ -250,6 +250,6 @@ def torch_nn_maxpool3d(self, input):
@
meta_patched_module
.
register
(
torch
.
nn
.
ReLU
)
@
meta_patched_module
.
register
(
torch
.
nn
.
ReLU6
)
def
torch_nn_func_relu
(
self
,
input
):
assert
not
self
.
inplace
,
'inplace is not supported yet'
return
input
.
clone
()
return
torch
.
empty
(
input
.
shape
,
device
=
'meta'
)
tests/test_fx/test_coloproxy.py
View file @
b6cb5a47
...
...
@@ -3,7 +3,6 @@ from colossalai.fx.proxy import ColoProxy
import
pytest
@
pytest
.
mark
.
skip
def
test_coloproxy
():
# create a dummy node only for testing purpose
model
=
torch
.
nn
.
Linear
(
10
,
10
)
...
...
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
0 → 100644
View file @
b6cb5a47
import
torch
import
pytest
try
:
import
timm.models
as
tm
except
:
pass
from
colossalai.fx
import
ColoTracer
from
torch.fx
import
GraphModule
def
trace_and_compare
(
model_cls
,
tracer
,
data
,
meta_args
=
None
):
# trace
model
=
model_cls
()
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
# convert to eval for inference
model
.
eval
()
gm
.
eval
()
# run forward
with
torch
.
no_grad
():
fx_out
=
gm
(
data
)
non_fx_out
=
model
(
data
)
# compare output
if
isinstance
(
fx_out
,
tuple
):
# some models produce tuple as output
for
v1
,
v2
in
zip
(
fx_out
,
non_fx_out
):
assert
torch
.
allclose
(
v1
,
v2
),
f
'
{
model
.
__class__
.
__name__
}
has inconsistent outputs,
{
v1
}
vs
{
v2
}
'
else
:
assert
torch
.
allclose
(
fx_out
,
non_fx_out
),
f
'
{
model
.
__class__
.
__name__
}
has inconsistent outputs,
{
fx_out
}
vs
{
non_fx_out
}
'
@
pytest
.
mark
.
skip
(
'skip as timm is required'
)
def
test_timm_models_without_control_flow
():
torch
.
backends
.
cudnn
.
deterministic
=
True
MODEL_LIST
=
[
tm
.
resnest
.
resnest50d
,
tm
.
beit
.
beit_base_patch16_224
,
tm
.
cait
.
cait_s24_224
,
tm
.
convmixer
.
convmixer_768_32
,
tm
.
efficientnet
.
efficientnetv2_m
,
tm
.
resmlp_12_224
,
tm
.
vision_transformer
.
vit_base_patch16_224
# results not aligned
# tm.deit_base_distilled_patch16_224,
]
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
for
model_cls
in
MODEL_LIST
:
trace_and_compare
(
model_cls
,
tracer
,
data
)
@
pytest
.
mark
.
skip
(
'skip as timm is required'
)
def
test_timm_models_with_control_flow
():
torch
.
backends
.
cudnn
.
deterministic
=
True
MODEL_LIST_WITH_CONTROL_FLOW
=
[
tm
.
convnext
.
convnext_base
,
tm
.
vgg
.
vgg11
,
# results not aligned
# tm.dpn.dpn68,
# tm.densenet.densenet121,
# tm.rexnet.rexnet_100,
# tm.swin_transformer.swin_base_patch4_window7_224
]
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
meta_args
=
{
'x'
:
data
.
to
(
'meta'
)}
for
model_cls
in
MODEL_LIST_WITH_CONTROL_FLOW
:
trace_and_compare
(
model_cls
,
tracer
,
data
,
meta_args
)
if
__name__
==
'__main__'
:
test_timm_models_with_control_flow
()
test_timm_models_without_control_flow
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment