Unverified commit 6b7e77b0, authored by Kexin Yu, committed by GitHub

DistributedFusedAdam Model Parallelism Support (Megatron) (#981)

Co-authored-by: Kexin Yu <kexiny@nvidia.com>
Co-authored-by: Kexin Yu <kexinznzn@gmail.com>
parent 8a1ed9e8
#include <torch/extension.h>

void multi_tensor_fused_adam_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor per_tensor_beta1,
    at::Tensor per_tensor_beta2,
    at::Tensor per_tensor_bias_correction,
    at::Tensor per_tensor_eps,
    at::Tensor per_tensor_weight_decay,
    float lr,
    float grad_scale,
    int step,
    int mode);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda,
        "Multi tensor Adam optimized CUDA implementation.");
}

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <THC/THCGeneral.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <cmath>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

typedef enum{
  ADAM_MODE_0 = 0, // eps under square root
  ADAM_MODE_1 = 1  // eps outside square root
} adamMode_t;
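
// Note: per element, the functor below applies the usual Adam update with
// per-tensor hyperparameters, roughly
//   g_hat = g / grad_scale
//   m     = b1*m + (1-b1)*g_hat
//   v     = b2*v + (1-b2)*g_hat*g_hat
//   denom = sqrt(v/bias_corr2) + eps   (ADAM_MODE_1; ADAM_MODE_0 keeps eps under the sqrt)
//   p     = p - lr * (m/bias_corr1/denom + weight_decay*p)
// With DEPTH == 5 the updated parameters are also written to p_copy in the
// gradient dtype, which the optimizer later all-gathers as the new parameters.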
template <int DEPTH, typename T, typename GRAD_T>
struct DistAdamFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<DEPTH>& tl,
    const float* per_tensor_beta1,
    const float* per_tensor_beta2,
    const int* per_tensor_bias_correction,
    const float* per_tensor_eps,
    const float* per_tensor_weight_decay,
    const float lr,
    const float grad_scale,
    const int step,
    adamMode_t mode)
  {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float b1 = per_tensor_beta1[tensor_num];
    float b2 = per_tensor_beta2[tensor_num];
    float eps = per_tensor_eps[tensor_num];
    float decay = per_tensor_weight_decay[tensor_num];

    float beta1_correction = 1.0f, beta2_correction = 1.0f;
    if (per_tensor_bias_correction[tensor_num] == 1) {
      beta1_correction = 1 - std::pow(b1, step);
      beta2_correction = 1 - std::pow(b2, step);
    }

    T* p = (T *)tl.addresses[0][tensor_loc];
    p += chunk_idx*chunk_size;
    T* m = (T *)tl.addresses[1][tensor_loc];
    m += chunk_idx*chunk_size;
    T* v = (T *)tl.addresses[2][tensor_loc];
    v += chunk_idx*chunk_size;
    GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
    g += chunk_idx*chunk_size;
    GRAD_T* p_copy = NULL;
    if (DEPTH == 5) {
      p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
      p_copy += chunk_idx*chunk_size;
    }

    n -= chunk_idx*chunk_size;

    T incoming_p[ILP];
    T incoming_m[ILP];
    T incoming_v[ILP];
    T incoming_g[ILP];

    // to make things simple, we put aligned case in a different code path
    if (n % ILP == 0 &&
        chunk_size % ILP == 0 &&
        is_aligned(p) &&
        is_aligned(m) &&
        is_aligned(v) &&
        is_aligned(g) &&
        is_aligned(p_copy)) {
      for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) {
        // load
        GRAD_T tmp_g[ILP];
        load_store(incoming_p, p, 0, i_start);
        load_store(incoming_m, m, 0, i_start);
        load_store(incoming_v, v, 0, i_start);
        load_store(tmp_g, g, 0, i_start);
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          incoming_g[ii] = static_cast<T>(tmp_g[ii]);
          T scaled_grad = incoming_g[ii]/grad_scale;
          incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
          incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
          T next_m_unbiased = incoming_m[ii] / beta1_correction;
          T next_v_unbiased = incoming_v[ii] / beta2_correction;
          float denom;
          if (mode == ADAM_MODE_0)
            denom = sqrtf(next_v_unbiased + eps);
          else // Mode 1
            denom = sqrtf(next_v_unbiased) + eps;
          float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
          incoming_p[ii] = incoming_p[ii] - (lr * update);
          if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
        }
        load_store(p, incoming_p, i_start, 0);
        load_store(m, incoming_m, i_start, 0);
        load_store(v, incoming_v, i_start, 0);
        if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
      }
    } else {
      for (int i_start = 0;
           i_start < n && i_start < chunk_size;
           i_start += blockDim.x*ILP) {
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          incoming_p[ii] = 0;
          incoming_m[ii] = 0;
          incoming_v[ii] = 0;
          incoming_g[ii] = 0;
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if (i < n && i < chunk_size) {
            incoming_p[ii] = p[i];
            incoming_m[ii] = m[i];
            incoming_v[ii] = v[i];
            incoming_g[ii] = static_cast<T>(g[i]);
          }
        }
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int j = i_start + threadIdx.x + ii*blockDim.x;
          if (j < n && j < chunk_size) {
            T scaled_grad = incoming_g[ii]/grad_scale;
            m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
            v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
            T next_m_unbiased = m[j] / beta1_correction;
            T next_v_unbiased = v[j] / beta2_correction;
            float denom;
            if (mode == ADAM_MODE_0)
              denom = sqrtf(next_v_unbiased + eps);
            else // Mode 1
              denom = sqrtf(next_v_unbiased) + eps;
            float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
            p[j] = incoming_p[ii] - (lr * update);
            if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
          }
        }
      }
    }
  }
};
void multi_tensor_fused_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
  at::Tensor per_tensor_beta1,
  at::Tensor per_tensor_beta2,
  at::Tensor per_tensor_bias_correction,
  at::Tensor per_tensor_eps,
  at::Tensor per_tensor_weight_decay,
  float lr,
  float grad_scale,
  int step,
  int mode)
{
  using namespace at;

  size_t tl_sz = tensor_lists.size();
  AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");

  if (tl_sz == 5) {
    DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      multi_tensor_apply<5>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        DistAdamFunctor<5, accscalar_t, scalar_t_0>(),
        per_tensor_beta1.DATA_PTR<float>(),
        per_tensor_beta2.DATA_PTR<float>(),
        per_tensor_bias_correction.DATA_PTR<int>(),
        per_tensor_eps.DATA_PTR<float>(),
        per_tensor_weight_decay.DATA_PTR<float>(),
        lr,
        grad_scale,
        step,
        (adamMode_t) mode);
    );
  } else {
    DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        DistAdamFunctor<4, accscalar_t, scalar_t_0>(),
        per_tensor_beta1.DATA_PTR<float>(),
        per_tensor_beta2.DATA_PTR<float>(),
        per_tensor_bias_correction.DATA_PTR<int>(),
        per_tensor_eps.DATA_PTR<float>(),
        per_tensor_weight_decay.DATA_PTR<float>(),
        lr,
        grad_scale,
        step,
        (adamMode_t) mode);
    );
  }
  THCudaCheck(cudaGetLastError());
}
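
For reference, the Python optimizer changed later in this commit drives this kernel through Apex's multi_tensor_applier, which supplies the chunk size and forwards the remaining arguments. A condensed sketch of that call (the lowercase variables are placeholders for the optimizer's `self._*` attributes shown below):

import distributed_adam_cuda
from apex.multi_tensor_apply import multi_tensor_applier

# tensor_lists is [p, m, v, g, p_copy]; the per-tensor hyperparameter tensors
# hold one entry per parameter fragment owned by this rank.
multi_tensor_applier(distributed_adam_cuda.multi_tensor_fused_adam,
                     overflow_buf,        # noop_flag
                     tensor_lists,        # p, m, v, g, p_copy
                     contrib_beta1, contrib_beta2, contrib_bias_correction,
                     contrib_epsilon, contrib_weight_decay,
                     lr, combined_scale, step, eps_mode)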
@@ -4,13 +4,15 @@ import importlib
 import amp_C
 from apex.multi_tensor_apply import multi_tensor_applier
+import torch.distributed.distributed_c10d as c10d
 class DistributedFusedAdam(torch.optim.Optimizer):
     """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
     ``python setup.py install --cuda_ext --cpp_ext``.
     It has been proposed in `Adam: A Method for Stochastic Optimization`_.
     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups.
@@ -19,20 +21,30 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             running averages of gradient and its square. (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability. (default: 1e-8)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
-            algorithm from the paper `On the Convergence of Adam and Beyond`_
-            (default: False) NOT SUPPORTED in FusedAdam!
         eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
             adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
-        use_mt (boolean, optional): use multi tensor apply for lower launch
-            latency. (default: False)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False) NOT SUPPORTED in FusedAdam!
         overlap_reductions(boolean, optional): whether to overlap reductions
             with bprop (default: True)
-        num_prestats (integer, optional): number of fp64 stats that will be
-            reduced during first fp16 gradient reduction block.
+        step_supports_amp_scaling(boolean, optional): whether to use customized
+            gradient unscaling logic (default: True)
+        num_process_groups (integer, optional): number of process groups in
+            the app (default: 1)
+        current_process_group (object, optional): the process group to work on
+            (default: None)
+        process_group_id (integer, optional): process group id (default: 0)
+        process_group_size (integer, optional): size of process group
+            (default: 0)
+        clip_grad_norm (boolean, optional): whether to handle gradient clipping
+            (default: True)
+        model_parallel (boolean, optional): whether model parallelism is used
+            (default: False)
     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
@@ -41,22 +53,28 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     """
     def __init__(self, params,
-                 lr=1e-3, bias_correction = True,
-                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
-                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
-                 amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
-                 compute_L2_grad_norm=False, distributed_weight_update=0,
-                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
-                 dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
-                 dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
-                 do_not_flatten_model=False):
-        global fused_adam_cuda
+                 lr=1e-3, bias_correction=True, betas=(0.9, 0.999),
+                 eps=1e-8, eps_inside_sqrt=False,
+                 weight_decay=0., max_grad_norm=0.,
+                 amsgrad=False, flat_mt=False,
+                 overlap_reductions=True,
+                 compute_L2_grad_norm=False,
+                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
+                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
+                 predivide=True, e5m2_allgather=False,
+                 do_not_flatten_model=False,
+                 step_supports_amp_scaling=True,
+                 num_process_groups=1,
+                 current_process_group=None,
+                 process_group_id=0,
+                 process_group_size=0,
+                 clip_grad_norm=True,
+                 model_parallel=False):
+        global fused_adam_cuda, distributed_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
-        self._amp_scale_adjustment = amp_scale_adjustment
-        if use_mt:
-            raise RuntimeError('DistributedFusedAdam does not support use_mt.')
+        distributed_adam_cuda = importlib.import_module("distributed_adam_cuda")
+        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
         if amsgrad:
             raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
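
For orientation, a minimal construction sketch based on the arguments above and on the unit test added at the end of this commit (`model` and the exact group sizes are placeholders; a Megatron-style application would additionally pass `model_parallel=True` and its own process-group arguments):

optimizer = DistributedFusedAdam(model.parameters(),
                                 lr=1e-3, eps=1e-6, weight_decay=0.01,
                                 overlap_reductions=False,
                                 process_group_size=torch.distributed.get_world_size(),
                                 dwu_group_size=1, dwu_num_blocks=1, dwu_num_chunks=1)
optimizer.set_global_scale(1.0)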
@@ -64,21 +82,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         max_grad_norm=max_grad_norm)
         super(DistributedFusedAdam, self).__init__(params, defaults)
-        self.eps_mode = 0 if eps_inside_sqrt else 1
+        # Misc
+        self.eps_mode = 0 if eps_inside_sqrt else 1
         self._overflow_buf = torch.cuda.IntTensor([0])
         self._has_overflow = False
-        # Way to revert a step
-        # 3 -> undo kernel + double buffer (debug, print norm of difference)
-        # 2 -> double buffer fp32 parameters
-        # 1 -> undo kernel
-        self._revert_method = revert_method
-        if self._revert_method > 1:
-            print("revert_method -> double buffer fp32 parameters, will consume more memory")
+        self._step_supports_amp_scaling = step_supports_amp_scaling
+        assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
         self._last_step = False
         self._overlap_reductions = overlap_reductions
         self._global_scale = None
@@ -87,33 +96,64 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._predivide = predivide
         self._e5m2_allgather = e5m2_allgather
         self._do_not_flatten_model = do_not_flatten_model
-        self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = None
+        self._flat_mt = flat_mt
+        self._init_done = False
+        self._resume_from_checkpoint = False
+        self._step = 0
+        # Process group related
+        self._clip_grad_norm = clip_grad_norm
+        self._model_parallel = model_parallel
+        self._num_process_groups = num_process_groups
+        self._current_process_group = current_process_group if current_process_group is not None else c10d._get_default_group()
+        self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
+        self._process_group_id = process_group_id
+        self._process_group_size = torch.cuda.device_count() if process_group_size <= 0 else process_group_size
+        self._world_size = self._process_group_size # world: the current process group
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
-        self._world_size = torch.distributed.get_world_size()
         self._num_groups = self._world_size // self._group_size
-        self._rank_in_group = torch.distributed.get_rank() % self._group_size
+        self._global_rank = torch.distributed.get_rank()
+        self._world_rank = self._global_rank // self._num_process_groups
+        self._group_rank = self._world_rank % self._group_size
+        #print("world_size:", self._world_size, ", group_size:", self._group_size, ", num_groups:", self._num_groups, ", global_rank:", self._global_rank, ", world_rank:", self._world_rank, ", group_rank:", self._group_rank)
+        self._num_rs_pg = dwu_num_rs_pg
+        self._num_ar_pg = dwu_num_ar_pg
+        self._num_ag_pg = dwu_num_ag_pg
+        # Master weight, moment, gradient buffers
+        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
+    def _first_step_init(self):
         p_offset = 0
         p_i = 0
-        self._param_state = None
         self._model_params = []
         self._grads_info = []
         self._grad_accs = []
+        self._group_properties = []
         for group in self.param_groups:
             self._param_group = group
             prev = None
+            beta1, beta2 = group['betas']
+            bias_correction = 1 if group['bias_correction'] else 0
+            eps = group['eps']
+            weight_decay = group['weight_decay']
             for p in group['params']:
-                torch.distributed.broadcast(p,0)
+                # broadcast from rank 0 of current process group
+                torch.distributed.broadcast(p, src=self._available_ranks[0], group=self._current_process_group)
                 if not p.requires_grad:
                     continue
                 self._model_params.append(p)
-                state = self.state[p]
-                if len(state) == 0:
-                    state['step'] = 0
-                if self._param_state is None:
-                    self._param_state = state
+                # Multiple param groups support:
+                # store one hyperparam item per parameter tensor
+                self._group_properties.append((
+                        beta1,
+                        beta2,
+                        bias_correction,
+                        eps,
+                        weight_decay
+                ))
                 p_grads_size = p.numel()
                 def wrapper(param, param_i, param_grads_size, param_offset):
                     param_tmp = param.expand_as(param)
@@ -133,7 +173,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 prev = p
                 p_i += 1
         self._grads_generated = [False]*len(self._grads_info)
-        self._flat_mt = flat_mt
         self._grads = []
         if self._overlap_reductions:
             self._current_block = self._num_blocks
@@ -145,7 +184,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._block_size = self._total_param_size // self._num_blocks
         self._chunk_size = self._block_size // self._num_chunks
         self._shard_size = self._chunk_size // self._group_size
-        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
+        #print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
         self._low_param_i = [0]*self._num_blocks
         for block_id in range(self._num_blocks-1,-1,-1):
@@ -153,14 +192,16 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
                 p_i -= 1
             self._low_param_i[block_id] = p_i
-        print(self._low_param_i)
+        #print(self._low_param_i)
         self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
         self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
         self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
-        self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
-        self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
-        self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
+        # initialize master weights, moments buffers if not loaded from checkpoint
+        if self._fp32_p is None:
+            self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
+            self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
+            self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
         # FIXME: Rethink fp16 label since it's either uint8 or fp16
         self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
         self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
@@ -213,12 +254,15 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         # 1) Copy model parameters into master buffer
         # 2) Create tensor lists for unpacking new parameter tensor after all-gather
         self._packed_flat_to_model_params = []
+        self._contrib_tensor_list = []
+        self._contrib_group_properties = []
+        self._non_parallel_grads = []
         for shard_id in range(self._group_size):
             for block_id in range(self._num_blocks):
                 for chunk_id in range(self._num_chunks):
                     flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
                     flat_shard_end = flat_shard_start + self._shard_size
-                    for p, grads_info in zip(self._model_params, self._grads_info):
+                    for (p, grads_info, group_props) in zip(self._model_params, self._grads_info, self._group_properties):
                         flat_grad_start = grads_info["param_offset"]
                         flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
                         clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)
@@ -230,60 +274,90 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                             model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
                             new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
                             self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
-                            if shard_id == self._rank_in_group:
+                            if shard_id == self._group_rank:
                                 # copy model parameters into master buffer
                                 master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
-                                print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
-                                master_param_fragment.copy_(model_param_fragment)
+                                opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
+                                opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
+                                opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
+                                opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
+                                #print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
+                                if not self._resume_from_checkpoint:
+                                    master_param_fragment.copy_(model_param_fragment)
+                                self._contrib_group_properties.append(group_props)
+                                self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, g, p_copy
+                                if self._model_parallel and hasattr(p, 'model_parallel') and not p.model_parallel:
+                                    self._non_parallel_grads.append(opti_state_g_fragment)
+        p, m, v, g, p_copy = list(zip(*self._contrib_tensor_list))
+        self._contrib_tensor_list = [p, m, v, g, p_copy]
+        math_type = self._fp32_p.dtype
+        beta1, beta2, bias_correction, epsilon, decay = list(zip(*self._contrib_group_properties))
+        self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')
+        self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')
+        self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')
+        self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')
+        self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')
         p_in, p_out = zip(*self._packed_flat_to_model_params)
         self._packed_flat_to_model_params = [p_in, p_out]
-        self._distributed_weight_update = distributed_weight_update # Is this still needed?
-        self._num_rs_pg = dwu_num_rs_pg
-        self._num_ar_pg = dwu_num_ar_pg
-        self._num_ag_pg = dwu_num_ag_pg
         if self._num_groups > 1:
             self._ar_pg = []
-            for dev_i in range(self._group_size):
-                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
-                for i in range(self._num_ar_pg):
-                    grp = torch.distributed.new_group(ranks=ranks)
-                    if torch.distributed.get_rank() in ranks:
-                        self._ar_pg.append(grp)
+            for i in range(self._num_process_groups):
+                # gather global ranks of all members of the current process group
+                ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]
+                for j in range(self._group_size):
+                    ar_idx = [j+k*self._group_size for k in range(self._num_groups)]
+                    ar_rank = [ranks[k] for k in ar_idx]
+                    #if self._global_rank in ar_rank:
+                    #    print("group for all reduce, ranks:", ar_rank)
+                    for _ in range(self._num_ar_pg):
+                        grp = torch.distributed.new_group(ranks=ar_rank)
+                        if self._global_rank in ar_rank:
+                            self._ar_pg.append(grp)
             self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
             for ar_pg in self._ar_pg:
                 torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
-        rs_ranks = []
-        for group_i in range(self._num_groups):
-            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
-        self._rs_pg = []
-        for group_i in range(self._num_groups):
-            ranks = rs_ranks[group_i]
-            for i in range(self._num_rs_pg):
-                grp = torch.distributed.new_group(ranks=ranks)
-                if torch.distributed.get_rank() in ranks:
-                    self._rs_pg.append(grp)
-            if self._compute_L2_grad_norm:
-                l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
-                if torch.distributed.get_rank() in ranks:
-                    self._l2_grad_norm_pg = l2_grad_norm_pg
-                    torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
+        self._rs_pg, rs_ranks = [],[]
+        for i in range(self._num_process_groups):
+            ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]
+            for j in range(self._num_groups):
+                rs_idx = [j*self._group_size+k for k in range(self._group_size)]
+                rs_rank = [ranks[k] for k in rs_idx]
+                #if self._global_rank in rs_rank:
+                #    print("group for reduce scatter, ranks:", rs_rank)
+                for _ in range(self._num_rs_pg):
+                    grp = torch.distributed.new_group(ranks=rs_rank)
+                    if self._global_rank in rs_rank:
+                        self._rs_pg.append(grp)
+                if self._compute_L2_grad_norm:
+                    l2_grad_norm_pg = torch.distributed.new_group(ranks=rs_rank)
+                    if self._global_rank in rs_rank:
+                        self._l2_grad_norm_pg = l2_grad_norm_pg
+                        torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
         self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
         for rs_pg in self._rs_pg:
             torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
         if self._num_ag_pg == 0:
             self._ag_pg = self._rs_pg
             self._ag_st = self._rs_st
             self._num_ag_pg = self._num_rs_pg
         else:
             self._ag_pg = []
-            for group_i in range(self._num_groups):
-                ranks = rs_ranks[group_i]
-                for i in range(self._num_ag_pg):
-                    grp = torch.distributed.new_group(ranks=ranks)
-                    if torch.distributed.get_rank() in ranks:
-                        self._ag_pg.append(grp)
+            for i in range(self._num_process_groups):
+                ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]
+                for j in range(self._num_groups):
+                    ag_rank = rs_ranks[j]
+                    #if self._global_rank in ag_rank:
+                    #    print("group for all gather, ranks:", ag_rank)
+                    for _ in range(self._num_ag_pg):
+                        grp = torch.distributed.new_group(ranks=ag_rank)
+                        if self._global_rank in ag_rank:
+                            self._ag_pg.append(grp)
         self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
         for ag_pg in self._ag_pg:
             torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
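
The `_contrib_*` tensors built above become the kernel's `per_tensor_*` arguments, with one entry per parameter fragment owned by this rank. As an illustration only (values assumed: a single param group with betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01 and two local fragments):

# self._contrib_beta1           -> tensor([0.9000, 0.9000], device='cuda:0')
# self._contrib_beta2           -> tensor([0.9990, 0.9990], device='cuda:0')
# self._contrib_bias_correction -> tensor([1, 1], device='cuda:0', dtype=torch.int32)
# self._contrib_epsilon         -> tensor([1.0000e-08, 1.0000e-08], device='cuda:0')
# self._contrib_weight_decay    -> tensor([0.0100, 0.0100], device='cuda:0')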
@@ -296,6 +370,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         import inspect
         assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
+    def _init_everything(self):
+        if not self._init_done:
+            self._first_step_init()
+            self._init_done = True
     def set_last_step(self, last_step):
         self._last_step = last_step
@@ -350,46 +428,43 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 l2_grad_norm_sq = torch.empty([1], device='cuda')
                 l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                 torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
+                # for model_parallel_rank=0, keep all gradients
+                # for the rest, subtract non_parallel gradients
+                if self._model_parallel and self._process_group_id: # non zero model_parallel_rank
+                    non_parallel_grad_norm_sq = torch.zeros([1], device='cuda')
+                    if len(self._non_parallel_grads): # non parallel grads exit
+                        non_parallel_grad_norm_sq = multi_tensor_applier(self.multi_tensor_l2norm,
+                                                                         self._overflow_buf,
+                                                                         [self._non_parallel_grads], False)[0]**2
+                    torch.distributed.all_reduce(non_parallel_grad_norm_sq, group=self._l2_grad_norm_pg)
+                    l2_grad_norm_sq = l2_grad_norm_sq - non_parallel_grad_norm_sq
                 self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
-    def __launch_step_kernel(self, p, p_copy, m, v, g):
+    def __launch_step_kernel(self):
+        # If self._clip_grad_norm is False, we assume gradient clipping already
+        # happened outside the optimizer and self._global_scale has already
+        # been set to the combined scale, i.e. it's no longer the current loss
+        # scale used by the loss scaler.
+        # For model parallelism cases in which we need to get global gradient
+        # norm via all-reduce outside the optimizer to do the clipping.
         combined_scale = self._global_scale
-        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
+        if self._clip_grad_norm and self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
             combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
             combined_scale = self._global_scale / min(1, combined_scale)
-        bias_correction = 1 if self._param_group['bias_correction'] else 0
-        beta1, beta2 = self._param_group['betas']
-        fused_adam_cuda.reversible_adam(
-                p, p_copy, m, v, g,
+        self._step += 1
+        multi_tensor_applier(distributed_adam_cuda.multi_tensor_fused_adam,
+                self._overflow_buf,
+                self._contrib_tensor_list, # p, m, v, g, p_copy
+                self._contrib_beta1,
+                self._contrib_beta2,
+                self._contrib_bias_correction,
+                self._contrib_epsilon,
+                self._contrib_weight_decay,
                 self._param_group['lr'],
-                beta1,
-                beta2,
-                self._param_group['eps'],
                 combined_scale,
-                self._param_state['step']+1,
-                self.eps_mode,
-                bias_correction,
-                self._param_group['weight_decay'])
+                self._step,
+                self.eps_mode)
-    def _pipeline_block_step(self, block_id):
-        # Call step kernel once per block
-        ag_stream = self._ag_st[block_id%self._num_ag_pg]
-        with torch.cuda.stream(ag_stream):
-            for chunk_id in range(self._num_chunks):
-                self._reductions_works[block_id][chunk_id].wait()
-            self.__launch_step_kernel(
-                self._fp32_p_blocks[block_id],
-                self._fp16_p_blocks[block_id],
-                self._fp32_m_blocks[block_id],
-                self._fp32_v_blocks[block_id],
-                self._fp16_g_blocks[block_id])
-        # Call all-gather once per step.
-        # FIXME: Determine which is faster, one all-gather per block or a single all-gather at end
-        if block_id == 0:
-            for other_ag_stream in self._ag_st:
-                self._completion_st.wait_stream(other_ag_stream)
-            with torch.cuda.stream(self._completion_st):
-                torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
     def _pipeline_step(self):
         # Call step kernel once per step
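
A note on the model-parallel correction in the L2-grad-norm code above: parameters that are not model-parallel (`p.model_parallel == False` under Megatron) are replicated across model-parallel ranks, so their squared gradient norm would be counted once per rank when the per-rank norms are later combined outside the optimizer. Keeping those terms only on model-parallel rank 0 and subtracting them on the other ranks lets that external reduction count them exactly once; schematically (a sketch, not code from this commit):

# rank 0 reports:     ||g_parallel_0||^2 + ||g_replicated||^2
# ranks > 0 report:   ||g_parallel_r||^2                      (replicated part subtracted)
# sum over MP ranks:  ||g_replicated||^2 + sum_r ||g_parallel_r||^2  -> replicated grads counted once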
@@ -398,12 +473,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             for block_id in range(self._num_blocks):
                 for chunk_id in range(self._num_chunks):
                     self._reductions_works[block_id][chunk_id].wait()
-            self.__launch_step_kernel(
-                self._fp32_p,
-                self._fp16_p,
-                self._fp32_m,
-                self._fp32_v,
-                self._fp16_g)
+            self.__launch_step_kernel()
             torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
     def _flatten_grad_mt(self, scale):
@@ -429,8 +499,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         while flush_block:
             block_id = flush_block[0] // self._block_size
             self._pipeline_block_reductions(block_id)
-            if self._full_pipeline:
-                self._pipeline_block_step(block_id)
             flush_block = self._get_flush_block()
     def set_global_scale(self, global_scale):
@@ -484,7 +552,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def complete_reductions(self):
         """Complete reductions if full pipeline is not selected or overlap is not allowed.
         """
+        self._init_everything()
         if self._last_step:
             # zero out gradients that have not been completed yet
             for param_i, grad_generated in enumerate(self._grads_generated):
@@ -506,53 +574,19 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._current_block = self._num_blocks
         self._grads_generated = [False]*len(self._grads_info)
-    def revert_step(self):
-        """Revert effect of previously calling partial_step.
-        """
-        # Call undo kernel once per step
-        combined_scale = self._global_scale
-        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
-            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
-            combined_scale = self._global_scale / min(1, combined_scale)
-        bias_correction = 1 if self._param_group['bias_correction'] else 0
-        beta1, beta2 = self._param_group['betas']
-        fused_adam_cuda.maybe_adam_undo(
-                torch.empty([0]),
-                self._fp32_p,
-                self._fp32_m,
-                self._fp32_v,
-                self._fp16_g,
-                self._param_group['lr'],
-                beta1,
-                beta2,
-                self._param_group['eps'],
-                combined_scale,
-                self._param_state['step']+1,
-                self.eps_mode,
-                bias_correction,
-                self._param_group['weight_decay'])
-    def step(self, closure=None, skip_overflow_check=False):
+    def step(self, closure=None):
         loss = None
         if closure is not None:
             loss = closure()
-        if self._last_step or not self._overlap_reductions or not self._full_pipeline:
-            self._pipeline_step()
+        self._pipeline_step()
         with torch.cuda.stream(self._completion_st):
-            # Check for overflow
-            # Store state for loss scaler calculation
-            has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
-            if has_overflow:
-                self.revert_step()
-            else:
-                # Copy self._new_params to model params
-                for p in self._model_params: self.state[p]['step'] += 1
-                multi_tensor_applier(
-                        fused_adam_cuda.maybe_cast_mt,
-                        self._overflow_buf,
-                        self._packed_flat_to_model_params)
+            # Copy self._new_params to model params
+            multi_tensor_applier(
+                    fused_adam_cuda.maybe_cast_mt,
+                    self._overflow_buf,
+                    self._packed_flat_to_model_params)
         torch.cuda.current_stream().wait_stream(self._completion_st)
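
As exercised by the unit test added below, the intended per-iteration call order with overlap disabled is: run the backward pass, call complete_reductions() (which on first use also runs _init_everything()), then call step(); for example:

loss.backward()
optimizer.complete_reductions()
optimizer.step()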
@@ -561,4 +595,42 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         return loss
+    def state_dict(self):
+        """
+        Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.
+        Example::
+            checkpoint = {}
+            checkpoint['model'] = model.state_dict()
+            checkpoint['optimizer'] = optimizer.state_dict()
+            torch.save(checkpoint, "saved.pth")
+        """
+        # save step, master weights and first/second moments
+        state_dict = {}
+        state_dict['step'] = self._step
+        state_dict['fp32_p'] = self._fp32_p
+        state_dict['fp32_m'] = self._fp32_m
+        state_dict['fp32_v'] = self._fp32_v
+        return state_dict
+    def load_state_dict(self, state_dict):
+        """
+        Loads a state_dict created by an earlier call to state_dict().
+        If an DistributedFusedAdam instance was constructed from some ``init_optimizer``,
+        whose parameters in turn came from ``model``, it is expected that the user
+        will call ``model.load_state_dict()`` before
+        ``optimizer.load_state_dict()`` is called.
+        Example::
+            model = torch.nn.Linear(D_in, D_out).cuda().half()
+            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
+            ...
+            checkpoint = torch.load("saved.pth")
+            model.load_state_dict(checkpoint['model'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+        """
+        # restore step, master weights and first/second moments
+        self._step = state_dict['step']
+        self._fp32_p = state_dict['fp32_p'].to(device="cuda")
+        self._fp32_m = state_dict['fp32_m'].to(device="cuda")
+        self._fp32_v = state_dict['fp32_v'].to(device="cuda")
+        self._resume_from_checkpoint = True
@@ -123,6 +123,25 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
     version_ge_1_5 = ['-DVERSION_GE_1_5']
 version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
+if "--distributed_adam" in sys.argv:
+    from torch.utils.cpp_extension import CUDAExtension
+    sys.argv.remove("--distributed_adam")
+    from torch.utils.cpp_extension import BuildExtension
+    cmdclass['build_ext'] = BuildExtension
+    if torch.utils.cpp_extension.CUDA_HOME is None:
+        raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
+    else:
+        ext_modules.append(
+            CUDAExtension(name='distributed_adam_cuda',
+                          sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
+                                   'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
+                          include_dirs=[os.path.join(this_dir, 'csrc')],
+                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
+                                              'nvcc': ['-O3',
+                                                       '--use_fast_math'] + version_dependent_macros}))
 if "--distributed_lamb" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
     sys.argv.remove("--distributed_lamb")
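For reference, building the new extension follows the same pattern as the other optional Apex extensions: pass the flag directly to setup.py (or forward it via pip's --global-option), e.g.:

python setup.py install --cpp_ext --cuda_ext --distributed_adam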
import argparse
import random
import sys
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from apex import amp
from apex.optimizers import FusedAdam
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(*[torch.nn.Linear(args.dim, args.dim, bias=args.bias) for _ in range(args.layers)])

    def forward(self, x):
        return self.linear(x)

def setup(args):
    ## Model
    ref_model = TestModel(args).cuda()
    dist_model = TestModel(args).cuda()

    # Same weights
    with torch.no_grad():
        for dp, rp in zip(dist_model.parameters(), ref_model.parameters()):
            dp.data.copy_(rp.data)

    dist_model = dist_model.half()

    ## Optimizer
    # same hyperparameters
    ref_opt_args = { 'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01 }
    ref_opt = FusedAdam(ref_model.parameters(), **ref_opt_args)

    dist_opt_args = ref_opt_args.copy()
    dist_opt_args.update( {'overlap_reductions' : False} )
    dist_opt_args.update( {'process_group_size' : args.n_gpu} )
    dist_opt_args.update( {'dwu_group_size' : args.dwu_group_size} )
    dist_opt_args.update( {'dwu_num_blocks' : 1} )
    dist_opt_args.update( {'dwu_num_chunks' : 1} )
    dist_opt = DistributedFusedAdam(dist_model.parameters(), **dist_opt_args)
    dist_opt.set_global_scale(1.)

    ## amp-init
    amp_args = { 'loss_scale' : 'dynamic' , 'opt_level' : 'O2'}
    ref_model, ref_opt = amp.initialize(ref_model, ref_opt, **amp_args)

    ## DDP
    ref_model = DDP(ref_model, device_ids=[args.rank])

    with torch.no_grad():
        for dp in dist_model.parameters():
            torch.distributed.broadcast(dp.data, src=0)
        for rp in ref_model.parameters():
            torch.distributed.broadcast(rp.data, src=0)
    torch.cuda.synchronize()
    torch.distributed.barrier()

    if get_rank() == 0:
        print(f'dist opt with {args.n_gpu} GPUs')

    return ref_model, ref_opt, dist_model, dist_opt

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--dim', type=int, default=4)
    parser.add_argument('--layers', type=int, default=2)
    parser.add_argument('--bias', action='store_true')
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1)
    parser.add_argument('--dwu_group_size', type=float, default=1)
    args = parser.parse_args()
    return args

def setup_env(args):
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.n_gpu = torch.distributed.get_world_size()
    seed = 42 + get_rank()
    random.seed(seed)
    torch.manual_seed(seed)
    return args

def get_rank():
    return torch.distributed.get_rank()

def main():
    args = parse_args()
    args = setup_env(args)
    tol_args = { 'atol' : args.atol, 'rtol' : args.rtol }
    torch.set_printoptions(precision=16)

    ref_model, ref_opt, dist_model, dist_opt = setup(args)

    # lazy_init not called yet, initialize stash
    stash = ref_opt._amp_stash
    stash.all_fp16_params, stash.all_fp32_from_fp16_params = [], []

    # make sure everything from _first_step_init_ is ready before training
    # e.g. registering allreduce_hook
    # so that gradients are copied/reduced when necessary
    dist_opt._init_everything()

    for i in range(args.steps):
        x_ref = torch.randn(args.batch, args.dim, dtype=torch.half).cuda().requires_grad_(True)
        x_dist = x_ref.clone().detach().requires_grad_(True)

        if get_rank() == 0:
            print(f'[{i}] Checking input')
        #print("x_ref:", x_ref.flatten()[:10])
        #print("x_dist:", x_dist.flatten()[:10])
        assert(torch.allclose(x_ref, x_dist, **tol_args))

        y_ref = ref_model(x_ref).half()
        y_dist = dist_model(x_dist)
        if get_rank() == 0:
            print(f'[{i}] Checking output')
        #print("y_ref:", y_ref.flatten()[:10])
        #print("y_dist:", y_dist.flatten()[:10])
        assert(torch.allclose(y_ref, y_dist, **tol_args))

        dy = torch.randn_like(y_ref)
        y_ref.backward(dy)
        y_dist.backward(dy)

        if get_rank() == 0:
            print(f'[{i}] Checking gradients')
        torch.distributed.barrier()
        torch.cuda.synchronize()
        assert(torch.allclose(x_ref.grad, x_dist.grad, **tol_args))

        # gradient all-reduce within distributed optimizer
        dist_opt.complete_reductions()

        if get_rank() == 0:
            print(f'[{i}] Stepping')
        ref_opt.step()
        dist_opt.step()
        torch.cuda.synchronize()
        torch.distributed.barrier()

        print('Checking new weights')
        if get_rank() == 0:
            print("ref param:", ref_model.module.linear[0].weight)
            print("dist param:", dist_model.linear[0].weight)

        for i, (rp, dp) in enumerate(zip(ref_model.parameters(), dist_model.parameters())):
            if not torch.allclose(rp, dp, **tol_args):
                if get_rank() == 0:
                    print(f'Rank: {get_rank()}, Param: {i}')
                    print(f'ref: {rp.sum().item()}, dist: {dp.sum().item()}')
                    print(rp)
                    print(dp)
                    print(torch.abs(rp-dp) > tol_args['atol'])
                sys.exit(0)

        # zero grads
        for rp, dp in zip(ref_model.parameters(), dist_model.parameters()):
            rp.grad = None
            dp.grad = None

if __name__ == "__main__":
    main()
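
The test reads --local_rank and initializes the NCCL process group with init_method='env://', so it is presumably meant to be launched with torch.distributed.launch; for example (script name and GPU count hypothetical):

python -m torch.distributed.launch --nproc_per_node=2 test_dist_adam.py --steps 20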