Unverified commit e7d3afc9, authored by HELSON, committed by GitHub
Browse files

[optimizer] add div_scale for optimizers (#2117)

* [optimizer] add div_scale for optimizers

* [zero] use div_scale in zero optimizer

* fix testing error
parent e5aa8333
...@@ -11,7 +11,7 @@ def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List ...@@ -11,7 +11,7 @@ def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List
... ...
def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float) -> None: def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float, div_scale: float) -> None:
... ...
......
...@@ -17,8 +17,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -17,8 +17,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
const float lr, const float beta1, const float lr, const float beta1,
const float beta2, const float epsilon, const float beta2, const float epsilon,
const int step, const int mode, const int step, const int mode,
const int bias_correction, const int bias_correction, const float weight_decay,
const float weight_decay); const float div_scale);
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
......
...@@ -28,7 +28,7 @@ struct AdamFunctor { ...@@ -28,7 +28,7 @@ struct AdamFunctor {
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
const float beta1, const float beta2, const float beta1_correction, const float beta1, const float beta2, const float beta1_correction,
const float beta2_correction, const float epsilon, const float lr, const float beta2_correction, const float epsilon, const float lr,
adamMode_t mode, const float decay) { adamMode_t mode, const float decay, const float div_scale) {
// I'd like this kernel to propagate infs/nans. // I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1) // if(*noop_gmem == 1)
// return; // return;
...@@ -79,6 +79,8 @@ struct AdamFunctor { ...@@ -79,6 +79,8 @@ struct AdamFunctor {
} }
#pragma unroll #pragma unroll
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
if (div_scale > 0) r_g[ii] /= div_scale;
if (mode == ADAM_MODE_0) { // L2 if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]); r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
...@@ -116,8 +118,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -116,8 +118,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
const float lr, const float beta1, const float lr, const float beta1,
const float beta2, const float epsilon, const float beta2, const float epsilon,
const int step, const int mode, const int step, const int mode,
const int bias_correction, const int bias_correction, const float weight_decay,
const float weight_decay) { const float div_scale) {
using namespace at; using namespace at;
// Handle bias correction mode // Handle bias correction mode
...@@ -133,7 +135,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -133,7 +135,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1, AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, beta2, bias_correction1, bias_correction2, epsilon,
lr, (adamMode_t)mode, weight_decay);) lr, (adamMode_t)mode, weight_decay, div_scale);)
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
} }
...@@ -117,7 +117,7 @@ class CPUAdam(NVMeOptimizer): ...@@ -117,7 +117,7 @@ class CPUAdam(NVMeOptimizer):
data.addcdiv_(exp_avg, denom, value=-step_size) data.addcdiv_(exp_avg, denom, value=-step_size)
@torch.no_grad() @torch.no_grad()
def step(self, closure=None): def step(self, closure=None, div_scale: float = -1):
loss = None loss = None
if closure is not None: if closure is not None:
with torch.enable_grad(): with torch.enable_grad():
...@@ -152,9 +152,10 @@ class CPUAdam(NVMeOptimizer): ...@@ -152,9 +152,10 @@ class CPUAdam(NVMeOptimizer):
self._pre_update(p, 'exp_avg', 'exp_avg_sq') self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'], group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], -1) state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq') self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda': elif target_device.type == 'cuda':
assert div_scale == -1, "div_scale should remain default"
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
......
...@@ -81,7 +81,7 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -81,7 +81,7 @@ class FusedAdam(torch.optim.Optimizer):
else: else:
super(FusedAdam, self).zero_grad() super(FusedAdam, self).zero_grad()
def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None): def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, div_scale: float = -1):
"""Performs a single optimization step. """Performs a single optimization step.
Arguments: Arguments:
...@@ -137,6 +137,6 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -137,6 +137,6 @@ class FusedAdam(torch.optim.Optimizer):
multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction, beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction,
group['weight_decay']) group['weight_decay'], div_scale)
return loss return loss
...@@ -89,7 +89,7 @@ class HybridAdam(NVMeOptimizer): ...@@ -89,7 +89,7 @@ class HybridAdam(NVMeOptimizer):
self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self._dummy_overflow_buf = torch.cuda.IntTensor([0])
@torch.no_grad() @torch.no_grad()
def step(self, closure=None): def step(self, closure=None, div_scale: float = -1):
loss = None loss = None
if closure is not None: if closure is not None:
with torch.enable_grad(): with torch.enable_grad():
...@@ -126,7 +126,7 @@ class HybridAdam(NVMeOptimizer): ...@@ -126,7 +126,7 @@ class HybridAdam(NVMeOptimizer):
self._pre_update(p, 'exp_avg', 'exp_avg_sq') self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'], group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], -1) state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq') self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda': elif target_device.type == 'cuda':
...@@ -146,6 +146,6 @@ class HybridAdam(NVMeOptimizer): ...@@ -146,6 +146,6 @@ class HybridAdam(NVMeOptimizer):
bias_correction = 1 if group['bias_correction'] else 0 bias_correction = 1 if group['bias_correction'] else 0
multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode, group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode,
bias_correction, group['weight_decay']) bias_correction, group['weight_decay'], div_scale)
self._post_step() self._post_step()
return loss return loss
...@@ -10,10 +10,12 @@ from torch.optim import Optimizer ...@@ -10,10 +10,12 @@ from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.gemini.chunk import Chunk, ChunkManager from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.nn.parallel.data_parallel import ZeroDDP from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.utils import disposable, get_current_device from colossalai.utils import disposable, get_current_device
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
class OptimState(Enum): class OptimState(Enum):
SCALED = 0 SCALED = 0
...@@ -62,6 +64,7 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -62,6 +64,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
**defaults: Any): **defaults: Any):
super().__init__(optim) super().__init__(optim)
assert isinstance(module, ZeroDDP) assert isinstance(module, ZeroDDP)
assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list"
self.module = module self.module = module
self.gemini_manager = module.gemini_manager self.gemini_manager = module.gemini_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
...@@ -162,21 +165,24 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -162,21 +165,24 @@ class ZeroOptimizer(ColossalaiOptimizer):
global_norm = math.sqrt(norm_sqr) global_norm = math.sqrt(norm_sqr)
return global_norm return global_norm
def _unscale_and_clip_grads(self): def _get_combined_scale(self):
assert self.optim_state == OptimState.SCALED loss_scale = 1
combined_scale = self.loss_scale if self.optim_state == OptimState.SCALED:
loss_scale = self.loss_scale
self.optim_state = OptimState.UNSCALED
combined_scale = loss_scale
if self.clipping_flag: if self.clipping_flag:
total_norm = self._calc_global_norm() total_norm = self._calc_global_norm()
clip = ((total_norm / self.loss_scale) + 1e-6) / self.max_norm clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
if clip > 1: if clip > 1:
combined_scale = clip * self.loss_scale combined_scale = clip * loss_scale
for group in self.optim.param_groups: if combined_scale == 1:
for p in group['params']: return -1
if p.grad is not None: else:
p.grad.data.div_(combined_scale) return combined_scale
self.optim_state = OptimState.UNSCALED
@property @property
def loss_scale(self): def loss_scale(self):
...@@ -199,12 +205,12 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -199,12 +205,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
self._update_fp16_params() self._update_fp16_params()
return return
# unscale grads if scaled # get combined scale. combined scale = loss scale * clipping norm
if self.optim_state == OptimState.SCALED: # so that gradient = gradient / combined scale
self._unscale_and_clip_grads() combined_scale = self._get_combined_scale()
self.grad_scaler.update(found_inf) self.grad_scaler.update(found_inf)
ret = self.optim.step(*args, **kwargs) ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self._register_states() self._register_states()
self.zero_grad() self.zero_grad()
self._update_fp16_params() self._update_fp16_params()
......
...@@ -71,7 +71,7 @@ def test_adam(adamw, step, p_dtype, g_dtype): ...@@ -71,7 +71,7 @@ def test_adam(adamw, step, p_dtype, g_dtype):
weight_decay = 0 weight_decay = 0
multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw, multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
True, weight_decay) True, weight_decay, -1)
torch_adam_update( torch_adam_update(
step, step,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment