OpenDAS / ColossalAI, commit 36086927 (unverified)

Authored Jul 14, 2022 by HELSON; committed by GitHub on Jul 14, 2022.

[hotfix] fix ColoTensor GPT2 unitest (#1309)

Parent: 3ef3791a
1 changed file with 19 additions and 15 deletions.

tests/test_tensor/test_gpt2.py (+19, -15)
@@ -12,7 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
+from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ColoTensor, ColoTensorSpec
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
@@ -21,18 +21,20 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 def init_1d_row_spec(model, pg: ProcessGroup):
     tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if 'weight' in n and 'ln' not in n:
-                p.set_tensor_spec(*tensor_spec)
+        for n, p in model.named_parameters():
+            p.set_process_group(pg)
+            if 'weight' in n and 'ln' not in n:
+                p.set_tensor_spec(*tensor_spec)


 def init_1d_col_spec(model, pg: ProcessGroup):
     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if 'ln' not in n and ('weight' in n or 'bias' in n):
-                p.set_tensor_spec(*spec)
+        for n, p in model.named_parameters():
+            p.set_process_group(pg)
+            if 'ln' not in n and ('weight' in n or 'bias' in n):
+                p.set_tensor_spec(*spec)


 def check_param_equal(model, torch_model, pg: ProcessGroup):
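Note (not part of the commit): reassembled from the hunk above for readability, this is what init_1d_row_spec looks like after the fix. Every parameter is now attached to the ProcessGroup via set_process_group(pg), while the 1-D shard spec is still applied only to non-LayerNorm weights; init_1d_col_spec follows the same pattern with ShardSpec([-1], ...).

    from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup

    def init_1d_row_spec(model, pg: ProcessGroup):
        tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
        with DistSpecManager.no_grad():
            for n, p in model.named_parameters():
                p.set_process_group(pg)              # new: set on every parameter
                if 'weight' in n and 'ln' not in n:  # shard only non-LayerNorm weights
                    p.set_tensor_spec(*tensor_spec)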
@@ -48,6 +50,7 @@ def check_grad_equal(model, torch_model, pg: ProcessGroup):
 def run_gpt(init_spec_func, use_ddp):
     set_seed(13234)
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
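Illustration (not part of the commit): how the dp_degree expression above combines with the tp_world_size expression from the run_dist hunk further down, for the 4-process case this test is parametrized with.

    world_size = 4
    for use_ddp in (False, True):
        dp_degree = 2 if (use_ddp and world_size >= 2) else 1        # from run_gpt above
        tp_world_size = world_size // 2 if use_ddp else world_size   # from run_dist below
        print(use_ddp, dp_degree, tp_world_size)
    # use_ddp=False -> dp_degree=1, tp_world_size=4 (tensor parallelism only)
    # use_ddp=True  -> dp_degree=2, tp_world_size=2 (2-way DDP x 2-way 1d TP)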
@@ -67,14 +70,16 @@ def run_gpt(init_spec_func, use_ddp):
         model = ColoDDP(model, process_group=pg)
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         torch_p.data.copy_(p)
     init_spec_func(model, pg)
     check_param_equal(model, torch_model, pg)
     model.train()
     torch_model.train()
     set_seed(pg.tp_local_rank())
     torch.distributed.barrier()
     for i, (input_ids, attn_mask) in enumerate(train_dataloader):
-        logits = model(input_ids, attn_mask)
+        colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
+        logits = model(colo_input, attn_mask)
         torch_logits = torch_model(input_ids, attn_mask)
         assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"
         loss = criterion(logits, input_ids)
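The substantive change in this hunk is that the model no longer receives a raw torch tensor: input_ids is wrapped as a ColoTensor carrying a ColoTensorSpec bound to the ProcessGroup before the forward pass. A minimal sketch of the pattern (forward_with_colo_input is a hypothetical helper; model, input_ids, attn_mask and pg are assumed to exist as in the test):

    from colossalai.tensor import ColoTensor, ColoTensorSpec

    def forward_with_colo_input(model, input_ids, attn_mask, pg):
        # wrap the plain tensor in a ColoTensor that carries the process group,
        # mirroring what the test now does before calling the ColoDDP-wrapped model
        colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        return model(colo_input, attn_mask)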
@@ -95,14 +100,13 @@ def run_dist(rank, world_size, port, use_ddp):
     tp_world_size = world_size // 2 if use_ddp else world_size
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    # run_gpt(init_1d_row_spec, use_ddp)
+    run_gpt(init_1d_row_spec, use_ddp)
     run_gpt(init_1d_col_spec, use_ddp)


 @pytest.mark.dist
-@pytest.mark.skip("under development")
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.parametrize('use_ddp', [False, True])
+@pytest.mark.parametrize('use_ddp', [False])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size, use_ddp):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
@@ -110,4 +114,4 @@ def test_gpt(world_size, use_ddp):
 if __name__ == '__main__':
-    test_gpt(4, True)
+    test_gpt(4, False)