Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c4b1b659
Commit
c4b1b659
authored
Jun 26, 2023
by
Frank Lee
Browse files
[test] fixed tests failed due to dtensor change (#4082)
* [test] fixed tests failed due to dtensor change * polish code
parent
92f67910
Changes
37
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
57 additions
and
93 deletions
+57
-93
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+2
-2
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+1
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+7
-2
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+1
-1
tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py
...est_tracer/test_torchaudio_model/test_torchaudio_model.py
+1
-1
tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
...t_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
+1
-1
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
...est_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
+1
-1
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
...t_tracer/test_torchvision_model/test_torchvision_model.py
+1
-1
tests/test_lazy/lazy_init_utils.py
tests/test_lazy/lazy_init_utils.py
+3
-1
tests/test_lazy/test_distribute.py
tests/test_lazy/test_distribute.py
+13
-17
tests/test_lazy/test_models.py
tests/test_lazy/test_models.py
+1
-1
tests/test_tensor/test_dtensor/test_comm_spec.py
tests/test_tensor/test_dtensor/test_comm_spec.py
+7
-26
tests/test_tensor/test_dtensor/test_dtensor.py
tests/test_tensor/test_dtensor/test_dtensor.py
+1
-1
tests/test_tensor/test_dtensor/test_layout_converter.py
tests/test_tensor/test_dtensor/test_layout_converter.py
+11
-32
tests/test_tensor/test_shape_consistency.py
tests/test_tensor/test_shape_consistency.py
+4
-3
tests/test_tensor/test_sharded_linear.py
tests/test_tensor/test_sharded_linear.py
+1
-1
tests/test_tensor/test_sharding_spec.py
tests/test_tensor/test_sharding_spec.py
+1
-1
No files found.
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
View file @
c4b1b659
...
...
@@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
def
test_gpt
():
sub_registry
=
model_zoo
.
get_sub_registry
(
'transformers_gpt'
)
for
name
,
(
model_fn
,
data_gen_fn
,
_
,
_
)
in
sub_registry
.
items
():
for
name
,
(
model_fn
,
data_gen_fn
,
_
,
_
,
_
)
in
sub_registry
.
items
():
model
=
model_fn
()
# TODO: support the following models
...
...
@@ -21,7 +21,7 @@ def test_gpt():
if
model
.
__class__
.
__name__
in
[
'GPT2DoubleHeadsModel'
]:
continue
trace_model_and_compare_output
(
model
,
data_gen_fn
)
trace_model_and_compare_output
(
model
,
data_gen_fn
,
ignore_data
=
[
'labels'
]
)
if
__name__
==
'__main__'
:
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
View file @
c4b1b659
...
...
@@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
def
test_opt
():
sub_registry
=
model_zoo
.
get_sub_registry
(
'transformers_opt'
)
for
name
,
(
model_fn
,
data_gen_fn
,
_
,
_
)
in
sub_registry
.
items
():
for
name
,
(
model_fn
,
data_gen_fn
,
_
,
_
,
_
)
in
sub_registry
.
items
():
model
=
model_fn
()
trace_model_and_compare_output
(
model
,
data_gen_fn
)
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
View file @
c4b1b659
...
...
@@ -12,9 +12,14 @@ from tests.kit.model_zoo import model_zoo
def
test_t5
():
sub_registry
=
model_zoo
.
get_sub_registry
(
'transformers_t5'
)
for
name
,
(
model_fn
,
data_gen_fn
,
_
,
_
)
in
sub_registry
.
items
():
for
name
,
(
model_fn
,
data_gen_fn
,
_
,
_
,
_
)
in
sub_registry
.
items
():
if
name
==
"transformers_t5_for_conditional_generation"
:
# cannot trace for loss function yet
# so we use a data gen which does not produce labels
data_gen_fn
=
sub_registry
.
get
(
'transformers_t5'
)[
1
]
model
=
model_fn
()
trace_model_and_compare_output
(
model
,
data_gen_fn
)
trace_model_and_compare_output
(
model
,
data_gen_fn
,
ignore_data
=
[
'labels'
]
)
if
__name__
==
'__main__'
:
...
...
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
View file @
c4b1b659
...
...
@@ -56,7 +56,7 @@ def test_timm_models():
sub_model_zoo
=
model_zoo
.
get_sub_registry
(
'timm'
)
for
name
,
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
_
,
attribute
)
in
sub_model_zoo
.
items
():
for
name
,
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
_
,
_
,
attribute
)
in
sub_model_zoo
.
items
():
data
=
data_gen_fn
()
if
attribute
is
not
None
and
attribute
.
has_control_flow
:
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
data
.
items
()}
...
...
tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py
View file @
c4b1b659
...
...
@@ -16,7 +16,7 @@ def test_torchaudio_models():
sub_model_zoo
=
model_zoo
.
get_sub_registry
(
'torchaudio'
)
for
name
,
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
_
,
attribute
)
in
sub_model_zoo
.
items
():
for
name
,
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
_
,
_
,
attribute
)
in
sub_model_zoo
.
items
():
model
=
model_fn
()
trace_and_compare
(
model
,
data_gen_fn
,
...
...
tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
View file @
c4b1b659
...
...
@@ -53,7 +53,7 @@ def test_torchrec_deepfm_models():
deepfm_models
=
model_zoo
.
get_sub_registry
(
'deepfm'
)
torch
.
backends
.
cudnn
.
deterministic
=
True
for
name
,
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
attribute
)
in
deepfm_models
.
items
():
for
name
,
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
_
,
attribute
)
in
deepfm_models
.
items
():
data
=
data_gen_fn
()
if
attribute
is
not
None
and
attribute
.
has_control_flow
:
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
data
.
items
()}
...
...
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
View file @
c4b1b659
...
...
@@ -53,7 +53,7 @@ def test_torchrec_dlrm_models():
torch
.
backends
.
cudnn
.
deterministic
=
True
dlrm_models
=
model_zoo
.
get_sub_registry
(
'dlrm'
)
for
name
,
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
attribute
)
in
dlrm_models
.
items
():
for
name
,
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
_
,
attribute
)
in
dlrm_models
.
items
():
data
=
data_gen_fn
()
# dlrm_interactionarch is not supported
...
...
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
View file @
c4b1b659
...
...
@@ -10,7 +10,7 @@ def test_torchvision_models():
torch
.
backends
.
cudnn
.
deterministic
=
True
tv_sub_registry
=
model_zoo
.
get_sub_registry
(
'torchvision'
)
for
name
,
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
model_attribute
)
in
tv_sub_registry
.
items
():
for
name
,
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
_
,
model_attribute
)
in
tv_sub_registry
.
items
():
data
=
data_gen_fn
()
if
model_attribute
is
not
None
and
model_attribute
.
has_stochastic_depth_prob
:
...
...
tests/test_lazy/lazy_init_utils.py
View file @
c4b1b659
...
...
@@ -6,6 +6,7 @@ import numpy as np
import
torch
from
packaging
import
version
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.lazy.lazy_init
import
LazyInitContext
,
LazyTensor
,
_MyTensor
from
colossalai.tensor.d_tensor
import
to_global
from
colossalai.tensor.d_tensor.layout
import
Layout
...
...
@@ -82,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
print
(
f
'
{
model
.
__class__
.
__name__
}
pass'
)
def
assert_dist_model_equal
(
model
:
torch
.
nn
.
Module
,
distributed_model
:
torch
.
nn
.
Module
,
layout_dict
:
dict
)
->
None
:
def
assert_dist_model_equal
(
model
:
torch
.
nn
.
Module
,
distributed_model
:
torch
.
nn
.
Module
,
device_mesh
:
DeviceMesh
,
sharding_spec_dict
:
dict
)
->
None
:
state
=
model
.
state_dict
()
distributed_state
=
distributed_model
.
state_dict
()
...
...
tests/test_lazy/test_distribute.py
View file @
c4b1b659
...
...
@@ -26,23 +26,19 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]:
return
dim
def
make_
layout
(
device_mesh
:
DeviceMesh
,
original_tensor
:
torch
.
Tensor
)
->
Layout
:
def
make_
sharding_spec
(
original_tensor
:
torch
.
Tensor
)
->
Layout
:
shard_dim
=
find_shard_dim
(
original_tensor
.
shape
)
dim_partition_dict
=
{
shard_dim
:
[
0
]}
if
shard_dim
is
not
None
else
{}
target_sharding_spec
=
ShardingSpec
(
dim_size
=
original_tensor
.
dim
(),
dim_partition_dict
=
dim_partition_dict
)
layout
=
Layout
(
device_mesh
=
device_mesh
,
device_type
=
torch
.
device
(
'cuda'
),
sharding_spec
=
target_sharding_spec
,
entire_shape
=
original_tensor
.
shape
)
return
layout
return
target_sharding_spec
def
_get_current_name
(
prefix
:
str
,
name
:
str
)
->
str
:
return
f
'
{
prefix
}
.
{
name
}
'
.
lstrip
(
'.'
)
def
generate_
layout
_dict
(
model
:
nn
.
Module
,
device_mesh
:
DeviceMesh
)
->
dict
:
layout
_dict
=
{}
def
generate_
sharding_spec
_dict
(
model
:
nn
.
Module
)
->
dict
:
sharding_spec
_dict
=
{}
@
torch
.
no_grad
()
def
generate_recursively
(
module
:
nn
.
Module
,
prefix
:
str
=
''
):
...
...
@@ -53,17 +49,17 @@ def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
# initialize tensors directly attached to the current module
for
name
,
param
in
module
.
named_parameters
(
recurse
=
False
):
if
isinstance
(
param
,
LazyTensor
):
layout
=
make_layout
(
device_mesh
,
param
)
layout
_dict
[
_get_current_name
(
prefix
,
name
)]
=
layout
sharding_spec
=
make_sharding_spec
(
param
)
sharding_spec
_dict
[
_get_current_name
(
prefix
,
name
)]
=
sharding_spec
for
name
,
buf
in
module
.
named_buffers
(
recurse
=
False
):
if
isinstance
(
buf
,
LazyTensor
):
layout
=
make_layout
(
device_mesh
,
buf
)
layout
_dict
[
_get_current_name
(
prefix
,
name
)]
=
layout
sharding_spec
=
make_sharding_spec
(
buf
)
sharding_spec
_dict
[
_get_current_name
(
prefix
,
name
)]
=
sharding_spec
generate_recursively
(
model
)
return
layout
_dict
return
sharding_spec
_dict
@
parameterize
(
'subset'
,
[
'torchvision'
,
'diffusers'
,
'timm'
,
'transformers'
,
'torchaudio'
,
'deepfm'
,
'dlrm'
])
...
...
@@ -75,7 +71,7 @@ def run_dist_lazy_init(subset, seed: int = 42):
for
name
,
entry
in
sub_model_zoo
.
items
():
# TODO(ver217): lazy init does not support weight norm, skip these models
if
name
in
(
'torchaudio_wav2vec2_base'
,
'torchaudio_hubert_base'
):
if
name
in
(
'torchaudio_wav2vec2_base'
,
'torchaudio_hubert_base'
)
or
name
.
startswith
(
'transformers_llama'
)
:
continue
print_rank_0
(
name
)
model_fn
,
data_gen_fn
,
output_transform_fn
,
_
,
model_attr
=
entry
...
...
@@ -85,9 +81,9 @@ def run_dist_lazy_init(subset, seed: int = 42):
ctx
=
LazyInitContext
()
with
ctx
:
deferred_model
=
model_fn
()
layout
_dict
=
generate_
layout
_dict
(
deferred_model
,
device_mesh
)
ctx
.
distribute
(
deferred_model
,
layout
_dict
,
verbose
=
True
)
assert_dist_model_equal
(
model
,
deferred_model
,
layout
_dict
)
sharding_spec
_dict
=
generate_
sharding_spec
_dict
(
deferred_model
)
ctx
.
distribute
(
deferred_model
,
device_mesh
,
sharding_spec
_dict
,
verbose
=
True
)
assert_dist_model_equal
(
model
,
deferred_model
,
device_mesh
,
sharding_spec
_dict
)
def
run_dist
(
rank
,
world_size
,
port
)
->
None
:
...
...
tests/test_lazy/test_models.py
View file @
c4b1b659
...
...
@@ -10,7 +10,7 @@ def test_torchvision_models_lazy_init(subset):
sub_model_zoo
=
model_zoo
.
get_sub_registry
(
subset
)
for
name
,
entry
in
sub_model_zoo
.
items
():
# TODO(ver217): lazy init does not support weight norm, skip these models
if
name
in
(
'torchaudio_wav2vec2_base'
,
'torchaudio_hubert_base'
):
if
name
in
(
'torchaudio_wav2vec2_base'
,
'torchaudio_hubert_base'
)
or
name
.
startswith
(
'transformers_llama'
)
:
continue
check_lazy_init
(
entry
,
verbose
=
True
)
...
...
tests/test_tensor/test_dtensor/test_comm_spec.py
View file @
c4b1b659
...
...
@@ -122,23 +122,6 @@ def check_all_reduce_bwd(process_groups_dict, rank):
assert
tensor_to_comm
.
equal
(
tensor_to_check
)
def
check_all_reduce_in_flatten_device_mesh
(
process_groups_dict
,
rank
):
# tensor to comm
tensor_to_comm
=
torch
.
ones
(
2
,
2
).
cuda
()
*
rank
# reduce through logical process axis 0 at flatten device mesh
# tensor to check
# tensor([[6., 6.],
# [6., 6.]])
tensor_to_check
=
torch
.
tensor
([[
6
,
6
],
[
6
,
6
]],
dtype
=
tensor_to_comm
.
dtype
).
cuda
()
# CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
comm_spec
=
CommSpec
(
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
process_groups_dict
,
logical_process_axis
=
0
)
tensor_to_comm
=
comm_spec
.
covert_spec_to_action
(
tensor_to_comm
)
assert
tensor_to_comm
.
equal
(
tensor_to_check
)
def
check_comm
(
rank
,
world_size
,
port
):
disable_existing_loggers
()
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
...
...
@@ -150,24 +133,22 @@ def check_comm(rank, world_size, port):
# [[0, 1,
# [2, 3]]
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
process_groups_dict
=
device_mesh
.
process_groups_dict
process_group_dict
=
device_mesh
.
_process_group_dict
[
rank
]
# test all gather
check_all_gather
(
process_group
s
_dict
,
rank
)
check_all_gather
(
process_group_dict
,
rank
)
# test shard
check_shard
(
process_group
s
_dict
,
rank
)
check_shard
(
process_group_dict
,
rank
)
# test all to all
check_all_to_all
(
process_group
s
_dict
,
rank
)
check_all_to_all
(
process_group_dict
,
rank
)
# test all reduce
check_all_reduce_fwd
(
process_group
s
_dict
,
rank
)
check_all_reduce_bwd
(
process_group
s
_dict
,
rank
)
check_all_reduce_fwd
(
process_group_dict
,
rank
)
check_all_reduce_bwd
(
process_group_dict
,
rank
)
flatten_process_groups_dict
=
device_mesh
.
flatten_device_mesh
.
process_groups_dict
# test all reduce in 1D flatten device mesh
check_all_reduce_in_flatten_device_mesh
(
flatten_process_groups_dict
,
rank
)
gpc
.
destroy
()
...
...
tests/test_tensor/test_dtensor/test_dtensor.py
View file @
c4b1b659
...
...
@@ -64,7 +64,7 @@ def check_dtensor(rank, world_size, port):
else
:
raise
ValueError
(
f
'rank
{
rank
}
is not in the device mesh'
)
dtensor_from_local
=
distribute_tensor
(
original_tensor
,
new_layout
)
dtensor_from_local
=
distribute_tensor
(
original_tensor
,
device_mesh
,
new_sharding_spec
)
if
rank
==
0
:
assert
dtensor_from_local
.
equal
(
original_tensor
.
narrow
(
0
,
0
,
1
))
...
...
tests/test_tensor/test_dtensor/test_layout_converter.py
View file @
c4b1b659
...
...
@@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
from
colossalai.tensor.d_tensor.sharding_spec
import
ShardingSpec
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
entire
_shape
=
torch
.
Size
((
64
,
32
,
16
))
global
_shape
=
torch
.
Size
((
64
,
32
,
16
))
layout_converter
=
LayoutConverter
()
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
.
reshape
(
2
,
2
)
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
...
...
@@ -30,10 +30,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,S1,R
# device_mesh_shape: (2, 2)
sharding_spec
=
ShardingSpec
(
dim_size
=
3
,
dim_partition_dict
=
dim_partition_dict
)
layout
=
Layout
(
device_mesh
=
device_mesh
,
device_type
=
torch
.
device
(
'cuda'
),
sharding_spec
=
sharding_spec
,
entire_shape
=
entire_shape
)
layout
=
Layout
(
device_mesh
=
device_mesh
,
sharding_spec
=
sharding_spec
,
global_shape
=
global_shape
)
rst_dict
=
layout_converter
.
all_gather_transform_layouts
(
layout
)
...
...
@@ -49,10 +46,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4)
sharding_spec_all2all
=
ShardingSpec
(
dim_size
=
3
,
dim_partition_dict
=
dim_partition_dict_all2all
)
layout_all2all
=
Layout
(
device_mesh
=
device_mesh
,
device_type
=
torch
.
device
(
'cuda'
),
sharding_spec
=
sharding_spec_all2all
,
entire_shape
=
entire_shape
)
layout_all2all
=
Layout
(
device_mesh
=
device_mesh
,
sharding_spec
=
sharding_spec_all2all
,
global_shape
=
global_shape
)
rst_dict_all2all
=
layout_converter
.
all_to_all_transform_layout
(
layout_all2all
)
...
...
@@ -71,10 +65,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,R,R
# device_mesh_shape: (4, 4)
sharding_spec_shard
=
ShardingSpec
(
dim_size
=
3
,
dim_partition_dict
=
dim_partition_shard
)
shard_layout
=
Layout
(
device_mesh
=
device_mesh
,
device_type
=
torch
.
device
(
'cuda'
),
sharding_spec
=
sharding_spec_shard
,
entire_shape
=
entire_shape
)
shard_layout
=
Layout
(
device_mesh
=
device_mesh
,
sharding_spec
=
sharding_spec_shard
,
global_shape
=
global_shape
)
rst_dict_shard
=
layout_converter
.
shard_transform_layout
(
shard_layout
)
...
...
@@ -100,19 +91,13 @@ def check_layout_converting(rank, world_size, port):
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source
=
ShardingSpec
(
dim_size
=
3
,
dim_partition_dict
=
dim_partition_source
)
source_layout
=
Layout
(
device_mesh
=
device_mesh
,
device_type
=
torch
.
device
(
'cuda'
),
sharding_spec
=
sharding_spec_source
,
entire_shape
=
entire_shape
)
source_layout
=
Layout
(
device_mesh
=
device_mesh
,
sharding_spec
=
sharding_spec_source
,
global_shape
=
global_shape
)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target
=
ShardingSpec
(
dim_size
=
3
,
dim_partition_dict
=
dim_partition_target
)
target_layout
=
Layout
(
device_mesh
=
device_mesh
,
device_type
=
torch
.
device
(
'cuda'
),
sharding_spec
=
sharding_spec_target
,
entire_shape
=
entire_shape
)
target_layout
=
Layout
(
device_mesh
=
device_mesh
,
sharding_spec
=
sharding_spec_target
,
global_shape
=
global_shape
)
transform_path
,
comm_action_sequence
=
layout_converter
.
layout_converting
(
source_layout
,
target_layout
)
...
...
@@ -137,7 +122,7 @@ def check_layout_converting(rank, world_size, port):
assert
comm_action_sequence
[
2
].
shard_dim
==
0
assert
comm_action_sequence
[
2
].
logical_process_axis
==
1
# checkout cached_spec_pairs_transform_path
# checkout c
h
ached_spec_pairs_transform_path
assert
layout_converter
.
cached_solution
[(
'[R, S01, R]'
,
'[S01, R, R]'
)][
0
]
==
transform_path
assert
layout_converter
.
cached_solution
[(
'[R, S01, R]'
,
'[S01, R, R]'
)][
1
]
==
comm_action_sequence
...
...
@@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port):
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source
=
ShardingSpec
(
dim_size
=
3
,
dim_partition_dict
=
dim_partition_source
)
source_layout
=
Layout
(
device_mesh
=
device_mesh
,
device_type
=
torch
.
device
(
'cuda'
),
sharding_spec
=
sharding_spec_source
,
entire_shape
=
entire_shape
)
source_layout
=
Layout
(
device_mesh
=
device_mesh
,
sharding_spec
=
sharding_spec_source
,
global_shape
=
global_shape
)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target
=
ShardingSpec
(
dim_size
=
3
,
dim_partition_dict
=
dim_partition_target
)
target_layout
=
Layout
(
device_mesh
=
device_mesh
,
device_type
=
torch
.
device
(
'cuda'
),
sharding_spec
=
sharding_spec_target
,
entire_shape
=
entire_shape
)
target_layout
=
Layout
(
device_mesh
=
device_mesh
,
sharding_spec
=
sharding_spec_target
,
global_shape
=
global_shape
)
original_tensor
=
torch
.
rand
(
entire
_shape
).
cuda
()
original_tensor
=
torch
.
rand
(
global
_shape
).
cuda
()
# tensor_to_apply: [R, S01, R]
tensor_to_apply
=
original_tensor
.
narrow
(
1
,
rank
*
8
,
8
)
...
...
tests/test_tensor/test_shape_consistency.py
View file @
c4b1b659
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
,
CollectiveCommPattern
import
torch
from
colossalai.tensor.sharding_spec
import
_DimSpec
,
ShardingSpec
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.tensor.shape_consistency
import
CollectiveCommPattern
,
ShapeConsistencyManager
from
colossalai.tensor.sharding_spec
import
ShardingSpec
,
_DimSpec
physical_mesh_id
=
torch
.
arange
(
0
,
16
)
.
reshape
(
2
,
8
)
physical_mesh_id
=
torch
.
arange
(
0
,
16
)
mesh_shape
=
(
4
,
4
)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
...
...
tests/test_tensor/test_sharded_linear.py
View file @
c4b1b659
...
...
@@ -26,7 +26,7 @@ def run_dist(rank, world_size, port):
# the mesh is in the following topo
# [[0, 1],
# [2, 3]]
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
.
reshape
(
2
,
2
)
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
)
row_id
=
rank
//
2
...
...
tests/test_tensor/test_sharding_spec.py
View file @
c4b1b659
...
...
@@ -5,7 +5,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
def
test_sharding_spec
():
physical_mesh_id
=
torch
.
arange
(
0
,
16
)
.
reshape
(
2
,
8
)
physical_mesh_id
=
torch
.
arange
(
0
,
16
)
mesh_shape
=
(
4
,
4
)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment