OpenDAS / ColossalAI · Commits · d49708ae

Commit d49708ae (unverified), authored Jul 15, 2022 by HELSON, committed by GitHub on Jul 15, 2022.
Parent: 250be4d3

[hotfix] fix ddp for unit test test_gpt2 (#1326)

Showing 4 changed files with 85 additions and 68 deletions:

colossalai/tensor/process_group.py    +17 -12
tests/test_tensor/test_gpt2.py        +21 -19
tests/test_tensor/test_model.py       +11 -7
tests/test_tensor/test_zero_optim.py  +36 -30
colossalai/tensor/process_group.py (view file @ d49708ae)
```diff
@@ -21,7 +21,7 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
         if pg_key not in self.dict:
             self.logger = get_dist_logger('ProcessGroup')
-            self.logger.info(f'NCCL initialize TP group on {rank_list}', ranks=[0])
+            self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])
             self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
         return self.dict[pg_key]
@@ -63,7 +63,6 @@ class ProcessGroup:
         self._rank_list = ranks
         self._rank_list.sort()    # ensure that the list is in order
-        self._rank_idx = self._rank_list.index(self._rank)
         self._world_size = len(self._rank_list)
         if dp_degree is None and tp_degree is None:
@@ -84,19 +83,22 @@ class ProcessGroup:
             f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \
             f"and TP degree {self._tp_degree}"
-        self._tp_rank_list = []
-        self._dp_rank_list = []
-        for idx, rank_id in enumerate(self._rank_list):
-            # idx and self._rank_idx in the same tp group
-            if idx % self._tp_degree == self._rank_idx % self._tp_degree:
-                self._dp_rank_list.append(rank_id)
-            if idx // self._tp_degree == self._rank_idx // self._tp_degree:
-                self._tp_rank_list.append(rank_id)
+        self._tp_rank_list = None
+        self._dp_rank_list = None
+
+        for i in range(self._dp_degree):
+            i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
+            PYTORCHPGDICT_.get(i_tp_list, 'nccl')
+            if self._rank in i_tp_list:
+                self._tp_rank_list = i_tp_list
+
+        for j in range(self._tp_degree):
+            j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
+            PYTORCHPGDICT_.get(j_dp_list, 'nccl')
+            if self._rank in j_dp_list:
+                self._dp_rank_list = j_dp_list

         self._has_cpu_groups = False
-        PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
-        PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
         self.is_init = True

     def set_cpu_groups(self):
@@ -106,6 +108,7 @@ class ProcessGroup:
             f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
         PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
         PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
+        self._has_cpu_groups = True

     @property
     def has_cpu_groups(self):
@@ -162,7 +165,9 @@ class ProcessGroup:
         return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')

     def cpu_dp_process_group(self):
+        assert self._has_cpu_groups
         return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

     def cpu_tp_process_group(self):
+        assert self._has_cpu_groups
         return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
```
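The rewrite matters because torch.distributed.new_group is a collective: every process must call it for every subgroup, in the same order, even for groups it will not join. The old loop only built the caller's own TP/DP rank lists, while the new one enumerates all dp_degree TP groups and all tp_degree DP groups on every rank (registering each through PYTORCHPGDICT_.get), then remembers the two lists containing self._rank. A plain-Python sketch of that arithmetic, with illustrative values rather than anything from the commit:

```python
# Rank layout for world_size=4, dp_degree=2, tp_degree=2 (illustrative).
rank_list = [0, 1, 2, 3]
dp_degree, tp_degree = 2, 2

# One TP group of tp_degree consecutive ranks per DP replica...
tp_groups = [[rank_list[i * tp_degree + j] for j in range(tp_degree)] for i in range(dp_degree)]
# ...and one DP group of dp_degree strided ranks per TP slot.
dp_groups = [[rank_list[i * tp_degree + j] for i in range(dp_degree)] for j in range(tp_degree)]

assert tp_groups == [[0, 1], [2, 3]]
assert dp_groups == [[0, 2], [1, 3]]

# Every rank enumerates all four groups; it keeps only the two it belongs to.
rank = 3
my_tp_rank_list = next(g for g in tp_groups if rank in g)    # [2, 3]
my_dp_rank_list = next(g for g in dp_groups if rank in g)    # [1, 3]
```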
tests/test_tensor/test_gpt2.py (view file @ d49708ae)
```diff
@@ -12,16 +12,13 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ColoTensor, ColoTensorSpec
+from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor, ColoTensorSpec
 from colossalai.nn.parallel.data_parallel import ColoDDP
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
 from tests.components_to_test.registry import non_distributed_component_funcs


 def init_1d_row_spec(model, pg: ProcessGroup):
     tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     for n, p in model.named_parameters():
         p.set_process_group(pg)
         if 'weight' in n and 'ln' not in n:
@@ -50,33 +47,39 @@ def check_grad_equal(model, torch_model, pg: ProcessGroup):
 def run_gpt(init_spec_func, use_ddp):
+    set_seed(13234)
     world_size = torch.distributed.get_world_size()
+    # build a PG with TP and DP hybrid
     pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))
+    # set seed make processes of the same tp group use the same seed
+    # set_seed(pg.tp_local_rank())
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    # make sure torch_model and model has the same parameter values
     with ColoInitContext(device=get_current_device()):
         model = model_builder()
     model = model.cuda()
     torch_model = model_builder().cuda()
     if use_ddp:
-        # torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg)
-        # torch.distributed.barrier()
-        torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
+        torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
         model = ColoDDP(model, process_group=pg)
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         torch_p.data.copy_(p)
     init_spec_func(model, pg)
     check_param_equal(model, torch_model, pg)
-    model.train()
-    torch_model.train()
-    torch.distributed.barrier()
+    # close the dropout in eval mode
+    model.eval()
+    torch_model.eval()
+    set_seed(pg.dp_local_rank())
+    torch.distributed.barrier()
     for i, (input_ids, attn_mask) in enumerate(train_dataloader):
         colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
         logits = model(colo_input, attn_mask)
@@ -92,21 +95,20 @@ def run_gpt(init_spec_func, use_ddp):
         check_grad_equal(model, torch_model, pg)
         if i > 0:
             break
-    set_seed(313)


 def run_dist(rank, world_size, port, use_ddp):
     if use_ddp and world_size == 1:
         return
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    tp_world_size = world_size // 2 if use_ddp else world_size
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_gpt(init_1d_row_spec, use_ddp)
     run_gpt(init_1d_col_spec, use_ddp)


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.parametrize('use_ddp', [False])
+@pytest.mark.parametrize('use_ddp', [False, True])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size, use_ddp):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
@@ -114,4 +116,4 @@ def test_gpt(world_size, use_ddp):
 if __name__ == '__main__':
-    test_gpt(4, False)
+    test_gpt(4, use_ddp=True)
```
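The substantive fix here is the DDP wiring: the torch baseline is now wrapped over the DP communicator owned by the same ProcessGroup the ColossalAI model uses, instead of the legacy gpc global context. A minimal sketch of that pattern, assuming four processes already launched through colossalai.launch (the linear model and degree value are illustrative, not from the test):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.tensor import ProcessGroup

# Assumes torch.distributed is already initialized on 4 ranks
# (e.g. via colossalai.launch, as in run_dist above).
pg = ProcessGroup(dp_degree=2)    # 4 ranks -> 2 DP replicas of 2 TP shards each

model = torch.nn.Linear(8, 8).cuda()
model = DDP(model,
            device_ids=[pg.rank()],               # this process's global rank
            process_group=pg.dp_process_group())  # all-reduce only across DP replicas
```

With this wiring, gradient averaging in the torch baseline spans exactly the DP replicas, matching what ColoDDP does on the ColossalAI side.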
tests/test_tensor/test_model.py (view file @ d49708ae)
```diff
@@ -77,9 +77,9 @@ def run_1d_hybrid_tp(model_name):
             split_param_row_tp1d(p, pg)

     model = model.cuda()
-    model.train()
+    model.eval()
     if rank == 0:
-        model_torch.train()
+        model_torch.eval()

     colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
@@ -89,6 +89,7 @@ def run_1d_hybrid_tp(model_name):
         colo_optimizer.zero_grad()
         if rank == 0:
             optimizer_torch.zero_grad()
+        torch.distributed.barrier()

         data = data.to(get_current_device())
         label = label.to(get_current_device())
@@ -113,6 +114,7 @@ def run_1d_hybrid_tp(model_name):
             output_torch = model_torch(data, label)
             loss_torch = output_torch
             assert torch.allclose(loss, loss_torch, rtol=1e-2)
+        torch.distributed.barrier()

         loss.backward()
         colo_optimizer.step()
@@ -125,7 +127,7 @@ def run_1d_hybrid_tp(model_name):
         # check param
         for p, torch_p in zip(model.parameters(), model_torch.parameters()):
             assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
+        torch.distributed.barrier()
         if i > 5:
             break
@@ -248,14 +250,15 @@ def run_1d_row_tp(model_name: str):
         else:
             output_torch = model_torch(data, label)
             loss_torch = output_torch
-            if rank == 0:
-                assert torch.allclose(loss, loss_torch, rtol=1e-2)
+            assert torch.allclose(loss, loss_torch, rtol=1e-2)
+        torch.distributed.barrier()

         loss.backward()

         if rank == 0:
             loss_torch.backward()
+        torch.distributed.barrier()
         if i > 5:
             break
@@ -296,8 +299,9 @@ def _run_pretrain_load():
 def run_model_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    for name in ['bert', 'simple_net']:
-        run_1d_row_tp(name)
+    # Comment below test for speed consideration
+    # for name in ['bert', 'simple_net']:
+    #     run_1d_row_tp(name)
     for name in ['bert', 'simple_net']:
         run_1d_hybrid_tp(name)
```
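The new torch.distributed.barrier() calls all follow the same shape: only rank 0 steps the single-process torch reference model, so the other ranks must re-synchronize before the next collective or they can race ahead of rank 0's extra local work. A schematic of that pattern, with hypothetical stand-in callables rather than the test's real models:

```python
import torch.distributed as dist

def compare_step(colo_forward, torch_forward, rank):
    # All ranks participate in the tensor-parallel forward (collective ops).
    loss = colo_forward()
    if rank == 0:
        # Only rank 0 runs the single-process reference model.
        torch_loss = torch_forward()
        assert abs(loss.item() - torch_loss.item()) <= 1e-2 * abs(torch_loss.item())
    # Re-align all ranks before the next collective (backward), so rank 0's
    # reference computation cannot interleave with the group's NCCL calls.
    dist.barrier()
    loss.backward()
```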
tests/test_tensor/test_zero_optim.py (view file @ d49708ae)
```diff
@@ -17,22 +17,25 @@ from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
+from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor


 def check_param_equal(model, torch_model, pg: ProcessGroup):
-    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
+    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
         if p.storage().size() > 0:
-            assert p.dtype == torch.half
-            assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
-                                      pg.tp_world_size()), f'{torch_p} vs {p}'
+            assert p.dtype == torch.float16
+            assert tensor_shard_equal(tp.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
+                                      pg.tp_world_size()), f'{tp} vs {p}\n{n}:\n\t{tp.shape} vs {p.shape}'


 def check_grad_equal(model, torch_model, pg: ProcessGroup):
-    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
+    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
         if p.grad is not None:
-            assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
-                                      pg.tp_local_rank(), pg.tp_world_size())
+            torch.distributed.barrier()
+            print(torch.distributed.get_rank(), p.grad)
+            assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
+                                      pg.tp_local_rank(), pg.tp_world_size()), \
+                f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'


 def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
@@ -46,23 +49,23 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
 def init_1d_row_spec(model, pg: ProcessGroup):
     spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if 'weight' in n and 'ln' not in n:
-                p.set_tensor_spec(*spec)
+    for n, p in model.named_parameters():
+        p.set_process_group(pg)
+        if 'weight' in n and 'ln' not in n:
+            p.set_tensor_spec(*spec)


 def init_1d_col_spec(model, pg: ProcessGroup):
     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if 'ln' not in n and ('weight' in n or 'bias' in n):
-                p.set_tensor_spec(*spec)
+    for n, p in model.named_parameters():
+        p.set_process_group(pg)
+        if 'ln' not in n and ('weight' in n or 'bias' in n):
+            p.set_tensor_spec(*spec)


-@parameterize('use_chunk', [False, True])
-@parameterize('use_zero', [False, True])
-@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('use_chunk', [False])
+@parameterize('use_zero', [False])
+@parameterize('placement_policy', ['cuda'])
 def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
@@ -70,10 +73,11 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     with ColoInitContext(device=get_current_device()):
         model = model_builder()
-    model = model.cuda().half()
+    model = model.cuda()
     torch_model = model_builder().cuda()
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
-        torch_p.data.copy_(p)
+        torch_p.data.copy_(p.data)

     world_size = torch.distributed.get_world_size()
@@ -93,23 +97,25 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
         gemini_manager = GeminiManager(placement_policy, chunk_manager)
         model = ZeroDDP(model, gemini_manager, pg)

     optim = HybridAdam(model.parameters(), lr=1e-3)
-    optim = ZeroOptimizer(optim, model, initial_scale=32)
+    optim = ZeroOptimizer(optim, model, initial_scale=1)

-    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
+    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

     # print(chunk_manager)
     check_param_equal(model, torch_model, pg)
-    model.train()
-    torch_model.train()
+    model.eval()
+    torch_model.eval()
     set_seed(pg.dp_local_rank())
     for i, (input_ids, attn_mask) in enumerate(train_dataloader):
         if i > 2:
             break
-        logits = run_fwd_bwd(model, criterion, optim, input_ids, attn_mask)
+        input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
+        logits = run_fwd_bwd(model, criterion, optim, input_ids_colo, attn_mask)
         torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
         assert tensor_equal(logits, torch_logits)
         check_grad_equal(model, torch_model, pg)
@@ -123,13 +129,13 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     if world_size == 4:
         run_gpt(tp_init_spec_func=init_1d_col_spec)
-        run_gpt(tp_init_spec_func=init_1d_row_spec)
+        # run_gpt(tp_init_spec_func=init_1d_row_spec)
     else:
-        run_gpt()
+        run_gpt(tp_init_spec_func=init_1d_col_spec)


 @pytest.mark.dist
-@pytest.mark.skip("under development")
+@pytest.mark.skip("buggy test")
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size):
```
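One detail worth noting: initial_scale for ZeroOptimizer and loss_scale for apex amp were moved from 32 to 1 together. Both sides scale the loss before the fp16 backward and unscale gradients afterwards, so the comparison in check_grad_equal is only exact when rounding happens at the same magnitude on both sides. A toy, non-ColossalAI illustration of that effect:

```python
import torch

g = torch.tensor([1e-4, 3e-4, 5e-4])    # pretend fp32 "true" gradients

def fp16_roundtrip(scale):
    # Scale, round to fp16 (what a scaled fp16 backward produces), unscale.
    return (g * scale).half().float() / scale

print(fp16_roundtrip(1.0))
print(fp16_roundtrip(32.0))
# The two prints can disagree in the low bits: fp16 rounding occurs at
# different magnitudes, so exact-equality checks need identical scales.
```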