Commit 795210dd authored by ver217's avatar ver217 Committed by Frank Lee
Browse files

add fp32 master params in sharded adam

parent a109225b
from enum import Enum from enum import Enum
from typing import Optional, Union from typing import Dict, Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -11,6 +11,7 @@ from colossalai.nn.optimizer import ColossalaiOptimizer ...@@ -11,6 +11,7 @@ from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model import ShardedModelV2
from torch import Tensor from torch import Tensor
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer from torch.optim import Optimizer
from ._utils import has_inf_or_nan from ._utils import has_inf_or_nan
...@@ -39,7 +40,7 @@ class ShardedAdam(ColossalaiOptimizer): ...@@ -39,7 +40,7 @@ class ShardedAdam(ColossalaiOptimizer):
super().__init__(adam_optim) super().__init__(adam_optim)
self.model: Union[nn.Module, ShardedModelV2] = sharded_model self.model: Union[nn.Module, ShardedModelV2] = sharded_model
self.model_is_sharded = isinstance(sharded_model, ShardedModelV2) self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
self.state_device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu') self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
self.optim_state: OptimState = OptimState.UNSCALED self.optim_state: OptimState = OptimState.UNSCALED
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA) self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL) self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
...@@ -51,35 +52,18 @@ class ShardedAdam(ColossalaiOptimizer): ...@@ -51,35 +52,18 @@ class ShardedAdam(ColossalaiOptimizer):
growth_interval=growth_interval, growth_interval=growth_interval,
hysteresis=hysteresis, hysteresis=hysteresis,
max_scale=max_scale) max_scale=max_scale)
self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.state_device) self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.device)
# Store fp32 params
self.master_params: Dict[Parameter, Tensor] = {}
# Early state initialization
for group in adam_optim.param_groups: for group in adam_optim.param_groups:
for p in group['params']: for p in group['params']:
state_shape = p.shape
if hasattr(p, 'ca_attr'): if hasattr(p, 'ca_attr'):
assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model' assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
# TODO: use payload shape self.master_params[p] = p.ca_attr.payload(self.device).to(torch.float)
state_shape = p.ca_attr.payload(self.state_device) else:
state = adam_optim.state[p] self.master_params[p] = p.data.to(torch.float)
assert len(state) == 0, 'adam optimizer initialized'
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros(state_shape,
memory_format=torch.preserve_format,
dtype=torch.float,
device=self.state_device)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros(state_shape,
memory_format=torch.preserve_format,
dtype=torch.float,
device=self.state_device)
if group['amsgrad']:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros(state_shape,
memory_format=torch.preserve_format,
dtype=torch.float,
device=self.state_device)
def step(self, *args, **kwargs): def step(self, *args, **kwargs):
# unscale grads if scaled # unscale grads if scaled
...@@ -93,19 +77,15 @@ class ShardedAdam(ColossalaiOptimizer): ...@@ -93,19 +77,15 @@ class ShardedAdam(ColossalaiOptimizer):
self.zero_grad() self.zero_grad()
return return
# Write payload back to p.data # Write master param to p.data
for group in self.optim.param_groups: for group in self.optim.param_groups:
for p in group['params']: for p in group['params']:
data = p.data p.data = self.master_params[p]
if hasattr(p, 'ca_attr'):
data = p.ca_attr.payload(self.state_device)
if torch.is_floating_point(data) and data.dtype != torch.float:
data = data.to(torch.float)
p.data = data
ret = self.optim.step(*args, **kwargs) ret = self.optim.step(*args, **kwargs)
# Set p.data to None # Write master param to payload and set p.data to None
for group in self.optim.param_groups: for group in self.optim.param_groups:
for p in group['params']: for p in group['params']:
# TODO: update payload
p.data = None p.data = None
return ret return ret
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment