Commit fce9432f authored by ver217
Browse files

sync before creating empty grad

parent ea6905a8
...@@ -218,6 +218,7 @@ class ShardedModelV2(nn.Module): ...@@ -218,6 +218,7 @@ class ShardedModelV2(nn.Module):
else: else:
self._reduce_scatter_callback(param, new_grad) self._reduce_scatter_callback(param, new_grad)
orig_grad_data.record_stream(self.comm_stream) orig_grad_data.record_stream(self.comm_stream)
torch.cuda.current_stream().wait_stream(self.comm_stream)
empty_grad = torch.empty_like(grad) empty_grad = torch.empty_like(grad)
free_storage(empty_grad) free_storage(empty_grad)
return empty_grad return empty_grad
......
...@@ -2,12 +2,14 @@ ...@@ -2,12 +2,14 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import copy import copy
from asyncio.log import logger
from functools import partial from functools import partial
import colossalai import colossalai
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
...@@ -18,12 +20,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs ...@@ -18,12 +20,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_grads_padding, run_fwd_bwd from common import CONFIG, check_grads_padding, run_fwd_bwd
from colossalai.zero.sharded_model.utils import col_model_deepcopy
def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy): def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
logger = get_dist_logger()
logger.set_level('DEBUG')
test_models = ['repeated_computed_layers', 'resnet18', 'bert'] test_models = ['repeated_computed_layers', 'resnet18', 'bert']
shard_strategy = shard_strategy() shard_strategy = shard_strategy()
for model_name in test_models: for model_name in test_models:
...@@ -60,8 +62,8 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s ...@@ -60,8 +62,8 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
check_grads_padding(model, zero_model, loose=True) check_grads_padding(model, zero_model, loose=True)
print('overall cuda ', zero_model._memstats_collector._overall_cuda) # logger.debug('overall cuda ', zero_model._memstats_collector._overall_cuda)
print('model cuda ', zero_model._memstats_collector._model_data_cuda) # logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)
@pytest.mark.dist @pytest.mark.dist
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment