Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0fecbb9e
Unverified
Commit
0fecbb9e
authored
Dec 08, 2022
by
YuliangLiu0306
Committed by
GitHub
Dec 08, 2022
Browse files
[autoparallel] support addbmm computation (#2102)
parent
d3d46304
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
179 additions
and
65 deletions
+179
-65
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py
...addition_patch/patched_bias_addition_function/__init__.py
+2
-1
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
...s_addition_patch/patched_bias_addition_function/addbmm.py
+75
-0
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
...as_addition_patch/patched_bias_addition_function/addmm.py
+4
-20
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
.../patched_bias_addition_function/bias_addition_function.py
+23
-0
colossalai/fx/tracer/registry.py
colossalai/fx/tracer/registry.py
+1
-0
colossalai/fx/tracer/tracer.py
colossalai/fx/tracer/tracer.py
+13
-5
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py
...est_tensor_shard/test_node_handler/test_addbmm_handler.py
+61
-39
No files found.
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py
View file @
0fecbb9e
from
.addbmm
import
Addbmm
from
.addmm
import
Addmm
from
.bias_addition_function
import
BiasAdditionFunc
,
LinearBasedBiasFunc
,
func_to_func_dict
from
.bias_addition_function
import
BiasAdditionFunc
,
LinearBasedBiasFunc
,
func_to_func_dict
,
method_to_func_dict
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
0 → 100644
View file @
0fecbb9e
import
operator
import
torch
import
torch.nn.functional
as
F
from
...registry
import
bias_addition_function
,
bias_addition_method
from
.bias_addition_function
import
LinearBasedBiasFunc
@
bias_addition_method
.
register
(
torch
.
Tensor
.
addbmm
)
@
bias_addition_function
.
register
(
torch
.
addbmm
)
class
Addbmm
(
LinearBasedBiasFunc
):
def
extract_kwargs_from_origin_func
(
self
):
kwargs
=
{}
if
'beta'
in
self
.
kwargs
:
kwargs
[
'beta'
]
=
self
.
kwargs
[
'beta'
]
if
'alpha'
in
self
.
kwargs
:
kwargs
[
'alpha'
]
=
self
.
kwargs
[
'alpha'
]
return
kwargs
def
create_non_bias_func_proxy
(
self
,
input_proxy
,
other_proxy
):
"""
This method is used to create the non_bias_func proxy, the node created by this proxy will
compute the main computation, such as convolution, with bias option banned.
"""
assert
self
.
substitute_func
==
torch
.
bmm
node_kind
=
'call_function'
node_target
=
self
.
substitute_func
node_args
=
(
input_proxy
,
other_proxy
)
# torch.bmm does not have any kwargs
node_kwargs
=
{}
non_bias_func_proxy
=
self
.
tracer
.
create_proxy
(
node_kind
,
node_target
,
node_args
,
node_kwargs
)
return
non_bias_func_proxy
def
insert_sum_node
(
self
,
input_proxy
,
sum_dims
=
0
):
'''
This method is used to sum the input_proxy through the sum_dims.
'''
node_kind
=
'call_function'
node_target
=
torch
.
sum
node_args
=
(
input_proxy
,
sum_dims
)
node_kwargs
=
{}
sum_proxy
=
self
.
tracer
.
create_proxy
(
node_kind
,
node_target
,
node_args
,
node_kwargs
)
return
sum_proxy
def
generate
(
self
):
# The formula for addbmm is output = beta * input + alpha * (torch.bmm(b1, b2))
# doing the non-bias computation(temp_0 = torch.bmm(b1, b2))
non_bias_linear_func_proxy
=
self
.
create_non_bias_func_proxy
(
self
.
args
[
1
],
self
.
args
[
2
])
# doing sum on the batch dimension(temp_1 = torch.sum(temp_0, 0))
sum_proxy
=
self
.
insert_sum_node
(
non_bias_linear_func_proxy
)
kwargs
=
self
.
extract_kwargs_from_origin_func
()
if
'beta'
in
kwargs
:
beta
=
kwargs
[
'beta'
]
# doing the multiplication with beta if it exists(temp_2 = beta * input)
beta_proxy
=
self
.
create_mul_node
(
self
.
args
[
0
],
beta
)
else
:
beta_proxy
=
self
.
args
[
0
]
if
'alpha'
in
kwargs
:
alpha
=
kwargs
[
'alpha'
]
# doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1)
alpha_proxy
=
self
.
create_mul_node
(
alpha
,
sum_proxy
)
else
:
alpha_proxy
=
sum_proxy
# doing the addition(temp_4 = temp_2 + temp_3)
bias_addition_proxy
=
self
.
create_bias_addition_proxy
(
alpha_proxy
,
beta_proxy
)
return
bias_addition_proxy
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
View file @
0fecbb9e
...
...
@@ -3,10 +3,11 @@ import operator
import
torch
import
torch.nn.functional
as
F
from
...registry
import
bias_addition_function
from
...registry
import
bias_addition_function
,
bias_addition_method
from
.bias_addition_function
import
LinearBasedBiasFunc
@
bias_addition_method
.
register
(
torch
.
Tensor
.
addmm
)
@
bias_addition_function
.
register
(
torch
.
addmm
)
class
Addmm
(
LinearBasedBiasFunc
):
...
...
@@ -18,23 +19,6 @@ class Addmm(LinearBasedBiasFunc):
kwargs
[
'alpha'
]
=
self
.
kwargs
[
'alpha'
]
return
kwargs
def
coefficent_for_addmm
(
self
,
input_proxy
,
coefficent
):
"""
This method is used to create a coefficent node for the numerical correctness.
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
Therefore, we need to use this method insert two more operator.mul nodes for
the computation graph to compute the final result.
"""
node_kind
=
'call_function'
node_target
=
operator
.
mul
node_args
=
(
input_proxy
,
coefficent
,
)
node_kwargs
=
{}
mul_proxy
=
self
.
tracer
.
create_proxy
(
node_kind
,
node_target
,
node_args
,
node_kwargs
)
return
mul_proxy
def
transpose_other_operand_for_linear
(
self
,
other_proxy
):
'''
This method is used to transpose the other operand for linear function.
...
...
@@ -61,13 +45,13 @@ class Addmm(LinearBasedBiasFunc):
if
'beta'
in
kwargs
:
beta
=
kwargs
[
'beta'
]
beta_proxy
=
self
.
c
oefficent_for_addmm
(
self
.
args
[
0
],
beta
)
beta_proxy
=
self
.
c
reate_mul_node
(
self
.
args
[
0
],
beta
)
else
:
beta_proxy
=
self
.
args
[
0
]
if
'alpha'
in
kwargs
:
alpha
=
kwargs
[
'alpha'
]
alpha_proxy
=
self
.
c
oefficent_for_addmm
(
alpha
,
non_bias_linear_func_proxy
)
alpha_proxy
=
self
.
c
reate_mul_node
(
alpha
,
non_bias_linear_func_proxy
)
else
:
alpha_proxy
=
non_bias_linear_func_proxy
...
...
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
View file @
0fecbb9e
...
...
@@ -52,6 +52,23 @@ class BiasAdditionFunc(ABC):
"""
pass
def
create_mul_node
(
self
,
input_proxy
,
coefficent
):
"""
This method is used to create a coefficent node for the numerical correctness.
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
Therefore, we need to use this method insert two more operator.mul nodes for
the computation graph to compute the final result.
"""
node_kind
=
'call_function'
node_target
=
operator
.
mul
node_args
=
(
input_proxy
,
coefficent
,
)
node_kwargs
=
{}
mul_proxy
=
self
.
tracer
.
create_proxy
(
node_kind
,
node_target
,
node_args
,
node_kwargs
)
return
mul_proxy
class
LinearBasedBiasFunc
(
BiasAdditionFunc
):
"""
...
...
@@ -88,4 +105,10 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
func_to_func_dict
=
{
torch
.
addmm
:
F
.
linear
,
torch
.
addbmm
:
torch
.
bmm
,
}
method_to_func_dict
=
{
torch
.
Tensor
.
addmm
:
F
.
linear
,
torch
.
Tensor
.
addbmm
:
torch
.
bmm
,
}
colossalai/fx/tracer/registry.py
View file @
0fecbb9e
...
...
@@ -25,3 +25,4 @@ meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution
meta_patched_module
=
PatchRegistry
(
name
=
'patched_modules_for_meta_execution'
)
bias_addition_function
=
PatchRegistry
(
name
=
'patched_function_for_bias_addition'
)
bias_addition_module
=
PatchRegistry
(
name
=
'patched_module_for_bias_addition'
)
bias_addition_method
=
PatchRegistry
(
name
=
'patched_method_for_bias_addition'
)
colossalai/fx/tracer/tracer.py
View file @
0fecbb9e
...
...
@@ -20,8 +20,14 @@ from torch.fx.proxy import ParameterProxy, Proxy
from
..proxy
import
ColoProxy
from
._tracer_utils
import
compute_meta_data_for_functions_proxy
,
extract_meta
,
is_element_in_list
from
.bias_addition_patch
import
func_to_func_dict
,
module_to_func_dict
from
.registry
import
bias_addition_function
,
bias_addition_module
,
meta_patched_function
,
meta_patched_module
from
.bias_addition_patch
import
func_to_func_dict
,
method_to_func_dict
,
module_to_func_dict
from
.registry
import
(
bias_addition_function
,
bias_addition_method
,
bias_addition_module
,
meta_patched_function
,
meta_patched_module
,
)
__all__
=
[
'ColoTracer'
]
...
...
@@ -100,12 +106,14 @@ class ColoTracer(Tracer):
handle
=
bias_addition_function
.
get
(
target
)(
self
,
target
,
args
,
kwargs
,
function_to_substitute
)
elif
bias_addition_function
.
has
(
target
.
__name__
):
# use name for some builtin op like @ (matmul)
handle
=
bias_addition_function
.
get
(
target
.
__name__
)(
self
,
target
,
args
,
kwargs
)
function_to_substitute
=
func_to_func_dict
[
target
]
handle
=
bias_addition_function
.
get
(
target
.
__name__
)(
self
,
target
,
args
,
kwargs
,
function_to_substitute
)
elif
kind
==
"call_method"
:
method
=
getattr
(
args_metas
[
0
].
__class__
,
target
)
if
bias_addition_function
.
has
(
method
):
handle
=
bias_addition_function
.
get
(
method
)(
self
,
target
,
args
,
kwargs
)
if
bias_addition_method
.
has
(
method
):
function_to_substitute
=
method_to_func_dict
[
method
]
handle
=
bias_addition_method
.
get
(
method
)(
self
,
target
,
args
,
kwargs
,
function_to_substitute
)
elif
kind
==
"call_module"
:
if
not
hasattr
(
self
,
"orig_forward"
):
...
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py
View file @
0fecbb9e
...
...
@@ -5,7 +5,7 @@ import torch
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
from
colossalai.auto_parallel.tensor_shard.node_handler
import
Add
BMMFunctionHandler
from
colossalai.auto_parallel.tensor_shard.node_handler
import
BMMFunctionHandler
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
OperationData
,
OperationDataType
,
StrategiesVector
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx
import
ColoGraphModule
,
ColoTracer
...
...
@@ -19,20 +19,36 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
class
AddBMMTensorMethodModule
(
nn
.
Module
):
def
__init__
(
self
,
using_kwargs
):
super
().
__init__
()
self
.
using_kwargs
=
using_kwargs
def
forward
(
self
,
bias
,
x1
,
x2
):
return
bias
.
addbmm
(
x1
,
x2
)
if
self
.
using_kwargs
:
output
=
bias
.
addbmm
(
x1
,
x2
,
alpha
=
2
,
beta
=
3
)
else
:
output
=
bias
.
addbmm
(
x1
,
x2
)
return
output
class
AddBMMTorchFunctionModule
(
nn
.
Module
):
def
__init__
(
self
,
using_kwargs
):
super
().
__init__
()
self
.
using_kwargs
=
using_kwargs
def
forward
(
self
,
bias
,
x1
,
x2
):
return
torch
.
addbmm
(
bias
,
x1
,
x2
)
if
self
.
using_kwargs
:
output
=
torch
.
addbmm
(
bias
,
x1
,
x2
,
alpha
=
2
,
beta
=
3
)
else
:
output
=
torch
.
addbmm
(
bias
,
x1
,
x2
)
return
output
def
check_2d_device_mesh
(
rank
,
module
,
bias_shape
,
world_size
,
port
):
def
check_2d_device_mesh
(
rank
,
module
,
bias_shape
,
using_kwargs
,
world_size
,
port
):
disable_existing_loggers
()
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
model
=
module
().
cuda
()
model
=
module
(
using_kwargs
).
cuda
()
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
...
...
@@ -54,6 +70,14 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
input_args
=
input_args
,
meta_arg_names
=
meta_arg_names
)
tracer
=
ColoTracer
()
# graph():
# %bias : torch.Tensor [#users=1] = placeholder[target=bias]
# %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
# %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
# %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
# %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
# return add
graph
=
tracer
.
trace
(
model
,
meta_args
=
{
'bias'
:
torch
.
rand
(
*
bias_shape
).
to
(
'meta'
),
...
...
@@ -62,11 +86,11 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
})
gm
=
ColoGraphModule
(
model
,
graph
)
linear
_mod_node
=
list
(
graph
.
nodes
)[
3
]
strategies_vector
=
StrategiesVector
(
linear
_mod_node
)
bmm
_mod_node
=
list
(
graph
.
nodes
)[
3
]
strategies_vector
=
StrategiesVector
(
bmm
_mod_node
)
# build handler
handler
=
Add
BMMFunctionHandler
(
node
=
linear
_mod_node
,
device_mesh
=
device_mesh
,
strategies_vector
=
strategies_vector
)
handler
=
BMMFunctionHandler
(
node
=
bmm
_mod_node
,
device_mesh
=
device_mesh
,
strategies_vector
=
strategies_vector
)
# check operation data mapping
mapping
=
handler
.
get_operation_data_mapping
()
...
...
@@ -89,19 +113,15 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
assert
mapping
[
'other'
].
type
==
OperationDataType
.
ARG
assert
mapping
[
'other'
].
logical_shape
==
torch
.
Size
([
4
,
16
,
8
])
assert
mapping
[
'bias'
].
name
==
"bias"
assert
mapping
[
'bias'
].
data
.
is_meta
assert
mapping
[
'bias'
].
data
.
shape
==
torch
.
Size
(
bias_shape
)
assert
mapping
[
'bias'
].
type
==
OperationDataType
.
ARG
assert
mapping
[
'bias'
].
logical_shape
==
torch
.
Size
([
8
,
8
])
assert
mapping
[
'output'
].
name
==
"addbmm"
assert
mapping
[
'output'
].
name
==
"bmm"
assert
mapping
[
'output'
].
data
.
is_meta
assert
mapping
[
'output'
].
data
.
shape
==
torch
.
Size
([
8
,
8
])
assert
mapping
[
'output'
].
data
.
shape
==
torch
.
Size
([
4
,
8
,
8
])
assert
mapping
[
'output'
].
type
==
OperationDataType
.
OUTPUT
strategies_vector
=
handler
.
register_strategy
(
compute_resharding_cost
=
False
)
strategy_name_list
=
[
val
.
name
for
val
in
strategies_vector
]
for
name
in
strategy_name_list
:
print
(
name
)
# one batch dim
assert
'Sb0 = Sb0 x Sb0'
not
in
strategy_name_list
...
...
@@ -123,23 +143,21 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
for
strategy
in
strategies_vector
:
input_sharding_spec
=
strategy
.
get_sharding_spec_by_name
(
'x1'
)
other_sharding_spec
=
strategy
.
get_sharding_spec_by_name
(
'x2'
)
bias_sharding_spec
=
strategy
.
get_sharding_spec_by_name
(
'bias'
)
output_sharding_spec
=
strategy
.
get_sharding_spec_by_name
(
'addbmm'
)
output_sharding_spec
=
strategy
.
get_sharding_spec_by_name
(
'bmm'
)
# make sure the sharding matches across different operation data
assert
input_sharding_spec
.
sharding_sequence
[
1
]
==
output_sharding_spec
.
sharding_sequence
[
0
]
assert
input_sharding_spec
.
sharding_sequence
[
0
]
==
output_sharding_spec
.
sharding_sequence
[
0
]
assert
other_sharding_spec
.
sharding_sequence
[
1
]
==
input_sharding_spec
.
sharding_sequence
[
-
1
]
assert
other_sharding_spec
.
sharding_sequence
[
-
1
]
==
output_sharding_spec
.
sharding_sequence
[
-
1
]
assert
bias_sharding_spec
.
sharding_sequence
[
-
1
]
==
output_sharding_spec
.
sharding_sequence
[
-
1
]
def
check_1d_device_mesh
(
rank
,
module
,
bias_shape
,
world_size
,
port
):
def
check_1d_device_mesh
(
rank
,
module
,
bias_shape
,
using_kwargs
,
world_size
,
port
):
disable_existing_loggers
()
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
1
,
4
)
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
model
=
module
().
cuda
()
model
=
module
(
using_kwargs
).
cuda
()
x1
=
torch
.
rand
(
4
,
8
,
16
).
cuda
()
x2
=
torch
.
rand
(
4
,
16
,
8
).
cuda
()
bias
=
torch
.
rand
(
bias_shape
).
cuda
()
...
...
@@ -159,6 +177,14 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
meta_arg_names
=
meta_arg_names
)
tracer
=
ColoTracer
()
# graph():
# %bias : torch.Tensor [#users=1] = placeholder[target=bias]
# %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
# %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
# %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
# %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
# return add
graph
=
tracer
.
trace
(
model
,
meta_args
=
{
'bias'
:
torch
.
rand
(
*
bias_shape
).
to
(
'meta'
),
...
...
@@ -166,11 +192,11 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
'x2'
:
torch
.
rand
(
4
,
16
,
8
).
to
(
'meta'
)
})
gm
=
ColoGraphModule
(
model
,
graph
)
linear
_mod_node
=
list
(
graph
.
nodes
)[
3
]
strategies_vector
=
StrategiesVector
(
linear
_mod_node
)
bmm
_mod_node
=
list
(
graph
.
nodes
)[
3
]
strategies_vector
=
StrategiesVector
(
bmm
_mod_node
)
# build handler
handler
=
Add
BMMFunctionHandler
(
node
=
linear
_mod_node
,
device_mesh
=
device_mesh
,
strategies_vector
=
strategies_vector
)
handler
=
BMMFunctionHandler
(
node
=
bmm
_mod_node
,
device_mesh
=
device_mesh
,
strategies_vector
=
strategies_vector
)
# check operation data mapping
mapping
=
handler
.
get_operation_data_mapping
()
...
...
@@ -193,15 +219,9 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
assert
mapping
[
'other'
].
type
==
OperationDataType
.
ARG
assert
mapping
[
'other'
].
logical_shape
==
torch
.
Size
([
4
,
16
,
8
])
assert
mapping
[
'bias'
].
name
==
"bias"
assert
mapping
[
'bias'
].
data
.
is_meta
assert
mapping
[
'bias'
].
data
.
shape
==
torch
.
Size
(
bias_shape
)
assert
mapping
[
'bias'
].
type
==
OperationDataType
.
ARG
assert
mapping
[
'bias'
].
logical_shape
==
torch
.
Size
([
8
,
8
])
assert
mapping
[
'output'
].
name
==
"addbmm"
assert
mapping
[
'output'
].
name
==
"bmm"
assert
mapping
[
'output'
].
data
.
is_meta
assert
mapping
[
'output'
].
data
.
shape
==
torch
.
Size
([
8
,
8
])
assert
mapping
[
'output'
].
data
.
shape
==
torch
.
Size
([
4
,
8
,
8
])
assert
mapping
[
'output'
].
type
==
OperationDataType
.
OUTPUT
strategies_vector
=
handler
.
register_strategy
(
compute_resharding_cost
=
False
)
...
...
@@ -213,14 +233,12 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
for
strategy
in
strategies_vector
:
input_sharding_spec
=
strategy
.
get_sharding_spec_by_name
(
'x1'
)
other_sharding_spec
=
strategy
.
get_sharding_spec_by_name
(
'x2'
)
bias_sharding_spec
=
strategy
.
get_sharding_spec_by_name
(
'bias'
)
output_sharding_spec
=
strategy
.
get_sharding_spec_by_name
(
'addbmm'
)
output_sharding_spec
=
strategy
.
get_sharding_spec_by_name
(
'bmm'
)
# make sure the sharding matches across different operation data
assert
input_sharding_spec
.
sharding_sequence
[
1
]
==
output_sharding_spec
.
sharding_sequence
[
0
]
assert
input_sharding_spec
.
sharding_sequence
[
0
]
==
output_sharding_spec
.
sharding_sequence
[
0
]
assert
other_sharding_spec
.
sharding_sequence
[
1
]
==
input_sharding_spec
.
sharding_sequence
[
-
1
]
assert
other_sharding_spec
.
sharding_sequence
[
-
1
]
==
output_sharding_spec
.
sharding_sequence
[
-
1
]
assert
bias_sharding_spec
.
sharding_sequence
[
-
1
]
==
output_sharding_spec
.
sharding_sequence
[
-
1
]
@
pytest
.
mark
.
skip
(
"skip due to bias cases not ready"
)
...
...
@@ -228,13 +246,15 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
@
pytest
.
mark
.
dist
@
parameterize
(
'module'
,
[
AddBMMTorchFunctionModule
,
AddBMMTensorMethodModule
])
@
parameterize
(
'bias_shape'
,
[[
8
],
[
1
,
8
],
[
8
,
8
]])
@
parameterize
(
'using_kwargs'
,
[
True
,
False
])
@
rerun_if_address_is_in_use
()
def
test_2d_device_mesh
(
module
,
bias_shape
):
def
test_2d_device_mesh
(
module
,
bias_shape
,
using_kwargs
):
world_size
=
4
run_func
=
partial
(
check_2d_device_mesh
,
module
=
module
,
bias_shape
=
bias_shape
,
world_size
=
world_size
,
using_kwargs
=
using_kwargs
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
...
...
@@ -244,12 +264,14 @@ def test_2d_device_mesh(module, bias_shape):
@
pytest
.
mark
.
dist
@
parameterize
(
'module'
,
[
AddBMMTorchFunctionModule
,
AddBMMTensorMethodModule
])
@
parameterize
(
'bias_shape'
,
[[
8
],
[
1
,
8
],
[
8
,
8
]])
@
parameterize
(
'using_kwargs'
,
[
True
,
False
])
@
rerun_if_address_is_in_use
()
def
test_1d_device_mesh
(
module
,
bias_shape
):
def
test_1d_device_mesh
(
module
,
bias_shape
,
using_kwargs
):
world_size
=
4
run_func
=
partial
(
check_1d_device_mesh
,
module
=
module
,
bias_shape
=
bias_shape
,
using_kwargs
=
using_kwargs
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment