Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
019a8474
Unverified
Commit
019a8474
authored
Mar 22, 2023
by
YuliangLiu0306
Committed by
GitHub
Mar 22, 2023
Browse files
[Analyzer] fix analyzer tests (#3197)
parent
f57d3495
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
60 additions
and
100 deletions
+60
-100
tests/test_analyzer/test_fx/test_bias_addition.py
tests/test_analyzer/test_fx/test_bias_addition.py
+20
-13
tests/test_analyzer/test_fx/test_shape_prop.py
tests/test_analyzer/test_fx/test_shape_prop.py
+16
-14
tests/test_analyzer/test_fx/test_symbolic_profile.py
tests/test_analyzer/test_fx/test_symbolic_profile.py
+10
-8
tests/test_analyzer/test_fx/zoo.py
tests/test_analyzer/test_fx/zoo.py
+4
-4
tests/test_analyzer/test_subclasses/test_flop_tensor.py
tests/test_analyzer/test_subclasses/test_flop_tensor.py
+6
-5
tests/test_analyzer/test_subclasses/test_meta_mode.py
tests/test_analyzer/test_subclasses/test_meta_mode.py
+4
-3
tests/test_analyzer/test_subclasses/zoo.py
tests/test_analyzer/test_subclasses/zoo.py
+0
-53
No files found.
tests/test_analyzer/test_fx/test_bias_addition.py
View file @
019a8474
...
...
@@ -3,6 +3,8 @@ import torch
from
packaging
import
version
from
torch.utils.checkpoint
import
checkpoint
from
colossalai.testing.utils
import
parameterize
try
:
from
colossalai._analyzer.fx
import
symbolic_trace
except
:
...
...
@@ -56,9 +58,13 @@ class SiuModel(torch.nn.Module):
self
.
linear
=
LinearModel
(
3
,
3
,
bias
)
self
.
conv
=
ConvModel
(
3
,
6
,
3
,
bias
)
def
forward
(
self
,
x
,
select
=
0
):
def
forward
(
self
,
x
,
select
=
torch
.
Tensor
([
0
])
):
x
=
self
.
linear
(
x
)
x
=
checkpoint
(
self
.
conv
,
x
,
select
)
if
select
:
x
=
checkpoint
(
self
.
conv
,
x
,
0
)
else
:
x
=
checkpoint
(
self
.
conv
,
x
,
1
)
return
x
...
...
@@ -75,10 +81,10 @@ class AddmmModel(torch.nn.Module):
@
pytest
.
mark
.
skipif
(
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.12.0'
),
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
parametrize
(
"bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"bias_addition_split"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"shape"
,
[(
3
,
3
,
3
),
(
3
,
3
,
3
,
3
)])
@
pytest
.
mark
.
parametrize
(
"select"
,
[
0
,
1
])
@
paramet
e
rize
(
"bias"
,
[
True
,
False
])
@
paramet
e
rize
(
"bias_addition_split"
,
[
True
,
False
])
@
paramet
e
rize
(
"shape"
,
[(
3
,
3
,
3
),
(
3
,
3
,
3
,
3
)])
@
paramet
e
rize
(
"select"
,
[
torch
.
Tensor
([
0
]),
torch
.
Tensor
([
1
])
])
def
test_siu_model
(
bias
,
bias_addition_split
,
shape
,
select
):
model
=
SiuModel
(
bias
=
bias
)
x
=
torch
.
rand
(
shape
)
...
...
@@ -87,18 +93,18 @@ def test_siu_model(bias, bias_addition_split, shape, select):
concrete_args
=
{
'select'
:
select
},
trace_act_ckpt
=
True
,
bias_addition_split
=
bias_addition_split
)
assert
torch
.
allclose
(
model
(
x
,
select
),
gm
(
x
,
select
)),
'original model and traced model should be the same!'
assert
torch
.
allclose
(
model
(
x
,
select
),
gm
(
x
)),
'original model and traced model should be the same!'
if
bias
and
bias_addition_split
:
assert
'+'
in
gm
.
code
,
'bias addition should be split!'
else
:
assert
'+'
not
in
gm
.
code
,
'bias addition should not be split!'
@
pytest
.
mark
.
skipif
(
torch
.
__version__
<
'1.12.0'
,
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
parametrize
(
"alpha"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"beta"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"bias_addition_split"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"shape"
,
[(
3
,
3
),
(
5
,
5
)])
@
pytest
.
mark
.
skipif
(
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.12.0'
)
,
reason
=
'torch version < 12'
)
@
paramet
e
rize
(
"alpha"
,
[
1
,
2
])
@
paramet
e
rize
(
"beta"
,
[
1
,
2
])
@
paramet
e
rize
(
"bias_addition_split"
,
[
True
,
False
])
@
paramet
e
rize
(
"shape"
,
[(
3
,
3
),
(
5
,
5
)])
def
test_addmm_model
(
alpha
,
beta
,
bias_addition_split
,
shape
):
model
=
AddmmModel
(
alpha
=
alpha
,
beta
=
beta
)
x
=
torch
.
rand
(
shape
)
...
...
@@ -111,4 +117,5 @@ def test_addmm_model(alpha, beta, bias_addition_split, shape):
if
__name__
==
'__main__'
:
test_siu_model
(
True
,
True
,
(
3
,
3
,
3
))
test_siu_model
()
test_addmm_model
()
tests/test_analyzer/test_fx/test_shape_prop.py
View file @
019a8474
import
pytest
import
timm.models
as
tmm
import
torch
import
torchvision.models
as
tm
from
.zoo
import
tm_models
,
tmm_models
from
packaging
import
version
from
colossalai.testing.utils
import
parameterize
from
tests.test_analyzer.test_fx.zoo
import
tm_models
,
tmm_models
try
:
from
colossalai._analyzer._subclasses
import
MetaTensorMode
from
colossalai._analyzer.fx
import
symbolic_trace
from
colossalai._analyzer.fx.passes.shape_prop
import
shape_prop_pass
from
colossalai._analyzer.fx.symbolic_profile
import
register_shape_impl
@
register_shape_impl
(
torch
.
nn
.
functional
.
linear
)
def
linear_impl
(
*
args
,
**
kwargs
):
assert
True
...
...
@@ -23,15 +24,15 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
for
node
in
gm
.
graph
.
nodes
:
assert
node
.
meta
[
'info'
].
outputs
,
f
'In
{
gm
.
__class__
.
__name__
}
,
{
node
}
has no output shape.'
if
node
.
op
in
[
#
'call_module', # can apply to params
#
'call_function', # can apply to params
#
'call_method', # can apply to params
'call_module'
,
# can apply to params
'call_function'
,
# can apply to params
'call_method'
,
# can apply to params
]:
assert
node
.
meta
[
'info'
]
.
inputs
,
f
'In
{
gm
.
__class__
.
__name__
}
,
{
node
}
has no input shape.'
assert
hasattr
(
node
.
meta
[
'info'
]
,
'
inputs
'
)
,
f
'In
{
gm
.
__class__
.
__name__
}
,
{
node
}
has no input shape.'
@
pytest
.
mark
.
skipif
(
torch
.
__version__
<
'1.12.0'
,
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
parametrize
(
'm'
,
tm_models
)
@
pytest
.
mark
.
skipif
(
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.12.0'
)
,
reason
=
'torch version < 12'
)
@
paramet
e
rize
(
'm'
,
tm_models
)
def
test_torchvision_shape_prop
(
m
):
with
MetaTensorMode
():
model
=
m
()
...
...
@@ -44,8 +45,8 @@ def test_torchvision_shape_prop(m):
_check_gm_validity
(
gm
)
@
pytest
.
mark
.
skipif
(
torch
.
__version__
<
'1.12.0'
,
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
parametrize
(
'm'
,
tmm_models
)
@
pytest
.
mark
.
skipif
(
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.12.0'
)
,
reason
=
'torch version < 12'
)
@
paramet
e
rize
(
'm'
,
tmm_models
)
def
test_timm_shape_prop
(
m
):
with
MetaTensorMode
():
model
=
m
()
...
...
@@ -53,11 +54,12 @@ def test_timm_shape_prop(m):
meta_args
=
{
"x"
:
data
,
}
gm
=
symbolic_trace
(
model
,
meta_args
=
meta_args
)
shape_prop_pass
(
gm
,
data
)
_check_gm_validity
(
gm
)
if
__name__
==
"__main__"
:
test_torchvision_shape_prop
(
tm
.
resnet18
)
test_timm_shape_prop
(
tmm
.
vgg11
)
test_torchvision_shape_prop
()
test_timm_shape_prop
()
tests/test_analyzer/test_fx/test_symbolic_profile.py
View file @
019a8474
import
pytest
import
timm.models
as
tmm
import
torch
import
torchvision.models
as
tm
from
.zoo
import
tm_models
,
tmm_models
from
packaging
import
version
from
colossalai.testing.utils
import
parameterize
from
tests.test_analyzer.test_fx.zoo
import
tm_models
,
tmm_models
try
:
from
colossalai._analyzer._subclasses
import
MetaTensorMode
...
...
@@ -16,8 +18,8 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
assert
len
(
node
.
meta
[
'info'
].
global_ctx
),
f
'In
{
gm
.
__class__
.
__name__
}
,
{
node
}
has empty global context.'
@
pytest
.
mark
.
skipif
(
torch
.
__version__
<
'1.12.0'
,
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
parametrize
(
'm'
,
tm_models
)
@
pytest
.
mark
.
skipif
(
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.12.0'
)
,
reason
=
'torch version < 12'
)
@
paramet
e
rize
(
'm'
,
tm_models
)
def
test_torchvision_profile
(
m
,
verbose
=
False
,
bias_addition_split
=
False
):
with
MetaTensorMode
():
model
=
m
()
...
...
@@ -30,8 +32,8 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
_check_gm_validity
(
gm
)
@
pytest
.
mark
.
skipif
(
torch
.
__version__
<
'1.12.0'
,
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
parametrize
(
'm'
,
tmm_models
)
@
pytest
.
mark
.
skipif
(
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.12.0'
)
,
reason
=
'torch version < 12'
)
@
paramet
e
rize
(
'm'
,
tmm_models
)
def
test_timm_profile
(
m
,
verbose
=
False
,
bias_addition_split
=
False
):
with
MetaTensorMode
():
model
=
m
()
...
...
@@ -45,5 +47,5 @@ def test_timm_profile(m, verbose=False, bias_addition_split=False):
if
__name__
==
"__main__"
:
test_torchvision_profile
(
tm
.
vit_b_16
,
verbose
=
True
,
bias_addition_split
=
False
)
test_timm_profile
(
tmm
.
gmlp_b16_224
,
verbose
=
True
,
bias_addition_split
=
False
)
test_torchvision_profile
()
test_timm_profile
()
tests/test_analyzer/test_fx/zoo.py
View file @
019a8474
...
...
@@ -33,18 +33,18 @@ tmm_models = [
tmm
.
dm_nfnet_f0
,
tmm
.
eca_nfnet_l0
,
tmm
.
efficientformer_l1
,
tmm
.
ese_vovnet19b_dw
,
#
tmm.ese_vovnet19b_dw,
tmm
.
gmixer_12_224
,
tmm
.
gmlp_b16_224
,
tmm
.
hardcorenas_a
,
#
tmm.hardcorenas_a,
tmm
.
hrnet_w18_small
,
tmm
.
inception_v3
,
tmm
.
mixer_b16_224
,
tmm
.
nf_ecaresnet101
,
tmm
.
nf_regnet_b0
,
# tmm.pit_b_224, # pretrained only
tmm
.
regnetv_040
,
tmm
.
skresnet18
,
#
tmm.regnetv_040,
#
tmm.skresnet18,
# tmm.swin_base_patch4_window7_224, # fx bad case
# tmm.tnt_b_patch16_224, # bad case
tmm
.
vgg11
,
...
...
tests/test_analyzer/test_subclasses/test_flop_tensor.py
View file @
019a8474
import
pytest
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torchvision.models
as
tm
from
.zoo
import
tm_models
,
tmm_models
from
packaging
import
version
from
tests.test_analyzer.test_fx.zoo
import
tm_models
,
tmm_models
try
:
from
colossalai._analyzer._subclasses
import
MetaTensorMode
,
flop_count
...
...
@@ -11,7 +12,7 @@ except:
pass
@
pytest
.
mark
.
skipif
(
torch
.
__version__
<
'1.12.0'
,
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
skipif
(
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.12.0'
)
,
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
parametrize
(
'm'
,
tm_models
+
tmm_models
)
def
test_flop_count_module
(
m
):
x
=
torch
.
rand
(
2
,
3
,
224
,
224
)
...
...
@@ -37,7 +38,7 @@ odd_cases = [
]
@
pytest
.
mark
.
skipif
(
torch
.
__version__
<
'1.12.0'
,
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
skipif
(
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.12.0'
)
,
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
parametrize
(
'func, args, kwargs'
,
odd_cases
)
def
test_flop_count_function
(
func
,
args
,
kwargs
):
rs_fwd
,
rs_bwd
=
flop_count
(
func
,
*
args
,
**
kwargs
,
verbose
=
True
)
...
...
@@ -46,5 +47,5 @@ def test_flop_count_function(func, args, kwargs):
if
__name__
==
'__main__'
:
test_flop_count_module
(
tm
.
resnet18
,
torch
.
rand
(
2
,
3
,
224
,
224
)
)
test_flop_count_module
(
tm
.
resnet18
)
test_flop_count_function
(
F
.
relu
,
(
torch
.
rand
(
2
,
3
,
224
,
224
,
requires_grad
=
True
),),
{
'inplace'
:
True
})
tests/test_analyzer/test_subclasses/test_meta_mode.py
View file @
019a8474
import
pytest
import
torch
import
torch.distributed
as
dist
import
torchvision.models
as
tm
from
packaging
import
version
try
:
from
colossalai._analyzer._subclasses
import
MetaTensor
,
MetaTensorMode
except
:
pass
from
.zoo
import
tm_models
,
tmm_models
from
tests.test_analyzer.test_fx
.zoo
import
tm_models
,
tmm_models
def
compare_all
(
tensor
:
torch
.
Tensor
,
meta_tensor
:
torch
.
Tensor
):
...
...
@@ -28,7 +29,7 @@ def run_and_compare(model):
compare_all
(
x
.
grad
,
meta_x
.
grad
)
@
pytest
.
mark
.
skipif
(
torch
.
__version__
<
'1.12.0'
,
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
skipif
(
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.12.0'
)
,
reason
=
'torch version < 12'
)
@
pytest
.
mark
.
parametrize
(
'm'
,
tm_models
+
tmm_models
)
def
test_meta_mode_shape
(
m
):
run_and_compare
(
m
())
...
...
tests/test_analyzer/test_subclasses/zoo.py
deleted
100644 → 0
View file @
f57d3495
import
timm.models
as
tmm
import
torchvision.models
as
tm
# input shape: (batch_size, 3, 224, 224)
tm_models
=
[
tm
.
alexnet
,
tm
.
convnext_base
,
tm
.
densenet121
,
# tm.efficientnet_v2_s,
# tm.googlenet, # output bad case
# tm.inception_v3, # bad case
tm
.
mobilenet_v2
,
tm
.
mobilenet_v3_small
,
tm
.
mnasnet0_5
,
tm
.
resnet18
,
tm
.
regnet_x_16gf
,
tm
.
resnext50_32x4d
,
tm
.
shufflenet_v2_x0_5
,
tm
.
squeezenet1_0
,
# tm.swin_s, # fx bad case
tm
.
vgg11
,
tm
.
vit_b_16
,
tm
.
wide_resnet50_2
,
]
tmm_models
=
[
tmm
.
beit_base_patch16_224
,
tmm
.
beitv2_base_patch16_224
,
tmm
.
cait_s24_224
,
tmm
.
coat_lite_mini
,
tmm
.
convit_base
,
tmm
.
deit3_base_patch16_224
,
tmm
.
dm_nfnet_f0
,
tmm
.
eca_nfnet_l0
,
tmm
.
efficientformer_l1
,
tmm
.
ese_vovnet19b_dw
,
tmm
.
gmixer_12_224
,
tmm
.
gmlp_b16_224
,
tmm
.
hardcorenas_a
,
tmm
.
hrnet_w18_small
,
tmm
.
inception_v3
,
tmm
.
mixer_b16_224
,
tmm
.
nf_ecaresnet101
,
tmm
.
nf_regnet_b0
,
# tmm.pit_b_224, # pretrained only
tmm
.
regnetv_040
,
tmm
.
skresnet18
,
# tmm.swin_base_patch4_window7_224, # fx bad case
# tmm.tnt_b_patch16_224, # bad case
tmm
.
vgg11
,
tmm
.
vit_base_patch16_18x2_224
,
tmm
.
wide_resnet50_2
,
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment