Unverified commit 6b7e77b0, authored by Kexin Yu, committed by GitHub

DistributedFusedAdam Model Parallelism Support (Megatron) (#981)



Co-authored-by: Kexin Yu <kexiny@nvidia.com>
Co-authored-by: Kexin Yu <kexinznzn@gmail.com>
parent 8a1ed9e8
#include <torch/extension.h>

void multi_tensor_fused_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_beta1,
  at::Tensor per_tensor_beta2,
  at::Tensor per_tensor_bias_correction,
  at::Tensor per_tensor_eps,
  at::Tensor per_tensor_weight_decay,
  float lr,
  float grad_scale,
  int step,
  int mode);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda,
        "Multi tensor Adam optimized CUDA implementation.");
}
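The file above is the extension's binding layer (presumably apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp, matching the sources list added to setup.py below): it declares and exposes the CUDA entry point, while the kernel implementation lives in multi_tensor_distopt_adam_kernel.cu, which follows.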
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <THC/THCGeneral.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <cmath>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

typedef enum{
  ADAM_MODE_0 = 0, // eps under square root
  ADAM_MODE_1 = 1  // eps outside square root
} adamMode_t;
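// The two modes differ only in where eps enters the denominator. With bias-corrected
// moments m_hat = m / (1 - beta1^step) and v_hat = v / (1 - beta2^step):
//   ADAM_MODE_0: p -= lr * ( m_hat / sqrtf(v_hat + eps) + weight_decay * p )
//   ADAM_MODE_1: p -= lr * ( m_hat / (sqrtf(v_hat) + eps) + weight_decay * p )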
template <int DEPTH, typename T, typename GRAD_T>
struct DistAdamFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<DEPTH>& tl,
    const float* per_tensor_beta1,
    const float* per_tensor_beta2,
    const int* per_tensor_bias_correction,
    const float* per_tensor_eps,
    const float* per_tensor_weight_decay,
    const float lr,
    const float grad_scale,
    const int step,
    adamMode_t mode)
  {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
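    // tensor_num indexes the per-tensor hyperparameter arrays (beta1, beta2, bias correction, eps, weight decay)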
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float b1 = per_tensor_beta1[tensor_num];
    float b2 = per_tensor_beta2[tensor_num];
    float eps = per_tensor_eps[tensor_num];
    float decay = per_tensor_weight_decay[tensor_num];

    float beta1_correction = 1.0f, beta2_correction = 1.0f;
    if (per_tensor_bias_correction[tensor_num] == 1) {
      beta1_correction = 1 - std::pow(b1, step);
      beta2_correction = 1 - std::pow(b2, step);
    }

    T* p = (T *)tl.addresses[0][tensor_loc];
    p += chunk_idx*chunk_size;
    T* m = (T *)tl.addresses[1][tensor_loc];
    m += chunk_idx*chunk_size;
    T* v = (T *)tl.addresses[2][tensor_loc];
    v += chunk_idx*chunk_size;
    GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
    g += chunk_idx*chunk_size;
    GRAD_T* p_copy = NULL;
    if (DEPTH == 5) {
      p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
      p_copy += chunk_idx*chunk_size;
    }

    n -= chunk_idx*chunk_size;

    T incoming_p[ILP];
    T incoming_m[ILP];
    T incoming_v[ILP];
    T incoming_g[ILP];

    // to make things simple, we put aligned case in a different code path
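    // (fast path below: load_store moves ILP elements per call through a single aligned ILP*sizeof(T) copy)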
    if (n % ILP == 0 &&
        chunk_size % ILP == 0 &&
        is_aligned(p) &&
        is_aligned(m) &&
        is_aligned(v) &&
        is_aligned(g) &&
        is_aligned(p_copy)) {
      for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) {
        // load
        GRAD_T tmp_g[ILP];
        load_store(incoming_p, p, 0, i_start);
        load_store(incoming_m, m, 0, i_start);
        load_store(incoming_v, v, 0, i_start);
        load_store(tmp_g, g, 0, i_start);
        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          incoming_g[ii] = static_cast<T>(tmp_g[ii]);
          T scaled_grad = incoming_g[ii]/grad_scale;
          incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
          incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
          T next_m_unbiased = incoming_m[ii] / beta1_correction;
          T next_v_unbiased = incoming_v[ii] / beta2_correction;
          float denom;
          if (mode == ADAM_MODE_0)
            denom = sqrtf(next_v_unbiased + eps);
          else // Mode 1
            denom = sqrtf(next_v_unbiased) + eps;
          float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
          incoming_p[ii] = incoming_p[ii] - (lr * update);
          if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
        }
        load_store(p, incoming_p, i_start, 0);
        load_store(m, incoming_m, i_start, 0);
        load_store(v, incoming_v, i_start, 0);
        if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
      }
    } else {
      for (int i_start = 0;
           i_start < n && i_start < chunk_size;
           i_start += blockDim.x*ILP) {
        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          incoming_p[ii] = 0;
          incoming_m[ii] = 0;
          incoming_v[ii] = 0;
          incoming_g[ii] = 0;

          int i = i_start + threadIdx.x + ii*blockDim.x;
          if (i < n && i < chunk_size) {
            incoming_p[ii] = p[i];
            incoming_m[ii] = m[i];
            incoming_v[ii] = v[i];
            incoming_g[ii] = static_cast<T>(g[i]);
          }
        }
        #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int j = i_start + threadIdx.x + ii*blockDim.x;
          if (j < n && j < chunk_size) {
            T scaled_grad = incoming_g[ii]/grad_scale;
            m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
            v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
            T next_m_unbiased = m[j] / beta1_correction;
            T next_v_unbiased = v[j] / beta2_correction;
            float denom;
            if (mode == ADAM_MODE_0)
              denom = sqrtf(next_v_unbiased + eps);
            else // Mode 1
              denom = sqrtf(next_v_unbiased) + eps;
            float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
            p[j] = incoming_p[ii] - (lr * update);
            if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
          }
        }
      }
    }
  }
};
void multi_tensor_fused_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
  at::Tensor per_tensor_beta1,
  at::Tensor per_tensor_beta2,
  at::Tensor per_tensor_bias_correction,
  at::Tensor per_tensor_eps,
  at::Tensor per_tensor_weight_decay,
  float lr,
  float grad_scale,
  int step,
  int mode)
{
  using namespace at;

  size_t tl_sz = tensor_lists.size();
  AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");

  if (tl_sz == 5) {
    DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      multi_tensor_apply<5>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        DistAdamFunctor<5, accscalar_t, scalar_t_0>(),
        per_tensor_beta1.DATA_PTR<float>(),
        per_tensor_beta2.DATA_PTR<float>(),
        per_tensor_bias_correction.DATA_PTR<int>(),
        per_tensor_eps.DATA_PTR<float>(),
        per_tensor_weight_decay.DATA_PTR<float>(),
        lr,
        grad_scale,
        step,
        (adamMode_t) mode);
    );
  } else {
    DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        DistAdamFunctor<4, accscalar_t, scalar_t_0>(),
        per_tensor_beta1.DATA_PTR<float>(),
        per_tensor_beta2.DATA_PTR<float>(),
        per_tensor_bias_correction.DATA_PTR<int>(),
        per_tensor_eps.DATA_PTR<float>(),
        per_tensor_weight_decay.DATA_PTR<float>(),
        lr,
        grad_scale,
        step,
        (adamMode_t) mode);
    );
  }
  THCudaCheck(cudaGetLastError());
}
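For orientation, here is a minimal sketch of how the compiled extension could be called directly from Python. The module name distributed_adam_cuda and the function multi_tensor_fused_adam come from the setup.py entry and the pybind11 binding in this commit; the tensor sizes, chunk size, and hyperparameter values are illustrative assumptions only, and in practice the DistributedFusedAdam optimizer is the intended caller rather than user code.

import torch
import distributed_adam_cuda  # extension built when --distributed_adam is passed to setup.py

# One flat parameter shard with fp16 gradients: the kernel keeps p, m, v in fp32
# (the accumulation type) and g plus the optional p_copy in the gradient dtype.
n = 1024  # illustrative size
p = torch.zeros(n, device='cuda')                          # master parameters
m = torch.zeros(n, device='cuda')                          # Adam first moment
v = torch.zeros(n, device='cuda')                          # Adam second moment
g = torch.randn(n, device='cuda', dtype=torch.half)        # scaled fp16 gradients
p_copy = torch.empty(n, device='cuda', dtype=torch.half)   # fp16 copy of updated params

noop_flag = torch.zeros(1, dtype=torch.int, device='cuda')
# Per-tensor hyperparameters, one entry per tensor in the lists below.
beta1 = torch.full((1,), 0.9, device='cuda')
beta2 = torch.full((1,), 0.999, device='cuda')
bias_correction = torch.ones(1, dtype=torch.int, device='cuda')
eps = torch.full((1,), 1e-8, device='cuda')
weight_decay = torch.full((1,), 0.01, device='cuda')

distributed_adam_cuda.multi_tensor_fused_adam(
    2048 * 32,                       # chunk_size (illustrative)
    noop_flag,
    [[p], [m], [v], [g], [p_copy]],  # tensor lists: p, m, v, g, p_copy
    beta1, beta2, bias_correction, eps, weight_decay,
    1e-3,                            # lr
    1.0,                             # grad_scale
    1,                               # step
    0)                               # mode 0: eps under the square root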
@@ -123,6 +123,25 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
    version_ge_1_5 = ['-DVERSION_GE_1_5']
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
if "--distributed_adam" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--distributed_adam")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
CUDAExtension(name='distributed_adam_cuda',
sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
if "--distributed_lamb" in sys.argv: if "--distributed_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--distributed_lamb") sys.argv.remove("--distributed_lamb")
...
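Consistent with apex's existing optional extensions, the new module is built only when the flag is passed at install time; with apex's usual install pattern that would look something like pip install -v --no-cache-dir --global-option="--distributed_adam" ./ (the command follows the project's standard convention and is not part of this diff).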
import argparse
import random
import sys

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from apex import amp
from apex.optimizers import FusedAdam
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(*[torch.nn.Linear(args.dim, args.dim, bias=args.bias) for _ in range(args.layers)])

    def forward(self, x):
        return self.linear(x)


def setup(args):
    ## Model
    ref_model = TestModel(args).cuda()
    dist_model = TestModel(args).cuda()

    # Same weights
    with torch.no_grad():
        for dp, rp in zip(dist_model.parameters(), ref_model.parameters()):
            dp.data.copy_(rp.data)

    dist_model = dist_model.half()

    ## Optimizer
    # same hyperparameters
    ref_opt_args = { 'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01 }
    ref_opt = FusedAdam(ref_model.parameters(), **ref_opt_args)

    dist_opt_args = ref_opt_args.copy()
    dist_opt_args.update( {'overlap_reductions' : False} )
    dist_opt_args.update( {'process_group_size' : args.n_gpu} )
    dist_opt_args.update( {'dwu_group_size' : args.dwu_group_size} )
    dist_opt_args.update( {'dwu_num_blocks' : 1} )
    dist_opt_args.update( {'dwu_num_chunks' : 1} )
    dist_opt = DistributedFusedAdam(dist_model.parameters(), **dist_opt_args)
    dist_opt.set_global_scale(1.)

    ## amp-init
    amp_args = { 'loss_scale' : 'dynamic', 'opt_level' : 'O2' }
    ref_model, ref_opt = amp.initialize(ref_model, ref_opt, **amp_args)

    ## DDP
    ref_model = DDP(ref_model, device_ids=[args.rank])

    with torch.no_grad():
        for dp in dist_model.parameters():
            torch.distributed.broadcast(dp.data, src=0)
        for rp in ref_model.parameters():
            torch.distributed.broadcast(rp.data, src=0)
    torch.cuda.synchronize()
    torch.distributed.barrier()

    if get_rank() == 0:
        print(f'dist opt with {args.n_gpu} GPUs')

    return ref_model, ref_opt, dist_model, dist_opt


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--dim', type=int, default=4)
    parser.add_argument('--layers', type=int, default=2)
    parser.add_argument('--bias', action='store_true')
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1)
    parser.add_argument('--dwu_group_size', type=float, default=1)
    args = parser.parse_args()
    return args


def setup_env(args):
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.n_gpu = torch.distributed.get_world_size()

    seed = 42 + get_rank()
    random.seed(seed)
    torch.manual_seed(seed)
    return args


def get_rank():
    return torch.distributed.get_rank()


def main():
    args = parse_args()
    args = setup_env(args)
    tol_args = { 'atol' : args.atol, 'rtol' : args.rtol }
    torch.set_printoptions(precision=16)

    ref_model, ref_opt, dist_model, dist_opt = setup(args)

    # lazy_init not called yet, initialize stash
    stash = ref_opt._amp_stash
    stash.all_fp16_params, stash.all_fp32_from_fp16_params = [], []

    # make sure everything from _first_step_init_ is ready before training
    # e.g. registering allreduce_hook
    # so that gradients are copied/reduced when necessary
    dist_opt._init_everything()

    for i in range(args.steps):
        x_ref = torch.randn(args.batch, args.dim, dtype=torch.half).cuda().requires_grad_(True)
        x_dist = x_ref.clone().detach().requires_grad_(True)
        if get_rank() == 0:
            print(f'[{i}] Checking input')
        #print("x_ref:", x_ref.flatten()[:10])
        #print("x_dist:", x_dist.flatten()[:10])
        assert(torch.allclose(x_ref, x_dist, **tol_args))

        y_ref = ref_model(x_ref).half()
        y_dist = dist_model(x_dist)
        if get_rank() == 0:
            print(f'[{i}] Checking output')
        #print("y_ref:", y_ref.flatten()[:10])
        #print("y_dist:", y_dist.flatten()[:10])
        assert(torch.allclose(y_ref, y_dist, **tol_args))

        dy = torch.randn_like(y_ref)
        y_ref.backward(dy)
        y_dist.backward(dy)

        if get_rank() == 0:
            print(f'[{i}] Checking gradients')
        torch.distributed.barrier()
        torch.cuda.synchronize()
        assert(torch.allclose(x_ref.grad, x_dist.grad, **tol_args))

        # gradient all-reduce within distributed optimizer
        dist_opt.complete_reductions()
        if get_rank() == 0:
            print(f'[{i}] Stepping')
        ref_opt.step()
        dist_opt.step()

        torch.cuda.synchronize()
        torch.distributed.barrier()

        print('Checking new weights')
        if get_rank() == 0:
            print("ref param:", ref_model.module.linear[0].weight)
            print("dist param:", dist_model.linear[0].weight)

        for i, (rp, dp) in enumerate(zip(ref_model.parameters(), dist_model.parameters())):
            if not torch.allclose(rp, dp, **tol_args):
                if get_rank() == 0:
                    print(f'Rank: {get_rank()}, Param: {i}')
                    print(f'ref: {rp.sum().item()}, dist: {dp.sum().item()}')
                    print(rp)
                    print(dp)
                    print(torch.abs(rp-dp) > tol_args['atol'])
                sys.exit(0)

        # zero grads
        for rp, dp in zip(ref_model.parameters(), dist_model.parameters()):
            rp.grad = None
            dp.grad = None


if __name__ == "__main__":
    main()
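The test initializes torch.distributed with the NCCL backend and env:// rendezvous and reads --local_rank, so it is meant to run under the standard multi-process launcher, e.g. python -m torch.distributed.launch --nproc_per_node=<num_gpus> <path_to_this_test>.py --steps 20 (the test file path is a placeholder; it is not shown in this excerpt).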