OpenDAS / ColossalAI · Commits

Commit fce9432f
Authored Mar 16, 2022 by ver217
Parent: ea6905a8

    sync before creating empty grad

Showing 2 changed files with 7 additions and 4 deletions:

    colossalai/zero/sharded_model/sharded_model_v2.py      +1  -0
    tests/test_zero_data_parallel/test_shard_model_v2.py   +6  -4
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -218,6 +218,7 @@ class ShardedModelV2(nn.Module):
             else:
                 self._reduce_scatter_callback(param, new_grad)
             orig_grad_data.record_stream(self.comm_stream)
+        torch.cuda.current_stream().wait_stream(self.comm_stream)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
         return empty_grad
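The one-line change is a stream-ordering fix. The reduce-scatter issued on `self.comm_stream` runs asynchronously; without the added `wait_stream`, the hook would go on to build and free `empty_grad` (and let the caching allocator recycle the gradient's storage) while the communication could still be reading it. `record_stream` only tags the tensor for the allocator; the new line additionally orders all subsequent default-stream work after the communication, hence "sync before creating empty grad". Below is a minimal sketch of the same pattern, assuming a CUDA device is available; `comm_stream` and the multiply are illustrative stand-ins, not the ColossalAI implementation:

    import torch

    # Minimal sketch of the record_stream / wait_stream pattern above,
    # assuming a CUDA device; the multiply stands in for the reduce-scatter.
    comm_stream = torch.cuda.Stream()

    grad = torch.randn(1 << 20, device='cuda')   # produced on the default stream

    comm_stream.wait_stream(torch.cuda.current_stream())  # comm waits for producer
    with torch.cuda.stream(comm_stream):
        reduced = grad * 2                       # "communication" on the side stream
        # Tag `grad` so the caching allocator will not recycle its storage
        # while comm_stream is still using it.
        grad.record_stream(comm_stream)

    # The line this commit adds plays the same role here: order everything the
    # default stream does next after the communication has finished.
    torch.cuda.current_stream().wait_stream(comm_stream)

    empty_grad = torch.empty_like(grad)          # safe to create and free now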
tests/test_zero_data_parallel/test_shard_model_v2.py

@@ -2,12 +2,14 @@
 # -*- encoding: utf-8 -*-
 
 import copy
+from asyncio.log import logger
 from functools import partial
 
 import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+from colossalai.logging import get_dist_logger
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)

@@ -18,12 +20,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from common import CONFIG, check_grads_padding, run_fwd_bwd
-from colossalai.zero.sharded_model.utils import col_model_deepcopy
 
 
 def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
+    logger = get_dist_logger()
+    logger.set_level('DEBUG')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = shard_strategy()
     for model_name in test_models:

@@ -60,8 +62,8 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
         check_grads_padding(model, zero_model, loose=True)
-        print('overall cuda ', zero_model._memstats_collector._overall_cuda)
-        print('model cuda ', zero_model._memstats_collector._model_data_cuda)
+        # logger.debug('overall cuda ', zero_model._memstats_collector._overall_cuda)
+        # logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)
 
 
 @pytest.mark.dist
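On the test side, the `print` diagnostics are replaced with (still commented-out) calls to ColossalAI's distributed logger, and `run_dist` now raises the log level to DEBUG. One caveat, offered as a reading of the API rather than something this diff confirms: a logging-style `debug` takes a single message string (plus, in colossalai, an optional rank filter), so the two-argument print-style calls would need to be reformatted before being enabled. A hypothetical sketch, with a stand-in value for the memstats counter:

    from colossalai.logging import get_dist_logger

    logger = get_dist_logger()
    logger.set_level('DEBUG')  # as added in run_dist above

    # Stand-in for zero_model._memstats_collector._overall_cuda, which is
    # only in scope inside the test; format the value into one message.
    overall_cuda = 1234567
    logger.debug(f'overall cuda {overall_cuda}')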