Unverified Commit f8a0e7fb authored by Frank Lee, committed by GitHub

Merge pull request #412 from hpcaitech/develop

merge develop to main
parents fc5101f2 21dc54e0
from typing import List

import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten

from .tensor_shard_strategy import TensorShardStrategy


class BucketTensorShardStrategy(TensorShardStrategy):

    def gather(self, tensor_list: List[ShardedTensor]):
        tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
        if len(tensor_list) == 0:
            return
        target_device = tensor_list[0].device
        dtype = tensor_list[0].dtype
        buffer_list: List[torch.Tensor] = []
        tensor_numels = [t.payload.numel() for t in tensor_list]
        buffer_size = sum(tensor_numels)
        for i in range(self.world_size):
            if i == self.local_rank:
                buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
                # Release payload here, to decrease peak memory usage
                for t in tensor_list:
                    t.reset_payload(None)
            else:
                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
        dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
        # Move to target device before splitting buffer
        # Ensure we utilize maximum PCIE bandwidth
        buffer_list = [buffer.to(target_device) for buffer in buffer_list]
        offset = 0
        for i, t in enumerate(tensor_list):
            gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
            gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
            t.reset_payload(gathered_payload)
            t.is_sharded = False
            offset += tensor_numels[i]
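The bucketed strategy above is a drop-in replacement for TensorShardStrategy: instead of issuing one collective per tensor, it flattens every local shard into a single per-rank buffer, performs one all_gather, and then slices the result back into the original shapes. A minimal usage sketch (hypothetical driver code, assuming torch.distributed is already initialized and that shard and gather are called with the same strategy instance, as in the tests later in this merge):

    import torch
    from colossalai.zero.shard_utils import BucketTensorShardStrategy
    from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

    strategy = BucketTensorShardStrategy(process_group=None)
    tensors = [ShardedTensor(tensor=torch.randn(8, 4)) for _ in range(3)]
    strategy.shard(tensors)     # each rank keeps only its own slice of every payload
    strategy.gather(tensors)    # one bucketed all_gather restores the full payloads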
@@ -17,7 +17,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
 from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
                            get_gradient_predivide_factor)
@@ -33,7 +34,8 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  offload_config: Optional[dict] = None,
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 shard_param: bool = True):
+                 shard_param: bool = True,
+                 use_memory_tracer: bool = False):
         r"""
         A demo to reconfigure zero1 shared_model.
         Currently do not consider the Optimizer States.
@@ -59,8 +61,16 @@ class ShardedModelV2(nn.Module):
            if self.shard_param:
                self.shard_strategy.shard([param.col_attr.data])
+        # Init Memory Statistics Collector
+        self._use_memory_tracer = use_memory_tracer
+        if self._use_memory_tracer:
+            self._memstats_collector = MemStatsCollector()
+        else:
+            self._memstats_collector = None
+        self._iter_cnter = 0
         # Register hooks
-        register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy)])
+        register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy, self._memstats_collector)])
         self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
@@ -84,6 +94,9 @@ class ShardedModelV2(nn.Module):
         return self._cpu_offload

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+        if self._iter_cnter == 0 and self._memstats_collector:
+            # the operation will affect the flag in ZeroHook
+            self._memstats_collector.start_collection()
         args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         return outputs
@@ -98,6 +111,12 @@ class ShardedModelV2(nn.Module):
     @torch.no_grad()
     def _final_backward_hook(self) -> None:
+        if self._iter_cnter == 0 and self._memstats_collector:
+            self._memstats_collector.finish_collection()
+        if self._memstats_collector:
+            self._memstats_collector.reset_sampling_cnter()
+        self._iter_cnter += 1
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -185,8 +204,10 @@ class ShardedModelV2(nn.Module):
         reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)
         # Maybe offload
+        # TODO() optimize GPU->CPU bandwidth utilization
         if self._cpu_offload:
-            reduced_grad.data = reduced_grad.data.cpu()
+            col_move_to_cpu(reduced_grad)
+            # reduced_grad.data = reduced_grad.data.cpu()
         if param.col_attr.grad is None:
             param.col_attr.grad = reduced_grad.data
......
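The two hunks above wire an optional MemStatsCollector into ShardedModelV2: collection starts in forward and is finalized in the final backward hook of the first iteration only, and reduced gradients are now moved to CPU through col_move_to_cpu when offloading is enabled. A hedged sketch of how a caller might enable the tracer and read the collected statistics (the attribute names mirror the test code later in this merge and are private, so treat them as illustrative):

    zero_model = ShardedModelV2(module, shard_strategy, use_memory_tracer=True)
    out = zero_model(data)    # the first forward pass starts the collection
    # ... run the backward pass; collection finishes in _final_backward_hook ...
    print('overall cuda ', zero_model._memstats_collector._overall_cuda)
    print('model cuda ', zero_model._memstats_collector._model_data_cuda)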
 from enum import Enum
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional, Union
 import torch
 import torch.distributed as dist
@@ -15,7 +15,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
+from typing import Type, Any
 from ._utils import has_inf_or_nan
@@ -27,8 +27,8 @@ class OptimState(Enum):
 class ShardedOptimizerV2(ColossalaiOptimizer):
     def __init__(self,
-                 optimizer: Optimizer,
                  sharded_model: ShardedModelV2,
+                 optimizer_class: Type[Optimizer],
                  shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
@@ -39,9 +39,34 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
-                 mp_process_group: Optional[ProcessGroup] = None) -> None:
+                 mp_process_group: Optional[ProcessGroup] = None,
+                 **defaults: Any) -> None:
"""
:param sharded_model: A sharded model initialized by class ShardedModelV2
:type sharded_model: sharded_model
:param optimizer_class: A type of Optimizer
:type optimizer_class: Type[Optimizer]
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:param cpu_offload: is offloading the optimizer states to CPU.
:type cpu_offload: bool
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:**defaults: any trailing arguments, which are forwarded to the local optimizer.
:type defaults: dict()
"""
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
-        super().__init__(optimizer)
+        self._optim_defaults = defaults
+        # initialize the M, V as zeros tensors and initialize param fp32 from sharded_model.parameters()
+        self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
+        super().__init__(self.optimizer)
         self.shard_strategy = shard_strategy
         self.model: ShardedModelV2 = sharded_model
         if cpu_offload and not sharded_model.cpu_offload:
@@ -65,7 +90,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}
-        for group in optimizer.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
                 is_param_sharded = p.col_attr.data.is_sharded
@@ -118,7 +143,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # We have to use `copy_payload` instead of `reset_payload`
                 # Since p.data is fp32 and p.col_attr.data is fp16
-                # TODO() optimize this line
+                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                 p.col_attr.data.copy_payload(p.data)
                 if not is_param_sharded:
......
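With this change ShardedOptimizerV2 no longer receives a pre-built optimizer: it takes the sharded model plus an optimizer class and constructs the inner optimizer itself, forwarding any trailing keyword arguments to it. A hedged before/after sketch of the calling convention (hypothetical caller code mirroring the test updates further down; assumes torch.optim.Adam):

    # before this merge
    sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
                                       zero_model, shard_strategy, cpu_offload=True)
    # after this merge: pass the class, extra kwargs go to the local optimizer
    sharded_optim = ShardedOptimizerV2(zero_model, Adam, shard_strategy,
                                       cpu_offload=True, lr=1e-3)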
-Subproject commit d50ef2db51e7d02ed3f7e9de13f9af86b04eaae9
+Subproject commit 5345187ad55e8c80c111e0c5f7ad9b9241e8f913
@@ -74,8 +74,5 @@ def get_training_components():
                                       sequence_length=sequence_length,
                                       is_distrbuted=True)
-    def get_optim(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
     criterion = None
-    return bert_model_builder, trainloader, testloader, get_optim, criterion
+    return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
@@ -49,8 +49,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
@@ -43,8 +43,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
@@ -29,8 +29,5 @@ def get_resnet_training_components():
     trainloader = get_cifar10_dataloader(train=True)
     testloader = get_cifar10_dataloader(train=False)
-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
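The four hunks above change the test component registries to return the optimizer class itself instead of a builder closure, so callers now instantiate it with the model parameters and whatever hyperparameters they need, as the test updates below do. A short sketch of the updated calling convention (hypothetical snippet; the learning rate is chosen for illustration):

    model_builder, trainloader, testloader, optimizer_class, criterion = get_components_func()
    model = model_builder(checkpoint=False)
    optimizer = optimizer_class(model.parameters(), lr=1e-3)    # e.g. torch.optim.Adam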
@@ -19,11 +19,11 @@ def run_train():
     # FIXME: test bert
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
         model = model_builder(checkpoint=False)
         engine, train_dataloader, *args = colossalai.initialize(model=model,
-                                                                optimizer=optimizer_builder(model),
+                                                                optimizer=optimizer_class(model.parameters(), lr=1e-3),
                                                                 criterion=criterion,
                                                                 train_dataloader=train_dataloader)
@@ -84,7 +84,7 @@ def run_engine(rank, world_size, port):
 @pytest.mark.dist
 def test_engine():
-    world_size = 4
+    world_size = 2
     run_func = partial(run_engine, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
......
@@ -25,9 +25,9 @@ def run_trainer_no_pipeline(rank, world_size, port):
     test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
     for name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(name)
-        model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model_builder()
-        optimizer = optimizer_builder(model)
+        optimizer = optimizer_class(model.parameters(), lr=1e-3)
         engine, train_dataloader, *_ = colossalai.initialize(model=model,
                                                              optimizer=optimizer,
                                                              criterion=criterion,
......
@@ -56,6 +56,7 @@ def test_activation_checkpointing(cpu_offload):
     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()
     # as seed manager is singleton
     # if we don't reset seeds here,
     # other tests will fail if running together with this test
......
@@ -4,21 +4,20 @@
 from functools import partial
 import colossalai
-from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
-def run_dist(rank, world_size, port, init_device):
+def run_dist(rank, world_size, port, init_device, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     for get_components_func in non_distributed_component_funcs:
@@ -26,7 +25,7 @@ def run_dist(rank, world_size, port, init_device):
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
-                             shard_strategy=TensorShardStrategy(),
+                             shard_strategy=shard_strategy(),
                              shard_param=True,
                              model_numel_tensor=model_numel_tensor):
             model = model_builder(checkpoint=True)
@@ -38,23 +37,25 @@ def run_dist(rank, world_size, port, init_device):
             assert param.col_attr.data.payload.device.type == init_device.type, \
                 f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
-        print(f'cpu usgae {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
-        print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
+        print(f'cuda usgae {ModelDataTracer().cuda_usage}')
         print(f'numel {model_numel_tensor}')
         if init_device.type == 'cuda':
-            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
-        elif init_device.type == 'cpu':
-            assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
+            assert (ModelDataTracer().cuda_usage > 0)
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
 @pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
-def test_zero_init_context(world_size, init_device):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_init_context(world_size, init_device, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       init_device=init_device,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
-    test_zero_init_context(2, torch.device('cpu'))
-    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))
+    # test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
+    test_zero_init_context(4, torch.device('cpu'), BucketTensorShardStrategy)
@@ -3,30 +3,29 @@
 import copy
 from functools import partial
-import pytest
+import colossalai
+import pytest
 import torch
 import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
-import colossalai
-from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
-from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP
 from common import CONFIG, check_grads_padding, run_fwd_bwd
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
-def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
+def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    shard_strategy = TensorShardStrategy()
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
@@ -35,12 +34,12 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
         if use_zero_init_ctx:
            with ZeroInitContext(convert_fp16=True,
-                                target_device=torch.device('cpu'),
+                                target_device=torch.device(f'cpu:0'),
                                 shard_strategy=shard_strategy,
                                 shard_param=True,
                                 rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
                zero_model = model_builder(checkpoint=True)
-           zero_model = ShardedModelV2(zero_model, shard_strategy)
+           zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
            model = model_builder(checkpoint=True).half()
            col_model_deepcopy(zero_model, model)
@@ -61,19 +60,24 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
         check_grads_padding(model, zero_model, loose=True)
+        print('overall cuda ', zero_model._memstats_collector._overall_cuda)
+        print('model cuda ', zero_model._memstats_collector._model_data_cuda)
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("enable_autocast", [True])
 @pytest.mark.parametrize("use_zero_init_ctx", [True])
-def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast, shard_strategy):
     run_func = partial(run_dist,
                        world_size=world_size,
                        port=free_port(),
                        use_zero_init_ctx=use_zero_init_ctx,
-                       enable_autocast=enable_autocast)
+                       enable_autocast=enable_autocast,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
-    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
+    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True, shard_strategy=TensorShardStrategy)
@@ -10,20 +10,20 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.test_zero_data_parallel.common import CONFIG, allclose
 from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_zero_data_parallel.common import CONFIG, allclose
-def _run_shard_tensor(rank, world_size, port):
+def _run_shard_tensor(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.origin_shape) == [world_size * 2, 3]
     assert list(t.shape) == [world_size * 2, 3]
-    shard_strategy = TensorShardStrategy(process_group=None)
+    shard_strategy = shard_strategy(process_group=None)
     # test shard strategy
     shard_strategy.shard([t])
@@ -34,8 +34,9 @@ def _run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_shard_tensor(world_size):
-    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_tensor(world_size, shard_strategy):
+    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
@@ -121,7 +122,7 @@ def test_init_shard_param(world_size):
 if __name__ == '__main__':
-    test_shard_tensor(2)
+    test_shard_tensor(2, TensorShardStrategy)
     test_shard_param(2)
     test_shard_param_v2(2)
     test_init_shard_param(4)
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
 import copy
 from functools import partial
@@ -10,7 +7,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -38,25 +35,27 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     optimizer.step()
-def run_dist(rank, world_size, port, cpu_offload):
+def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model),
                                     shard_strategy,
                                     offload_config=dict(device='cpu') if cpu_offload else None)
         if dist.get_world_size() > 1:
             model = DDP(model)
-        optim = Adam(model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
-                                           zero_model,
+        lr = 1e-3
+        optim = optimizer_class(model.parameters(), lr=lr)
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           optimizer_class,
                                            shard_strategy,
                                            cpu_offload=cpu_offload,
-                                           initial_scale=2**5)
+                                           initial_scale=2**5,
+                                           lr=lr)
         for i, (data, label) in enumerate(train_dataloader):
             if i > 2:
                 break
@@ -69,10 +68,15 @@ def run_dist(rank, world_size, port, cpu_offload):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("cpu_offload", [True, False])
-def test_sharded_optim_v2(world_size, cpu_offload):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), cpu_offload=cpu_offload)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       cpu_offload=cpu_offload,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2, cpu_offload=True)
+    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)
\ No newline at end of file
@@ -11,7 +11,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -47,23 +47,24 @@ def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
     optimizer.step()
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
        model = model(checkpoint=True).cuda()
        zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy, offload_config={'device': 'cpu'})
        if dist.get_world_size() > 1:
            model = DDP(model)
        optim = Adam(model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(CPUAdam(zero_model.parameters(), lr=1e-3),
-                                           zero_model,
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           CPUAdam,
                                            shard_strategy,
                                            initial_scale=2**5,
-                                           cpu_offload=True)
+                                           cpu_offload=True,
+                                           lr=1e-3)
        for i, (data, label) in enumerate(train_dataloader):
            if i > 2:
                break
@@ -79,10 +80,11 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_sharded_optim_v2(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2)
+    test_sharded_optim_v2(world_size=2, shard_strategy=TensorShardStrategy)
@@ -9,22 +9,21 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model_builder()
-        shard_strategy = TensorShardStrategy()
         model = model.half().cuda()
         zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
         zero_state_dict = zero_model.state_dict()
@@ -33,11 +32,12 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-def test_zero_state_dict():
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_state_dict(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
-    test_zero_state_dict()
+    test_zero_state_dict(2, TensorShardStrategy)