Commit 795210dd authored by ver217, committed by Frank Lee

add fp32 master params in sharded adam

parent a109225b
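
The change replaces eager Adam state initialization with an fp32 master copy of every (possibly sharded) fp16 parameter: `step()` points `p.data` at the master tensor so the inner Adam update and its moment buffers live in full precision. Below is a minimal, self-contained sketch of that pattern; the `Fp32MasterAdam` wrapper and its gradient handling are illustrative assumptions, not ColossalAI's actual API, which additionally deals with sharded payloads, CPU offload, and loss scaling.

```python
# Hedged sketch of the fp32-master-params pattern (not the real ShardedAdam).
from typing import Dict

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import Adam


class Fp32MasterAdam:
    """Keeps fp32 master weights for fp16 params and steps Adam on them."""

    def __init__(self, params, lr: float = 1e-3):
        self.optim = Adam(params, lr=lr)
        # One fp32 master tensor per parameter, copied from the fp16 weights.
        self.master_params: Dict[Parameter, Tensor] = {
            p: p.data.detach().clone().float()
            for group in self.optim.param_groups for p in group['params']
        }

    @torch.no_grad()
    def step(self):
        for group in self.optim.param_groups:
            for p in group['params']:
                # Swap the fp32 master copy into p.data for the update; Adam's
                # state (exp_avg, exp_avg_sq) is then created in fp32.
                p.data = self.master_params[p]
                if p.grad is not None:
                    # Assumption: grads arrive in fp16; cast so the whole
                    # update runs in fp32.
                    p.grad = p.grad.float()
        ret = self.optim.step()
        for group in self.optim.param_groups:
            for p in group['params']:
                # master_params[p] was updated in place; write an fp16 copy
                # back as the working weight for the next forward pass.
                p.data = self.master_params[p].half()
        return ret
```

The diff below follows the same idea, except that the fp16 weights live in ColossalAI's sharded `ca_attr` payloads, so after the step `p.data` is simply set to `None` and copying the master values back into the payload is left as a `TODO`.
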
 from enum import Enum
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 import torch
 import torch.distributed as dist
@@ -11,6 +11,7 @@ from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.zero.sharded_model import ShardedModelV2
 from torch import Tensor
 from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
 from ._utils import has_inf_or_nan
@@ -39,7 +40,7 @@ class ShardedAdam(ColossalaiOptimizer):
         super().__init__(adam_optim)
         self.model: Union[nn.Module, ShardedModelV2] = sharded_model
         self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
-        self.state_device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
+        self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
         self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
@@ -51,35 +52,18 @@ class ShardedAdam(ColossalaiOptimizer):
                                              growth_interval=growth_interval,
                                              hysteresis=hysteresis,
                                              max_scale=max_scale)
-        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.state_device)
+        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.device)
+        # Store fp32 params
+        self.master_params: Dict[Parameter, Tensor] = {}
-        # Early state initialization
         for group in adam_optim.param_groups:
             for p in group['params']:
-                state_shape = p.shape
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    # TODO: use payload shape
-                    state_shape = p.ca_attr.payload(self.state_device)
-                state = adam_optim.state[p]
-                assert len(state) == 0, 'adam optimizer initialized'
-                state['step'] = 0
-                # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros(state_shape,
-                                               memory_format=torch.preserve_format,
-                                               dtype=torch.float,
-                                               device=self.state_device)
-                # Exponential moving average of squared gradient values
-                state['exp_avg_sq'] = torch.zeros(state_shape,
-                                                  memory_format=torch.preserve_format,
-                                                  dtype=torch.float,
-                                                  device=self.state_device)
-                if group['amsgrad']:
-                    # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_sq'] = torch.zeros(state_shape,
-                                                          memory_format=torch.preserve_format,
-                                                          dtype=torch.float,
-                                                          device=self.state_device)
+                    self.master_params[p] = p.ca_attr.payload(self.device).to(torch.float)
+                else:
+                    self.master_params[p] = p.data.to(torch.float)

     def step(self, *args, **kwargs):
         # unscale grads if scaled
@@ -93,19 +77,15 @@ class ShardedAdam(ColossalaiOptimizer):
             self.zero_grad()
             return
-        # Write payload back to p.data
+        # Write master param to p.data
         for group in self.optim.param_groups:
             for p in group['params']:
-                data = p.data
-                if hasattr(p, 'ca_attr'):
-                    data = p.ca_attr.payload(self.state_device)
-                if torch.is_floating_point(data) and data.dtype != torch.float:
-                    data = data.to(torch.float)
-                p.data = data
+                p.data = self.master_params[p]
         ret = self.optim.step(*args, **kwargs)
-        # Set p.data to None
+        # Write master param to payload and set p.data to None
         for group in self.optim.param_groups:
             for p in group['params']:
+                # TODO: update payload
                 p.data = None
         return ret
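
The overflow handling elided between the two `step()` hunks above relies on the `_found_overflow` buffer, the `has_inf_or_nan` import, and the data-/model-parallel process groups set up in `__init__`. A plausible reconstruction of that check, for orientation only (the real implementation may differ; `torch.isfinite` here stands in for ColossalAI's `has_inf_or_nan` helper):

```python
# Hedged sketch of a distributed gradient-overflow check, not ShardedAdam's
# actual _check_overflow.
import torch
import torch.distributed as dist
from torch.optim import Optimizer


def check_overflow(optim: Optimizer, found_overflow: torch.Tensor,
                   dp_group, mp_group) -> bool:
    # Local pass: flag any inf/NaN gradient on this rank.
    found_overflow.fill_(0.0)
    for group in optim.param_groups:
        for p in group['params']:
            if p.grad is not None and not torch.isfinite(p.grad).all():
                found_overflow.fill_(1.0)
                break
    # MAX-reduce across data-parallel and model-parallel ranks so every rank
    # agrees on whether to skip this step.
    dist.all_reduce(found_overflow, op=dist.ReduceOp.MAX, group=dp_group)
    dist.all_reduce(found_overflow, op=dist.ReduceOp.MAX, group=mp_group)
    return found_overflow.item() > 0
```

When an overflow is detected, the diff's `step()` zeroes the gradients and returns without touching the master params, leaving the grad scaler to back off its scale.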