Unverified commit f8a0e7fb authored by Frank Lee, committed by GitHub

Merge pull request #412 from hpcaitech/develop

merge develop to main
parents fc5101f2 21dc54e0
from typing import List
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten
from .tensor_shard_strategy import TensorShardStrategy
class BucketTensorShardStrategy(TensorShardStrategy):

    def gather(self, tensor_list: List[ShardedTensor]):
        tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
        if len(tensor_list) == 0:
            return
        target_device = tensor_list[0].device
        dtype = tensor_list[0].dtype
        buffer_list: List[torch.Tensor] = []
        tensor_numels = [t.payload.numel() for t in tensor_list]
        buffer_size = sum(tensor_numels)
        for i in range(self.world_size):
            if i == self.local_rank:
                buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
                # Release payload here, to decrease peak memory usage
                for t in tensor_list:
                    t.reset_payload(None)
            else:
                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
        dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
        # Move to target device before splitting buffer
        # Ensure we utilize maximum PCIe bandwidth
        buffer_list = [buffer.to(target_device) for buffer in buffer_list]
        offset = 0
        for i, t in enumerate(tensor_list):
            gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
            gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
            t.reset_payload(gathered_payload)
            t.is_sharded = False
            offset += tensor_numels[i]
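# ---------------------------------------------------------------------------
# A minimal, self-contained sketch (not from the PR) of the unpacking step
# above: after `dist.all_gather`, every rank holds `world_size` flat buffers,
# and each full tensor is rebuilt by slicing the same [offset, offset + numel)
# window out of every buffer and concatenating the pieces in rank order.
# All names and sizes below are illustrative.
import torch

world_size = 2
tensor_numels = [4, 6]    # numel of each shard, as computed in gather()
buffer_size = sum(tensor_numels)
# One flat buffer per rank; offset by 100 * r so each slice's origin is visible
buffer_list = [torch.arange(buffer_size) + 100 * r for r in range(world_size)]

offset = 0
for numel in tensor_numels:
    pieces = [buf[offset:offset + numel] for buf in buffer_list]
    full_payload = torch.cat(pieces)    # rank 0's shard, then rank 1's, ...
    assert full_payload.numel() == numel * world_size
    offset += numel
# ---------------------------------------------------------------------------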
@@ -17,7 +17,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
                           get_gradient_predivide_factor)
@@ -33,7 +34,8 @@ class ShardedModelV2(nn.Module):
                 fp32_reduce_scatter: bool = False,
                 offload_config: Optional[dict] = None,
                 gradient_predivide_factor: Optional[float] = 1.0,
                 shard_param: bool = True):
                 shard_param: bool = True,
                 use_memory_tracer: bool = False):
r"""
A demo to reconfigure zero1 shared_model.
Currently do not consider the Optimizer States.
@@ -59,8 +61,16 @@ class ShardedModelV2(nn.Module):
            if self.shard_param:
                self.shard_strategy.shard([param.col_attr.data])

        # Init memory statistics collector
        self._use_memory_tracer = use_memory_tracer
        if self._use_memory_tracer:
            self._memstats_collector = MemStatsCollector()
        else:
            self._memstats_collector = None
        self._iter_cnter = 0

        # Register hooks
        register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy)])
        register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy, self._memstats_collector)])
        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
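# ---------------------------------------------------------------------------
# A minimal usage sketch (not from the PR) of the `use_memory_tracer` flag
# added above. Assumes colossalai.launch() has already initialized the
# distributed context; the toy module and strategy choice are illustrative.
import torch.nn as nn
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

net = nn.Linear(8, 8).half().cuda()
zero_model = ShardedModelV2(net, BucketTensorShardStrategy(), use_memory_tracer=True)
# After one forward/backward iteration the collector's statistics can be read
# back, e.g. zero_model._memstats_collector._overall_cuda (see the test below).
# ---------------------------------------------------------------------------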
@@ -84,6 +94,9 @@ class ShardedModelV2(nn.Module):
        return self._cpu_offload

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        if self._iter_cnter == 0 and self._memstats_collector:
            # this operation will affect the flag in ZeroHook
            self._memstats_collector.start_collection()
        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
        outputs = self.module(*args, **kwargs)
        return outputs
@@ -98,6 +111,12 @@ class ShardedModelV2(nn.Module):

    @torch.no_grad()
    def _final_backward_hook(self) -> None:
        if self._iter_cnter == 0 and self._memstats_collector:
            self._memstats_collector.finish_collection()
        if self._memstats_collector:
            self._memstats_collector.reset_sampling_cnter()
        self._iter_cnter += 1

        if self._require_backward_grad_sync:
            # Flush any unreduced buckets in the post_backward stream.
            with torch.cuda.stream(self.comm_stream):
@@ -185,8 +204,10 @@ class ShardedModelV2(nn.Module):
        reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)

        # Maybe offload
        # TODO() optimize GPU->CPU bandwidth utilization
        if self._cpu_offload:
            reduced_grad.data = reduced_grad.data.cpu()
            col_move_to_cpu(reduced_grad)
            # reduced_grad.data = reduced_grad.data.cpu()

        if param.col_attr.grad is None:
            param.col_attr.grad = reduced_grad.data
......
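# ---------------------------------------------------------------------------
# The memory-statistics changes above follow a "collect during the first
# iteration only" pattern: collection starts in forward(), stops in the final
# backward hook, and the sampling counter is reset every iteration. A plain
# Python sketch of that control flow (names are placeholders, not colossalai
# API):
class FirstIterOnly:

    def __init__(self, collector):
        self.collector = collector    # e.g. a MemStatsCollector, may be None
        self.iter_cnt = 0

    def on_forward(self):
        if self.iter_cnt == 0 and self.collector:
            self.collector.start_collection()

    def on_backward_done(self):
        if self.iter_cnt == 0 and self.collector:
            self.collector.finish_collection()
        if self.collector:
            self.collector.reset_sampling_cnter()
        self.iter_cnt += 1
# ---------------------------------------------------------------------------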
from enum import Enum
from typing import Dict, Optional
from typing import Callable, Dict, Optional, Union
import torch
import torch.distributed as dist
@@ -15,7 +15,7 @@ from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from typing import Type, Any
from ._utils import has_inf_or_nan
@@ -27,8 +27,8 @@ class OptimState(Enum):

class ShardedOptimizerV2(ColossalaiOptimizer):

    def __init__(self,
                 optimizer: Optimizer,
                 sharded_model: ShardedModelV2,
                 optimizer_class: Type[Optimizer],
                 shard_strategy: BaseShardStrategy,
                 cpu_offload: bool = False,
                 initial_scale: float = 2**32,
@@ -39,9 +39,34 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 hysteresis: float = 2,
                 max_scale: int = 2**32,
                 dp_process_group: Optional[ProcessGroup] = None,
                 mp_process_group: Optional[ProcessGroup] = None) -> None:
                 mp_process_group: Optional[ProcessGroup] = None,
                 **defaults: Any) -> None:
"""
:param sharded_model: A sharded model initialized by class ShardedModelV2
:type sharded_model: sharded_model
:param optimizer_class: A type of Optimizer
:type optimizer_class: Type[Optimizer]
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:param cpu_offload: is offloading the optimizer states to CPU.
:type cpu_offload: bool
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:**defaults: any trailing arguments, which are forwarded to the local optimizer.
:type defaults: dict()
"""
        assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'

        super().__init__(optimizer)
        self._optim_defaults = defaults
        # Initialize the optimizer states (e.g. momentum M and variance V) as zero
        # tensors, and the fp32 master params, from sharded_model.parameters()
        self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
        super().__init__(self.optimizer)
        self.shard_strategy = shard_strategy
        self.model: ShardedModelV2 = sharded_model
        if cpu_offload and not sharded_model.cpu_offload:
@@ -65,7 +90,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        # Store fp32 param shards
        self.master_params: Dict[Parameter, Tensor] = {}

        for group in optimizer.param_groups:
        for group in self.optimizer.param_groups:
            for p in group['params']:
                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
                is_param_sharded = p.col_attr.data.is_sharded
@@ -118,7 +143,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                # We have to use `copy_payload` instead of `reset_payload`,
                # since p.data is fp32 while p.col_attr.data is fp16
                # TODO() optimize this line
                # TODO() optimize this line: CPU (fp32) -> GPU (fp16)
                p.col_attr.data.copy_payload(p.data)

                if not is_param_sharded:
......
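# ---------------------------------------------------------------------------
# A usage sketch (not from the PR) of the reworked constructor: the optimizer
# *class* is passed instead of an instance, and ShardedOptimizerV2 builds the
# optimizer itself from sharded_model.parameters(); trailing keyword arguments
# such as lr are forwarded to it. Assumes a launched distributed context and a
# model already wrapped in ShardedModelV2; hyperparameters are illustrative.
from colossalai.nn.optimizer import CPUAdam
from colossalai.zero.sharded_optim import ShardedOptimizerV2

sharded_optim = ShardedOptimizerV2(zero_model,         # a ShardedModelV2 instance
                                   CPUAdam,            # the class, not CPUAdam(...)
                                   shard_strategy,
                                   cpu_offload=True,
                                   initial_scale=2**5,
                                   lr=1e-3)
# ---------------------------------------------------------------------------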
Subproject commit d50ef2db51e7d02ed3f7e9de13f9af86b04eaae9
Subproject commit 5345187ad55e8c80c111e0c5f7ad9b9241e8f913
@@ -74,8 +74,5 @@ def get_training_components():
                          sequence_length=sequence_length,
                          is_distrbuted=True)

    def get_optim(model):
        return torch.optim.Adam(model.parameters(), lr=0.001)

    criterion = None
    return bert_model_builder, trainloader, testloader, get_optim, criterion
    return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
@@ -49,8 +49,5 @@ def get_training_components():
    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()

    def optim_builder(model):
        return torch.optim.Adam(model.parameters(), lr=0.001)

    criterion = torch.nn.CrossEntropyLoss()
    return model_builder, trainloader, testloader, optim_builder, criterion
    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
@@ -43,8 +43,5 @@ def get_training_components():
    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()

    def optim_builder(model):
        return torch.optim.Adam(model.parameters(), lr=0.001)

    criterion = torch.nn.CrossEntropyLoss()
    return model_builder, trainloader, testloader, optim_builder, criterion
    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
@@ -29,8 +29,5 @@ def get_resnet_training_components():
    trainloader = get_cifar10_dataloader(train=True)
    testloader = get_cifar10_dataloader(train=False)

    def optim_builder(model):
        return torch.optim.Adam(model.parameters(), lr=0.001)

    criterion = torch.nn.CrossEntropyLoss()
    return model_builder, trainloader, testloader, optim_builder, criterion
    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
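# ---------------------------------------------------------------------------
# Why return the optimizer *class* instead of an `optim_builder` closure
# (a sketch, not from the PR): callers can now bind their own hyperparameters
# at the call site, or hand the class straight to ShardedOptimizerV2.
import torch

optimizer_class = torch.optim.Adam
model = torch.nn.Linear(4, 4)
optim = optimizer_class(model.parameters(), lr=1e-3)    # bound where it is used
# ---------------------------------------------------------------------------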
@@ -19,11 +19,11 @@ def run_train():
    # FIXME: test bert
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
        model = model_builder(checkpoint=False)
        engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                                optimizer=optimizer_builder(model),
                                                                optimizer=optimizer_class(model.parameters(), lr=1e-3),
                                                                criterion=criterion,
                                                                train_dataloader=train_dataloader)
@@ -84,7 +84,7 @@ def run_engine(rank, world_size, port):

@pytest.mark.dist
def test_engine():
    world_size = 4
    world_size = 2
    run_func = partial(run_engine, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
......
@@ -25,9 +25,9 @@ def run_trainer_no_pipeline(rank, world_size, port):
    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
    for name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(name)
        model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()
        model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
        model = model_builder()
        optimizer = optimizer_builder(model)
        optimizer = optimizer_class(model.parameters(), lr=1e-3)
        engine, train_dataloader, *_ = colossalai.initialize(model=model,
                                                             optimizer=optimizer,
                                                             criterion=criterion,
......
@@ -56,6 +56,7 @@ def test_activation_checkpointing(cpu_offload):
    assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
    torch.cuda.empty_cache()

    # As the seed manager is a singleton, if we don't reset seeds here,
    # other tests will fail when running together with this test
......
@@ -4,21 +4,20 @@
from functools import partial
import colossalai
from colossalai.utils.cuda import get_current_device
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import \
    TensorShardStrategy
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
def run_dist(rank, world_size, port, init_device):
def run_dist(rank, world_size, port, init_device, shard_strategy):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    for get_components_func in non_distributed_component_funcs:
@@ -26,7 +25,7 @@ def run_dist(rank, world_size, port, init_device):
        model_numel_tensor = torch.zeros(1, dtype=torch.int)
        with ZeroInitContext(convert_fp16=True,
                             target_device=init_device,
                             shard_strategy=TensorShardStrategy(),
                             shard_strategy=shard_strategy(),
                             shard_param=True,
                             model_numel_tensor=model_numel_tensor):
            model = model_builder(checkpoint=True)
@@ -38,23 +37,25 @@ def run_dist(rank, world_size, port, init_device):
            assert param.col_attr.data.payload.device.type == init_device.type, \
                f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'

        print(f'cpu usage {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
        print(f'cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
        print(f'cuda usage {ModelDataTracer().cuda_usage}')
        print(f'numel {model_numel_tensor}')
        if init_device.type == 'cuda':
            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
        elif init_device.type == 'cpu':
            assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
        assert (ModelDataTracer().cuda_usage > 0)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
@pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
def test_zero_init_context(world_size, init_device):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def test_zero_init_context(world_size, init_device, shard_strategy):
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       init_device=init_device,
                       shard_strategy=shard_strategy)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_init_context(2, torch.device('cpu'))
    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))
    # test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
    test_zero_init_context(4, torch.device('cpu'), BucketTensorShardStrategy)
@@ -3,30 +3,29 @@
import copy
from functools import partial
import pytest
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.utils import free_port
from colossalai.zero.shard_utils.tensor_shard_strategy import \
    TensorShardStrategy
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_grads_padding, run_fwd_bwd
from colossalai.zero.sharded_model.utils import col_model_deepcopy
def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    shard_strategy = TensorShardStrategy()
    shard_strategy = shard_strategy()
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, _, criterion = get_components_func()
@@ -35,12 +34,12 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
        if use_zero_init_ctx:
            with ZeroInitContext(convert_fp16=True,
                                 target_device=torch.device('cpu'),
                                 target_device=torch.device('cpu:0'),
                                 shard_strategy=shard_strategy,
                                 shard_param=True,
                                 rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
                zero_model = model_builder(checkpoint=True)
            zero_model = ShardedModelV2(zero_model, shard_strategy)
            zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)

            model = model_builder(checkpoint=True).half()
            col_model_deepcopy(zero_model, model)
@@ -61,19 +60,24 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
        check_grads_padding(model, zero_model, loose=True)

        print('overall cuda ', zero_model._memstats_collector._overall_cuda)
        print('model cuda ', zero_model._memstats_collector._model_data_cuda)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("enable_autocast", [True])
@pytest.mark.parametrize("use_zero_init_ctx", [True])
def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast, shard_strategy):
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_zero_init_ctx=use_zero_init_ctx,
                       enable_autocast=enable_autocast)
                       enable_autocast=enable_autocast,
                       shard_strategy=shard_strategy)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True, shard_strategy=TensorShardStrategy)
@@ -10,20 +10,20 @@ import torch
import torch.multiprocessing as mp
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import free_port
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from tests.test_zero_data_parallel.common import CONFIG, allclose
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_zero_data_parallel.common import CONFIG, allclose


def _run_shard_tensor(rank, world_size, port):
def _run_shard_tensor(rank, world_size, port, shard_strategy):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
    assert list(t.origin_shape) == [world_size * 2, 3]
    assert list(t.shape) == [world_size * 2, 3]

    shard_strategy = TensorShardStrategy(process_group=None)
    shard_strategy = shard_strategy(process_group=None)

    # test shard strategy
    shard_strategy.shard([t])
@@ -34,8 +34,9 @@ def _run_shard_tensor(rank, world_size, port):

@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
def test_shard_tensor(world_size):
    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def test_shard_tensor(world_size, shard_strategy):
    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
    mp.spawn(run_func, nprocs=world_size)
@@ -121,7 +122,7 @@ def test_init_shard_param(world_size):

if __name__ == '__main__':
    test_shard_tensor(2)
    test_shard_tensor(2, TensorShardStrategy)
    test_shard_param(2)
    test_shard_param_v2(2)
    test_init_shard_param(4)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
@@ -10,7 +7,7 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -38,25 +35,27 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
    optimizer.step()


def run_dist(rank, world_size, port, cpu_offload):
def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    shard_strategy = shard_strategy()
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        shard_strategy = TensorShardStrategy()
        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
        model, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
        model = model(checkpoint=True).cuda()
        zero_model = ShardedModelV2(copy.deepcopy(model),
                                    shard_strategy,
                                    offload_config=dict(device='cpu') if cpu_offload else None)
        if dist.get_world_size() > 1:
            model = DDP(model)
        optim = Adam(model.parameters(), lr=1e-3)
        sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
                                           zero_model,
        lr = 1e-3
        optim = optimizer_class(model.parameters(), lr=lr)
        sharded_optim = ShardedOptimizerV2(zero_model,
                                           optimizer_class,
                                           shard_strategy,
                                           cpu_offload=cpu_offload,
                                           initial_scale=2**5)
                                           initial_scale=2**5,
                                           lr=lr)
        for i, (data, label) in enumerate(train_dataloader):
            if i > 2:
                break
@@ -69,10 +68,15 @@ def run_dist(rank, world_size, port, cpu_offload):

@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("cpu_offload", [True, False])
def test_sharded_optim_v2(world_size, cpu_offload):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), cpu_offload=cpu_offload)
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       cpu_offload=cpu_offload,
                       shard_strategy=shard_strategy)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim_v2(world_size=2, cpu_offload=True)
    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)
\ No newline at end of file
@@ -11,7 +11,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.nn.optimizer import CPUAdam
from colossalai.utils import free_port
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -47,23 +47,24 @@ def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
    optimizer.step()


def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, shard_strategy):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    shard_strategy = shard_strategy()
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        shard_strategy = TensorShardStrategy()
        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
        model = model(checkpoint=True).cuda()
        zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy, offload_config={'device': 'cpu'})
        if dist.get_world_size() > 1:
            model = DDP(model)
        optim = Adam(model.parameters(), lr=1e-3)
        sharded_optim = ShardedOptimizerV2(CPUAdam(zero_model.parameters(), lr=1e-3),
                                           zero_model,
        sharded_optim = ShardedOptimizerV2(zero_model,
                                           CPUAdam,
                                           shard_strategy,
                                           initial_scale=2**5,
                                           cpu_offload=True)
                                           cpu_offload=True,
                                           lr=1e-3)
        for i, (data, label) in enumerate(train_dataloader):
            if i > 2:
                break
@@ -79,10 +80,11 @@ def run_dist(rank, world_size, port):

@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
def test_sharded_optim_v2(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def test_sharded_optim_v2(world_size, shard_strategy):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim_v2(world_size=2)
    test_sharded_optim_v2(world_size=2, shard_strategy=TensorShardStrategy)
@@ -9,22 +9,21 @@ import pytest
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.zero.shard_utils.tensor_shard_strategy import \
    TensorShardStrategy
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from tests.components_to_test.registry import non_distributed_component_funcs

from common import CONFIG


def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, shard_strategy):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_models = ['repeated_computed_layers', 'resnet18']
    shard_strategy = shard_strategy()
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
        model = model_builder()
        shard_strategy = TensorShardStrategy()
        model = model.half().cuda()
        zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
        zero_state_dict = zero_model.state_dict()
@@ -33,11 +32,12 @@ def run_dist(rank, world_size, port):

@pytest.mark.dist
def test_zero_state_dict():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def test_zero_state_dict(world_size, shard_strategy):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_state_dict()
    test_zero_state_dict(2, TensorShardStrategy)