OpenDAS / ColossalAI, commit 16335cb5 (Unverified)
Authored Dec 20, 2022 by YuliangLiu0306; committed by GitHub on Dec 20, 2022
[hotfix] fix aten default bug (#2158)
Parent: a4b4bb01

Showing 10 changed files with 137 additions and 122 deletions (+137, -122)
colossalai/fx/profiler/opcount.py (+122, -116)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py (+1, -1)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py (+1, -1)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py (+2, -0)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py (+1, -1)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py (+3, -3)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py (+1, -0)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py (+2, -0)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py (+1, -0)
tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py (+3, -0)
colossalai/fx/profiler/opcount.py

@@ -7,6 +7,7 @@ from numbers import Number
 from typing import Any, Callable, List
 
 import torch
+from packaging import version
 
 aten = torch.ops.aten

@@ -188,7 +189,8 @@ def zero_flop_jit(*args):
     return 0
 
 
-flop_mapping = {
-    # gemm
-    aten.mm.default: matmul_flop_jit,
-    aten.matmul.default: matmul_flop_jit,
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    flop_mapping = {
+        # gemm
+        aten.mm.default: matmul_flop_jit,
+        aten.matmul.default: matmul_flop_jit,

@@ -228,9 +230,9 @@ flop_mapping = {
-    aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
-    aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
-    aten.embedding.default: elementwise_flop_counter(1, 0),
-}
-
-elementwise_flop_aten = [
-    # basic op
-    aten.add.Tensor,
-    aten.add_.Tensor,
+        aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
+        aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
+        aten.embedding.default: elementwise_flop_counter(1, 0),
+    }
+
+    elementwise_flop_aten = [
+        # basic op
+        aten.add.Tensor,
+        aten.add_.Tensor,

@@ -275,13 +277,12 @@ elementwise_flop_aten = [
-    # dropout
-    aten.native_dropout.default,
-    aten.native_dropout_backward.default,
-]
-
-for op in elementwise_flop_aten:
-    flop_mapping[op] = elementwise_flop_counter(1, 0)
-
-# TODO: this will be removed in future
-zero_flop_aten = [
-    aten.as_strided.default,
-    aten.as_strided_.default,
-    aten.bernoulli_.float,
+        # dropout
+        aten.native_dropout.default,
+        aten.native_dropout_backward.default,
+    ]
+
+    for op in elementwise_flop_aten:
+        flop_mapping[op] = elementwise_flop_counter(1, 0)
+
+    # TODO: this will be removed in future
+    zero_flop_aten = [
+        aten.as_strided.default,
+        aten.as_strided_.default,
+        aten.bernoulli_.float,

@@ -312,7 +313,12 @@ zero_flop_aten = [
-    aten.where.self,
-    aten.zero_.default,
-    aten.zeros_like.default,
-]
-
-for op in zero_flop_aten:
-    flop_mapping[op] = zero_flop_jit
+        aten.where.self,
+        aten.zero_.default,
+        aten.zeros_like.default,
+    ]
+
+    for op in zero_flop_aten:
+        flop_mapping[op] = zero_flop_jit
+
+else:
+    flop_mapping = {}
+    elementwise_flop_aten = {}
+    zero_flop_aten = {}
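For context, the guard introduced above follows a common pattern: register aten overloads in the FLOP table only when the running torch build is new enough to expose them, and fall back to empty tables otherwise. Below is a minimal, self-contained sketch of that pattern; the `_fake_matmul_flops` counter is a placeholder for illustration, not ColossalAI's `matmul_flop_jit`.

# Minimal sketch of the version-gating pattern used in the hunk above (illustrative only).
# The counter function below is a stand-in, not ColossalAI's implementation.
import torch
from packaging import version


def _fake_matmul_flops(*args, **kwargs):
    # placeholder counter; the real profiler inspects the tensor shapes of the op
    return 0


if version.parse(torch.__version__) >= version.parse('1.12.0'):
    aten = torch.ops.aten
    flop_mapping = {
        aten.mm.default: _fake_matmul_flops,
        aten.matmul.default: _fake_matmul_flops,
    }
else:
    # older torch builds may not expose these aten overloads, so keep the table empty
    flop_mapping = {}

print(len(flop_mapping))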
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py

@@ -207,9 +207,9 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
     assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 @parameterize('op', [torch.add])
 @parameterize('other_dim', [1, 2])
-@run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_binary_elementwise_handler(op, other_dim):
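The change in this file only moves `@run_on_environment_flag(name='AUTO_PARALLEL')` to the top of the decorator stack. Since Python applies stacked decorators bottom-up, the topmost decorator becomes the outermost wrapper. The generic sketch below (plain decorators, not the ColossalAI or pytest ones) illustrates that ordering.

# Generic illustration of decorator stacking order (not the ColossalAI decorators).
# The decorator written closest to the function is applied first; the topmost
# decorator ends up outermost and is entered first when the function is called.
def tag(label):
    def decorator(fn):
        def wrapper(*args, **kwargs):
            print(f"enter {label}")
            return fn(*args, **kwargs)
        wrapper.__name__ = fn.__name__
        return wrapper
    return decorator


@tag("outer")    # analogous to the decorator moved to the top of the stack
@tag("inner")    # analogous to the decorators left below it
def sample_test():
    print("body")


sample_test()
# prints: enter outer / enter inner / body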
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py

@@ -203,8 +203,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
     assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
-@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_bmm_handler(module):
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py

@@ -23,6 +23,7 @@ class GetItemFromTensorModel(nn.Module):
         return x
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_getitem_from_tensor_handler():
     model = GetItemFromTensorModel()
     tracer = ColoTracer()

@@ -96,6 +97,7 @@ class GetItemFromTupleModel(nn.Module):
         return x
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_getitem_from_tuple_handler():
     model = GetItemFromTupleModel()
     tracer = ColoTracer()
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py

@@ -308,8 +308,8 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
-@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_handler(input_shape, bias=False):
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py

@@ -2,15 +2,15 @@ import pytest
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \
-    NormPoolingHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_norm_pool_handler():
     model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
     tracer = ColoTracer()
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py

@@ -20,6 +20,7 @@ class ReshapeModel(nn.Module):
         return reshape_node
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_reshape_handler():
     model = ReshapeModel()
     tracer = ColoTracer()
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py

@@ -5,6 +5,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handl
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 class TensorConstructorModel(nn.Module):

@@ -18,6 +19,7 @@ class TensorConstructorModel(nn.Module):
         return x
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_where_handler():
     model = TensorConstructorModel()
     tracer = ColoTracer()
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py

@@ -22,6 +22,7 @@ class ReLuModel(nn.Module):
         return relu_node
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_elementwise_handler():
     model = ReLuModel()
     tracer = ColoTracer()
tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py

@@ -10,6 +10,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
 )
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 def _param_resharding_cost_assertion(node):

@@ -51,6 +52,7 @@ class ConvModel(torch.nn.Module):
         return x
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_linear_module():
     model = LinearModel(4, 8)
     physical_mesh_id = torch.arange(0, 4)

@@ -86,6 +88,7 @@ def test_linear_module():
     _param_resharding_cost_assertion(linear_node)
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_module():
     model = ConvModel(3, 6, 2)
     physical_mesh_id = torch.arange(0, 4)
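All of the test-file changes above gate tests behind the AUTO_PARALLEL environment flag. As a rough illustration of what such a gate can look like, here is a hypothetical stand-in built on pytest.mark.skipif; the name `env_flag_gate` is invented for this sketch and it is not ColossalAI's actual run_on_environment_flag implementation.

# Hypothetical stand-in for an environment-flag gate, for illustration only;
# ColossalAI's run_on_environment_flag may be implemented differently.
import os

import pytest


def env_flag_gate(name: str):
    """Skip the decorated test unless the environment variable `name` is set to 1."""
    enabled = os.environ.get(name, '0') == '1'
    return pytest.mark.skipif(not enabled, reason=f"{name} is not set")


@env_flag_gate('AUTO_PARALLEL')
def test_example():
    assert 1 + 1 == 2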