OpenDAS / ColossalAI / Commits

Commit a39a5c66 (Unverified)
Authored Sep 04, 2023 by Hongxin Liu; committed via GitHub on Sep 04, 2023.

    Merge branch 'main' into feature/shardformer

Parents: e79b1e80, aaeb520c
Changes (138 files in the commit); showing 20 changed files with 3 additions and 2102 deletions (+3 −2102):
tests/test_ddp/test_ddp_state_dict.py                  +0 −67
tests/test_ddp/test_reducer.py                         +0 −47
tests/test_ops/test_addmm_tp.py                        +0 −73
tests/test_ops/test_embedding_bag_tp.py                +0 −43
tests/test_ops/test_embedding_tp.py                    +0 −44
tests/test_ops/test_linear_tp.py                       +0 −48
tests/test_ops/test_loss_func.py                       +0 −48
tests/test_ops/test_op.py                              +0 −87
tests/test_ops/test_view.py                            +0 −97
tests/test_pipeline/test_pipelinable.py                +2 −0
tests/test_shardformer/test_model/test_shard_gpt2.py   +1 −0
tests/test_tensor/core/test_tensor.py                  +0 −153
tests/test_tensor/model/test_gpt2.py                   +0 −148
tests/test_tensor/model/test_model.py                  +0 −334
tests/test_tensor/model/test_module_spec.py            +0 −227
tests/test_tensor/test_colo_checkpoint_tools.py        +0 −41
tests/test_tensor/test_context.py                      +0 −64
tests/test_tensor/test_sharded_linear.py               +0 −232
tests/test_tensor/test_tp_with_zero.py                 +0 −143
tests/test_utils/test_colo_checkpoint.py               +0 −206
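
Every test in this diff uses the same two-layer harness: a per-rank worker that initializes the NCCL process group through colossalai.launch, driven by colossalai.testing.spawn under pytest. A minimal sketch of that pattern, using only the APIs that appear in the files below:

import pytest
import torch.distributed as dist

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # every worker joins the NCCL process group before running its checks
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    assert dist.get_world_size() == world_size


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_harness(world_size):
    # spawn launches one process per rank and supplies the port argument,
    # the usage pattern throughout this commit
    spawn(run_dist, world_size)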
tests/test_ddp/test_ddp_state_dict.py (deleted, 100644 → 0)
from collections import OrderedDict

import pytest
import torch

import colossalai
from colossalai.nn.parallel import ColoDDP
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs


def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
    for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()):
        assert k1 == k2
        if t1.device != t2.device:
            temp_t2 = t2.to(t1.device)
        else:
            temp_t2 = t2
        assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2)


def init_ddp(module: torch.nn.Module) -> ColoDDP:
    pg = ProcessGroup()
    return ColoDDP(module, process_group=pg)


def run_ddp_state_dict():
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    torch_model = model_builder().cuda()
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = init_ddp(model)
    torch_state_dict = torch_model.state_dict()

    for param in model.parameters():
        if isinstance(param, ColoParameter):
            assert param.get_process_group() is not None

    model.load_state_dict(torch_state_dict)

    for param in model.parameters():
        if isinstance(param, ColoParameter):
            assert param.get_process_group() is not None

    state_dict = model.state_dict()
    check_state_dict_equal(torch_state_dict, state_dict)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_ddp_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_state_dict(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_state_dict(2)
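
As a usage note, check_state_dict_equal above compares entry by entry and moves tensors onto a common device before torch.equal. A minimal single-process sketch, assuming only the helper defined above and plain torch modules:

import torch

net_a = torch.nn.Linear(2, 2)
net_b = torch.nn.Linear(2, 2)
net_b.load_state_dict(net_a.state_dict())    # round-trip through a state dict

# Same keys and values -> passes. Because the helper moves t2 onto t1's
# device before comparing, a CUDA copy of the same weights would pass too.
check_state_dict_equal(net_a.state_dict(), net_b.state_dict())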
tests/test_ddp/test_reducer.py (deleted, 100644 → 0)
from functools import partial

import pytest
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

import colossalai
from colossalai.nn.parallel.reducer import Reducer
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device

REDUCE_CNT = 0


def check_eq(grad, grad_clone):
    global REDUCE_CNT
    print(f'Rank {dist.get_rank()} check {REDUCE_CNT}')
    REDUCE_CNT += 1
    assert torch.allclose(grad, grad_clone)


def run_reducer():
    grads = [torch.rand(64, i + 1, device=get_current_device()) for i in range(10)]
    grads_clone = [g.clone().detach() for g in grads]
    for g in grads:
        dist.all_reduce(g)
    reducer = Reducer(bucket_size_mb=1)
    for g, g_clone in zip(grads, grads_clone):
        reducer.all_reduce_async(g_clone, _get_default_group(), partial(check_eq, g))
    reducer.flush()


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_reducer()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_reducer(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_reducer(2)
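
Reducer above coalesces gradients into size-capped buckets before reducing them. A single-process sketch of just the bucketing logic; flush_fn is a hypothetical stand-in for the collective call, not a colossalai API:

import torch

def bucket_tensors(tensors, bucket_size_mb, flush_fn):
    bucket, bucket_bytes = [], 0
    cap = bucket_size_mb * 1024 * 1024
    for t in tensors:
        bucket.append(t)
        bucket_bytes += t.numel() * t.element_size()
        if bucket_bytes >= cap:
            flush_fn(bucket)          # one collective per full bucket
            bucket, bucket_bytes = [], 0
    if bucket:
        flush_fn(bucket)              # flush the trailing partial bucket

grads = [torch.rand(64, i + 1) for i in range(10)]
bucket_tensors(grads, bucket_size_mb=1, flush_fn=lambda b: print(f'reduce {len(b)} tensors'))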
tests/test_ops/test_addmm_tp.py (deleted, 100644 → 0)
import pytest
import torch
import torch.nn as nn

import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal


class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.ones(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x


def run_with_spec(spec_init_func, split_bias):
    model = Conv1D(4, 16).cuda()
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
    bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))

    spec_init_func(weight, pg)
    if split_bias:
        spec_init_func(bias, pg)

    x = torch.rand(2, 16).cuda()
    out = model(x)
    colo_out = torch.addmm(bias, x, weight)
    colo_out = colo_out.to_replicate()
    assert tensor_equal(out, colo_out)
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
    tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=False)
    run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=True)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_addmm_1d(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_addmm_1d(4)
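
The Conv1D docstring above says it works like a linear layer with transposed weights; a quick torch-only check of that claim, with shapes matching the test's Conv1D(4, 16):

import torch
import torch.nn.functional as F

nx, nf = 16, 4                      # input/output features, as in Conv1D(4, 16)
w = torch.randn(nx, nf)             # Conv1D stores its weight as (nx, nf)
b = torch.randn(nf)
x = torch.randn(2, nx)

out_conv1d = torch.addmm(b, x, w)   # what Conv1D.forward computes
out_linear = F.linear(x, w.t(), b)  # nn.Linear convention: weight is (nf, nx)
assert torch.allclose(out_conv1d, out_linear)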
tests/test_ops/test_embedding_bag_tp.py (deleted, 100644 → 0)
import pytest
import torch
from torch.nn import functional as F

import colossalai
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal


def run_with_spec(spec_init_func):
    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    model = torch.nn.EmbeddingBag(10, 4).cuda()
    weight = ColoParameter(model.weight.clone(), True, ColoTensorSpec(pg))

    spec_init_func(weight, pg)

    inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
    offsets = torch.tensor([0, 4]).cuda()
    out = model(inputs, offsets=offsets)
    colo_out = F.embedding_bag(inputs, weight, offsets=offsets)
    assert tensor_equal(out, colo_out)
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_with_spec(split_param_col_tp1d)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_embedding_bag_1d(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_embedding_bag_1d(4)
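
In the test above, offsets=[0, 4] splits the flat inputs into two bags, inputs[0:4] and inputs[4:]. A torch-only sketch of that semantics (F.embedding_bag defaults to mode='mean'):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)
inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])    # bag 0 = inputs[0:4], bag 1 = inputs[4:]

out = F.embedding_bag(inputs, weight, offsets=offsets)    # default mode='mean'
manual = torch.stack([weight[inputs[0:4]].mean(0), weight[inputs[4:]].mean(0)])
assert torch.allclose(out, manual)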
tests/test_ops/test_embedding_tp.py (deleted, 100644 → 0)
import pytest
import torch
from torch.nn import functional as F

import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal


def run_with_spec(spec_init_func, pg: ProcessGroup):
    model = torch.nn.Embedding(12, 32).cuda()
    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))

    spec_init_func(weight, pg)

    x = torch.tensor((0, 3, 6, 9)).cuda()
    out = model(x)
    colo_out = F.embedding(x, weight)
    assert tensor_equal(out, colo_out)
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    # compare grad inside a TP group
    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_dist(rank, world_size, port):
    # config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)
    run_with_spec(split_param_row_tp1d, pg)
    run_with_spec(split_param_col_tp1d, pg)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_embedding_1d(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_embedding_1d(4)
tests/test_ops/test_linear_tp.py (deleted, 100644 → 0)
import pytest
import torch
import torch.nn.functional as F

import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal


def run_with_spec(spec_init_func, split_bias):
    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    model = torch.nn.Linear(4, 8).cuda()
    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
    bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))

    spec_init_func(weight, pg)
    if split_bias:
        spec_init_func(bias, pg)

    x = torch.rand(2, 4).cuda()
    out = model(x)
    colo_out = F.linear(x, weight, bias)
    colo_out = colo_out.to_replicate()
    assert tensor_equal(out, colo_out)
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
    assert tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=False)
    run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=True)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_linear_1d(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_linear_1d(4)
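
The two run_with_spec calls above exercise the two classic 1D layouts for a linear layer. A single-process, torch-only sketch of why both recover the full result; the explicit chunk/cat/sum here stands in for what split_param_col_tp1d / split_param_row_tp1d plus the collectives do across ranks:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4)
w = torch.randn(8, 4)    # torch.nn.Linear(4, 8) stores weight as (out, in)
b = torch.randn(8)

full = F.linear(x, w, b)

# Column-style split: shard the output dimension; concatenate partial outputs.
w0, w1 = torch.chunk(w, 2, dim=0)
b0, b1 = torch.chunk(b, 2, dim=0)
col = torch.cat([F.linear(x, w0, b0), F.linear(x, w1, b1)], dim=-1)

# Row-style split: shard the input dimension; sum partial outputs, add bias once.
x0, x1 = torch.chunk(x, 2, dim=-1)
v0, v1 = torch.chunk(w, 2, dim=1)
row = F.linear(x0, v0) + F.linear(x1, v1) + b

assert torch.allclose(full, col, atol=1e-6) and torch.allclose(full, row, atol=1e-6)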
tests/test_ops/test_loss_func.py (deleted, 100644 → 0)
import pytest
import torch
import torch.nn.functional as F

import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device


def check_cross_entropy():
    input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
    input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
    with torch.no_grad():
        input_ct.copy_(input_t)

    target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
    input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
    input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))

    output = F.cross_entropy(input_t, target)
    output_colo = F.cross_entropy(input_shard, target)
    assert torch.allclose(output_colo, output)

    output.backward()
    output_colo.backward()
    assert torch.allclose(input_t.grad, input_ct.grad)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_cross_entropy()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_loss_func(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_loss_func(1)
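
The test above shards the logits along the class dimension (ShardSpec([-1], ...)) and still expects the exact loss. The identity behind this: cross entropy needs only the target-class logit and a logsumexp over all classes, and logsumexp factors over class-dim shards. A single-process torch sketch:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 4)
target = torch.randint(4, (4,))

full = F.cross_entropy(logits, target)

# Two class-dim shards, as ShardSpec([-1], [tp_world_size]) would produce.
shards = torch.chunk(logits, 2, dim=-1)
# Combine per-shard logsumexp values: logsumexp over stacked shard results
# equals logsumexp over the full class dimension.
lse = torch.logsumexp(torch.stack([s.logsumexp(dim=-1) for s in shards]), dim=0)
picked = logits.gather(1, target.unsqueeze(1)).squeeze(1)    # target-class logit
manual = (lse - picked).mean()

assert torch.allclose(full, manual, atol=1e-6)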
tests/test_ops/test_op.py (deleted, 100644 → 0)
import pytest
import torch
import torch.nn.functional as F
from torch.nn import Parameter

import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device


def _run_layer_norm():
    ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())

    input_t = torch.randn(3, 2, device=get_current_device())

    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    input_t_colo = ColoTensor.from_torch_tensor(input_t.clone().detach(), ColoTensorSpec(pg))

    # prepare colossalai LN
    weight = ColoTensor(Parameter(ln_op.weight.detach()), ColoTensorSpec(pg))
    bias = ColoTensor(Parameter(ln_op.bias.detach()), ColoTensorSpec(pg))

    output = ln_op(input_t)
    output_colo = F.layer_norm(input_t_colo, ln_op.normalized_shape, weight, bias, ln_op.eps)

    assert torch.allclose(output_colo, output)

    torch.mean(output).backward()
    torch.mean(output_colo).backward()

    assert torch.allclose(ln_op.weight.grad, weight.grad)


def check_spec_eq(tensor, other):
    assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
    for k in dir(tensor.dist_spec):
        if not k.startswith('__'):
            assert hasattr(other.dist_spec, k), f"{k}"
            assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k)


def check_element_wise_ops():
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    t = torch.rand(2, 2)
    x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()])))

    check_spec_eq(x, x.cuda())
    assert torch.equal(x.cuda(), t.cuda())
    check_spec_eq(x, torch.abs(x))
    assert torch.equal(torch.abs(x), torch.abs(t))
    check_spec_eq(x, F.sigmoid(x))
    assert torch.equal(F.sigmoid(x), F.sigmoid(t))


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_element_wise_ops()
    _run_layer_norm()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_element_wise_ops(world_size):
    spawn(run_dist, world_size)


def run_dist2(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_layer_norm()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1])
@rerun_if_address_is_in_use()
def test_ln(world_size):
    spawn(run_dist2, world_size)


def check_all():
    test_element_wise_ops(2)


if __name__ == '__main__':
    check_all()
tests/test_ops/test_view.py (deleted, 100644 → 0)
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
from colossalai.tensor.distspec import DistPlacementPattern
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d


def exam_view_core(pg):
    # the case of replicated ColoTensors
    x = torch.randn(4, 4).cuda()
    x_colo = ColoTensor(x, ColoTensorSpec(pg))

    y = x.view(2, -1, 2)
    y_colo = x_colo.view(2, -1, 2)

    assert torch.all(y == y_colo)
    assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE

    # the perfect case of col-sliced ColoTensors
    split_param_col_tp1d(x_colo, pg)

    z = x.view(torch.Size((2, 1, 2, -1)))
    z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
    if dist.get_rank() == 0:
        z = z[:, :, :, 0:2]
    else:
        z = z[:, :, :, 2:]
    assert torch.all(z == z_colo)
    assert z_colo.dist_spec == x_colo.dist_spec

    # the perfect case of row-sliced ColoTensors
    split_param_row_tp1d(x_colo, pg)

    z = x.view(torch.Size((-1, 2, 2)))
    z_colo = x_colo.view(torch.Size((-1, 2, 2)))
    if dist.get_rank() == 0:
        z = z[0:2, :, :]
    else:
        z = z[2:, :, :]
    assert torch.all(z == z_colo)
    assert z_colo.dist_spec == x_colo.dist_spec

    # the normal case of row-sliced ColoTensors
    z = x.view(-1, 2, 2, 2)
    z_colo = x_colo.view(-1, 2, 2, 2)
    assert torch.all(z == z_colo)
    assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE


def exam_view_autograd(pg):
    x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
    y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
    with torch.no_grad():
        y.copy_(x)
    y = ColoTensor(y, ColoTensorSpec(pg))
    y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))

    xx = x.view(2, 2, -1)
    yy_slice = y_slice.view(2, 2, -1)
    yy = yy_slice.to_replicate()
    grad = torch.randn(2, 2, 4, device=get_current_device())

    xx.backward(grad)
    yy.backward(grad)
    assert torch.all(x.grad == y.grad)


def exam_view_errors(pg):
    x = torch.randn(8, 2, device=get_current_device())
    x = ColoTensor(x, ColoTensorSpec(pg))
    split_param_row_tp1d(x, pg)

    x.view('a', 'b', 'c')
    x.view(8, -1)
    x.view([-2, -2, -2])
    x.view((-1, -1, -1))


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    exam_view_core(pg)
    exam_view_autograd(pg)
    # exam_view_errors(pg)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_view(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_view(2)
tests/test_pipeline/test_pipelinable.py (modified, +2 −0)

+import pytest
 import torch
 from colossalai.pipeline.pipelinable import PipelinableContext
...
@@ -48,6 +49,7 @@ def run_pipelinable(rank, world_size, port):
     assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count


+@pytest.mark.skip(reason="this is useless")
 @rerun_if_address_is_in_use()
 def test_pipelinable():
     spawn(run_pipelinable, 1)
...
tests/test_shardformer/test_model/test_shard_gpt2.py (modified, +1 −0)

...
@@ -219,6 +219,7 @@ def check_gpt2_3d(rank, world_size, port):
     run_gpt2_3d_test()


 @pytest.mark.dist
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
...
tests/test_tensor/core/test_tensor.py (deleted, 100644 → 0)
import pytest
import torch
from numpy import allclose

import colossalai
from colossalai.core import global_context as gpc
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec
from colossalai.testing import rerun_if_address_is_in_use, spawn


def _run_tensor_indexing():
    pg = ProcessGroup()
    torch_t = torch.randn(2, 3)
    colo_t = ColoTensor(torch_t, ColoTensorSpec(pg))
    assert allclose(torch_t[:, 1], colo_t[:, 1])


def _run_wrapped_tensor_func():
    pg = ProcessGroup()
    t_ref = torch.randn(4, 5)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))

    # non-func attr
    assert t.is_cuda == t_ref.is_cuda

    # return 1 torch.Tensor
    t_abs = t.abs()
    assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs())

    # return 1 non-torch.Tensor
    assert t.dim() == t_ref.dim()

    # return >1 torch.Tensor
    assert isinstance(t, ColoTensor)
    t_split1, t_split2 = t.split(2)
    assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}"


def _run_operand(world_size):
    pg = ProcessGroup()
    t_ref = torch.randn(4, 5)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))

    t_ref_res = t_ref + t_ref
    t_res = t + t

    assert isinstance(t_res, ColoTensor)
    assert torch.allclose(t_ref_res, t_res)

    pg = ProcessGroup(tp_degree=world_size)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
    t.set_dist_spec(ShardSpec([0], [world_size]))
    t_new = torch.zeros_like(t)
    assert isinstance(t_new, ColoTensor)
    assert t_new.is_sharded()


#### Test distributed init of a ColoTensor


def _run_view(world_size):
    t_ref = torch.randn(4, 5)
    rank = gpc.get_global_rank()
    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
    t = ColoTensor.from_torch_tensor(
        t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])))

    assert t.size_global()[0] == 4 * world_size
    assert t.size_global(1) == 5
    assert t.size_global() == torch.Size([4 * world_size, 5])

    t = t.view(4 * 5 * world_size)
    assert t.shape == torch.Size([4 * 5 * world_size])


def _run_tensor_shard_init(world_size):
    t_ref = torch.randn(4, 5)
    pg = ProcessGroup(tp_degree=world_size)
    shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
    tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
    t.set_dist_spec(ReplicaSpec())
    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"


def _run_tensor_replicated_init(world_size):
    t_ref = torch.randn(4 * world_size, 5)
    pg = ProcessGroup()
    spec = ColoTensorSpec(pg)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), spec)
    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"


def _run_process_group(world_size):
    pg1 = ProcessGroup()
    pg2 = ProcessGroup()
    assert pg1 == pg2


def _run_redistributed(world_size):
    if world_size != 4:
        return
    pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
    pg2 = ProcessGroup(tp_degree=4, dp_degree=1)

    spec1 = ColoTensorSpec(pg1)
    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
    t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()]))
    assert t1.is_sharded()
    t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)
    assert t1.is_sharded()
    pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
    t1 = t1.redistribute(ReplicaSpec(), pg3)
    assert t1.is_replicate()


def _run_set_tensor_spec(world_size):
    if world_size != 4:
        return
    pg = ProcessGroup(tp_degree=2, dp_degree=2)
    spec1 = ColoTensorSpec(pg)
    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)

    dist_spec2 = ShardSpec([-1], [pg.tp_world_size()])
    assert t1.is_replicate()
    t1.set_dist_spec(dist_spec2)
    assert t1.is_shard_1dcol()


def run_dist_tests(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_tensor_shard_init(world_size)
    _run_tensor_replicated_init(world_size)
    _run_view(world_size)
    _run_process_group(world_size)
    _run_tensor_indexing()
    _run_operand(world_size)
    _run_wrapped_tensor_func()
    _run_redistributed(world_size)
    _run_set_tensor_spec(world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_dist_cases(world_size):
    spawn(run_dist_tests, world_size)


if __name__ == '__main__':
    test_dist_cases(4)
tests/test_tensor/model/test_gpt2.py (deleted, 100644 → 0)
import pytest
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import (
    debug_print,
    set_seed,
    split_param_col_tp1d,
    split_param_row_tp1d,
    tensor_equal,
    tensor_shard_equal,
)


def init_1d_row_spec(model, pg: ProcessGroup):
    tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        if 'weight' in n and 'ln' not in n:
            p.set_tensor_spec(*tensor_spec)


def init_1d_col_spec(model, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        if 'ln' not in n and ('weight' in n or 'bias' in n):
            p.set_tensor_spec(*spec)


def init_megatron_spec(model, pg: ProcessGroup):
    for mn, module in model.named_modules():
        # debug_print([0], mn)
        for pn, param in module.named_parameters(recurse=False):
            # debug_print([0], '\t', pn, param.compute_spec, param.shape)
            param.set_process_group(pg)

            if 'mlp.c_fc' in mn:
                if 'weight' in pn or 'bias' in pn:
                    split_param_col_tp1d(param, pg)
                    param.compute_spec.set_output_replicate(False)
                else:
                    raise RuntimeError
            elif 'mlp.c_proj' in mn:
                if 'weight' in pn:
                    split_param_row_tp1d(param, pg)
                else:
                    assert 'bias' in pn
            elif 'wte' in mn or 'wpe' in mn:
                assert 'weight' in pn
                split_param_col_tp1d(param, pg)
            elif 'c_attn' in mn or 'c_proj' in mn:
                split_param_col_tp1d(param, pg)
            # debug_print([0], '\t', param.compute_spec, param.shape)


def check_param_equal(model, torch_model, pg: ProcessGroup):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()} 1"
        assert pg.tp_world_size() is not None
        assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())


def check_grad_equal(model, torch_model, pg: ProcessGroup):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert tensor_shard_equal(torch_p.grad, p.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_gpt(init_spec_func, use_ddp):
    world_size = torch.distributed.get_world_size()

    # build a PG with TP and DP hybrid
    pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))

    # set the seed so that processes of the same tp group use the same seed
    # set_seed(pg.tp_local_rank())
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # make sure torch_model and model have the same parameter values
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    if use_ddp:
        torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
        model = ColoDDP(model, process_group=pg)
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)

    init_spec_func(model, pg)
    check_param_equal(model, torch_model, pg)

    # disable dropout by switching to eval mode
    model.eval()
    torch_model.eval()
    set_seed(pg.dp_local_rank())
    torch.distributed.barrier()
    for i, (input_ids, label) in enumerate(train_dataloader):
        colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        logits = model(colo_input)
        torch_logits = torch_model(input_ids)
        assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"
        loss = criterion(logits, input_ids)
        torch_loss = criterion(torch_logits, input_ids)
        if use_ddp:
            model.backward(loss)
        else:
            loss.backward()
        torch_loss.backward()
        check_grad_equal(model, torch_model, pg)
        if i > 0:
            break
    set_seed(313)


def run_dist(rank, world_size, port, use_ddp):
    if use_ddp and world_size == 1:
        return
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # The tests below are commented out for speed:
    # run_gpt(init_1d_row_spec, use_ddp)
    # run_gpt(init_1d_col_spec, use_ddp)
    run_gpt(init_megatron_spec, use_ddp)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_gpt(world_size, use_ddp):
    spawn(run_dist, world_size, use_ddp=use_ddp)


if __name__ == '__main__':
    test_gpt(4, use_ddp=False)
tests/test_tensor/model/test_model.py (deleted, 100644 → 0)
import pytest
import torch

import colossalai
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import (
    check_equal,
    set_seed,
    split_param_col_tp1d,
    split_param_row_tp1d,
    tensor_shard_equal,
)


def run_1d_hybrid_tp(model_name):
    # A simple net with two stacked nn.Linear
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()
        optimizer_torch = ColossalaiOptimizer(torch.optim.SGD(model_torch.parameters(), lr=0.1))

        # Make the two models share the same init params
        for p1, p2 in zip(model.parameters(), model_torch.parameters()):
            p2.data.copy_(p1.data)
    else:
        model_torch = None
        optimizer_torch = None

    pg = ProcessGroup(tp_degree=world_size)
    if 'bert' == model_name:
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            # num_class = type_vocab_size = 2 | (8, 2)
            if 'classifier' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            # num_class = vocab_size = 30524 | (30524, 8)
            elif 'word_embeddings' in name and 'weight' in name:
                split_param_row_tp1d(p, pg)
            # num_class = seq_len = 512 | (512, 8)
            elif 'position_embeddings' in name and 'weight' in name:
                split_param_row_tp1d(p, pg)
            # num_class = type_vocab_size = 2 | (2, 8)
            elif 'token_type_embeddings' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
    elif "simple_net" == model_name:
        # A naive way to set specs for all weights in Linear
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            if 'embed' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            if 'proj1' in name and ('weight' in name or 'bias' in name):
                split_param_row_tp1d(p, pg)
            if 'proj2' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            if 'classifier' in name and ('weight' in name or 'bias' in name):
                split_param_row_tp1d(p, pg)

    model = model.cuda()
    model.eval()
    if rank == 0:
        model_torch.eval()

    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

    for i, (data, label) in enumerate(train_dataloader):
        # Zero grad
        colo_optimizer.zero_grad()
        if rank == 0:
            optimizer_torch.zero_grad()
        torch.distributed.barrier()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Broadcast rank0 data to all processes
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # Test output
        if rank == 0:
            if criterion:
                output_torch = model_torch(data)
                loss_torch = criterion(output_torch, label)
            else:
                output_torch = model_torch(data, label)
                loss_torch = output_torch
            assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"
        torch.distributed.barrier()

        loss.backward()
        colo_optimizer.step()

        if rank == 0:
            loss_torch.backward()
            optimizer_torch.step()

            with torch.no_grad():
                # check params
                for p, torch_p in zip(model.parameters(), model_torch.parameters()):
                    assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
        torch.distributed.barrier()
        if i > 5:
            break


# Test the overridden parameters() and named_parameters() member functions
def test_model_parameters():
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build a module with 2 Linear, 4 parameters in total.
    class Net(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2))
            self.extra_param = torch.nn.Parameter(torch.randn(2))

    with ColoInitContext(device=get_current_device()):
        model = Net()

    param_cnt = 0
    for name, p in model.named_parameters():
        param_cnt += 1
    assert param_cnt == 5

    for name, colo_p in model.named_parameters():
        assert colo_p.is_model_data()

    param_cnt = 0
    for name, p in model.named_parameters(recurse=False):
        param_cnt += 1
    assert param_cnt == 1

    param_cnt = 0
    for p in model.fcs[0].parameters(recurse=False):
        param_cnt += 1
    assert param_cnt == 2


def test_colo_optimizer():
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
    for i, (data, label) in enumerate(train_dataloader):
        colo_optimizer.zero_grad()
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        loss.backward()
        colo_optimizer.step()

        if i > 5:
            break


def run_1d_row_tp(model_name: str):
    # A simple net with two stacked nn.Linear
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    rank = torch.distributed.get_rank()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    set_seed(1)
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()

    # A naive way to set specs for all weights in Linear
    for mo_name, module in model.named_modules():
        # print(mo_name)
        for pa_name, param in module.named_parameters(recurse=False):
            # print('\t', pa_name, param.shape)
            if not isinstance(param, ColoTensor):
                continue
            if 'weight' in pa_name:
                if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name:
                    split_param_row_tp1d(param, pg)
                elif 'LayerNorm' not in mo_name and 'ln' not in mo_name:
                    split_param_col_tp1d(param, pg)

    model = model.cuda()

    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Broadcast rank0 data to all processes
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # For reference
        if rank == 0:
            if criterion:
                output_torch = model_torch(data)
                loss_torch = criterion(output_torch, label)
            else:
                output_torch = model_torch(data, label)
                loss_torch = output_torch
            assert torch.allclose(loss, loss_torch, rtol=1e-2)
        torch.distributed.barrier()

        loss.backward()

        if rank == 0:
            loss_torch.backward()
        torch.distributed.barrier()

        if i > 5:
            break


def _run_pretrain_load():
    from transformers import BertForMaskedLM
    set_seed(1)
    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
    with ColoInitContext(device=get_current_device()):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    model_pretrained = model_pretrained.cuda()
    model = model.cuda()

    dict_pretrained = {}
    dict_col = {}
    c_ref = 0
    for name, param in model_pretrained.named_parameters():
        dict_pretrained[name] = param
        c_ref += 1
    c1 = 0
    c2 = 0
    for name, param in model.named_parameters():
        if isinstance(param, ColoParameter):
            c1 += 1
        else:
            c2 += 1
        dict_col[name] = param
    assert c_ref == c1
    assert c2 == 0
    if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias:
        assert model.cls.predictions.decoder.bias is model.cls.predictions.bias

    for name, param in dict_pretrained.items():
        check_equal(param, dict_col[name])


def run_model_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # The loop below is commented out for speed:
    # for name in ['bert', 'simple_net']:
    #     run_1d_row_tp(name)
    for name in ['bert', 'simple_net']:
        run_1d_hybrid_tp(name)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_model(world_size):
    spawn(run_model_dist, world_size)


def run_pretrain_load_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_pretrain_load()


# This test case has to download huggingface pretrained models from the internet,
# so we trigger it manually.
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_pretrain_load(world_size):
    spawn(run_pretrain_load_dist, world_size)


if __name__ == '__main__':
    # test_model_parameters()
    # test_colo_optimizer()
    test_model(4)
    # test_pretrain_load(4)
tests/test_tensor/model/test_module_spec.py (deleted, 100644 → 0)
from copy import deepcopy

import pytest
import torch

import colossalai
from colossalai.nn.parallel.layers import check_colo_module, init_colo_module
from colossalai.tensor import (
    ColoTensor,
    ColoTensorSpec,
    ComputePattern,
    ComputeSpec,
    ProcessGroup,
    ReplicaSpec,
    ShardSpec,
    distspec,
)
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal


def run_model_with_spec(mode, model_name):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    rank = pg.rank()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=False)

    if rank == 0:
        model_seq = model_builder(checkpoint=False)
        model_seq = model_seq.cuda()

        # Make the two models share the same init params
        for p1, p2 in zip(model.parameters(), model_seq.parameters()):
            p2.data.copy_(p1.data)

    compute_spec = ComputeSpec(ComputePattern.TP1D)
    # Not every dim in BERT is divisible by 4. For example, row sharding of all
    # layers is invalid because the first dim of some layers is the
    # classification type size, which is 2.
    if 'bert' == model_name:
        if 'col' == mode:
            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row')
        elif 'row' == mode:
            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col')
            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode)
    elif 'simple_net' == model_name:
        init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)

    model = model.cuda()

    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # For reference
        if rank == 0:
            if criterion:
                output_seq = model_seq(data)
                loss_seq = criterion(output_seq, label)
            else:
                output_seq = model_seq(data, label)
                loss_seq = output_seq

        if rank == 0:
            with torch.no_grad():
                assert torch.allclose(loss, loss_seq, rtol=1e-2)

        loss.backward()

        if rank == 0:
            loss_seq.backward()

            with torch.no_grad():
                # check params
                for p1, p2 in zip(model.parameters(), model_seq.parameters()):
                    if p1.size() == p2.size():
                        assert torch.allclose(p1, p2)
                    else:
                        if p1.size(-1) < p2.size(-1):    # col
                            world_size = p2.size(-1) // p1.size(-1)
                            split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
                        elif p1.size(0) < p2.size(0):    # row
                            world_size = p2.size(0) // p1.size(0)
                            split_p2 = torch.chunk(p2, world_size, dim=0)[0]
                        assert torch.allclose(p1, split_p2)

        if i > 3:
            break


def run_linear_with_spec(mode):
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Linear(4, 8)

    model_handy = deepcopy(model)
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)

    x = torch.rand(2, 4).cuda()
    colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg))
    out = model(x)
    colo_out = model_handy(colo_x)
    assert tensor_equal(out, colo_out)
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
    assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_check_shared_param():
    from transformers import BertConfig, BertForMaskedLM

    hidden_dim = 8
    num_head = 4
    sequence_length = 12
    num_layer = 2
    vocab_size = 24

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    rank = pg.rank()

    config = BertConfig(vocab_size=vocab_size,
                        hidden_size=hidden_dim,
                        intermediate_size=hidden_dim * 4,
                        num_attention_heads=num_head,
                        max_position_embeddings=sequence_length,
                        num_hidden_layers=num_layer,
                        hidden_dropout_prob=0.,
                        attention_probs_dropout_prob=0.)
    with ColoInitContext(device=get_current_device()):
        model = BertForMaskedLM(config)

    model = model.cuda()
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
    assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
    # They are all Linear modules, so row sharding is allowed for both. This should pass the check.
    init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')

    # This should be detected by the check, because the weight cannot be sharded
    # by row while the bias is sharded by col.
    col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    # TODO(jiaruifang) optimize this line
    if not model.cls.predictions.bias.has_initialized:
        model.cls.predictions.bias.pg = pg
        model.cls.predictions.bias.dist_spec = ReplicaSpec()
        model.cls.predictions.bias.has_initialized = True
    model.cls.predictions.bias.set_tensor_spec(*col_spec)
    try:
        check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False)
    except Exception as e:
        assert 'incorrectly sharded' in str(e)


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_linear_with_spec('col')
    run_linear_with_spec('row')


def run_dist_model(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    for model_name in ['simple_net', 'bert']:
        run_model_with_spec('col', model_name)
        run_model_with_spec('row', model_name)


def run_dist_check(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_check_shared_param()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_linear_1d(world_size):
    spawn(run_dist, world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_model(world_size):
    spawn(run_dist_model, world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_check(world_size):
    spawn(run_dist_check, world_size)


if __name__ == '__main__':
    test_module_linear_1d(4)
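
run_check_shared_param above relies on BertForMaskedLM tying cls.predictions.decoder.bias to cls.predictions.bias, so a sharding decision made through one access path must hold for the other. A minimal torch-only sketch of such a tied parameter:

import torch

decoder = torch.nn.Linear(8, 24)
bias = torch.nn.Parameter(torch.zeros(24))
decoder.bias = bias            # tie the module's bias to an external Parameter
assert decoder.bias is bias    # one tensor, two access paths -> specs must agree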
tests/test_tensor/test_colo_checkpoint_tools.py (deleted, 100644 → 0)
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
from tests.test_tensor.common_utils import tensor_shard_equal


def run_dist(rank, world_size, port, dp_degree, tp_degree):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
    x = torch.randn(4, 4)
    param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))

    spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
    param.set_tensor_spec(*spec)

    gather_tensor(param)
    if dist.get_rank() == 0:
        assert torch.all(x == param)
    else:
        assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
    dist.barrier()

    scatter_tensor(param, spec[0])
    assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
    assert param.requires_grad is True
    dist.barrier()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size):
    spawn(run_dist, world_size, dp_degree=2, tp_degree=world_size // 2)


if __name__ == '__main__':
    test_checkpoint(world_size=4)
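
The gather/scatter round trip checked above can be pictured in a single process: each TP rank holds a last-dim shard, gather_tensor reassembles them, and scatter_tensor hands each rank its chunk back. A torch-only sketch in which chunk/cat stand in for the collectives:

import torch

world = 2
x = torch.randn(4, 4)
shards = list(torch.chunk(x, world, dim=-1))    # what each TP rank holds
gathered = torch.cat(shards, dim=-1)            # gather_tensor on rank 0
assert torch.equal(gathered, x)
for rank in range(world):                       # scatter_tensor back
    assert torch.equal(torch.chunk(gathered, world, dim=-1)[rank], shards[rank])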
tests/test_tensor/test_context.py (deleted, 100644 → 0)
import pytest
import torch

import colossalai
from colossalai.tensor import (
    ColoParameter,
    ColoTensorSpec,
    ComputePattern,
    ComputeSpec,
    ProcessGroup,
    ReplicaSpec,
    ShardSpec,
)
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed


def run_colo_init_context(rank: int, world_size: int, port: int):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # make sure the seed of each process is the same, so the params are consistent
    # among processes, i.e. exactly replicated
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # keep parameters replicated during init
    with ColoInitContext(device=get_current_device()):
        model1 = model_builder()

    # shard the parameters during init
    set_seed(42)
    shard_spec = ReplicaSpec()
    # If ShardSpec were used here instead, the assertions below would fail.
    # That is not a bug: the initialized values simply would not be consistent
    # with the replicated ones.
    # shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])
    default_pg = ProcessGroup(tp_degree=world_size)
    with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec):
        model2 = model_builder()

    # reshard both models
    new_shard = ShardSpec(dims=[-1], num_partitions=[world_size])
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        p1: ColoParameter = p1
        p1.set_process_group(ProcessGroup(tp_degree=world_size))
        p1.set_dist_spec(new_shard)
        p2.set_dist_spec(new_shard)

    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        assert (torch.allclose(p1, p2))


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_colo_init_context(world_size):
    spawn(run_colo_init_context, world_size)


if __name__ == '__main__':
    test_colo_init_context(2)
tests/test_tensor/test_sharded_linear.py (deleted, 100644 → 0)
import pytest
import torch
import torch.nn.functional as F

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.nn._ops._utils import gather_forward_split_backward
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # create mlp vars
    x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
    w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
    b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()

    # run normal forward
    out = F.linear(x, w, b)

    # create mesh meta; the mesh has the following topology:
    # [[0, 1],
    #  [2, 3]]
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    row_id = rank // 2
    column_id = rank % 2

    # create pg
    row_process_group = None
    col_process_group = None
    row_to_ranks = {0: [0, 1], 1: [2, 3]}
    col_to_ranks = {0: [0, 2], 1: [1, 3]}

    for idx in range(2):
        # row ranks
        row_ranks = row_to_ranks[idx]
        row_pg = ProcessGroup(ranks=row_ranks, tp_degree=2)
        # col ranks
        col_ranks = col_to_ranks[idx]
        col_pg = ProcessGroup(ranks=col_ranks, tp_degree=2)

        if rank in row_ranks:
            row_process_group = row_pg
        if rank in col_ranks:
            col_process_group = col_pg

    ########################
    #   RRR x RS0 -> RRS0  #
    ########################
    # w will be transposed in F.linear
    x_replica = x.detach().clone()
    w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[row_id]
    b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[row_id]

    # adding sharding spec
    x_replica.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [0]})
    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [0]})

    # check sharding spec
    assert str(x_replica.sharding_spec.sharding_sequence) == "[R, R, R]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[S0, R]"
    assert str(b_shard.sharding_spec.sharding_sequence) == "[S0]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_replica, w_shard, b_shard)
    assert str(out_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"

    # each mesh row holds a different shard of the output features
    expected_out_shard = torch.chunk(out, chunks=2, dim=2)[row_id]
    assert torch.allclose(out_shard, expected_out_shard)
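The identity this first case relies on can be checked in plain torch with no device mesh at all: chunking the weight and bias along the output dimension yields exactly the matching chunk of the full output. A self-contained sketch (fresh variable names, not the test's tensors):

import torch
import torch.nn.functional as F

x0 = torch.rand(4, 4, 8)
w0 = torch.rand(16, 8)
b0 = torch.rand(16)
full = F.linear(x0, w0, b0)
for shard_id in range(2):
    w_s = torch.chunk(w0, chunks=2, dim=0)[shard_id]
    b_s = torch.chunk(b0, chunks=2, dim=0)[shard_id]
    # each weight shard produces the matching slice of the output features
    assert torch.allclose(F.linear(x0, w_s, b_s), torch.chunk(full, chunks=2, dim=2)[shard_id])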
    ########################
    # S0RR x RS1 -> S0RS1  #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.detach().clone(), chunks=2, dim=0)[row_id]
    w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[column_id]
    b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[column_id]

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1]})
    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, R]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, R]"
    assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_shard)

    # each mesh row keeps its mini-batch; each mesh column keeps a shard of the output features
    expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
    expected_out_shard = torch.chunk(expected_out_shard, chunks=2, dim=2)[column_id]
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    # S0RS1 x S1R -> S0RR  #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=0)[row_id]
    x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
    b_replica = b.clone()

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0], 2: [1]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, S1]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_replica)

    # each mesh row keeps its mini-batch; the sharded contraction dim is reduced away
    expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
    assert torch.allclose(out_shard, expected_out_shard)
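This case differs in kind from the previous two: the contraction dimension itself is sharded, so each rank computes only a partial sum, and the TP op must finish with an all-reduce over the S1 axis before the assert can hold. The underlying identity in plain torch (a sketch; bias omitted because it must be added exactly once, after the reduction):

import torch
import torch.nn.functional as F

x0 = torch.rand(4, 8)
w0 = torch.rand(16, 8)
partials = [
    F.linear(torch.chunk(x0, chunks=2, dim=1)[i], torch.chunk(w0, chunks=2, dim=1)[i])
    for i in range(2)
]
# the sum of the partial products equals the unsharded result
assert torch.allclose(partials[0] + partials[1], F.linear(x0, w0))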
    ########################
    #   RRS0 x S0R -> RRR  #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
    b_replica = b.clone()

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [0]})
    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S0]"
    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_replica)

    # the sharded contraction dim is reduced away, so every rank recovers the full output
    expected_out_shard = out
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    # RS0S1 x S1R -> RS0R  #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=1)[row_id]
    x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
    b_replica = b.clone()

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={1: [0], 2: [1]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, S0, S1]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_replica)

    # each mesh row keeps a shard along dim 1; the sharded contraction dim is reduced away
    expected_out_shard = torch.chunk(out, chunks=2, dim=1)[row_id]
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    # RRS0 x S0S1 -> RRS1  #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
    w_shard = torch.chunk(w_shard, chunks=2, dim=0)[column_id]
    b_shard = torch.chunk(b.clone(), chunks=2, dim=0)[column_id]

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1], 1: [0]})
    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, S0]"
    assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_shard)

    # each mesh column keeps a shard of the output features; the sharded contraction dim is reduced away
    expected_out_shard = torch.chunk(out, chunks=2, dim=2)[column_id]
    assert torch.allclose(out_shard, expected_out_shard)
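Taken together, the six cases fall into two families. In the first two, only batch or output-feature dimensions are sharded, so each rank's local F.linear already yields its final output shard with no communication. In the remaining four, the contraction dimension (the last dim of x, dim 1 of w) is sharded, so each rank holds a partial sum and the op has to reduce across that mesh axis before the asserts can pass.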
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_sharded_mlp(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_sharded_mlp(4)
tests/test_tensor/test_tp_with_zero.py deleted 100644 → 0
import pytest
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP
from colossalai.zero.gemini import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
from tests.test_tensor.model.test_gpt2 import init_megatron_spec


def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # torch keys are prefixed with 'module.' (7 characters), so we strip it
        key = key[7:]
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
        assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \
            "parameter '{}' has problem.".format(key)
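The key[7:] slice assumes every torch-side key carries the 'module.' prefix that DDP prepends. A slightly more defensive variant (a sketch, not part of the original test) strips the prefix only when it is present:

prefix = 'module.'
key = key[len(prefix):] if key.startswith(prefix) else key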
def run_fwd_bwd(model, criterion, optimizer, input_ids):
    optimizer.zero_grad()
    logits = model(input_ids)
    logits = logits.float()
    loss = criterion(logits, input_ids)
    optimizer.backward(loss)
    return logits


def init_1d_row_spec(model, pg: ProcessGroup):
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        if 'weight' in n and 'ln' not in n:
            p.set_tensor_spec(*spec)


def init_1d_col_spec(model, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        if 'ln' not in n and ('weight' in n or 'bias' in n):
            p.set_tensor_spec(*spec)
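The only difference between the two initializers is the shard dimension. In plain torch terms, for a (16, 8) weight split across two ranks, the two specs correspond to (a standalone sketch, independent of ColossalAI):

import torch

w0 = torch.rand(16, 8)
row_shards = torch.chunk(w0, chunks=2, dim=0)    # ShardSpec([0], [2]):  each rank keeps an (8, 8) slice
col_shards = torch.chunk(w0, chunks=2, dim=-1)   # ShardSpec([-1], [2]): each rank keeps a (16, 4) slice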
@parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(placement_policy, tp_init_spec_func=None):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()

    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()

    # with world size 4, use dp = 2 and tp = 2 to construct hybrid parallelism
    if world_size == 4:
        pg = ProcessGroup(tp_degree=2)
    else:
        pg = ProcessGroup(tp_degree=world_size)

    if tp_init_spec_func:
        tp_init_spec_func(model, pg)

    dp_world_size = pg.dp_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[dp_world_size]['chunk_size'] = 5000
    config_dict[dp_world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    model = GeminiDDP(model, init_device, placement_policy, True, False)
    # Equivalent to the following 3 lines:
    # chunk_manager = ChunkManager(config_dict, init_device=init_device)
    # gemini_manager = GeminiManager(placement_policy, chunk_manager)
    # model = ZeroDDP(model, gemini_manager, pin_memory=True)

    zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
    # Equivalent to the following 2 lines:
    # optimizer = HybridAdam(model.parameters(), lr=1e-3)
    # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    check_param(model, torch_model, pg)

    model.eval()
    torch_model.eval()

    set_seed(pg.dp_local_rank())
    for i, (input_ids, label) in enumerate(train_dataloader):
        if i > 2:
            break

        input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)

        zero_optim.step()
        torch_optim.step()
        check_param(model, torch_model, pg)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    if world_size == 4:
        run_gpt(tp_init_spec_func=init_megatron_spec)
    else:
        run_gpt(tp_init_spec_func=init_1d_col_spec)
        run_gpt(tp_init_spec_func=init_1d_row_spec)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_gpt(4)
tests/test_utils/test_colo_checkpoint.py deleted 100644 → 0
import os
import shutil
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR

import colossalai
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs


def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_col_linear(weight, pg):
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_row_embedding(weight, pg):
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_col_embedding(weight, pg):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        if 'embed' in name and 'weight' in name:
            init_1d_col_embedding(p, pg)
        if 'proj1' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)
        if 'proj2' in name and 'weight' in name:
            init_1d_row_linear(p, pg)
        if 'classifier' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)


def check_param_equal(model, torch_model):
    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
        assert torch.all(p.data == tp.data), "{} went wrong.\n{} vs {}\n{}".format(n, p, tp, p.shape)


def remove(path):
    """ param <path> could either be relative or absolute. """
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    else:
        raise ValueError("file {} is not a file or dir.".format(path))


def compare_optims(optim1, optim2):
    state1 = optim1.state_dict()['state']
    state2 = optim2.state_dict()['state']
    for k, p1 in state1.items():
        if k not in state2:
            continue
        p2 = state2[k]
        for n, t1 in p1.items():
            if n not in p2:
                continue
            t2 = p2[n]
            if isinstance(t1, ColoTensor):
                assert isinstance(t2, ColoTensor)
                assert torch.allclose(t1, t2, rtol=0, atol=0)


def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if use_mp_reload:
        if 'bert' == model_name:
            for name, p in model.named_parameters():
                if not isinstance(p, ColoTensor):
                    continue
                # num_class = type_vocab_size = 2 | (8, 2)
                if 'classifier' in name and 'weight' in name:
                    init_1d_row_linear(p, pg)
                # num_class = vocab_size = 30524 | (30524, 8)
                elif 'word_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = seq_len = 512 | (512, 8)
                elif 'position_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = type_vocab_size = 2 | (2, 8)
                elif 'token_type_embeddings' in name and 'weight' in name:
                    init_1d_col_embedding(p, pg)
                elif p.process_group.tp_world_size() == 1:
                    p.set_process_group(pg)
        elif "simple_net" == model_name:
            init_spec_func(model, pg)

    model_reload = deepcopy(model)
    model = model.cuda()
    model.eval()

    model_reload = model_reload.cuda()
    model_reload.eval()

    opt_class = torch.optim.Adam
    colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
    colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))

    for i, (data, label) in enumerate(train_dataloader):
        # zero grad
        colo_optimizer.zero_grad()
        colo_optimizer_reload.zero_grad()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # broadcast rank 0's data to all processes
        dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
        dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())

        if criterion:
            output = model(data)
            output_reload = model_reload(data)
            loss = criterion(output, label)
            loss_reload = criterion(output_reload, label)
        else:
            loss = model(data, label)
            loss_reload = model_reload(data, label)

        loss.backward()
        loss_reload.backward()

        colo_optimizer.step()
        colo_optimizer_reload.step()

        if i > 2:
            break

    if not os.path.isdir('./checkpoint') and rank == 0:
        os.mkdir('./checkpoint')
    dist.barrier()

    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)

    check_param_equal(model, model_reload)
    compare_optims(colo_optimizer, colo_optimizer_reload)

    if rank == 0:
        remove('./checkpoint')
    dist.barrier()


def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)
    # the data loader of BERT runs in DDP mode, so the input data is not replicated in the TP context
    for model_name in ['bert']:
        _run_checkpoint(model_name,
                        init_1d_row_for_linear_weight_spec,
                        use_ddp,
                        use_mp_reload,
                        test_scheduler=test_scheduler,
                        pg=pg)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [False])
@pytest.mark.parametrize('use_mp_reload', [True, False])
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
    spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler)


if __name__ == '__main__':
    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")
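For contrast, the same save/load/compare round trip in plain torch, without the ColoTensor-aware save_checkpoint/load_checkpoint machinery this test exercises (a minimal sketch; the file name is arbitrary):

import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters(), lr=0.1)
torch.save({'model': model.state_dict(), 'optim': optim.state_dict()}, 'ckpt.pt')

model_reload = torch.nn.Linear(4, 4)
optim_reload = torch.optim.Adam(model_reload.parameters(), lr=0.1)
state = torch.load('ckpt.pt')
model_reload.load_state_dict(state['model'])
optim_reload.load_state_dict(state['optim'])
assert all(torch.equal(p, q) for p, q in zip(model.parameters(), model_reload.parameters()))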