Unverified Commit 6e51d296 authored by HELSON, committed by GitHub

[zero] migrate zero1&2 (#1878)

* add zero1&2 optimizer

* rename test directory

* rename test files

* change tolerance in test
parent cc55ff0a
@@ -2,9 +2,11 @@ from typing import Tuple
import torch
import torch.nn as nn
from colossalai.logging import get_dist_logger
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
-from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
+from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2
from .zero_optimizer import ZeroOptimizer
@@ -36,4 +38,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
return zero_model, zero_optimizer
-__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
+__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
from .low_level_optim import LowLevelZeroOptimizer
from .sharded_optim_v2 import ShardedOptimizerV2
-__all__ = ['ShardedOptimizerV2']
+__all__ = ['ShardedOptimizerV2', 'LowLevelZeroOptimizer']
from .bucket_store import BucketStore
from .gradient_store import GradientStore
from .parameter_store import ParameterStore
from .tensor_bucket import TensorBucket
__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
class BaseStore:
def __init__(self, dp_parallel_mode=ParallelMode.DATA):
self._world_size = gpc.get_world_size(dp_parallel_mode)
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
@property
def world_size(self):
return self._world_size
@property
def local_rank(self):
return self._local_rank
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from .base_store import BaseStore
class BucketStore(BaseStore):
def __init__(self, dp_parallel_mode):
super().__init__(dp_parallel_mode)
self._grads = dict()
self._params = dict()
self._num_elements_in_bucket = dict()
self.reset()
def num_elements_in_bucket(self, reduce_rank: int = None):
return self._num_elements_in_bucket[reduce_rank]
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
self._num_elements_in_bucket[reduce_rank] += num_elements
def add_grad(self, tensor, reduce_rank: int = None):
self._grads[reduce_rank].append(tensor)
def add_param(self, tensor, reduce_rank: int = None):
self._params[reduce_rank].append(tensor)
def reset(self):
keys = [None] + list(range(self._world_size))
self._grads = {rank: [] for rank in keys}
self._params = {rank: [] for rank in keys}
self._num_elements_in_bucket = {rank: 0 for rank in keys}
def reset_by_rank(self, reduce_rank=None):
self._grads[reduce_rank] = []
self._params[reduce_rank] = []
self._num_elements_in_bucket[reduce_rank] = 0
def get_grad(self, reduce_rank: int = None):
return self._grads[reduce_rank]
def get_param(self, reduce_rank: int = None):
return self._params[reduce_rank]
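# --- Editor's illustration (not part of this commit): BucketStore usage ---
# A minimal sketch of how the bucket store appears to be filled during the
# backward pass and drained once enough elements have accumulated.
# `on_grad_ready` and `reduce_bucket_size` are hypothetical names used only for
# illustration; colossalai must already be launched so ParallelMode.DATA resolves.
def on_grad_ready(store: BucketStore, param, reduce_rank=None, reduce_bucket_size=2 ** 20):
    store.add_param(param, reduce_rank)
    store.add_grad(param.grad, reduce_rank)
    store.add_num_elements_in_bucket(param.numel(), reduce_rank)
    if store.num_elements_in_bucket(reduce_rank) > reduce_bucket_size:
        grads = store.get_grad(reduce_rank)
        # ... reduce `grads` across the data-parallel group here ...
        store.reset_by_rank(reduce_rank)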
from typing import List
from torch import Tensor
from .base_store import BaseStore
class GradientStore(BaseStore):
def __init__(self, *args):
super().__init__(*args)
# bookkeeping data structures
self._averaged_gradients = dict()
# for backward reduction hooks
self._grad_acc_objs = []
def add_accumulate_grad_object(self, obj):
"""
Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
be attached successfully.
:param obj: An object of :class:`AccumulateGrad` class
:type obj: :class:`AccumulateGrad`
"""
self._grad_acc_objs.append(obj)
def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
"""
Return average gradients of a parameter group
:param group_id: The index of parameter group
:type group_id: int
:return: The list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
:rtype: List[torch.Tensor]
"""
return self._averaged_gradients[group_id]
def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
"""
Append an average gradient to the list of averaged gradients of a parameter group
:param group_id: The index of a parameter group
:param tensor: A :class:`torch.Tensor` object
:type group_id: int
:type tensor: torch.Tensor
"""
if group_id in self._averaged_gradients:
self._averaged_gradients[group_id].append(tensor)
else:
self._averaged_gradients[group_id] = [tensor]
def reset_average_gradients_by_group(self, group_id: int) -> None:
"""
Reset the bookkeeping data structure for averaged gradients to an empty list
:param group_id: The index of a parameter group
:type group_id: int
"""
self._averaged_gradients[group_id] = []
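# --- Editor's illustration (not part of this commit): GradientStore usage ---
# A minimal sketch of the bookkeeping flow: averaged gradients are appended per
# parameter group after reduction, read back for the optimizer step, and reset.
# The random tensors below are stand-ins for reduced gradients; colossalai must
# already be launched for the store to be constructed.
import torch

def bookkeeping_example(store: GradientStore, group_id: int = 0):
    for grad in (torch.randn(4), torch.randn(8)):    # stand-ins for reduced grads
        store.add_average_gradient_by_group(group_id, grad)
    averaged = store.get_averaged_gradients_by_group(group_id)
    # ... e.g. use `averaged` to build the fp32 gradients for the update ...
    store.reset_average_gradients_by_group(group_id)
    return averaged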
from typing import List
from torch import Tensor
from .base_store import BaseStore
class ParameterStore(BaseStore):
def __init__(self, dp_parallel_mode):
super().__init__(dp_parallel_mode)
# param partitioning data structures
self._fp16_param_to_rank = dict()
self._rank_groupid_to_fp16_param_list = dict()
self._rank_group_id_to_flat_fp16_param = dict()
# param reduction data structures
self._is_param_reduced = dict()
self._reduced_param = []
def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
"""
Set the mapping from a parameter to its owning rank; each parameter should be owned by exactly one rank.
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
:param rank: The rank whose process is responsible for updating the parameter
:type rank: int
"""
self._fp16_param_to_rank[tensor] = rank
def get_param_rank(self, tensor: Tensor) -> int:
"""
Return the rank to which the parameter belongs
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
:return: The rank to which the parameter belongs
:rtype: int
"""
return self._fp16_param_to_rank[tensor]
def belongs_to_current_rank(self, tensor) -> bool:
"""
Check whether a parameter is supposed to be updated by the process of the current rank
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
:return: True if the parameter should be updated by the current rank. Otherwise false.
:rtype: bool
"""
tensor_rank = self._fp16_param_to_rank[tensor]
return tensor_rank == self._local_rank
def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
if rank not in self._rank_groupid_to_fp16_param_list:
self._rank_groupid_to_fp16_param_list[rank] = dict()
if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
self._rank_groupid_to_fp16_param_list[rank][group_id] = []
self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list)
def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
return self._rank_groupid_to_fp16_param_list[rank][group_id]
def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
if rank not in self._rank_group_id_to_flat_fp16_param:
self._rank_group_id_to_flat_fp16_param[rank] = dict()
self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor
def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
return self._rank_group_id_to_flat_fp16_param[rank][group_id]
def is_param_reduced(self, tensor):
return self._is_param_reduced[tensor]
def set_param_reduction_state(self, tensor, state):
self._is_param_reduced[tensor] = state
def get_param_reduction_states(self):
return self._is_param_reduced
def reset_previous_reduced_params(self):
self._reduced_param = []
def add_previous_reduced_param(self, tensor):
self._reduced_param.append(tensor)
def clear_grads_of_previous_reduced_params(self):
if len(self._reduced_param) > 0:
for param in self._reduced_param:
param.grad = None
self.reset_previous_reduced_params()
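# --- Editor's illustration (not part of this commit): ParameterStore usage ---
# A minimal sketch of recording which data-parallel rank owns each fp16
# parameter and selecting the parameters the current rank should update.
# `params` is assumed to come from an existing fp16 model, and the round-robin
# assignment is only an example; colossalai must already be launched.
def partition_example(store: ParameterStore, params):
    for i, p in enumerate(params):
        store.set_param_to_rank(p, i % store.world_size)    # example ownership scheme
    owned = [p for p in params if store.belongs_to_current_rank(p)]
    return owned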
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
class TensorBucket:
def __init__(self, size):
self._max_size = size
self._current_size = 0
self._bucket = []
@property
def max_size(self):
return self._max_size
@property
def current_size(self):
return self._current_size
def is_full_or_oversized(self):
return self._current_size >= self._max_size
def is_empty(self):
return len(self._bucket) == 0
def add_to_bucket(self, tensor, allow_oversize=False):
tensor_size = tensor.numel()
if not allow_oversize and self.will_exceed_max_size(tensor_size):
msg = f"The param bucket max size {self._max_size} is exceeded" \
+ f"by tensor (size {tensor_size})"
raise RuntimeError(msg)
self._bucket.append(tensor)
self._current_size += tensor_size
def will_exceed_max_size(self, tensor_size):
expected_size = self._current_size + tensor_size
return expected_size > self._max_size
def get_bucket(self):
return self._bucket
def empty(self):
self._bucket = []
self._current_size = 0
def flatten(self):
return _flatten_dense_tensors(self._bucket)
def unflatten_and_copy(self, flat_tensor):
unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
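# --- Editor's illustration (not part of this commit): TensorBucket usage ---
# A minimal sketch of the flatten / unflatten round trip: tensors are packed
# into one contiguous buffer suitable for a single collective call, then the
# (possibly modified) flat buffer is copied back into the original tensors.
# The mul_ call is a local stand-in for a real communication op.
import torch

def bucket_round_trip():
    bucket = TensorBucket(size=1024)
    for shape in ((16, 16), (32,)):
        bucket.add_to_bucket(torch.randn(shape))
    flat = bucket.flatten()          # one contiguous tensor for communication
    flat.mul_(0.5)                   # stand-in for e.g. dist.all_reduce(flat)
    bucket.unflatten_and_copy(flat)  # write results back into the original tensors
    bucket.empty()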
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer
def check_equal(a, b):
"""
This function checks if two tensors are equal within tolerance
"""
assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'
def check_completely_equal(a, b):
"""
This function checks if two tensors are completely equal
"""
assert torch.all(a == b), f'a = {a}, b = {b}'
def check_sharded_param_consistency():
"""
In this test, we check whether ZeRO stage 1 and stage 2
deliver the same numerical results despite their different communication
patterns.
We use these prefixes to differentiate the ZeRO stages:
oss: partition optimizer states only (stage 1)
pg: partition gradients and optimizer states (stage 2)
"""
# create layers
oss_linear1 = nn.Linear(128, 256)
oss_linear2 = nn.Linear(256, 512)
# create model
oss_model = nn.Sequential(oss_linear1, oss_linear2)
pg_model = copy.deepcopy(oss_model)
oss_model = oss_model.cuda().half()
pg_model = pg_model.cuda().half()
# create optimizer
oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
oss_optimizer = LowLevelZeroOptimizer(oss_optimizer,
overlap_communication=True,
initial_scale=1,
clip_grad_norm=0.0)
pg_optimizer = LowLevelZeroOptimizer(pg_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=1,
clip_grad_norm=0.0)
# create input data
input_data = torch.rand(32, 128).cuda().half()
# forward
oss_output = oss_model(input_data)
pg_output = pg_model(input_data)
check_completely_equal(oss_output, pg_output)
# backward
oss_optimizer.backward(oss_output.mean().float())
pg_optimizer.backward(pg_output.mean().float())
# check grad
# as the params are small, the backward reduction
# will not be fired
oss_linear1_grad = oss_model[0].weight.grad
oss_linear2_grad = oss_model[1].weight.grad
pg_linear1_grad = pg_model[0].weight.grad
pg_linear2_grad = pg_model[1].weight.grad
check_completely_equal(oss_linear1_grad, pg_linear1_grad)
check_completely_equal(oss_linear2_grad, pg_linear2_grad)
# sync grad before step
oss_optimizer.sync_grad()
pg_optimizer.sync_grad()
# step
oss_optimizer.step()
pg_optimizer.step()
# check updated param
check_completely_equal(oss_model[0].weight, pg_model[0].weight)
check_completely_equal(oss_model[1].weight, pg_model[1].weight)
def check_sharded_optim_against_torch_ddp():
"""
In this test, two pairs of models and optimizers are created.
1. zero: use sharded optimizer and fp16 parameters
2. torch: use torch DDP and fp32 parameters
We feed these two sets of models with the same input and check if the
differences in model output and updated parameters are within tolerance.
"""
# create layer
zero_linear1 = nn.Linear(128, 256)
zero_linear2 = nn.Linear(256, 512)
# create model
zero_model = nn.Sequential(zero_linear1, zero_linear2)
torch_model = copy.deepcopy(zero_model)
zero_model = zero_model.cuda().half()
torch_model = DDP(torch_model.cuda())
# create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)
# we only test stage 1 here
# in `check_sharded_param_consistency`, we test whether
# stage 1 and stage 2 produce exactly the same results
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=True,
initial_scale=1,
clip_grad_norm=0.0)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)
# create input data
input_data = torch.rand(32, 128).cuda()
# zero-dp forward
zero_output = zero_model(input_data.half())
# torch-ddp forward
torch_output = torch_model(input_data)
check_equal(zero_output, torch_output)
# zero-dp backward
zero_optimizer.backward(zero_output.mean().float())
# torch-ddp backward
torch_output.mean().backward()
# check grad
zero_linear1_grad = zero_model[0].weight.grad
zero_linear2_grad = zero_model[1].weight.grad
torch_linear1_grad = torch_model.module[0].weight.grad
torch_linear2_grad = torch_model.module[1].weight.grad
check_equal(zero_linear1_grad, torch_linear1_grad)
check_equal(zero_linear2_grad, torch_linear2_grad)
# zero-dp step
zero_optimizer.sync_grad()
zero_optimizer.step()
# torch ddp step
torch_optimizer.step()
# check updated param
check_equal(zero_model[0].weight, torch_model.module[0].weight)
check_equal(zero_model[1].weight, torch_model.module[1].weight)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_sharded_optim_against_torch_ddp()
check_sharded_param_consistency()
@pytest.mark.dist
def test_sharded_optim():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_sharded_optim()
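# --- Editor's illustration (not part of this commit): per-iteration pattern ---
# A minimal training-step sketch distilled from the test above, assuming
# colossalai.launch has already been called in the current process and the
# optimizer was wrapped with LowLevelZeroOptimizer as shown there.
def train_step(model, optimizer: LowLevelZeroOptimizer, data):
    output = model(data)
    loss = output.mean().float()
    optimizer.backward(loss)    # backward goes through the zero optimizer wrapper
    optimizer.sync_grad()       # synchronize gradients before the step, as in the test
    optimizer.step()            # each rank updates the parameters it owns
    return loss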