OpenDAS / ColossalAI · Commits

Commit fce9432f
authored Mar 16, 2022 by ver217

sync before creating empty grad

parent ea6905a8
Showing 2 changed files with 7 additions and 4 deletions:

colossalai/zero/sharded_model/sharded_model_v2.py      +1 -0
tests/test_zero_data_parallel/test_shard_model_v2.py   +6 -4
colossalai/zero/sharded_model/sharded_model_v2.py
@@ -218,6 +218,7 @@ class ShardedModelV2(nn.Module):
         else:
             self._reduce_scatter_callback(param, new_grad)
         orig_grad_data.record_stream(self.comm_stream)
+        torch.cuda.current_stream().wait_stream(self.comm_stream)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
         return empty_grad
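The one-line addition makes the default CUDA stream wait until self.comm_stream has drained its reduce-scatter work before the (immediately freed) empty_grad placeholder is created and returned, so later work on the default stream cannot race with the in-flight communication. Below is a minimal, self-contained sketch of that record_stream / wait_stream pattern; the helper name reduce_on_side_stream and the dummy reduction inside it are illustrative assumptions, not ColossalAI code.

import torch


def reduce_on_side_stream(grad: torch.Tensor, comm_stream: torch.cuda.Stream) -> torch.Tensor:
    # Launch the (stand-in) reduction on the side communication stream.
    with torch.cuda.stream(comm_stream):
        reduced = grad * 0.5  # placeholder for the real reduce-scatter work
    # Tell the caching allocator that grad's storage is still in use on
    # comm_stream, so it is not recycled while the kernel is in flight.
    grad.record_stream(comm_stream)
    # Make the current (default) stream wait for comm_stream, mirroring the
    # line this commit adds: nothing allocated or freed afterwards can race
    # with the communication work.
    torch.cuda.current_stream().wait_stream(comm_stream)
    return reduced


if torch.cuda.is_available():
    comm = torch.cuda.Stream()
    g = torch.randn(1024, device='cuda')
    out = reduce_on_side_stream(g, comm)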
tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -2,12 +2,14 @@
# -*- encoding: utf-8 -*-
import copy
from asyncio.log import logger
from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -18,12 +20,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_grads_padding, run_fwd_bwd
from colossalai.zero.sharded_model.utils import col_model_deepcopy


def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    logger = get_dist_logger()
    logger.set_level('DEBUG')
    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    shard_strategy = shard_strategy()
    for model_name in test_models:
@@ -60,8 +62,8 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
        check_grads_padding(model, zero_model, loose=True)
        print('overall cuda ', zero_model._memstats_collector._overall_cuda)
        print('model cuda ', zero_model._memstats_collector._model_data_cuda)
        # logger.debug('overall cuda ', zero_model._memstats_collector._overall_cuda)
        # logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)


@pytest.mark.dist
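For context, distributed test modules like this one are typically driven by spawning one process per rank; the imports shown above (functools.partial, torch.multiprocessing as mp, free_port) and the @pytest.mark.dist marker suggest the pattern sketched below. The test body and flag values here are hypothetical examples, not the code hidden in the collapsed part of this diff.

from functools import partial

import torch.multiprocessing as mp


def test_shard_model_v2(world_size: int = 2):
    # Bind the shared arguments; mp.spawn passes the rank as the first
    # positional argument to run_dist. The flag values are examples only.
    run_fn = partial(run_dist,
                     world_size=world_size,
                     port=free_port(),
                     use_zero_init_ctx=True,
                     enable_autocast=False,
                     shard_strategy=TensorShardStrategy)
    mp.spawn(run_fn, nprocs=world_size)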