OpenDAS / ColossalAI / Commits

Commit ddcf58ca (Unverified)
Revert "[sync] sync feature/shardformer with develop"
Authored Jun 09, 2023 by Frank Lee; committed by GitHub on Jun 09, 2023
Parent: 24651fdd
Changes: 48
Showing 8 changed files with 94 additions and 46 deletions (+94, -46).
tests/test_lazy/lazy_init_utils.py                        +3  -7
tests/test_lazy/test_distribute.py                        +16 -12
tests/test_tensor/test_dtensor/test_comm_spec.py          +26 -7
tests/test_tensor/test_dtensor/test_dtensor.py            +13 -4
tests/test_tensor/test_dtensor/test_layout_converter.py   +31 -10
tests/test_tensor/test_shape_consistency.py               +3  -4
tests/test_tensor/test_sharded_linear.py                  +1  -1
tests/test_tensor/test_sharding_spec.py                   +1  -1
tests/test_lazy/lazy_init_utils.py

@@ -6,9 +6,7 @@ import numpy as np
 import torch
 from packaging import version

-from colossalai.device.device_mesh import DeviceMesh
 from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
-from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.layout_converter import to_global
 from tests.kit.model_zoo.registry import ModelAttribute

@@ -83,8 +81,7 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
         print(f'{model.__class__.__name__} pass')


-def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
-                            sharding_spec_dict: dict) -> None:
+def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
     state = model.state_dict()
     distributed_state = distributed_model.state_dict()

@@ -94,7 +91,6 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
         assert n1 == n2
         t1 = t1.cuda()
         t2 = t2.cuda()
-        if n2 in sharding_spec_dict:
-            layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
-            t2 = to_global(t2, layout)
+        if n2 in layout_dict:
+            t2 = to_global(t2, layout_dict[n2])
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
tests/test_lazy/test_distribute.py

@@ -26,19 +26,23 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]:
     return dim


-def make_sharding_spec(original_tensor: torch.Tensor) -> Layout:
+def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
     shard_dim = find_shard_dim(original_tensor.shape)
     dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
-    return target_sharding_spec
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=target_sharding_spec,
+                    entire_shape=original_tensor.shape)
+    return layout


 def _get_current_name(prefix: str, name: str) -> str:
     return f'{prefix}.{name}'.lstrip('.')


-def generate_sharding_spec_dict(model: nn.Module) -> dict:
-    sharding_spec_dict = {}
+def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
+    layout_dict = {}

     @torch.no_grad()
     def generate_recursively(module: nn.Module, prefix: str = ''):

@@ -49,17 +53,17 @@ def generate_sharding_spec_dict(model: nn.Module) -> dict:
         # initialize tensors directly attached to the current module
         for name, param in module.named_parameters(recurse=False):
             if isinstance(param, LazyTensor):
-                sharding_spec = make_sharding_spec(param)
-                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
+                layout = make_layout(device_mesh, param)
+                layout_dict[_get_current_name(prefix, name)] = layout

         for name, buf in module.named_buffers(recurse=False):
             if isinstance(buf, LazyTensor):
-                sharding_spec = make_sharding_spec(buf)
-                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
+                layout = make_layout(device_mesh, buf)
+                layout_dict[_get_current_name(prefix, name)] = layout

     generate_recursively(model)
-    return sharding_spec_dict
+    return layout_dict


 @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])

@@ -81,9 +85,9 @@ def run_dist_lazy_init(subset, seed: int = 42):
         ctx = LazyInitContext()
         with ctx:
             deferred_model = model_fn()
-        sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
-        ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
-        assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)
+        layout_dict = generate_layout_dict(deferred_model, device_mesh)
+        ctx.distribute(deferred_model, layout_dict, verbose=True)
+        assert_dist_model_equal(model, deferred_model, layout_dict)


 def run_dist(rank, world_size, port) -> None:
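The hunks above capture the API difference this commit reverts: before the revert, the tests built a ShardingSpec per tensor and passed it around together with the DeviceMesh; after the revert, they build a full Layout that also carries device_type and entire_shape. A minimal sketch of the restored constructor, assuming colossalai is installed and using only the classes and keyword arguments visible in this diff (the concrete mesh and shape below are illustrative, not part of the commit):

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

# Illustrative 2x2 logical mesh over ranks 0-3, as in the tests above.
physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

# Shard tensor dim 0 along mesh axis 0.
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0]})

# Pre-revert (removed) form:
#     Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=torch.Size((64, 32, 16)))
# Post-revert (restored) form:
layout = Layout(device_mesh=device_mesh,
                device_type=torch.device('cuda'),
                sharding_spec=sharding_spec,
                entire_shape=torch.Size((64, 32, 16)))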
tests/test_tensor/test_dtensor/test_comm_spec.py

@@ -125,6 +125,23 @@ def check_all_reduce_bwd(process_groups_dict, rank):
     assert tensor_to_comm.equal(tensor_to_check)


+def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
+    # tensor to comm
+    tensor_to_comm = torch.ones(2, 2).cuda() * rank
+
+    # reduce through logical process axis 0 at flatten device mesh
+    # tensor to check
+    # tensor([[6., 6.],
+    #         [6., 6.]])
+    tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
+
+    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
+    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
+
+    assert tensor_to_comm.equal(tensor_to_check)
+
+
 def check_comm(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

@@ -136,22 +153,24 @@ def check_comm(rank, world_size, port):
     #  [[0, 1,
     #   [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    process_group_dict = device_mesh._process_group_dict[rank]
+    process_groups_dict = device_mesh.process_groups_dict

     # test all gather
-    check_all_gather(process_group_dict, rank)
+    check_all_gather(process_groups_dict, rank)

     # test shard
-    check_shard(process_group_dict, rank)
+    check_shard(process_groups_dict, rank)

     # test all to all
-    check_all_to_all(process_group_dict, rank)
+    check_all_to_all(process_groups_dict, rank)

     # test all reduce
-    check_all_reduce_fwd(process_group_dict, rank)
-    check_all_reduce_bwd(process_group_dict, rank)
+    check_all_reduce_fwd(process_groups_dict, rank)
+    check_all_reduce_bwd(process_groups_dict, rank)
+
+    flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
+    # test all reduce in 1D flatten device mesh
+    check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)

     gpc.destroy()
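One detail worth spelling out in the restored check_all_reduce_in_flatten_device_mesh: the flattened 2x2 mesh all-reduces across all four ranks, and each rank contributes torch.ones(2, 2) * rank, so every element of the result is 0 + 1 + 2 + 3 = 6. A quick pure-PyTorch check of that arithmetic (illustrative only, not part of the commit):

import torch

# Sum of per-rank contributions ones(2, 2) * rank for ranks 0..3.
expected = sum(torch.ones(2, 2) * rank for rank in range(4))
assert torch.equal(expected, torch.tensor([[6., 6.], [6., 6.]]))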
tests/test_tensor/test_dtensor/test_dtensor.py

@@ -31,9 +31,13 @@ def check_dtensor(rank, world_size, port):
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
-    d_tensor = DTensor(original_tensor, device_mesh, target_sharding_spec)
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=target_sharding_spec,
+                    entire_shape=original_tensor.shape)
+    d_tensor = DTensor(original_tensor, layout)

-    assert d_tensor.global_shape == original_tensor.shape
+    assert d_tensor.entire_shape == original_tensor.shape
     assert d_tensor.data_type == original_tensor.dtype

     if rank in (0, 1):

@@ -53,7 +57,12 @@ def check_dtensor(rank, world_size, port):
         raise ValueError(f'rank {rank} is not in the device mesh')

     new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
-    d_tensor.layout_convert(device_mesh, new_sharding_spec)
+    new_layout = Layout(device_mesh=device_mesh,
+                        device_type=torch.device('cuda'),
+                        sharding_spec=new_sharding_spec,
+                        entire_shape=original_tensor.shape)
+
+    d_tensor.layout_convert(new_layout)

     if rank == 0:
         assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))

@@ -66,7 +75,7 @@ def check_dtensor(rank, world_size, port):
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')

-    dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)
+    dtensor_from_local = distribute_tensor(original_tensor, new_layout)

     if rank == 0:
         assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
tests/test_tensor/test_dtensor/test_layout_converter.py

@@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
 from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn

-global_shape = torch.Size((64, 32, 16))
+entire_shape = torch.Size((64, 32, 16))
 layout_converter = LayoutConverter()
-physical_mesh_id = torch.arange(0, 4)
+physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
 mesh_shape = (2, 2)

@@ -30,7 +30,10 @@ def check_one_step_transform(rank, world_size, port):
     #     shard_sequence: S0,S1,R
     #     device_mesh_shape: (2, 2)
     sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
-    layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=sharding_spec,
+                    entire_shape=entire_shape)

     rst_dict = layout_converter.all_gather_transform_layouts(layout)

@@ -46,7 +49,10 @@ def check_one_step_transform(rank, world_size, port):
     #     shard_sequence: S0,S1,R
     #     device_mesh_shape: (4, 4)
     sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
-    layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape)
+    layout_all2all = Layout(device_mesh=device_mesh,
+                            device_type=torch.device('cuda'),
+                            sharding_spec=sharding_spec_all2all,
+                            entire_shape=entire_shape)

     rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)

@@ -65,7 +71,10 @@ def check_one_step_transform(rank, world_size, port):
     #     shard_sequence: S0,R,R
     #     device_mesh_shape: (4, 4)
     sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
-    shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape)
+    shard_layout = Layout(device_mesh=device_mesh,
+                          device_type=torch.device('cuda'),
+                          sharding_spec=sharding_spec_shard,
+                          entire_shape=entire_shape)

     rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)

@@ -91,13 +100,19 @@ def check_layout_converting(rank, world_size, port):
     #     shard_sequence: R,S01,R
     #     device_mesh_shape: (4, 4)
     sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-    source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
+    source_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_source,
+                           entire_shape=entire_shape)

     # DistSpec:
     #     shard_sequence: S01,R,R
     #     device_mesh_shape: (4, 4)
     sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
+    target_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_target,
+                           entire_shape=entire_shape)

     transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)

@@ -144,15 +159,21 @@ def check_layout_converting_apply(rank, world_size, port):
     #     shard_sequence: R,S01,R
     #     device_mesh_shape: (4, 4)
     sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-    source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
+    source_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_source,
+                           entire_shape=entire_shape)

     # DistSpec:
     #     shard_sequence: S01,R,R
     #     device_mesh_shape: (4, 4)
     sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
+    target_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_target,
+                           entire_shape=entire_shape)

-    original_tensor = torch.rand(global_shape).cuda()
+    original_tensor = torch.rand(entire_shape).cuda()

     # tensor_to_apply: [R, S01, R]
     tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
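A small arithmetic note on the restored values in check_layout_converting_apply: with entire_shape = (64, 32, 16) and an S01 shard on dim 1, that dimension is split over both axes of the 2x2 mesh, i.e. four ways into chunks of 32 / 4 = 8, which is why each rank's slice is original_tensor.narrow(1, rank * 8, 8). A quick check of that chunk size (illustrative only, not part of the commit):

import torch

entire_shape = torch.Size((64, 32, 16))
num_shards = 2 * 2    # S01 shards dim 1 over both axes of the 2x2 mesh
shard_size = entire_shape[1] // num_shards
assert shard_size == 8    # matches narrow(1, rank * 8, 8) in the test above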
tests/test_tensor/test_shape_consistency.py

+from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
 import torch
+from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec

-physical_mesh_id = torch.arange(0, 16)
+physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
 mesh_shape = (4, 4)
 # [[0, 1, 2, 3],
 #  [4, 5, 6, 7],
tests/test_tensor/test_sharded_linear.py

@@ -26,7 +26,7 @@ def run_dist(rank, world_size, port):
     # the mesh is in the following topo
     # [[0, 1],
     #  [2, 3]]
-    physical_mesh_id = torch.arange(0, 4)
+    physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     row_id = rank // 2
tests/test_tensor/test_sharding_spec.py

@@ -5,7 +5,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 def test_sharding_spec():
-    physical_mesh_id = torch.arange(0, 16)
+    physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
     mesh_shape = (4, 4)
     # [[0, 1, 2, 3],
     #  [4, 5, 6, 7],