Commit a651e2c2 (unverified), authored by Kexin Yu, committed by GitHub

sync-free Distributed LAMB + parameter reordering (#1055)



* sync-free Distributed LAMB

* initialize lr with the provided value

* wait for the L2-norm stream

* reorder parameters

* fix indentation

Co-authored-by: Kexin Yu <kexiny@nvidia.com>
parent d86d1b09
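
Note on usage: with this change the optimizer's step counter, learning rate, loss scale, and global gradient norm all live in CUDA tensors (`self._step`, `self._lr`, `global_scale`, `global_grad_norm`), so a training step no longer needs a host-device synchronization to decide whether to skip on overflow or to clip. The sketch below is only an illustration of the intended flow: the model/loader/loss names are placeholders, the import path is assumed, and the exact hand-off between `torch.cuda.amp.GradScaler` and `step(grad_scaler=...)` can differ across PyTorch versions.

```python
import torch
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB  # assumed path

# torch.distributed must already be initialized (e.g. via torchrun) before the
# optimizer is constructed, since it creates its own process groups.
model = build_model().cuda()            # placeholder
loader = build_loader()                 # placeholder
loss_fn = torch.nn.CrossEntropyLoss()

optimizer = DistributedFusedLAMB(model.parameters(), lr=1e-3,
                                 max_grad_norm=1.0,             # clipping now happens inside the kernel
                                 step_supports_amp_scaling=True)
scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()       # backward hooks overlap the gradient reductions
    optimizer.complete_reductions()     # flush any reductions not yet issued
    optimizer.step(grad_scaler=scaler)  # sync-free: overflow status is written back to the scaler
    scaler.update()
```
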
@@ -8,11 +8,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_beta2,
   at::Tensor per_tensor_beta3,
   at::Tensor per_tensor_bias_correction,
-  const int step,
+  at::Tensor step,
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float grad_scale);
+  at::Tensor global_scale,
+  at::Tensor global_grad_norm,
+  const float max_grad_norm);
 
 void multi_tensor_lamb_update_weights_cuda(
   int chunk_size,
@@ -20,8 +22,10 @@ void multi_tensor_lamb_update_weights_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate,
+  at::Tensor update_norm_offset,
+  at::Tensor learning_rate,
   at::Tensor per_tensor_decay,
+  at::Tensor global_grad_norm,
   bool use_nvlamb);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
@@ -116,28 +116,36 @@ struct DistOptLAMBStage1Functor
     const MATH_T* per_tensor_beta2,
     const MATH_T* per_tensor_beta3,
     const int* per_tensor_bias_correction,
-    const int step,
+    const int* step,
     const MATH_T* per_tensor_epsilon,
     adamMode_t mode,
     const MATH_T* per_tensor_decay,
-    const float grad_scale)
+    const MATH_T* global_scale,
+    const MATH_T* global_grad_norm,
+    const float max_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
-    // if(*noop_gmem == 1)
-    //   return;
+    if (*noop_gmem == 1)
+      return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
+    float combined_scale = *global_scale;
+    if (max_grad_norm > 0) {
+      combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6);
+      combined_scale = *global_scale / std::min((float) 1.0, combined_scale);
+    }
+
     MATH_T beta1 = per_tensor_beta1[tensor_num];
     MATH_T beta2 = per_tensor_beta2[tensor_num];
     MATH_T beta3 = 1 - beta1;
     MATH_T beta1_correction, beta2_correction;
     if (per_tensor_bias_correction[tensor_num] == 1) {
-      beta1_correction = 1 - pow(beta1, step);
-      beta2_correction = 1 - pow(beta2, step);
+      beta1_correction = 1 - pow(beta1, *step);
+      beta2_correction = 1 - pow(beta2, *step);
     } else {
       beta1_correction = (MATH_T) 1.0;
       beta2_correction = (MATH_T) 1.0;
@@ -204,7 +212,7 @@ struct DistOptLAMBStage1Functor
       for(int ii = 0; ii < ILP; ii++)
       {
         if (mode == MOMENT_MODE_0) {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           // L2 on scaled grad
           scaled_grad = scaled_grad + decay*r_p[ii];
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -215,7 +223,7 @@ struct DistOptLAMBStage1Functor
           r_p[ii] = next_m_unbiased / denom;
         }
         else {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
           r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
           MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -274,7 +282,7 @@ struct DistOptLAMBStage1Functor
       for(int ii = 0; ii < ILP; ii++)
       {
         if (mode == MOMENT_MODE_0) {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           // L2 on scaled grad
           scaled_grad = scaled_grad + decay*r_p[ii];
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -285,7 +293,7 @@ struct DistOptLAMBStage1Functor
           r_p[ii] = next_m_unbiased / denom;
         }
         else {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
           r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
           MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -321,13 +329,15 @@ struct DistOptLAMBStage2Functor
     TensorListMetadata<3>& tl,
     const MATH_T* per_tensor_param_norm,
     const MATH_T* per_tensor_update_norm,
-    const MATH_T learning_rate,
+    const long* update_norm_offset,
+    const MATH_T* learning_rate,
     const MATH_T* per_tensor_decay,
+    const MATH_T* global_grad_norm,
     bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
-    // if(*noop_gmem == 1)
-    //   return;
+    if (*noop_gmem == 1)
+      return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
@@ -336,14 +346,14 @@ struct DistOptLAMBStage2Functor
     MATH_T decay = per_tensor_decay[tensor_num];
 
-    MATH_T ratio = learning_rate;
+    MATH_T ratio = *learning_rate;
     // nvlamb: apply adaptive learning rate to all parameters
     // otherwise, only apply to those with non-zero weight decay
     if (use_nvlamb || (decay != (MATH_T) 0.0))
     {
       MATH_T param_norm = per_tensor_param_norm[tensor_num];
-      MATH_T update_norm = per_tensor_update_norm[tensor_num];
-      ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
+      MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]];
+      ratio = (update_norm != 0.0 && param_norm != 0.0) ? (*learning_rate) * (param_norm / update_norm) : (*learning_rate);
     }
 
     MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
@@ -427,11 +437,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_beta2,
   at::Tensor per_tensor_beta3,
   at::Tensor per_tensor_bias_correction,
-  const int step,
+  at::Tensor step,
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float grad_scale)
+  at::Tensor global_scale,
+  at::Tensor global_grad_norm,
+  const float max_grad_norm)
 {
   using namespace at;
@@ -448,11 +460,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
       per_tensor_beta2.DATA_PTR<scalar_t_2>(),
       per_tensor_beta3.DATA_PTR<scalar_t_2>(),
       per_tensor_bias_correction.DATA_PTR<int>(),
-      step,
+      step.DATA_PTR<int>(),
       per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
       (adamMode_t) mode,
       per_tensor_decay.DATA_PTR<scalar_t_2>(),
-      grad_scale); )))
+      global_scale.DATA_PTR<scalar_t_2>(),
+      global_grad_norm.DATA_PTR<scalar_t_2>(),
+      max_grad_norm); )))
   AT_CUDA_CHECK(cudaGetLastError());
 }
@@ -463,8 +477,10 @@ void multi_tensor_lamb_update_weights_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate,
+  at::Tensor update_norm_offset,
+  at::Tensor learning_rate,
   at::Tensor per_tensor_decay,
+  at::Tensor global_grad_norm,
   bool use_nvlamb)
 {
   using namespace at;
@@ -480,8 +496,10 @@ void multi_tensor_lamb_update_weights_cuda(
       DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
       per_tensor_param_norm.DATA_PTR<scalar_t_2>(),
       per_tensor_update_norm.DATA_PTR<scalar_t_2>(),
-      (scalar_t_2) learning_rate,
+      update_norm_offset.DATA_PTR<long>(),
+      learning_rate.DATA_PTR<scalar_t_2>(),
       per_tensor_decay.DATA_PTR<scalar_t_2>(),
+      global_grad_norm.DATA_PTR<scalar_t_2>(),
       use_nvlamb); )))
   AT_CUDA_CHECK(cudaGetLastError());
......
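
Note on the clipping math in the stage-1 kernel above: loss-scale division and global-norm clipping are folded into a single per-step divisor computed from device-resident scalars. The Python mirror below is only for readability and is not part of the commit; dividing the raw scaled gradient by this factor is equivalent to first unscaling it by `global_scale` and then clipping the global norm to `max_grad_norm`.

```python
def combined_scale(global_scale, global_grad_norm, max_grad_norm):
    """Python mirror of the divisor computed in DistOptLAMBStage1Functor."""
    scale = global_scale
    if max_grad_norm > 0:
        # global_grad_norm / global_scale is the norm of the *unscaled* gradients
        clip = max_grad_norm / (global_grad_norm / global_scale + 1e-6)
        scale = global_scale / min(1.0, clip)
    return scale

# g_scaled / combined_scale == (g_scaled / global_scale) * min(1, max_grad_norm / (unscaled_norm + 1e-6))
```
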
@@ -4,6 +4,8 @@ import importlib
 import amp_C
 from apex.multi_tensor_apply import multi_tensor_applier
 
+import torch.distributed.distributed_c10d as c10d
+
 class DistributedFusedLAMB(torch.optim.Optimizer):
     """Implements LAMB algorithm.
@@ -56,8 +58,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             (default: 1.0)
         use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
             weight decay parameter (default: False)
-        clip_grad_norm (boolean, optional): whether to handle gradient clipping
-            (default: True)
+        step_supports_amp_scaling (boolean, optional): whether to use customized
+            gradient unscaling logic (default: True)
 
     .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
         https://arxiv.org/abs/1904.00962
@@ -65,12 +67,24 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         https://openreview.net/forum?id=ryQu7f-RZ
     """
 
+    class AtomicCounter(object):
+        def __init__(self):
+            self.value = 0
+            self.order = []
+            import threading
+            self._lock = threading.Lock()
+
+        def add(self, idx):
+            with self._lock:
+                self.value += 1
+                self.order.append(idx)
+
     def __init__(self, params,
                  lr=1e-3, bias_correction = True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0., max_grad_norm=0.,
-                 adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
-                 amp_scale_adjustment=1.0, overlap_reductions=True,
+                 adam_w_mode=True, use_nvlamb=False,
+                 step_supports_amp_scaling=True, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
                  e5m2_allgather=False):
@@ -81,46 +95,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         super(DistributedFusedLAMB, self).__init__(params, defaults)
-        self._init_args = {
-                'lr': lr,
-                'bias_correction': bias_correction,
-                'grad_averaging': grad_averaging,
-                'betas': betas,
-                'eps': eps,
-                'weight_decay': weight_decay,
-                'max_grad_norm': max_grad_norm,
-                'adam_w_mode': adam_w_mode,
-                'use_nvlamb': use_nvlamb,
-                'clip_grad_norm': clip_grad_norm,
-                'amp_scale_adjustment': amp_scale_adjustment,
-                'overlap_reductions': overlap_reductions,
-                'dwu_group_size': dwu_group_size,
-                'dwu_num_blocks': dwu_num_blocks,
-                'dwu_num_chunks': dwu_num_chunks,
-                'dwu_num_rs_pg': dwu_num_rs_pg,
-                'dwu_num_ar_pg': dwu_num_ar_pg,
-                'dwu_num_ag_pg': dwu_num_ag_pg,
-                'e5m2_allgather': e5m2_allgather}
-        self._init_done = False
-        import inspect
-        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
-
-    def __first_step_init__(self,
-                 lr=1e-3, bias_correction = True, grad_averaging=True,
-                 betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=0., max_grad_norm=0.,
-                 adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
-                 amp_scale_adjustment=1.0, overlap_reductions=True,
-                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
-                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
-                 e5m2_allgather=False):
         global fused_adam_cuda, distributed_lamb_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
         distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")
-        self._amp_scale_adjustment = amp_scale_adjustment
 
         self._overflow_buf = torch.cuda.IntTensor([0])
         self._has_overflow = False
         self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
@@ -128,9 +106,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import amp_C
         self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
 
+        self._grad_averaging = grad_averaging
         self._adam_w_mode = 1 if adam_w_mode else 0
         self._use_nvlamb = use_nvlamb
-        self._clip_grad_norm = clip_grad_norm
+        self._step_supports_amp_scaling = step_supports_amp_scaling
         self._is_accumulation_step = False
         self._last_step = False
         self._overlap_reductions = overlap_reductions
@@ -139,43 +118,126 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._num_chunks = dwu_num_chunks
         self._e5m2_allgather = e5m2_allgather
         self._L2_grad_norm = None
+        self._current_process_group = c10d._get_default_group()
+        self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
         self._world_size = torch.distributed.get_world_size()
         self._num_groups = self._world_size // self._group_size
         self._rank_in_group = torch.distributed.get_rank() % self._group_size
+        self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')
+        self._resume_from_checkpoint = False
+        self._step = torch.cuda.IntTensor([0])
+
+        # Master weight, moment, gradient buffers
+        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
+
+        #import inspect
+        #assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
+
+        self._num_rs_pg = dwu_num_rs_pg
+        self._num_ar_pg = dwu_num_ar_pg
+        self._num_ag_pg = dwu_num_ag_pg
+        if self._num_groups > 1:
+            self._ar_pg = []
+            for dev_i in range(self._group_size):
+                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
+                for i in range(self._num_ar_pg):
+                    grp = torch.distributed.new_group(ranks=ranks)
+                    if torch.distributed.get_rank() in ranks:
+                        self._ar_pg.append(grp)
+            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
+            #for ar_pg in self._ar_pg:
+            #    torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
+        rs_ranks = []
+        for group_i in range(self._num_groups):
+            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
+        self._rs_pg = []
+        for group_i in range(self._num_groups):
+            ranks = rs_ranks[group_i]
+            for i in range(self._num_rs_pg):
+                grp = torch.distributed.new_group(ranks=ranks)
+                if torch.distributed.get_rank() in ranks:
+                    self._rs_pg.append(grp)
+            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+            if torch.distributed.get_rank() in ranks:
+                self._l2_grad_norm_pg = l2_grad_norm_pg
+                #torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
+        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
+        #for rs_pg in self._rs_pg:
+        #    torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
+        if self._num_ag_pg == 0:
+            self._ag_pg = self._rs_pg
+            self._ag_st = self._rs_st
+            self._num_ag_pg = self._num_rs_pg
+        else:
+            self._ag_pg = []
+            for group_i in range(self._num_groups):
+                ranks = rs_ranks[group_i]
+                for i in range(self._num_ag_pg):
+                    grp = torch.distributed.new_group(ranks=ranks)
+                    if torch.distributed.get_rank() in ranks:
+                        self._ag_pg.append(grp)
+            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
+            #for ag_pg in self._ag_pg:
+            #    torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
+        self._l2_grad_norm_st = torch.cuda.Stream()
+        self._completion_st = torch.cuda.Stream()
+        self._step.record_stream(self._completion_st)
+
+        self._reductions_works = [None]*self._num_blocks
+        self._allgather_works = [None]*self._num_blocks
+
+        self._one = torch.cuda.IntTensor([1])
+
+        self._first_step = True
+        self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
+        self._param_order = self.AtomicCounter()
+
+    def _lazy_init_stage1(self):
+        if self._lazy_init_stage1_done: return
+
         p_offset = 0
         p_i = 0
         self._model_params = []
-        self._grads_info = []
         self._grad_accs = []
         self._group_properties = []
         for group in self.param_groups:
             prev = None
             beta1, beta2 = group['betas']
+            beta3 = 1.0 - beta1 if self._grad_averaging else 1.0
+            bias_correction = 1 if group['bias_correction'] else 0
+            eps = group['eps']
+            weight_decay = group['weight_decay']
             for p in group['params']:
-                torch.distributed.broadcast(p,0)
+                torch.distributed.broadcast(p, 0)
                 if not p.requires_grad:
                     continue
                 self._model_params.append(p)
                 self._group_properties.append((
-                    group['weight_decay'],
-                    1 if group['bias_correction'] else 0,
+                    weight_decay,
+                    bias_correction,
                     beta1,
                     beta2,
-                    1.0 - beta1 if grad_averaging else 1.0,
-                    group['eps']
+                    beta3,
+                    eps
                 ))
                 p_grads_size = p.numel()
-                def wrapper(param, param_i, param_grads_size, param_offset):
+                def wrapper(param, param_i):
                     param_tmp = param.expand_as(param)
                     grad_acc = param_tmp.grad_fn.next_functions[0][0]
                     def allreduce_hook(*unused):
-                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
+                        if self._first_step:
+                            # first time
+                            self._param_order.add(param_i)
+                        else:
+                            idx = self._param_order.order.index(param_i)
+                            self._do_overlapped_reduction(idx, param)
                     grad_acc.register_hook(allreduce_hook)
                     self._grad_accs.append(grad_acc)
-                self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
-                wrapper(p, p_i, p_grads_size, p_offset)
+                wrapper(p, p_i)
                 p_offset += p_grads_size
                 # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
                 # RNN is one example of consecutive parameters:
@@ -184,7 +246,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     p_offset = ((p_offset + 63) // 64) * 64
                 prev = p
                 p_i += 1
-        self._grads_generated = [False]*len(self._grads_info)
+        self._grads_generated = [False]*len(self._model_params)
         self._grads_fp16, self._grads_fp32 = [], []
         if self._overlap_reductions:
             self._current_block = self._num_blocks
@@ -196,19 +258,13 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._block_size = self._total_param_size // self._num_blocks
         self._chunk_size = self._block_size // self._num_chunks
         self._shard_size = self._chunk_size // self._group_size
-        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
-        self._low_param_i = [0]*self._num_blocks
-        for block_id in range(self._num_blocks-1,-1,-1):
-            p_i = len(self._grads_info)-1
-            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
-                p_i -= 1
-            self._low_param_i[block_id] = p_i
-        print(self._low_param_i)
+        #print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
         self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
         self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
         self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
+        # initialize master weights, moments buffers if not loaded from checkpoint
+        if self._fp32_p is None:
             self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
             self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
             self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
@@ -217,10 +273,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
             self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
 
-        self._individual_flat_grads = []
-        for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
-            self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
-
         def _flat_split(p):
             def __blockify(p):
                 return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
@@ -262,6 +314,45 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
         self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
 
+        self._lazy_init_stage1_done = True
+
+    def _lazy_init_stage2(self):
+        if self._lazy_init_stage2_done: return
+
+        self._param_order.order.reverse()
+
+        # re-order model_params, grad_accs, group_properties lists
+        self._model_params = [self._model_params[i] for i in self._param_order.order]
+        self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
+        self._group_properties = [self._group_properties[i] for i in self._param_order.order]
+
+        # re-collect grads info (size, offset) after ordering
+        prev = None
+        p_offset = 0
+        self._grads_info = []
+        self._individual_flat_grads = []
+        for i, p in enumerate(self._model_params):
+            p_grads_size = p.numel()
+            self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
+            self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
+            # for the first iteration
+            self._do_overlapped_reduction(i, p)
+            p_offset += p_grads_size
+            # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
+            # RNN is one example of consecutive parameters:
+            # (weight_ih, weight_hh, bias_ih, bias_hh)
+            if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
+                p_offset = ((p_offset + 63) // 64) * 64
+            prev = p
+
+        self._low_param_i = [0]*self._num_blocks
+        for block_id in range(self._num_blocks-1,-1,-1):
+            p_i = len(self._grads_info)-1
+            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
+                p_i -= 1
+            self._low_param_i[block_id] = p_i
+        #print("self._low_param_i", self._low_param_i)
+
         # This paragraph does two things:
         # 1) Copy model parameters into master buffer
         # 2) Create tensor lists for unpacking new parameter tensor after all-gather
@@ -274,7 +365,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._contrib_model_param_for_norm_fp16 = []
         self._contrib_model_param_for_norm_fp32 = []
         self._contrib_model_param_for_norm_is_fp16 = []
-        self._model_param_is_contrib = [False]*self._model_params_num
+        self._model_param_is_contrib = []
         self._contrib_group_properties = []
         for shard_id in range(self._group_size):
             for block_id in range(self._num_blocks):
@@ -297,7 +388,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     else:
                         self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
                     if shard_id == self._rank_in_group:
-                        self._model_param_is_contrib[param_i] = True
+                        self._model_param_is_contrib.append(param_i)
                         # copy model parameters into master buffer
                         master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                         opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
@@ -306,6 +397,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                        opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                        opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                        #print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
+                       if not self._resume_from_checkpoint:
                            master_param_fragment.copy_(model_param_fragment)
                        self._contrib_group_properties.append(group_props)
                        self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy
@@ -322,7 +414,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None
         self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
         self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
-        self._model_param_is_contrib = torch.tensor(self._model_param_is_contrib, dtype=torch.bool, device='cuda')
+        self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda')
 
         p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))
         self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]
@@ -340,62 +432,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None
         self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None
-        self._num_rs_pg = dwu_num_rs_pg
-        self._num_ar_pg = dwu_num_ar_pg
-        self._num_ag_pg = dwu_num_ag_pg
-        if self._num_groups > 1:
-            self._ar_pg = []
-            for dev_i in range(self._group_size):
-                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
-                for i in range(self._num_ar_pg):
-                    grp = torch.distributed.new_group(ranks=ranks)
-                    if torch.distributed.get_rank() in ranks:
-                        self._ar_pg.append(grp)
-            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
-            for ar_pg in self._ar_pg:
-                torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
-        rs_ranks = []
-        for group_i in range(self._num_groups):
-            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
-        self._rs_pg = []
-        for group_i in range(self._num_groups):
-            ranks = rs_ranks[group_i]
-            for i in range(self._num_rs_pg):
-                grp = torch.distributed.new_group(ranks=ranks)
-                if torch.distributed.get_rank() in ranks:
-                    self._rs_pg.append(grp)
-            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
-            if torch.distributed.get_rank() in ranks:
-                self._l2_grad_norm_pg = l2_grad_norm_pg
-                torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
-        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
-        for rs_pg in self._rs_pg:
-            torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
-        if self._num_ag_pg == 0:
-            self._ag_pg = self._rs_pg
-            self._ag_st = self._rs_st
-            self._num_ag_pg = self._num_rs_pg
-        else:
-            self._ag_pg = []
-            for group_i in range(self._num_groups):
-                ranks = rs_ranks[group_i]
-                for i in range(self._num_ag_pg):
-                    grp = torch.distributed.new_group(ranks=ranks)
-                    if torch.distributed.get_rank() in ranks:
-                        self._ag_pg.append(grp)
-            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
-            for ag_pg in self._ag_pg:
-                torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
-        self._l2_grad_norm_st = torch.cuda.Stream()
-        self._completion_st = torch.cuda.Stream()
-        self._reductions_works = [None]*self._num_blocks
-        self._allgather_works = [None]*self._num_blocks
-
-    def _init_everything(self):
-        if not self._init_done:
-            self.__first_step_init__(**self._init_args)
-            self._init_done = True
+        self._lazy_init_stage2_done = True
+
+        self.complete_reductions()
+        self._first_step = False
 
     def set_is_accumulation_step(self, is_accumulation_step):
         self._is_accumulation_step = is_accumulation_step
@@ -431,7 +471,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
             rs_stream.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(rs_stream):
-                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
+                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True)#,no_copy=True)
 
         # Reduction across nodes for each rank
         if self._num_groups > 1:
@@ -453,7 +493,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 l2_grad_norm_sq = torch.empty([1], device='cuda')
                 l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                 torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
-                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
+                self._L2_grad_norm = l2_grad_norm_sq.sqrt()
 
     def __compute_contrib_param_norm(self):
         if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
@@ -471,24 +511,24 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
     def __compute_contrib_update_norm(self):
         l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
         local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
-        l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
+        l2_norm.scatter_(dim=0, index=self._offsets, src=local_contrib_l2_norm)
        torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
        l2_norm = torch.sqrt(l2_norm)
-        return l2_norm.masked_select(self._model_param_is_contrib)
+        return l2_norm
 
     def _pipeline_step(self):
-        # If self._clip_grad_norm is False, we assume gradient clipping already
-        # happened outside the optimizer and self._global_scale has already
-        # been set to the combined scale, i.e. it's no longer the current loss
-        # scale used by the loss scaler.
-        # For model parallelism cases in which we need to get global gradient
-        # norm via all-reduce outside the optimizer to do the clipping.
-        combined_scale = self.global_scale
+        global_scale = self.global_scale
         max_grad_norm = self.defaults['max_grad_norm']
         global_grad_norm = self.L2_grad_norm
-        if self._clip_grad_norm and max_grad_norm > 0 and math.isfinite(global_grad_norm):
-            combined_scale = max_grad_norm / (global_grad_norm / self.global_scale + 1e-6)
-            combined_scale = self.global_scale / min(1, combined_scale)
+
+        # check global_grad_norm and fill overflow_buf
+        is_finite = (global_grad_norm + 1 > global_grad_norm).int()
+        self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1
+
+        # increment step counter if no overflow
+        self._step += is_finite
+        self._completion_st.wait_stream(torch.cuda.current_stream())
+        self._completion_st.wait_stream(self._l2_grad_norm_st)
 
         # Call step kernel once per step
         # Call all-gather once per step
@@ -504,21 +544,25 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             self._contrib_beta2,
             self._contrib_beta3,
             self._contrib_bias_correction,
-            self.param_groups[0]['step'],
+            self._step,
             self._contrib_epsilon,
             self._adam_w_mode,
             self._contrib_weight_decay,
-            combined_scale)
+            global_scale,
+            global_grad_norm,
+            max_grad_norm)
         upd_norm = self.__compute_contrib_update_norm()
         multi_tensor_applier(self.multi_tensor_lamb_update_weights,
             self._overflow_buf,
            self._contrib_update_weights_tensor_list, # u, p, p_copy
            param_norm,
            upd_norm,
-            self.param_groups[0]['lr'],
+            self._offsets,
+            self._lr,
            self._contrib_weight_decay,
+            global_grad_norm,
            self._use_nvlamb)
-        torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
+        torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0])#, no_copy=True)
 
     def _flatten_grad_mt(self, scale):
         if len(self._grads_fp16) > 0:
@@ -538,8 +582,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 scale)
             self._grads_fp32 = []
 
-    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
-        self._init_everything()
+    def _do_overlapped_reduction(self, param_i, param):
         if not self._is_accumulation_step:
             # handle overlapped reductions
             if param.dtype == torch.float16:
@@ -547,7 +590,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             else:
                 self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
             self._grads_generated[param_i]=True
-            if self._overlap_reductions and not self._last_step:
+            if not self._first_step and not self._last_step:
+                if self._overlap_reductions:
                     flush_block = self._get_flush_block()
                     while flush_block:
                         block_id = flush_block[0] // self._block_size
@@ -571,8 +615,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
     def complete_reductions(self):
         """Complete reductions if full pipeline is not selected or overlap is not allowed.
         """
-        self._init_everything()
-
         if self._last_step:
             # zero out gradients that have not been completed yet
             for param_i, grad_generated in enumerate(self._grads_generated):
@@ -583,7 +625,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     self._flat_grads[param_offset:param_offset+param_size].zero_()
                     self._grads_generated[param_i] = True
 
-        if self._last_step or not self._overlap_reductions:
+        if self._first_step or self._last_step or not self._overlap_reductions:
             # nothing done so far, run full pipeline after reductions
             for block_id in range(self._num_blocks-1,-1,-1):
                 self._pipeline_block_reductions(block_id)
@@ -593,24 +635,23 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._current_block = self._num_blocks
         self._grads_generated = [False]*len(self._grads_info)
 
-    def step(self, closure=None):
+    def step(self, closure=None, grad_scaler=None):
         loss = None
         if closure is not None:
             loss = closure()
 
-        # assume same step across group now to simplify things
-        # per parameter step can be easily support by making it tensor, or pass list into kernel
-        for param_group in self.param_groups:
-            if 'step' in param_group:
-                param_group['step'] += 1
-            else:
-                param_group['step'] = 1
-
         self._pipeline_step()
 
+        if grad_scaler is not None:
+            found_inf = self._overflow_buf.float()
+            optimizer_state = grad_scaler._per_optimizer_states[id(self)]
+            current_device = torch.device('cuda', torch.cuda.current_device())
+            optimizer_state["found_inf_per_device"][current_device] = found_inf
+
+        self._completion_st.wait_stream(torch.cuda.current_stream())
+
         with torch.cuda.stream(self._completion_st):
             # Copy self._new_params to model params
-            self._overflow_buf.zero_()
             with torch.no_grad():
                 if self._packed_flat_to_model_params_fp16 is not None:
                     multi_tensor_applier(
@@ -630,4 +671,42 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         return loss
+    def state_dict(self):
+        """
+        Returns a dict containing the current state of this :class:`DistributedFusedLAMB` instance.
+        Example::
+            checkpoint = {}
+            checkpoint['model'] = model.state_dict()
+            checkpoint['optimizer'] = optimizer.state_dict()
+            torch.save(checkpoint, "saved.pth")
+        """
+        # save step, master weights and first/second moments
+        state_dict = {}
+        state_dict['step'] = self._step
+        state_dict['fp32_p'] = self._fp32_p
+        state_dict['fp32_m'] = self._fp32_m
+        state_dict['fp32_v'] = self._fp32_v
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        """
+        Loads a state_dict created by an earlier call to state_dict().
+        If a DistributedFusedLAMB instance was constructed from some ``init_optimizer``,
+        whose parameters in turn came from ``model``, it is expected that the user
+        will call ``model.load_state_dict()`` before
+        ``optimizer.load_state_dict()`` is called.
+        Example::
+            model = torch.nn.Linear(D_in, D_out).cuda().half()
+            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
+            ...
+            checkpoint = torch.load("saved.pth")
+            model.load_state_dict(checkpoint['model'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+        """
+        # restore step, master weights and first/second moments
+        self._step = state_dict['step']
+        self._fp32_p = state_dict['fp32_p'].to(device="cuda")
+        self._fp32_m = state_dict['fp32_m'].to(device="cuda")
+        self._fp32_v = state_dict['fp32_v'].to(device="cuda")
+        self._resume_from_checkpoint = True
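
One detail in the `_pipeline_step()` change above worth spelling out: the overflow check `is_finite = (global_grad_norm + 1 > global_grad_norm).int()` stays entirely on the GPU. `x + 1 > x` is false for both `inf` and `nan`, so the result is a 0/1 tensor that can drive `_overflow_buf` and the step-counter increment without a host synchronization. A small stand-alone illustration (not part of the commit, assumes a CUDA device; the final `int()`/`print` sync here is only for display):

```python
import torch

for val in [3.0, float("inf"), float("nan")]:
    norm = torch.tensor(val, device="cuda")
    is_finite = (norm + 1 > norm).int()   # 1 if finite, 0 for inf/nan
    print(val, int(is_finite))            # 3.0 -> 1, inf -> 0, nan -> 0
```
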