Commit 9506a8be authored by ver217

use double buffer to handle grad

parent 0f5f5dd5
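
What "double buffer" means here: every ShardedParamV2 now carries two gradient buffers instead of the single _grad_sharded_tensor. fp16_grad is the landing slot that the post-backward hook fills with the freshly reduce-scattered fp16 shard; fp32_grad holds the fp32 gradient accumulated across backward passes. The fp32 cast, the optional CPU offload, and the accumulation all move out of the per-reduction hook and run once at the end of backward. A runnable sketch of the two buffers interacting follows the ShardedModelV2 hunks below.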
-from typing import Optional
 import torch
 from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import \
-    GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 from ._base_ophook import BaseOpHook
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from typing import Optional

 @OPHOOKS.register_module
@@ -62,8 +64,8 @@ class ZeroHook(BaseOpHook):
         if param.grad is not None:
             if param.col_attr.bwd_count == 0:
                 # We haven't stored local accumulated grad yet
-                assert param.col_attr.grad is None
-                param.col_attr.grad = param.grad.data
+                assert param.col_attr.fp32_grad is None
+                param.col_attr.fp32_grad = param.grad.data
                 param.grad = None
             else:
                 # We have stored local accumulated grad
...
 import functools
 import torch
-from colossalai.utils.memory_tracer.model_data_memtracer import \
-    GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

 # Inserts _post_init_method at the end of init method
@@ -154,6 +155,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
                 GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
-            if param.col_attr.grad and self.shard_grad:
-                self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
+            # if param.col_attr.grad and self.shard_grad:
+            #     self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
+            #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
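
Note on the commented-out branch: after this commit the gradient is no longer wrapped in a ShardedTensor (_grad_sharded_tensor is removed from ShardedParamV2 in the last hunk of this diff), so there is nothing for the shard strategy to shard at init time; fp16_grad and fp32_grad are plain tensors that only come into existence during backward.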
-from ast import Try
 import functools
+from ast import Try
 from collections import OrderedDict
 from typing import Any, Optional
@@ -12,16 +12,17 @@ from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.utils.commons.memory import col_cuda_memory_capacity
-from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
 from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
                            get_gradient_predivide_factor)
+from colossalai.utils.commons.memory import col_cuda_memory_capacity

 class ShardedModelV2(nn.Module):
@@ -164,8 +165,15 @@ class ShardedModelV2(nn.Module):
                 # If world size == 1 and sharded param,
                 # the shape `grad` is the same as unsharded param
                 # So we can just use `view(-1)` to ensure grad is a flat tensor shard
-                p.grad.data = p.col_attr.grad.view(-1)
-                p.col_attr.grad = None
+                grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+                if self._cpu_offload:
+                    col_move_to_cpu(grad)
+                if p.col_attr.fp32_grad is not None:
+                    p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
+                    grad = p.col_attr.fp32_grad
+                p.grad.data = grad.view(-1)
+                p.col_attr.fp16_grad = None
+                p.col_attr.fp32_grad = None

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -216,23 +224,7 @@ class ShardedModelV2(nn.Module):
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
-        # Make sure we store fp32 grad
-        reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)
-
-        # Maybe offload
-        # TODO() optimize GPU->CPU bandwidth utilization
-        if self._cpu_offload:
-            col_move_to_cpu(reduced_grad)
-            # reduced_grad.data = reduced_grad.data.cpu()
-
-        if param.col_attr.grad is None:
-            param.col_attr.grad = reduced_grad.data
-        else:
-            # When dp size = 1
-            # param.col_attr.grad is local accumulated grad shard (full but flatten)
-            # But reduced_grad here is full grad
-            # We should call `view_as`
-            param.col_attr.grad.add_(reduced_grad.data.view_as(param.col_attr.grad))
+        param.col_attr.fp16_grad = reduced_grad.data

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
...
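
As promised above, a hedged walk-through of how the two hunks in this file cooperate across two gradient-accumulation passes. This is plain single-process Python, not the real API: cast_tensor_to_fp32 is approximated by .float(), the Attr class stands in for param.col_attr, and the step where ZeroHook re-stores the accumulated grad into fp32_grad (the bwd_count == 0 branch in the first file) is done by hand.

    import torch

    class Attr:
        fp16_grad = None    # stand-in for ShardedParamV2.fp16_grad
        fp32_grad = None    # stand-in for ShardedParamV2.fp32_grad

    col_attr = Attr()

    def grad_post_backward_hook(reduced_grad):
        # New hook body: just stash the reduced fp16 shard; no cast, no offload.
        col_attr.fp16_grad = reduced_grad

    def end_of_backward():
        # Mirrors the new logic in the @@ -164,8 +165,15 @@ hunk above.
        grad = col_attr.fp16_grad.float()        # one fp32 cast per pass
        if col_attr.fp32_grad is not None:       # fold into the accumulated grad
            col_attr.fp32_grad.add_(grad.view_as(col_attr.fp32_grad))
            grad = col_attr.fp32_grad
        col_attr.fp16_grad = None
        col_attr.fp32_grad = None
        return grad.view(-1)

    # Pass 1: the reducer delivers an fp16 shard of ones.
    grad_post_backward_hook(torch.ones(4, dtype=torch.float16))
    p_grad = end_of_backward()                   # tensor([1., 1., 1., 1.])

    # Between passes, ZeroHook would store the accumulated grad back:
    col_attr.fp32_grad = p_grad

    # Pass 2: another fp16 shard of ones arrives and is accumulated in fp32.
    grad_post_backward_hook(torch.ones(4, dtype=torch.float16))
    print(end_of_backward())                     # tensor([2., 2., 2., 2.])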
@@ -16,7 +16,8 @@ class ShardedParamV2(object):
                  process_group: Optional[dist.ProcessGroup] = None,
                  rm_torch_payload=False) -> None:
         self._data_sharded_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
-        self._grad_sharded_tensor: Optional[torch.Tensor] = None
+        self.fp16_grad: Optional[torch.Tensor] = None
+        self.fp32_grad: Optional[torch.Tensor] = None

         # make sure the shared param is the only owner of payload
         # The param.data maybe used to init the other part of the model.
@@ -39,14 +40,6 @@ class ShardedParamV2(object):
     def data(self):
         return self._data_sharded_tensor

-    @property
-    def grad(self):
-        return self._grad_sharded_tensor
-
-    @grad.setter
-    def grad(self, t: torch.Tensor):
-        self._grad_sharded_tensor = t
-
     @property
     def param_is_sharded(self):
         return self._data_sharded_tensor.is_sharded
...
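
Migration note: the grad property/setter pair is gone, so any code that read or wrote param.col_attr.grad must now pick explicitly between fp16_grad (the latest reduced shard) and fp32_grad (the accumulated fp32 gradient), as the updated ZeroHook and ShardedModelV2 above do.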