OpenDAS / ColossalAI · Commit cefc29ff (unverified)
Authored May 21, 2022 by ver217; committed by GitHub on May 21, 2022
Parent: ae7c3381

[tensor] impl ColoDDP for ColoTensor (#1009)

* impl ColoDDP for ColoTensor
* polish code

Changes: 2 changed files with 102 additions and 9 deletions (+102 -9)

  colossalai/nn/parallel.py       +78  -0
  tests/test_tensor/test_gpt.py   +24  -9
colossalai/nn/parallel.py (new file, 0 → 100644)
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from functools import partial

__all__ = ['ColoDDP']


def free_storage(data: torch.Tensor) -> None:
    """Free underlying storage of a Tensor."""
    if data.storage().size() > 0:
        # Since we're modifying the Tensor's Storage directly, make sure the Tensor
        # is the sole occupant of the Storage.
        assert data.storage_offset() == 0
        data.storage().resize_(0)


class ColoDDP(torch.nn.Module):

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module
        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
        for p in module.parameters():
            if p.requires_grad:
                p.register_hook(partial(self.grad_handle, p))

    def parameters(self, recurse: bool = True):
        return self.module.parameters(recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        return self.module.named_parameters(prefix, recurse)

    def forward(self, *args, **kwargs):
        self.module.zero_grad(set_to_none=True)
        return self.module(*args, **kwargs)

    def backward(self, loss: torch.Tensor):
        loss.backward()
        torch.cuda.current_stream().wait_stream(self.comm_stream)
        for p in self.module.parameters():
            p.grad = p._saved_grad

    def grad_handle(self, p, grad):
        empty_grad = torch.empty_like(grad)
        free_storage(empty_grad)
        if self.dp_world_size > 1:
            grad = grad / self.dp_world_size
            self.comm_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.comm_stream):
                dist.all_reduce(grad, group=gpc.get_group(ParallelMode.DATA))
                ColoDDP._save_grad(p, grad)
            grad.record_stream(self.comm_stream)
        else:
            ColoDDP._save_grad(p, grad)
        return empty_grad

    @staticmethod
    def _save_grad(p, grad):
        if hasattr(p, '_saved_grad'):
            p._saved_grad.add_(grad)
        else:
            p._saved_grad = grad

    def zero_grad(self, set_to_none: bool = False) -> None:
        self.module.zero_grad(set_to_none=True)
        for p in self.module.parameters():
            if getattr(p, '_saved_grad', None) is not None:
                if set_to_none:
                    p._saved_grad = None
                else:
                    if p._saved_grad.grad_fn is not None:
                        p._saved_grad.detach_()
                    else:
                        p._saved_grad.requires_grad_(False)
                    p._saved_grad.zero_()
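
How the wrapper is meant to be driven, as a hedged sketch rather than part of this commit: colossalai.launch() is assumed to have already initialized the DATA parallel group, and build_model, criterion and train_dataloader are placeholder names standing in for the user's own model and data pipeline.

import torch
from colossalai.nn.parallel import ColoDDP

# Placeholders below: build_model, criterion and train_dataloader are assumed
# to exist; the optimizer choice is arbitrary.
model = ColoDDP(build_model().cuda())      # __init__ registers a grad hook on every parameter
optimizer = torch.optim.Adam(model.parameters())

for data, label in train_dataloader:
    logits = model(data.cuda())            # forward() first calls module.zero_grad(set_to_none=True)
    loss = criterion(logits, label.cuda())
    model.backward(loss)                   # runs loss.backward(), waits on comm_stream, then sets p.grad = p._saved_grad
    optimizer.step()
    model.zero_grad()                      # also resets the _saved_grad buffers
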
tests/test_tensor/test_gpt.py
@@ -9,8 +9,10 @@ from colossalai.utils import ColoInitContext
 from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
 from colossalai.core import global_context as gpc
 from functools import partial
-from _utils import tensor_equal, tensor_shard_equal
+from _utils import tensor_equal, tensor_shard_equal, set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP
+from colossalai.nn.parallel import ColoDDP


 def init_1d_row_spec(model):
@@ -43,7 +45,7 @@ def check_grad_equal(model, torch_model):
         assert tensor_shard_equal(torch_p.grad, p.grad)


-def run_gpt(init_spec_func):
+def run_gpt(init_spec_func, use_ddp):
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -51,37 +53,50 @@ def run_gpt(init_spec_func):
     model = model_builder()
     model = model.cuda()
     torch_model = model_builder().cuda()
+    if use_ddp:
+        model = ColoDDP(model)
+        torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         torch_p.data.copy_(p)
     init_spec_func(model)
     check_param_equal(model, torch_model)
     model.train()
     torch_model.train()
+    set_seed(gpc.get_local_rank(ParallelMode.DATA))
     for i, (input_ids, attn_mask) in enumerate(train_dataloader):
         logits = model(input_ids, attn_mask)
         torch_logits = torch_model(input_ids, attn_mask)
         assert tensor_equal(torch_logits, logits)
         loss = criterion(logits, input_ids)
         torch_loss = criterion(torch_logits, input_ids)
-        loss.backward()
+        if use_ddp:
+            model.backward(loss)
+        else:
+            loss.backward()
         torch_loss.backward()
         check_grad_equal(model, torch_model)
         if i > 0:
             break


-def run_dist(rank, world_size, port):
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+def run_dist(rank, world_size, port, use_ddp):
+    if use_ddp and world_size == 1:
+        return
+    tp_world_size = world_size // 2 if use_ddp else world_size
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_gpt(init_1d_row_spec)
-    run_gpt(init_1d_col_spec)
+    run_gpt(init_1d_row_spec, use_ddp)
+    run_gpt(init_1d_col_spec, use_ddp)


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
+@pytest.mark.parametrize('use_ddp', [False, True])
 @rerun_if_address_is_in_use()
-def test_gpt(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+def test_gpt(world_size, use_ddp):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
     mp.spawn(run_func, nprocs=world_size)
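
For readers unfamiliar with the hook trick that colossalai/nn/parallel.py relies on, here is a CPU-only toy sketch of my own (not part of the commit): a per-parameter hook stashes the incoming gradient on a side attribute, exactly like ColoDDP._save_grad, and returns a storage-freed placeholder so that nothing useful accumulates into p.grad until ColoDDP.backward() copies the saved gradient back.

import torch
from functools import partial

def toy_grad_handle(p, grad):
    # Stash the real gradient on a side attribute, as ColoDDP._save_grad does.
    if hasattr(p, '_saved_grad'):
        p._saved_grad.add_(grad)
    else:
        p._saved_grad = grad.clone()
    # Hand autograd back an emptied tensor, mirroring empty_like + free_storage.
    empty_grad = torch.empty_like(grad)
    empty_grad.storage().resize_(0)
    return empty_grad

w = torch.nn.Parameter(torch.ones(3))
w.register_hook(partial(toy_grad_handle, w))
(w * 2.0).sum().backward()
print(w._saved_grad)    # tensor([2., 2., 2.]) -- the gradient lives here, not in w.grad
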