".github/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "03e52ecba3b60b04b552d82809043e5642509005"
Unverified commit d565a248, authored by HELSON, committed by GitHub

[zero] add unit testings for hybrid parallelism (#2486)

parent fcc6d61d
@@ -7,7 +7,6 @@ class BucketStore(BaseStore):
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
-        self._grads = dict()
         self._params = dict()
         self._num_elements_in_bucket = dict()
@@ -19,25 +18,24 @@ class BucketStore(BaseStore):
     def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
         self._num_elements_in_bucket[reduce_rank] += num_elements
 
-    def add_grad(self, tensor, reduce_rank: int = None):
-        self._grads[reduce_rank].append(tensor)
-
     def add_param(self, tensor, reduce_rank: int = None):
         self._params[reduce_rank].append(tensor)
 
     def reset(self):
         keys = [None] + list(range(self._world_size))
-        self._grads = {rank: [] for rank in keys}
         self._params = {rank: [] for rank in keys}
         self._num_elements_in_bucket = {rank: 0 for rank in keys}
 
     def reset_by_rank(self, reduce_rank=None):
-        self._grads[reduce_rank] = []
         self._params[reduce_rank] = []
         self._num_elements_in_bucket[reduce_rank] = 0
 
     def get_grad(self, reduce_rank: int = None):
-        return self._grads[reduce_rank]
+        param_list = self.get_param(reduce_rank)
+        for param in param_list:
+            # the param must have grad for reduction
+            assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
+        return [param.grad for param in param_list]
 
     def get_param(self, reduce_rank: int = None):
         return self._params[reduce_rank]
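This change removes the separately tracked `_grads` dict; `get_grad` now derives the gradient list from the stored parameters on demand, so the two stores can no longer drift apart. A minimal standalone sketch of the pattern (the `Store` class is a hypothetical stand-in for `BucketStore`, reduced to the relevant idea):

# Minimal sketch: derive grads from stored params instead of tracking them separately.
import torch

class Store:

    def __init__(self):
        self._params = {None: []}

    def add_param(self, p, reduce_rank=None):
        self._params[reduce_rank].append(p)

    def get_grad(self, reduce_rank=None):
        params = self._params[reduce_rank]
        for p in params:
            # every bucketed param must already hold a grad
            assert p.grad is not None, 'param has no grad to reduce'
        return [p.grad for p in params]

store = Store()
p = torch.nn.Parameter(torch.ones(4))
p.sum().backward()        # populates p.grad with ones
store.add_param(p)
print(store.get_grad())   # [tensor([1., 1., 1., 1.])]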
@@ -46,7 +46,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                  reduce_bucket_size: int = 1024 * 1024,    # communication
                  communication_dtype: Optional[torch.dtype] = None,
                  overlap_communication: bool = False,
-                 partition_grad: bool = False,    # stage 2
+                 partition_grad: bool = False,    # stage 2 flag
                  cpu_offload: bool = False,    # cpu offload
                  forced_dtype: Optional[torch.dtype] = None):
@@ -248,9 +248,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
         return params_per_rank
 
-    ###########################################################
+    ###########################
     # Backward Reduction Hook #
-    ###########################################################
+    ###########################
 
+    def _grad_handler(self, param, grad, reduce_rank):
+        self._add_to_reduction_bucket(param, reduce_rank)
+        return grad
+
     def _attach_reduction_hook(self):
         # we iterate over the fp16 params
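The new `_grad_handler` swaps the old `AccumulateGrad`-object hooks for `Tensor.register_hook`, whose callback receives the gradient itself and may return it (optionally modified). A self-contained sketch of this hook pattern, with a hypothetical free-function version of the handler:

# Sketch of the Tensor.register_hook + functools.partial pattern used above.
from functools import partial

import torch

def grad_handler(param, grad, reduce_rank=None):
    # here the real optimizer would bucket `param` for reduction
    print(f'grad of shape {tuple(grad.shape)} arrived, reduce_rank={reduce_rank}')
    return grad    # returning the grad leaves it unchanged

p = torch.nn.Parameter(torch.randn(2, 2))
# partial pre-binds the param; autograd passes only the grad at backward time
p.register_hook(partial(grad_handler, p, reduce_rank=None))
p.sum().backward()    # triggers the handler once for p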
@@ -268,53 +272,61 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         else:
             reduce_rank = None
 
-        def _define_and_attach(param, reduce_rank):
-            # get the AccumulateGrad object of the param itself
-            accum_grad_obj = get_grad_accumulate_object(param)
-            self._grad_store.add_accumulate_grad_object(accum_grad_obj)
-
-            reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
-                                     param=param,
-                                     reduce_rank=reduce_rank)
-
-            # define hook
-            # NOT IMPORTANT BUT GOOD TO KNOW:
-            # args here is not grad, but allow_unreachable and accumulate_grad
-            def reduce_grad_hook(*args):
-                reduction_func()
-
-            accum_grad_obj.register_hook(reduce_grad_hook)
-
-        _define_and_attach(param, reduce_rank)
-
-    def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
-        param_size = param.numel()
-
-        # check if the bucket is full
-        # if full, will reduce the grads already in the bucket
-        # after reduction, the bucket will be empty
-        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._reduce_grads_in_bucket(reduce_rank)
-
-        # the param must not be reduced to ensure correctness
-        is_param_reduced = self._param_store.is_param_reduced(param)
-        if is_param_reduced:
-            msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
-                + 'duplicate reduction will lead to arithmetic incorrectness'
-            raise RuntimeError(msg)
-
-        # the param must have grad for reduction
-        assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
-
-        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
-        self._bucket_store.add_grad(param.grad, reduce_rank)
-        self._bucket_store.add_param(param, reduce_rank)
+        param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
+
+    def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
+        if self._overlap_communication:
+            torch.cuda.synchronize()
+            self._param_store.clear_grads_of_previous_reduced_params()
+            stream = self._comm_stream
+        else:
+            stream = torch.cuda.current_stream()
+
+        with torch.cuda.stream(stream):
+            flat = bucket.flatten()
+            reduce_global_rank = None
+            if reduce_rank is not None:
+                reduce_global_rank = self._dp_global_ranks[reduce_rank]
+            reduced_flat = reduce_tensor_dp_group(tensor=flat,
+                                                  dtype=self._communication_dtype,
+                                                  dst_local_rank=reduce_rank,
+                                                  dst_global_rank=reduce_global_rank,
+                                                  group=self._dp_torch_group)
+
+            # update the reduced tensor
+            if reduce_rank is None or reduce_rank == self._local_rank:
+                bucket.unflatten_and_copy(reduced_flat)
+
+    def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
+        param_bucket = TensorBucket(size=bucket_size)
+
+        for tensor in tensor_list:
+            param_bucket.add_to_bucket(tensor, allow_oversize=True)
+            if param_bucket.is_full_or_oversized():
+                self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
+                param_bucket.empty()
+
+        if not param_bucket.is_empty():
+            self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
+
+    def _reduce_grads(self, reduce_rank, grads, bucket_size):
+        grad_buckets_by_dtype = split_half_float_double(grads)
+        for tensor_list in grad_buckets_by_dtype:
+            self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
+                                                    bucket_size=bucket_size,
+                                                    reduce_rank=reduce_rank)
 
     #######################
     # Reduction Functions #
     #######################
 
-    def _reduce_grads_in_bucket(self, reduce_rank=None):
+    def _run_reduction(self, reduce_rank=None):
         # reduce grads
-        self._reduce_grads_by_rank(reduce_rank=reduce_rank,
-                                   grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
-                                   bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
+        self._reduce_grads(reduce_rank=reduce_rank,
+                           grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
+                           bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
 
         # use communication stream if overlapping
         # communication with computation
@@ -351,50 +363,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             self._bucket_store.reset_by_rank(reduce_rank)
 
-    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
-        grad_buckets_by_dtype = split_half_float_double(grads)
-
-        for tensor_list in grad_buckets_by_dtype:
-            self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
-
-    ##############################
-    # Reduction Utility Function #
-    ##############################
-    def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
-        param_bucket = TensorBucket(size=bucket_size)
-
-        for tensor in tensor_list:
-            param_bucket.add_to_bucket(tensor, allow_oversize=True)
-
-            if param_bucket.is_full_or_oversized():
-                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
-                param_bucket.empty()
-
-        if not param_bucket.is_empty():
-            self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
-
-    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
-        if self._overlap_communication:
-            torch.cuda.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
-            stream = self._comm_stream
-        else:
-            stream = torch.cuda.current_stream()
-
-        with torch.cuda.stream(stream):
-            flat = bucket.flatten()
-            reduce_global_rank = None
-            if reduce_rank is not None:
-                reduce_global_rank = self._dp_global_ranks[reduce_rank]
-            reduced_flat = reduce_tensor_dp_group(tensor=flat,
-                                                  dtype=self._communication_dtype,
-                                                  dst_local_rank=reduce_rank,
-                                                  dst_global_rank=reduce_global_rank,
-                                                  group=self._dp_torch_group)
-
-            # update the reduced tensor
-            if reduce_rank is None or reduce_rank == self._local_rank:
-                bucket.unflatten_and_copy(reduced_flat)
+    def _add_to_reduction_bucket(self, param, reduce_rank=None):
+        param_size = param.numel()
+
+        # check if the bucket is full
+        # if full, will reduce the grads already in the bucket
+        # after reduction, the bucket will be empty
+        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
+            self._run_reduction(reduce_rank)
+
+        # the param must not be reduced to ensure correctness
+        is_param_reduced = self._param_store.is_param_reduced(param)
+        if is_param_reduced:
+            msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+                + 'duplicate reduction will lead to arithmetic incorrectness'
+            raise RuntimeError(msg)
+
+        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
+        self._bucket_store.add_param(param, reduce_rank)
 
     ################################
     # torch.optim.Optimizer methods
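`_add_to_reduction_bucket` keeps the old flush-then-append rule: when appending a parameter would push the bucket past `reduce_bucket_size`, the pending gradients are reduced first, so the bucket overshoots the threshold by at most one parameter. A toy trace of the rule (the sizes are made up):

# Toy trace of the flush-then-append bucketing rule (hypothetical sizes).
REDUCE_BUCKET_SIZE = 1024

bucket, in_bucket = [], 0
for numel in [512, 400, 300, 900]:      # parameter sizes arriving from grad hooks
    if in_bucket + numel > REDUCE_BUCKET_SIZE:
        print(f'flush {in_bucket} elements: {bucket}')    # _run_reduction(...)
        bucket, in_bucket = [], 0
    bucket.append(numel)
    in_bucket += numel
print(f'left in bucket: {bucket}')      # reduced later by the final sync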
@@ -498,8 +484,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # broadcast the updated model weights
         handles = []
         for group_id in range(self.num_param_groups):
-            for rank in range(self._world_size):
-                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
+            for index in range(self._world_size):
+                rank = self._dp_global_ranks[index]
+                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
                 handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
                 handles.append(handle)
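This hunk fixes a real hybrid-parallel bug: `dist.broadcast(src=...)` expects a global rank even when a subgroup is passed via `group=`, while the flat fp16 shards are keyed by the local index inside the DP group. With pure data parallelism the two coincide, but once TP splits the world the mapping must go through `_dp_global_ranks`. A small sketch of the distinction under an assumed 2-way TP x 2-way DP layout over 4 processes:

# Sketch: local DP index vs. global rank (assumed layout, not queried from torch).
# Suppose this process has TP rank 1, so its DP group contains global ranks 1 and 3.
dp_global_ranks = [1, 3]

for index in range(len(dp_global_ranks)):   # local index keys the stored shard
    src = dp_global_ranks[index]            # global rank goes to dist.broadcast(src=...)
    print(f'dp-local index {index} -> broadcast src (global rank) {src}')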
@@ -585,11 +572,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
                 if param.grad is not None:
-                    self._reduce_and_remove_grads_by_bucket(param)
+                    self._add_to_reduction_bucket(param)
 
         # we need to reduce the gradients
         # left in the communication bucket
-        self._reduce_grads_in_bucket()
+        self._run_reduction()
 
     def _reduce_grad_stage2(self):
         # when partition_grads is True, reduction hooks
@@ -597,4 +584,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # only need to reduce the gradients
         # left in the communication bucket
         for reduce_rank in range(self._world_size):
-            self._reduce_grads_in_bucket(reduce_rank)
+            self._run_reduction(reduce_rank)
@@ -4,6 +4,7 @@ import random
 import numpy as np
 import torch
 import torch.distributed as dist
+from torch.testing import assert_close
 
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
@@ -41,14 +42,20 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
     return tensor_chunk.clone()
 
 
-def tensor_equal(A, B):
-    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1):
+    assert_close(t_a, t_b, rtol=rtol, atol=atol)
+    return True
 
 
-def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
+def tensor_shard_equal(tensor: torch.Tensor,
+                       shard: torch.Tensor,
+                       rank: int,
+                       world_size: int,
+                       rtol: float = 1e-3,
+                       atol: float = 1e-1):
     assert tensor.ndim == shard.ndim
     if tensor.shape == shard.shape:
-        return tensor_equal(tensor, shard)
+        return tensor_equal(tensor, shard, rtol, atol)
     else:
         dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
         if dims_not_eq.numel() == 1:
@@ -58,7 +65,7 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_si
             world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
         if rank is None:
             rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-        return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+        return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol)
     else:
         raise NotImplementedError
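Threading `rtol`/`atol` through lets each test pick its own tolerance, and `assert_close` reports the offending elements on mismatch instead of just returning False. A standalone usage sketch of the shard comparison that `tensor_shard_equal` performs along the one mismatched dimension:

# Sketch: comparing a full tensor against a 1-D shard, as tensor_shard_equal does.
import torch
from torch.testing import assert_close

full = torch.arange(8.0).reshape(4, 2)
world_size, rank, dim = 2, 1, 0                    # shard along the mismatched dim 0
shard = full.chunk(world_size, dim)[rank].clone()

# passes: the rank-1 chunk of the full tensor matches the shard
assert_close(full.chunk(world_size, dim)[rank], shard, rtol=1e-3, atol=1e-1)
print('shard matches full tensor chunk')

The commit also adds a new test file, shown below, which exercises LowLevelZeroOptimizer under combined tensor parallelism and ZeRO data parallelism.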
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import LowLevelZeroOptimizer
from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal


def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4):
    return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol)


class TestModel(nn.Module):

    def __init__(self):
        super(TestModel, self).__init__()
        self.linear1 = nn.Linear(32, 128)
        self.act = nn.GELU()
        self.linear2 = nn.Linear(128, 32)

    def forward(self, x):
        y = self.linear1(x)
        y = self.act(y)
        y = self.linear2(y)
        return x + y


@parameterize("overlap_flag", [False, True])
@parameterize("partition_flag", [False, True])
def exam_zero_with_tp(overlap_flag, partition_flag):
    set_seed(233010)
    tp_pg = ProcessGroup(tp_degree=2)

    with ColoInitContext(device=get_current_device(), default_pg=tp_pg):
        hybrid_model = TestModel()
    torch_model = TestModel().cuda()
    for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()):
        pt.data.copy_(ph.data)

    for name, param in hybrid_model.named_parameters():
        if 'linear1' in name:
            split_param_row_tp1d(param, tp_pg)
            param.compute_spec.set_output_replicate(False)
        if 'linear2.weight' in name:
            split_param_col_tp1d(param, tp_pg)

    torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group())
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1)
    hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1)
    hybrid_optim = LowLevelZeroOptimizer(hybrid_optim,
                                         initial_scale=1,
                                         overlap_communication=overlap_flag,
                                         partition_grad=partition_flag)

    dp_local_rank = tp_pg.dp_local_rank()
    set_seed(255 + dp_local_rank)

    data = torch.randn(8, 32, device=get_current_device())
    torch_loss = torch_model(data).sum()
    hybrid_loss = hybrid_model(data).sum()
    assert_close(torch_loss, hybrid_loss)

    torch_loss.backward()
    hybrid_optim.backward(hybrid_loss)
    hybrid_optim.sync_grad()

    torch_optim.step()
    hybrid_optim.step()

    for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()):
        assert strict_shard_equal(pt.data, ph.data, tp_pg)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
    exam_zero_with_tp()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_with_tp():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_with_tp()
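The test runs 4 processes with `tp_degree=2`, so every rank sits in a 2-way TP group and a 2-way DP group; the DDP baseline is attached to the same DP group so losses and updated shards can be compared step for step. A pure-arithmetic sketch of the rank layout this assumes (consecutive ranks sharing a TP group mirrors the usual ColossalAI convention, but is an assumption here):

# Assumed rank layout for world_size=4, tp_degree=2 (pure arithmetic, no ColossalAI):
# consecutive ranks form a TP group; ranks with the same TP position form a DP group.
world_size, tp_degree = 4, 2

for rank in range(world_size):
    tp_group = [r for r in range(world_size) if r // tp_degree == rank // tp_degree]
    dp_group = [r for r in range(world_size) if r % tp_degree == rank % tp_degree]
    print(f'rank {rank}: tp_group={tp_group}, dp_group={dp_group}')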