Unverified commit a241f61b, authored by ver217 and committed by GitHub
Browse files

[zero] Update initialize for ZeRO (#458)

* polish code

* shard strategy receive pg in shard() / gather()

* update zero engine

* polish code
parent 642846d6
from typing import Optional from typing import Optional
import torch import torch
import torch.distributed as dist
from colossalai.registry import OPHOOKS from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
...@@ -17,9 +18,13 @@ class ZeroHook(BaseOpHook): ...@@ -17,9 +18,13 @@ class ZeroHook(BaseOpHook):
A hook to process sharded param for ZeRO method. A hook to process sharded param for ZeRO method.
""" """
def __init__(self, shard_strategy: BaseShardStrategy, memstarts_collector: Optional[MemStatsCollector]): def __init__(self,
shard_strategy: BaseShardStrategy,
memstarts_collector: Optional[MemStatsCollector],
process_group: Optional[dist.ProcessGroup] = None):
super().__init__() super().__init__()
self.shard_strategy = shard_strategy self.shard_strategy = shard_strategy
self.process_group = process_group
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
self.computing_device = torch.device(f'cuda:{get_current_device()}') self.computing_device = torch.device(f'cuda:{get_current_device()}')
...@@ -30,7 +35,7 @@ class ZeroHook(BaseOpHook): ...@@ -30,7 +35,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters(): for param in module.parameters():
assert hasattr(param, 'col_attr') assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data) tensor_list.append(param.col_attr.data)
self.shard_strategy.gather(tensor_list) self.shard_strategy.gather(tensor_list, self.process_group)
for param in module.parameters(): for param in module.parameters():
if param.col_attr.data.device != self.computing_device: if param.col_attr.data.device != self.computing_device:
param.col_attr.data.to(self.computing_device) param.col_attr.data.to(self.computing_device)
...@@ -45,7 +50,7 @@ class ZeroHook(BaseOpHook): ...@@ -45,7 +50,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters(): for param in module.parameters():
assert hasattr(param, 'col_attr') assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data) tensor_list.append(param.col_attr.data)
self.shard_strategy.shard(tensor_list) self.shard_strategy.shard(tensor_list, self.process_group)
for param in module.parameters(): for param in module.parameters():
param.col_attr.remove_torch_payload() param.col_attr.remove_torch_payload()
...@@ -54,7 +59,7 @@ class ZeroHook(BaseOpHook): ...@@ -54,7 +59,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters(): for param in module.parameters():
assert hasattr(param, 'col_attr') assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data) tensor_list.append(param.col_attr.data)
self.shard_strategy.gather(tensor_list) self.shard_strategy.gather(tensor_list, self.process_group)
for param in module.parameters(): for param in module.parameters():
if param.col_attr.data.device != self.computing_device: if param.col_attr.data.device != self.computing_device:
param.col_attr.data.to(self.computing_device) param.col_attr.data.to(self.computing_device)
...@@ -80,7 +85,7 @@ class ZeroHook(BaseOpHook): ...@@ -80,7 +85,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters(): for param in module.parameters():
assert hasattr(param, 'col_attr') assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data) tensor_list.append(param.col_attr.data)
self.shard_strategy.shard(tensor_list) self.shard_strategy.shard(tensor_list, self.process_group)
for param in module.parameters(): for param in module.parameters():
param.col_attr.remove_torch_payload() param.col_attr.remove_torch_payload()
......
...@@ -278,7 +278,10 @@ def initialize(model: nn.Module, ...@@ -278,7 +278,10 @@ def initialize(model: nn.Module,
cfg_ = {} cfg_ = {}
optimizer_config = zero_cfg.get('optimizer_config', None) optimizer_config = zero_cfg.get('optimizer_config', None)
model_config = zero_cfg.get('model_config', None) model_config = zero_cfg.get('model_config', None)
model, optimizer = convert_to_zero_v2(model, model_config=model_config, optimizer_config=optimizer_config) model, optimizer = convert_to_zero_v2(model,
optimizer,
model_config=model_config,
optimizer_config=optimizer_config)
logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
# FIXME() throw a warning if using zero with MP # FIXME() throw a warning if using zero with MP
......
from typing import Tuple from typing import Tuple
import torch
import torch.nn as nn import torch.nn as nn
from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
...@@ -11,7 +12,8 @@ from .sharded_model import ShardedModel ...@@ -11,7 +12,8 @@ from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer from .sharded_optim import ShardedOptimizer
def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
""" """
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
...@@ -34,7 +36,7 @@ def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tupl ...@@ -34,7 +36,7 @@ def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tupl
model_config = dict() model_config = dict()
zero_model = ShardedModelV2(model, **model_config) zero_model = ShardedModelV2(model, **model_config)
zero_optimizer = ShardedOptimizerV2(zero_model, **optimizer_config) zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
return zero_model, zero_optimizer return zero_model, zero_optimizer
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
import torch.distributed as dist
from typing import List, Optional from typing import List, Optional
import torch.distributed as dist
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
class BaseShardStrategy(ABC): class BaseShardStrategy(ABC):
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None: def __init__(self) -> None:
"""Abstract Shard Strategy. Use to shard a tensors on multiple GPUs. """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
Args:
process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None.
""" """
self.process_group = process_group
self.world_size = dist.get_world_size(self.process_group)
self.local_rank = dist.get_rank(self.process_group)
super().__init__() super().__init__()
@abstractmethod @abstractmethod
def shard(self, tensor_list: List[ShardedTensor]): def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
pass pass
@abstractmethod @abstractmethod
def gather(self, tensor_list: List[ShardedTensor]): def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
pass pass
from typing import List from typing import List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten
from .tensor_shard_strategy import TensorShardStrategy from .tensor_shard_strategy import TensorShardStrategy
class BucketTensorShardStrategy(TensorShardStrategy): class BucketTensorShardStrategy(TensorShardStrategy):
def gather(self, tensor_list: List[ShardedTensor]): def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded] tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
if len(tensor_list) == 0: if len(tensor_list) == 0:
return return
...@@ -21,15 +20,17 @@ class BucketTensorShardStrategy(TensorShardStrategy): ...@@ -21,15 +20,17 @@ class BucketTensorShardStrategy(TensorShardStrategy):
buffer_list: List[torch.Tensor] = [] buffer_list: List[torch.Tensor] = []
tensor_numels = [t.payload.numel() for t in tensor_list] tensor_numels = [t.payload.numel() for t in tensor_list]
buffer_size = sum(tensor_numels) buffer_size = sum(tensor_numels)
for i in range(self.world_size): world_size = dist.get_world_size(process_group)
if i == self.local_rank: rank = dist.get_rank(process_group)
for i in range(world_size):
if i == rank:
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
# Release payload here, to decrease peak memory usage # Release payload here, to decrease peak memory usage
for t in tensor_list: for t in tensor_list:
t.reset_payload(None) t.reset_payload(None)
else: else:
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group) dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
# Move to target device before splitting buffer # Move to target device before splitting buffer
# Ensure we utilize maximum PCIE bandwidth # Ensure we utilize maximum PCIE bandwidth
buffer_list = [buffer.to(target_device) for buffer in buffer_list] buffer_list = [buffer.to(target_device) for buffer in buffer_list]
......
...@@ -2,49 +2,44 @@ from typing import List, Optional ...@@ -2,49 +2,44 @@ from typing import List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._zero3_utils import get_shard from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
class TensorShardStrategy(BaseShardStrategy): class TensorShardStrategy(BaseShardStrategy):
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None: def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
super().__init__(process_group)
def shard(self, tensor_list: List[ShardedTensor]):
for t in tensor_list: for t in tensor_list:
self._shard_tensor(t) self._shard_tensor(t, process_group)
def gather(self, tensor_list: List[ShardedTensor]): def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
for t in tensor_list: for t in tensor_list:
self._gather_tensor(t) self._gather_tensor(t, process_group)
def _shard_tensor(self, t: ShardedTensor): def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
if t.is_sharded: if t.is_sharded:
return return
sharded_payload, _ = get_shard(t.payload, self.local_rank, self.world_size) sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
t.reset_payload(sharded_payload) t.reset_payload(sharded_payload)
t.is_sharded = True t.is_sharded = True
def _gather_tensor(self, t: ShardedTensor): def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
if not t.is_sharded: if not t.is_sharded:
return return
target_device = t.device target_device = t.device
buffer_list = [] buffer_list = []
payload_numel = t.payload.numel() payload_numel = t.payload.numel()
for i in range(self.world_size): world_size = dist.get_world_size(process_group)
if i == self.local_rank: rank = dist.get_rank(process_group)
for i in range(world_size):
if i == rank:
buffer_list.append(t.payload.cuda(get_current_device())) buffer_list.append(t.payload.cuda(get_current_device()))
else: else:
buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device())) buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
torch.distributed.all_gather(buffer_list, dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
buffer_list[self.local_rank],
group=self.process_group,
async_op=False)
gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape) gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
t.reset_payload(gathered_payload) t.reset_payload(gathered_payload)
t.to(target_device) t.to(target_device)
......
...@@ -70,7 +70,8 @@ class ShardedModelV2(nn.Module): ...@@ -70,7 +70,8 @@ class ShardedModelV2(nn.Module):
self._iter_cnter = 0 self._iter_cnter = 0
# Register hooks # Register hooks
register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy, self._memstats_collector)]) register_ophooks_recursively(self.module,
[ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)])
self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters())) self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook) self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
...@@ -145,7 +146,7 @@ class ShardedModelV2(nn.Module): ...@@ -145,7 +146,7 @@ class ShardedModelV2(nn.Module):
if self.shard_param: if self.shard_param:
for p in self.module.parameters(): for p in self.module.parameters():
if not p.col_attr.param_is_sharded: if not p.col_attr.param_is_sharded:
self.shard_strategy.shard([p.col_attr.data]) self.shard_strategy.shard([p.col_attr.data], self.process_group)
for p in self.module.parameters(): for p in self.module.parameters():
p.col_attr.bwd_count = 0 p.col_attr.bwd_count = 0
if not p.requires_grad: if not p.requires_grad:
...@@ -229,13 +230,13 @@ class ShardedModelV2(nn.Module): ...@@ -229,13 +230,13 @@ class ShardedModelV2(nn.Module):
param.col_attr.fp16_grad = reduced_grad.data param.col_attr.fp16_grad = reduced_grad.data
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]': def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()]) self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()], self.process_group)
prev_params = {} prev_params = {}
for p in self.module.parameters(): for p in self.module.parameters():
prev_params[p] = p.data prev_params[p] = p.data
p.data = p.col_attr.data.payload p.data = p.col_attr.data.payload
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars) gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()]) self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()], self.process_group)
for p in self.module.parameters(): for p in self.module.parameters():
p.data = prev_params[p] p.data = prev_params[p]
return gathered_state_dict return gathered_state_dict
......
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
...@@ -101,6 +102,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): ...@@ -101,6 +102,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
hysteresis=hysteresis, hysteresis=hysteresis,
max_scale=max_scale) max_scale=max_scale)
self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device()) self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
self._logger = get_dist_logger()
# Store fp32 param shards # Store fp32 param shards
self.master_params: Dict[Parameter, Tensor] = {} self.master_params: Dict[Parameter, Tensor] = {}
...@@ -113,12 +115,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer): ...@@ -113,12 +115,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# TODO (ver217): we may not use shard / gather here # TODO (ver217): we may not use shard / gather here
# Param is not sharded, which means we use ZeRO-2 here # Param is not sharded, which means we use ZeRO-2 here
# As we only store param shard, we shard it here # As we only store param shard, we shard it here
self.shard_strategy.shard([p.col_attr.data]) self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device) self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
if not is_param_sharded: if not is_param_sharded:
# In this branch, there's no need to shard param # In this branch, there's no need to shard param
# So we gather here # So we gather here
self.shard_strategy.gather([p.col_attr.data]) self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
def step(self, *args, **kwargs): def step(self, *args, **kwargs):
# unscale grads if scaled # unscale grads if scaled
...@@ -155,7 +157,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): ...@@ -155,7 +157,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# But we only have updated fp32 param shard here # But we only have updated fp32 param shard here
# So we first shard full fp16 param and copy fp32 param shard to it # So we first shard full fp16 param and copy fp32 param shard to it
# Then we will gather them # Then we will gather them
self.shard_strategy.shard([p.col_attr.data]) self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
# We have to use `copy_payload` instead of `reset_payload` # We have to use `copy_payload` instead of `reset_payload`
# Since p.data is fp32 and p.col_attr.data is fp16 # Since p.data is fp32 and p.col_attr.data is fp16
...@@ -164,7 +166,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): ...@@ -164,7 +166,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if not is_param_sharded: if not is_param_sharded:
# We gather full fp16 param here # We gather full fp16 param here
self.shard_strategy.gather([p.col_attr.data]) self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
p.data = p.col_attr.data.payload p.data = p.col_attr.data.payload
return ret return ret
......
...@@ -16,7 +16,7 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25, ...@@ -16,7 +16,7 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
offload_config=None, offload_config=None,
gradient_predivide_factor=1.0, gradient_predivide_factor=1.0,
use_memory_tracer=False, use_memory_tracer=False,
shard_strategy=TensorShardStrategy) shard_strategy=TensorShardStrategy())
_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False, _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
initial_scale=2**5, initial_scale=2**5,
...@@ -25,8 +25,7 @@ _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False, ...@@ -25,8 +25,7 @@ _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
backoff_factor=0.5, backoff_factor=0.5,
growth_interval=1000, growth_interval=1000,
hysteresis=2, hysteresis=2,
max_scale=2**32, max_scale=2**32)
lr=1e-3)
ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
zero=dict( zero=dict(
......
...@@ -7,26 +7,27 @@ import colossalai ...@@ -7,26 +7,27 @@ import colossalai
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.testing import parameterize
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG from common import CONFIG
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.testing import parameterize
@parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')]) @parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(init_device, shard_strategy): def run_model_test(init_device, shard_strategy_class):
for get_components_func in non_distributed_component_funcs: for get_components_func in non_distributed_component_funcs:
model_builder, _, _, _, _ = get_components_func() model_builder, _, _, _, _ = get_components_func()
model_numel_tensor = torch.zeros(1, dtype=torch.int) model_numel_tensor = torch.zeros(1, dtype=torch.int)
with ZeroInitContext(convert_fp16=True, with ZeroInitContext(convert_fp16=True,
target_device=init_device, target_device=init_device,
shard_strategy=shard_strategy(), shard_strategy=shard_strategy_class(),
shard_param=True, shard_param=True,
model_numel_tensor=model_numel_tensor): model_numel_tensor=model_numel_tensor):
model = model_builder(checkpoint=True) model = model_builder(checkpoint=True)
......
...@@ -9,22 +9,22 @@ import pytest ...@@ -9,22 +9,22 @@ import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.testing import parameterize
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_param import ShardedParam, ShardedTensor from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_zero_data_parallel.common import CONFIG, allclose from tests.test_zero_data_parallel.common import CONFIG, allclose
from colossalai.testing import parameterize
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_shard_tensor_with_strategy(shard_strategy, world_size): def run_shard_tensor_with_strategy(shard_strategy_class, world_size):
t = ShardedTensor(tensor=torch.randn(world_size * 2, 3)) t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
assert list(t.origin_shape) == [world_size * 2, 3] assert list(t.origin_shape) == [world_size * 2, 3]
assert list(t.shape) == [world_size * 2, 3] assert list(t.shape) == [world_size * 2, 3]
shard_strategy = shard_strategy(process_group=None) shard_strategy = shard_strategy_class()
# test shard strategy # test shard strategy
shard_strategy.shard([t]) shard_strategy.shard([t])
......
...@@ -11,6 +11,8 @@ import torch.multiprocessing as mp ...@@ -11,6 +11,8 @@ import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from torchvision.models import resnet50 from torchvision.models import resnet50
...@@ -19,7 +21,7 @@ def run_dist(rank, world_size, port): ...@@ -19,7 +21,7 @@ def run_dist(rank, world_size, port):
# as this model has sync batch normalization # as this model has sync batch normalization
# need to configure cudnn deterministic so that # need to configure cudnn deterministic so that
# randomness of convolution layers will be disabled # randomness of convolution layers will be disabled
zero_config = dict(optimizer_config=dict(optimizer_class=torch.optim.Adam, lr=1e-3)) zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy()))
colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False), colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False),
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
...@@ -27,7 +29,11 @@ def run_dist(rank, world_size, port): ...@@ -27,7 +29,11 @@ def run_dist(rank, world_size, port):
port=port, port=port,
backend='nccl') backend='nccl')
model = resnet50() with ZeroInitContext(convert_fp16=True,
target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True):
model = resnet50()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss() criterion = torch.nn.CrossEntropyLoss()
...@@ -64,10 +70,6 @@ def run_dist(rank, world_size, port): ...@@ -64,10 +70,6 @@ def run_dist(rank, world_size, port):
'expected the output from different ranks to be the same, but got different values' 'expected the output from different ranks to be the same, but got different values'
# FIXME: enable this test in next PR
@pytest.mark.skip
@pytest.mark.dist @pytest.mark.dist
def test_sharded_optim_with_sync_bn(): def test_sharded_optim_with_sync_bn():
""" """
......
...@@ -8,7 +8,6 @@ import pytest ...@@ -8,7 +8,6 @@ import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.init_ctx import ZeroInitContext
...@@ -17,8 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan ...@@ -17,8 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding)
check_sharded_params_padding)
def run_dist(rank, world_size, port, parallel_config): def run_dist(rank, world_size, port, parallel_config):
...@@ -35,18 +33,19 @@ def run_dist(rank, world_size, port, parallel_config): ...@@ -35,18 +33,19 @@ def run_dist(rank, world_size, port, parallel_config):
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'), with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'),
target_device=torch.cuda.current_device(), target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shared_strategy( shard_strategy=gpc.config.zero.model_config.shard_strategy,
gpc.get_group(ParallelMode.DATA)),
shard_param=True): shard_param=True):
colo_model = model_builder(checkpoint=True) colo_model = model_builder(checkpoint=True)
torch_model = model_builder(checkpoint=True).half() colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)
col_model_deepcopy(colo_model, torch_model)
torch_model = torch_model.cuda().float()
engine, train_dataloader, _, _ = colossalai.initialize(colo_model, engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
optimizer=optimizer_class, optimizer=colo_optimizer,
criterion=criterion, criterion=criterion,
train_dataloader=train_dataloader) train_dataloader=train_dataloader)
torch_model = model_builder(checkpoint=True).half()
col_model_deepcopy(engine.model, torch_model)
torch_model = torch_model.cuda().float()
engine.train() engine.train()
torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3) torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
...@@ -102,7 +101,6 @@ def test_mp_engine(world_size): ...@@ -102,7 +101,6 @@ def test_mp_engine(world_size):
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
@pytest.mark.skip
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2]) @pytest.mark.parametrize("world_size", [1, 2])
def test_zero_engine(world_size): def test_zero_engine(world_size):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment