Commit c6ab9698 authored by LuGY, committed by Hongxin Liu

[zero] refactor low level zero for shard evenly (#4030)

* refactor low level zero

* fix zero2 and support cpu offload

* avg gradient and modify unit test

* refactor grad store, support layer drop

* refactor bucket store, support grad accumulation

* fix and update unit test of zero and ddp

* compatible with tp, ga and unit test

* fix memory leak and polish

* add zero layer drop unittest

* polish code

* fix import err in unit test

* support different comm dtype, modify docstring style

* polish code

* test padding and fix

* fix unit test of low level zero

* fix pad recording in bucket store

* support some models

* polish
parent 5187c96b
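The even-sharding idea behind this refactor can be sketched as follows: pad each flattened gradient so its length divides the world size, then hand every rank one slice of identical length. The helper name shard_evenly below is hypothetical and not part of this commit; only the padding arithmetic mirrors the new code.

import torch


def shard_evenly(flat_grad: torch.Tensor, world_size: int):
    # pad the flattened gradient so it splits into equal per-rank slices
    padding_size = (world_size - flat_grad.numel() % world_size) % world_size
    if padding_size > 0:
        flat_grad = torch.nn.functional.pad(flat_grad, [0, padding_size])
    # each rank owns exactly one slice of identical length
    return flat_grad.split(flat_grad.numel() // world_size), padding_size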
@@ -253,7 +253,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
     return total_norm


-def sync_param(flat_tensor, tensor_list):
+def sync_tensor(flat_tensor, tensor_list):
     """
     Synchronize the flattened tensor and unflattened tensor list. When
     a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`,
......
+from typing import Dict
+
+import torch
+from torch import Tensor
+from torch._utils import _flatten_dense_tensors
 from torch.distributed import ProcessGroup

 from .base_store import BaseStore

@@ -7,35 +12,102 @@ class BucketStore(BaseStore):

     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
-        self._params = dict()
-        self._num_elements_in_bucket = dict()

+        # init and reset
+        self.current_group_id = 0
+        # mapping gradient slices and parameter
+        self.grad_to_param_mapping = dict()
+
+        self._param_list = []
+        self._padding_size = []
+
         self.reset()

-    def num_elements_in_bucket(self, reduce_rank: int = None):
-        return self._num_elements_in_bucket[reduce_rank]
+    def num_elements_in_bucket(self) -> int:
+        """Return the total number of elements in the bucket
+
+        Returns:
+            int: the total number of elements in the bucket
+        """
+
+        return self._num_elements_in_bucket
+
+    def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
+        """Add a param to the bucket and record its padding size for gradient padding
+
+        Args:
+            group_id (int): The index of a parameter group
+            param (Tensor): The parameter
+            padding_size (int): The padding size of the parameter
+        """
+
+        self._param_list.append(param)
+        self._padding_size.append(padding_size)
+        self._num_elements_in_bucket += (param.numel() + padding_size)
+        self.current_group_id = group_id
+
+    def build_grad_in_bucket(self):
+        """Organize the parameters' gradients (pad and split), following the parameters' splitting method
+
+        Data structure of self._grad_in_bucket:
+        {
+            rank0: [grad0_rank0, grad1_rank0, ...]
+            rank1: [grad0_rank1, grad1_rank1, ...]
+        }
+        """
+
+        for param, padding_size in zip(self._param_list, self._padding_size):
+            with torch.no_grad():
+                grad = param.grad.detach().flatten()
+                if padding_size > 0:
+                    grad = torch.nn.functional.pad(grad, [0, padding_size])
+                grad_list = grad.split(grad.numel() // self._world_size)
+                for rank in range(self._world_size):
+                    grad_current_rank = grad_list[rank].detach()
+                    self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
+                    self._grad_in_bucket[rank].append(grad_current_rank)
+            param.grad = None
+
+    def get_grad(self) -> Dict:
+        """Return the dictionary of gradient slices, keyed by rank
+
+        Returns:
+            Dict: The dictionary of gradient slices
+        """
+
+        return self._grad_in_bucket
+
+    def get_flatten_grad(self) -> Tensor:
+        """Return the flattened gradient slices in the bucket; the data organization of the flattened tensor is:
+        [grad0_rank0, grad1_rank0, ..., grad0_rank1, grad1_rank1, ...]
+
+        Returns:
+            Tensor: the flattened gradient slices in the bucket
+        """
+
+        flat_grad = []
+        for grad_list in self._grad_in_bucket.values():
+            flat_grad.append(_flatten_dense_tensors(grad_list))
+        flat_grad = _flatten_dense_tensors(flat_grad)
+        return flat_grad
+
+    def get_param_id_of_grad(self, grad: Tensor) -> int:
+        """Return the id of the parameter which the gradient slice belongs to
+
+        Args:
+            grad (Tensor): the gradient slice
+
+        Returns:
+            int: the id of the parameter which the gradient slice belongs to
+        """
+
+        return self.grad_to_param_mapping[id(grad)]

-    def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
-        self._num_elements_in_bucket[reduce_rank] += num_elements
-
-    def add_param(self, tensor, reduce_rank: int = None):
-        self._params[reduce_rank].append(tensor)
-
     def reset(self):
-        keys = [None] + list(range(self._world_size))
-        self._params = {rank: [] for rank in keys}
-        self._num_elements_in_bucket = {rank: 0 for rank in keys}
-
-    def reset_by_rank(self, reduce_rank=None):
-        self._params[reduce_rank] = []
-        self._num_elements_in_bucket[reduce_rank] = 0
-
-    def get_grad(self, reduce_rank: int = None):
-        param_list = self.get_param(reduce_rank)
-        for param in param_list:
-            # the param must have grad for reduction
-            assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
-        return [param.grad for param in param_list]
-
-    def get_param(self, reduce_rank: int = None):
-        return self._params[reduce_rank]
+        self.grad_to_param_mapping = dict()
+        self._num_elements_in_bucket = 0
+        self._param_list = []
+        self._padding_size = []
+        self._grad_in_bucket = dict()
+        for rank in range(self._world_size):
+            self._grad_in_bucket[rank] = []
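Taken together, the new BucketStore methods form a small per-bucket pipeline. The sketch below shows the intended call order; reduce_one_bucket is a hypothetical driver helper, and the reduce-scatter actually issued by the optimizer is omitted.

from torch import Tensor


def reduce_one_bucket(bucket, param_store, params_with_grad, group_id: int = 0) -> Tensor:
    # `bucket` is a BucketStore and `param_store` a ParameterStore from this commit
    for param in params_with_grad:
        padding = param_store.get_param_padding_size(param)
        bucket.add_param_grad(group_id, param, padding)
    bucket.build_grad_in_bucket()             # pad, split and record each grad slice per rank
    flat_grad = bucket.get_flatten_grad()     # [all rank-0 slices | all rank-1 slices | ...]
    # ... a reduce-scatter / all-reduce over `flat_grad` would happen here ...
    bucket.reset()                            # clear bookkeeping for the next bucket
    return flat_grad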
 from typing import List

 from torch import Tensor
-from torch._utils import _flatten_dense_tensors

 from .base_store import BaseStore


 class GradientStore(BaseStore):

-    def __init__(self, *args):
+    def __init__(self, *args, partition_grad: bool = False):
         super().__init__(*args)
-        # bookkeeping data structures
-        self._averaged_gradients = dict()
-
-        # for backward reduction hooks
-        self._grad_acc_objs = []
-
-    def append_accumulate_grad_object(self, obj):
-        """
-        Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
-        be attached successfully.
-
-        :param obj: An object of :class:`AccumulateGrad` class
-        :type obj: :class:`AccumulateGrad`
-        """
-        self._grad_acc_objs.append(obj)
-
-    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
-        """
-        Return average gradients of a parameter group
-
-        :param group_id: The index of parameter group
-        :type group_id: int
-
-        :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
-        :rtype: List[torch.Tensor]
-        """
-        if group_id not in self._averaged_gradients:
-            self._averaged_gradients[group_id] = []
-
-        return self._averaged_gradients[group_id]
-
-    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
-        """
-        Append an average gradient to the list of averaged gradients of a parameter group
-
-        :param group_id: The index of a parameter group
-        :param tensor: A :class:`torch.Tensor` object
-        :type group_id: int
-        :type tensor: torch.Tensor
-        """
-        if group_id in self._averaged_gradients:
-            self._averaged_gradients[group_id].append(tensor)
-        else:
-            self._averaged_gradients[group_id] = [tensor]
-
-    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
-        """
-        Add an average gradient to the list of averaged gradients of a parameter group
-
-        :param group_id: The index of a parameter group
-        :param tensor_idx: The index of a tensor in the list of averaged gradients
-        :param tensor: A :class:`torch.Tensor` object
-        :type group_id: int
-        :type tensor_idx: int
-        :type tensor: torch.Tensor
-        """
-        self._averaged_gradients[group_id][tensor_idx].add_(tensor)
-
-    def reset_average_gradients_by_group(self, group_id: int) -> None:
-        """
-        Reset the bookkeeping data structure for averaged gradients to an empty list
-
-        :param group_id: The index of a parameter group
-        :type group_id: int
-        """
-        self._averaged_gradients[group_id] = []
-
-    def reset_all_average_gradients(self) -> None:
-        """
-        Reset the bookkeeping data structure for averaged gradients to an empty list
-        """
-        self._averaged_gradients = dict()
+        """
+        self._grads_of_params maps each parameter to its gradient slices
+        data structure:
+        {
+            group_id: {
+                param_id: [grad_rank0, grad_rank1, ...]
+            }
+        }
+        """
+        self._grads_of_params = dict()
+        # for zero2, it's `param_id: [grad_local_rank]`
+        self._working_index = 0 if partition_grad else self._local_rank
+
+    def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
+        """Return the list of gradient slices of a specific parameter
+
+        Args:
+            group_id (int): The index of a parameter group
+            param_id (int): The id of a parameter
+
+        Returns:
+            List: the list of gradient slices of a parameter
+        """
+
+        if group_id in self._grads_of_params:
+            if param_id in self._grads_of_params[group_id]:
+                return self._grads_of_params[group_id][param_id]
+        # the param has no grad, for instance, in layer drop
+        return []
+
+    def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int):
+        """Append a gradient slice to the parameter's gradient slice list
+
+        Args:
+            grad (Tensor): The gradient slice to append to the list
+            group_id (int): The index of a parameter group
+            param_id (int): The id of a parameter
+        """
+
+        if group_id not in self._grads_of_params:
+            self._grads_of_params[group_id] = dict()
+        if param_id not in self._grads_of_params[group_id]:
+            self._grads_of_params[group_id][param_id] = [grad]
+        else:
+            self._grads_of_params[group_id][param_id].append(grad)
+
+    def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
+        """For old gradient accumulation, not in use now.
+        Add a gradient slice onto an existing slice of the parameter's gradient
+
+        Args:
+            grad (Tensor): The split gradient to add to the list
+            grad_idx (int): The index of the existing slice
+            group_id (int): The index of a parameter group
+            param_id (int): The id of a parameter
+        """
+
+        self._grads_of_params[group_id][param_id][grad_idx].add_(grad)
+
+    def get_working_grads_by_group_id(self, group_id: int) -> List:
+        """Return the list of working gradient slices in the group
+
+        Args:
+            group_id (int): The index of a parameter group
+
+        Returns:
+            List: the list of working gradient slices in the group
+        """
+
+        grad_list = []
+        for param_grads in self._grads_of_params[group_id].values():
+            grad_list.append(param_grads[self._working_index])
+
+        return grad_list
+
+    def reset_grads_by_group_id(self, group_id: int):
+        self._grads_of_params[group_id] = dict()
+
+    def reset_all_gradients(self):
+        self._grads_of_params = dict()
-from typing import List
-
 from torch import Tensor
 from torch.distributed import ProcessGroup

@@ -10,88 +8,43 @@ class ParameterStore(BaseStore):

     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
-        # param partitioning data structures
-        self._param_to_rank = dict()
-        self._rank_group_id_to_param_list = dict()
-        self._rank_group_id_to_flat_param = dict()

-        # param reduction data structures
-        self._is_param_reduced = dict()
-        self._reduced_param = []
+        # record the padding size of each param
+        self._padding_map = dict()

-    def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
-        """
-        Set the mapping between parameter and rank; each parameter should be owned by a rank.
-
-        :param tensor: A :class:`torch.Tensor` object
-        :type tensor: torch.Tensor
-        :param rank: The rank of the process responsible for updating the parameter
-        :type rank: int
-        """
-        self._param_to_rank[tensor] = rank
+        # mapping working param and master param
+        self.master_to_working_param = dict()
+        self.working_to_master_param = dict()

-    def get_param_rank(self, tensor: Tensor) -> int:
-        """
-        Give the rank which the parameter belongs to
-
-        :param tensor: A :class:`torch.Tensor` object
-        :type tensor: torch.Tensor
-        """
-        return self._param_to_rank[tensor]
+    def record_param_padding_size(self, param: Tensor, padding_size: int):
+        """Record the padding size of a param
+
+        Args:
+            param (Tensor): The parameter
+            padding_size (int): The padding size of the parameter
+        """
+
+        self._padding_map[id(param)] = padding_size

-    def belongs_to_current_rank(self, tensor) -> bool:
-        """
-        Check whether a parameter is supposed to be updated by the process of the current rank
-
-        :param tensor: A :class:`torch.Tensor` object
-        :type tensor: torch.Tensor
-        :return: True if the parameter should be updated by the current rank, otherwise False.
-        :rtype: bool
-        """
-        tensor_rank = self._param_to_rank[tensor]
-        return tensor_rank == self._local_rank
+    def get_param_padding_size(self, param: Tensor) -> int:
+        """Return the padding size of the parameter
+
+        Args:
+            param (Tensor): The parameter
+
+        Returns:
+            int: the padding size of the parameter
+        """
+
+        return self._padding_map[id(param)]

-    def add_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
-        if rank not in self._rank_group_id_to_param_list:
-            self._rank_group_id_to_param_list[rank] = dict()
-
-        if group_id not in self._rank_group_id_to_param_list[rank]:
-            self._rank_group_id_to_param_list[rank][group_id] = []
-
-        self._rank_group_id_to_param_list[rank][group_id].extend(tensor_list)
-
-    def get_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
-        return self._rank_group_id_to_param_list[rank][group_id]
-
-    def add_flat_param_by_rank_group(self, rank, group_id, tensor) -> None:
-        if rank not in self._rank_group_id_to_flat_param:
-            self._rank_group_id_to_flat_param[rank] = dict()
-
-        self._rank_group_id_to_flat_param[rank][group_id] = tensor
-
-    def get_flat_param_by_rank_group(self, rank, group_id) -> Tensor:
-        return self._rank_group_id_to_flat_param[rank][group_id]
-
-    def is_param_reduced(self, tensor):
-        return self._is_param_reduced[tensor]
-
-    def set_param_reduction_state(self, tensor, state):
-        self._is_param_reduced[tensor] = state
-
-    def get_param_reduction_states(self):
-        return self._is_param_reduced
-
-    def reset_previous_reduced_params(self):
-        self._reduced_param = []
-
-    def add_previous_reduced_param(self, tensor):
-        self._reduced_param.append(tensor)
-
-    def clear_grads_of_previous_reduced_params(self):
-        if len(self._reduced_param) > 0:
-            for param in self._reduced_param:
-                param.grad = None
-            self.reset_previous_reduced_params()
+    def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
+        """Map a master parameter to its working parameter and vice versa
+
+        Args:
+            master_param (Tensor): The parameter copy in the optimizer
+            working_param (Tensor): The parameter of the model
+        """
+
+        self.master_to_working_param[id(master_param)] = working_param
+        self.working_to_master_param[id(working_param)] = master_param
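The slimmed-down ParameterStore only remembers per-parameter padding sizes and a bidirectional master/working mapping; building the master shards is left to the optimizer. A hedged sketch of what creating one rank's fp32 master shard for a working parameter could look like (the helper name make_master_shard and the surrounding flow are assumptions, only the padding formula is taken from this commit; the optimizer would then register the pair via ParameterStore.link_master_and_working_param):

import torch
from torch import Tensor


def make_master_shard(working_param: Tensor, world_size: int, local_rank: int) -> Tensor:
    # flatten and upcast the working (fp16/bf16) parameter, pad it to a multiple of
    # world_size, and keep only the slice this rank owns as its fp32 master copy
    with torch.no_grad():
        flat = working_param.detach().flatten().float()
        padding = (world_size - flat.numel() % world_size) % world_size
        if padding > 0:
            flat = torch.nn.functional.pad(flat, [0, padding])
        shard = flat.split(flat.numel() // world_size)[local_rank].clone()
    return shard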
@@ -11,14 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo

 # These models are not compatible with AMP
-_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
+_AMP_ERR_MODELS = ['timm_convit', 'deepfm_interactionarch']
 # These models have no parameters
-_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch']
-# These models will get stuck
-_STUCK_MODELS = [
-    'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
-    'transformers_bert_for_pretraining', 'transformers_gpt_double_heads'
-]
+_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']


 def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
@@ -58,7 +53,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
     """
     passed_models = []
     failed_info = {}    # (model_name, error) pair
-    ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
+    ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS
     skipped_models = []

     for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
......
@@ -39,37 +39,37 @@ def exam_zero_1_2_grad_acc():
                                             overlap_communication=True,
                                             initial_scale=32,
                                             clip_grad_norm=1.0,
+                                            grad_accumulate_interval=2,
                                             verbose=True)
     zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                             overlap_communication=True,
                                             partition_grad=True,
                                             initial_scale=32,
-                                            clip_grad_norm=1.0)
+                                            clip_grad_norm=1.0,
+                                            grad_accumulate_interval=2)
     # create data
     seed_all(2021 + local_rank)
     input_data1 = torch.randn(32, 128).cuda()
     input_data2 = torch.randn(32, 128).cuda()

-    def fwd_bwd_func(number, cur_data):
+    def fwd_bwd_func(number, cur_data, check_flag):
         # zero-dp forward
         zero1_output = zero1_model(cur_data)
         zero2_output = zero2_model(cur_data)
         assert torch.equal(zero1_output, zero2_output)

         # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
-        zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)
-
-        for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
-            if z2p.grad is not None:
-                # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
-                assert torch.equal(z1p.grad, z2p.grad)
-
-        zero1_optimizer._sync_grad()
-        zero2_optimizer._sync_grad()
+        zero1_optimizer.backward(zero1_output.sum().float())
+        zero2_optimizer.backward(zero2_output.sum().float())
+
+        if check_flag:
+            for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
+                if z2p.grad is not None:
+                    # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
+                    assert torch.equal(z1p.grad, z2p.grad)

-    fwd_bwd_func(0, input_data1)
-    fwd_bwd_func(1, input_data2)
+    fwd_bwd_func(0, input_data1, True)
+    fwd_bwd_func(1, input_data2, False)

     # step
     zero1_optimizer.step()
@@ -101,7 +101,8 @@ def exam_zero_1_grad_acc():
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                            overlap_communication=False,
                                            reduce_bucket_size=262144,
-                                           clip_grad_norm=1.0)
+                                           clip_grad_norm=1.0,
+                                           grad_accumulate_interval=2)

     torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@@ -115,13 +116,19 @@ def exam_zero_1_grad_acc():
         zero_output = zero_model(cur_data)

         # torch-ddp forward
-        torch_output = torch_model(cur_data)
-        assert torch.equal(zero_output, torch_output)

         # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)
+        zero_optimizer.backward(zero_output.sum().float())

         # torch-ddp backward
-        torch_output.sum().backward()
+        if number < 1:
+            with torch_model.no_sync():
+                torch_output = torch_model(cur_data)
+                assert torch.equal(zero_output, torch_output)
+                torch_output.sum().backward()
+        else:
+            torch_output = torch_model(cur_data)
+            assert torch.equal(zero_output, torch_output)
+            torch_output.sum().backward()

         if check_flag:
             # check grad
@@ -129,8 +136,6 @@ def exam_zero_1_grad_acc():
                 # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
                 assert torch.equal(p.grad, z1p.grad)

-        zero_optimizer._sync_grad()
-
     fwd_bwd_func(0, input_data1, True)
     fwd_bwd_func(1, input_data2, False)
@@ -148,7 +153,8 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

     exam_zero_1_grad_acc()
-    exam_zero_1_2_grad_acc()
+    # gradient accumulation is not compatible with ZeRO-2
+    # exam_zero_1_2_grad_acc()


 @pytest.mark.dist
......
@@ -2,6 +2,7 @@ import copy

 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
@@ -16,8 +17,9 @@ class MlpModel(nn.Module):

     def __init__(self):
         super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(128, 256)
-        self.linear2 = nn.Linear(256, 512)
+        self.linear1 = nn.Linear(123, 253)
+        self.linear_drop = nn.Linear(253, 253)
+        self.linear2 = nn.Linear(253, 512)

     def forward(self, x):
         x = self.linear1(x)
@@ -41,6 +43,16 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
     assert_close(a, b, rtol=rtol, atol=atol)


+def split_ddp_grad(grad, world_size):
+    with torch.no_grad():
+        grad = grad.clone().detach().flatten()
+        padding_size = (world_size - grad.numel() % world_size) % world_size
+        if padding_size > 0:
+            grad = torch.nn.functional.pad(grad, [0, padding_size])
+        splited_grad = grad.split(grad.numel() // world_size)
+    return splited_grad
+
+
 def exam_zero_1_2():
     """
     In this test, we want to test whether zero stage 1 and 2
@@ -72,23 +84,21 @@ def exam_zero_1_2():
                                             initial_scale=128)
     # create data
     seed_all(2001 + local_rank)
-    input_data = torch.randn(32, 128).cuda()
+    input_data = torch.randn(32, 123).cuda()

     zero1_output = zero1_model(input_data)
     zero2_output = zero2_model(input_data)
     assert torch.equal(zero1_output, zero2_output)

     # zero-dp backward
-    zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
-    zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)
-
-    for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
-        if z2p.grad is not None:
-            # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
-            assert torch.equal(z1p.grad, z2p.grad)
-
-    zero1_optimizer._sync_grad()
-    zero2_optimizer._sync_grad()
+    zero1_optimizer.backward(zero1_output.mean().float())
+    zero2_optimizer.backward(zero2_output.mean().float())
+
+    # check grad
+    z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
+    z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
+    for z1g, z2g in zip(z1g_list, z2g_list):
+        assert torch.equal(z1g, z2g)

     # step
     zero1_optimizer.step()
@@ -100,7 +110,7 @@ def exam_zero_1_2():


 @parameterize('dtype', [torch.float16, torch.bfloat16])
-def exam_zero_1_torch_ddp(dtype: torch.dtype):
+def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype):
     """
     In this test, two pairs of model and optimizers are created.
     1. zero: use sharded optimizer and fp16 parameters
@@ -116,7 +126,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
     torch_model = MlpModel().cuda()
     zero_model = copy.deepcopy(torch_model).to(dtype)
-    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda()
+    torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()

     # create optimizer
     zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
@@ -133,7 +143,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
     seed_all(1453 + local_rank)

     # create
-    input_data = torch.rand(32, 128).cuda()
+    input_data = torch.rand(32, 123).cuda()

     # zero-dp forward
     zero_output = zero_model(input_data.to(dtype))
@@ -143,17 +153,20 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
     loose_close(zero_output, torch_output, dtype=dtype)

     # zero-dp backward
-    zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
+    zero_optimizer.backward(zero_output.mean().float())

     # torch-ddp backward
     torch_output.mean().backward()

     # check grad
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        loose_close(p.grad, z1p.grad, dtype=dtype)
+        if p.grad is not None:
+            zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
+            torch_grad_list = split_ddp_grad(p.grad, world_size)
+            for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
+                loose_close(zero_grad, torch_grad, dtype=dtype)

     # zero-dp step
-    zero_optimizer._sync_grad()
     zero_optimizer.step()

     # torch ddp step
@@ -161,14 +174,13 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):

     # check updated param
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        # print(n, torch.max(torch.abs(p.data - z1p.data)))
         loose_close(p.data, z1p.data, dtype=dtype)


 def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

-    exam_zero_1_torch_ddp()
+    exam_zero_1_torch_ddp(world_size=world_size)
     exam_zero_1_2()
......