Commit 6f7a8b39 authored by lcskrishna

Merge remote-tracking branch 'rocm_upstream/master' into ifu_07272020

parents 459de22d 9c80f6d3
@@ -91,7 +91,7 @@ class FusedAdagrad(torch.optim.Optimizer):
 if len(state) == 0:
 # Exponential moving average of gradient values
 state['sum'] = torch.zeros_like(p.data)
-if p.dtype == torch.float16:
+if p.dtype in {torch.float16, torch.bfloat16}:
 g_16.append(p.grad.data)
 p_16.append(p.data)
 h_16.append(state['sum'])
@@ -100,7 +100,7 @@ class FusedAdagrad(torch.optim.Optimizer):
 p_32.append(p.data)
 h_32.append(state['sum'])
 else:
-raise RuntimeError('FusedAdagrad only support fp16 and fp32.')
+raise RuntimeError('FusedAdagrad only support fp16, bfloat16 and fp32.')
 if(len(g_16) > 0):
 multi_tensor_applier(self.multi_tensor_adagrad,
...
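For readers unfamiliar with the pattern above: the fused optimizers split parameters into a low-precision bucket (fp16 and, after this change, bf16) and an fp32 bucket before handing them to `multi_tensor_applier`. A minimal standalone sketch of that grouping logic, assuming only the list names from the diff (`g_16`, `p_16`, `g_32`, `p_32`); the real optimizer also carries its state tensors:

```python
import torch

def bucket_params_by_dtype(params):
    """Group (grad, param) pairs the way the fused optimizers above do:
    fp16/bf16 tensors go to the low-precision lists, fp32 to the fp32 lists."""
    g_16, p_16, g_32, p_32 = [], [], [], []
    for p in params:
        if p.grad is None:
            continue
        if p.dtype in {torch.float16, torch.bfloat16}:
            g_16.append(p.grad.data)
            p_16.append(p.data)
        elif p.dtype == torch.float32:
            g_32.append(p.grad.data)
            p_32.append(p.data)
        else:
            raise RuntimeError('only fp16, bfloat16 and fp32 are supported')
    return (g_16, p_16), (g_32, p_32)
```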
@@ -130,7 +130,7 @@ class FusedAdam(torch.optim.Optimizer):
 # Exponential moving average of squared gradient values
 state['exp_avg_sq'] = torch.zeros_like(p.data)
-if p.dtype == torch.float16:
+if p.dtype in {torch.float16, torch.bfloat16}:
 g_16.append(p.grad.data)
 p_16.append(p.data)
 m_16.append(state['exp_avg'])
@@ -141,7 +141,7 @@ class FusedAdam(torch.optim.Optimizer):
 m_32.append(state['exp_avg'])
 v_32.append(state['exp_avg_sq'])
 else:
-raise RuntimeError('FusedAdam only support fp16 and fp32.')
+raise RuntimeError('FusedAdam only support fp16, bfloat16 and fp32.')
 if(len(g_16) > 0):
 multi_tensor_applier(self.multi_tensor_adam,
...
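A hedged usage sketch of the behaviour enabled here: with this change, `apex.optimizers.FusedAdam` should accept parameters stored in bfloat16 as well as fp16/fp32 (assuming apex is built with its C++/CUDA extensions; the toy model below is illustrative only):

```python
import torch
import torch.nn as nn
from apex.optimizers import FusedAdam  # requires apex built with --cpp_ext --cuda_ext

# Toy model kept entirely in bfloat16; previously FusedAdam raised
# "only support fp16 and fp32" for these parameters.
model = nn.Linear(128, 64).to(device="cuda", dtype=torch.bfloat16)
optimizer = FusedAdam(model.parameters(), lr=1e-3)

x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()
loss.backward()
optimizer.step()
```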
@@ -165,7 +165,7 @@ class FusedLAMB(torch.optim.Optimizer):
 # Exponential moving average of gradient values
 state['exp_avg_sq'] = torch.zeros_like(p.data)
-if p.dtype == torch.float16:
+if p.dtype in {torch.float16, torch.bfloat16}:
 g_16.append(p.grad.data)
 p_16.append(p.data)
 m_16.append(state['exp_avg'])
@@ -176,7 +176,7 @@ class FusedLAMB(torch.optim.Optimizer):
 m_32.append(state['exp_avg'])
 v_32.append(state['exp_avg_sq'])
 else:
-raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+raise RuntimeError('FusedLAMB only support fp16, bfloat16 and fp32.')
 if(len(g_16) > 0):
 multi_tensor_applier(self.multi_tensor_lamb,
...
@@ -142,7 +142,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
 # Exponential moving average of gradient values
 state['exp_avg'] = torch.zeros_like(p.data)
-if p.dtype == torch.float16:
+if p.dtype in {torch.float16, torch.bfloat16}:
 g_16.append(p.grad.data)
 p_16.append(p.data)
 m_16.append(state['exp_avg'])
@@ -151,7 +151,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
 p_32.append(p.data)
 m_32.append(state['exp_avg'])
 else:
-raise RuntimeError('FusedNovoGrad only support fp16 and fp32.')
+raise RuntimeError('FusedNovoGrad only support fp16, bfloat16 and fp32.')
 # we store per weight norm as one tensor for one group/precision combination
 # different from optim.Adam, we store norm here(not ^2) so we can unify calculation for norm types
...
@@ -48,8 +48,8 @@ def apply_flat_dist_call(bucket, call, extra_args=None):
 for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
 buf.copy_(synced)
-def split_half_float_double(tensors):
+def split_half_float_double_bfloat16(tensors):
-dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor"]
+dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
 buckets = []
 for i, dtype in enumerate(dtypes):
 bucket = [t for t in tensors if t.type() == dtype]
@@ -240,7 +240,8 @@ class DistributedDataParallel(Module):
 self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
 "torch.cuda.FloatTensor" : 1,
-"torch.cuda.DoubleTensor" : 2}
+"torch.cuda.DoubleTensor" : 2,
+"torch.cuda.BFloat16Tensor" : 3}
 if multi_tensor_applier.available:
 # TODO: I really need to centralize the C++ backed imports
@@ -498,7 +499,7 @@ class DistributedDataParallel(Module):
 else:
 grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
-split_buckets = split_half_float_double(grads)
+split_buckets = split_half_float_double_bfloat16(grads)
 # If retain_allreduce_buffers is True and delay_allreduce is False,
 # this will only be done during the first backward pass, ignored by the
@@ -578,8 +579,8 @@ class DistributedDataParallel(Module):
 if self.needs_refresh:
 self.active_i_buckets = []
 self.buckets = []
-self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
+self.tmp_buckets = [[], [], [], []] # [running half, float, double, bfloat16 buckets]
-self.tmp_numels = [0, 0, 0]
+self.tmp_numels = [0, 0, 0, 0]
 self.bucket_sizes = []
 self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
 self.param_id_to_bucket = {}
...
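The renamed helper above buckets all-reduce work by CUDA tensor type string, with bfloat16 now getting its own bucket. A small hedged illustration of what that bucketing does, written as a standalone function rather than the apex internals:

```python
import torch

def split_by_type(tensors, dtypes=("torch.cuda.HalfTensor",
                                   "torch.cuda.FloatTensor",
                                   "torch.cuda.DoubleTensor",
                                   "torch.cuda.BFloat16Tensor")):
    """Group tensors into one bucket per CUDA type string, mirroring
    split_half_float_double_bfloat16 in the diff above."""
    buckets = []
    for dtype in dtypes:
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets

# Example: mixed-precision gradients end up in separate all-reduce buckets.
grads = [torch.zeros(4, device="cuda", dtype=torch.float16),
         torch.zeros(4, device="cuda", dtype=torch.bfloat16),
         torch.zeros(4, device="cuda", dtype=torch.float32)]
print([b[0].type() for b in split_by_type(grads)])
```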
'''
This file contains common utility functions for running the unit tests on ROCM.
'''
import torch
import os
import sys
from functools import wraps
import unittest
TEST_WITH_ROCM = os.getenv('APEX_TEST_WITH_ROCM', '0') == '1'
## Wrapper to skip the unit tests.
def skipIfRocm(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if TEST_WITH_ROCM:
raise unittest.SkipTest("test doesn't currently work on ROCm stack.")
else:
fn(*args, **kwargs)
return wrapper
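A usage sketch for the helper added above (the import path and test names are illustrative, not part of this commit): decorate individual unit tests that are known to fail on ROCm, and set `APEX_TEST_WITH_ROCM=1` in the ROCm CI environment so they are skipped.

```python
import unittest
import torch

# Hypothetical import path; in practice the decorator comes from the common
# test utility module introduced above.
from test_utils import skipIfRocm

class TestFusedOptimizers(unittest.TestCase):
    @skipIfRocm
    def test_not_yet_supported_on_rocm(self):
        # Skipped when APEX_TEST_WITH_ROCM=1 is set in the environment.
        self.assertTrue(torch.cuda.is_available())

if __name__ == "__main__":
    unittest.main()
```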
@@ -130,7 +130,8 @@ std::vector<at::Tensor> layer_norm(
 int n1,n2;
 check_args(input,normalized_shape,n1,n2);
 at::Tensor output = at::empty_like(input);
-at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half ||
+input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
 at::Tensor invvar = at::empty_like(mean);
 cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
 normalized_shape,NULL,NULL,epsilon);
@@ -152,7 +153,8 @@ std::vector<at::Tensor> layer_norm_affine(
 int n1,n2;
 check_args(input,normalized_shape,gamma,beta,n1,n2);
 at::Tensor output = at::empty_like(input);
-at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half ||
+input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
 at::Tensor invvar = at::empty_like(mean);
 cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
 normalized_shape,&gamma,&beta,epsilon);
...
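The change above keeps the layer-norm `mean`/`invvar` buffers in fp32 whenever the input is a reduced-precision type, now bf16 as well as fp16. A hedged PyTorch-level mirror of that dtype rule, just to make the rule explicit:

```python
import torch

def layernorm_stats_dtype(input_dtype: torch.dtype) -> torch.dtype:
    """Mirror of the C++ logic above: accumulate mean/invvar in fp32 for
    reduced-precision inputs, otherwise keep the input's own dtype."""
    if input_dtype in (torch.float16, torch.bfloat16):
        return torch.float32
    return input_dtype

assert layernorm_stats_dtype(torch.bfloat16) is torch.float32
assert layernorm_stats_dtype(torch.float64) is torch.float64
```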
@@ -172,8 +172,8 @@ void cuWelfordMuSigma2(
 for (; l+7 < n2; l+=8*numx) {
 for (int k = 0; k < 8; k+=2) {
 float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
-cuWelfordOnlineSum(curr.x,mu,sigma2,count);
+cuWelfordOnlineSum<float>(curr.x,mu,sigma2,count);
-cuWelfordOnlineSum(curr.y,mu,sigma2,count);
+cuWelfordOnlineSum<float>(curr.y,mu,sigma2,count);
 }
 }
 for (; l < n2; ++l) {
@@ -230,9 +230,15 @@ void cuWelfordMuSigma2(
 template<typename U> U rsqrt(U v) {
 return U(1) / sqrt(v);
 }
+#if defined __HIP_PLATFORM_HCC__
+__device__ float rsqrt(float v) {
+return rsqrtf(v);
+}
+#else
 template<> float rsqrt(float v) {
 return rsqrtf(v);
 }
+#endif
 template<> double rsqrt(double v) {
 return rsqrt(v);
 }
@@ -293,7 +299,7 @@ void cuApplyLayerNorm(
 // 1) blockDim.x == warpSize
 // 2) Tensors are contiguous
 //
-for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
+for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
 SharedMemory<U> shared;
 U* buf = shared.getPointer();
 U mu,sigma2;
@@ -531,7 +537,7 @@ void cuComputeGradInput(
 const T* gamma,
 T* grad_input)
 {
-for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
+for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
 U sum_loss1 = U(0);
 U sum_loss2 = U(0);
 const U c_mean = mean[i1];
@@ -684,7 +690,7 @@ void cuda_layer_norm(
 double epsilon)
 {
 using namespace at;
-DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
+DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(input->scalar_type(), 0, "layer_norm_cuda_kernel",
 using accscalar_t = at::acc_type<scalar_t_0, true>;
 HostApplyLayerNorm(
 output->DATA_PTR<scalar_t_0>(),
@@ -724,7 +730,8 @@ void HostLayerNormGradient(
 const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
 const int nshared2_b = threads2.x * threads2.y * sizeof(U);
 const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
-at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
+at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype((input->scalar_type()==at::ScalarType::Half ||
+input->scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input->scalar_type()));
 at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
 cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
 dout,
@@ -787,7 +794,7 @@ void cuda_layer_norm_gradient(
 at::Tensor* grad_beta)
 {
 using namespace at;
-DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(input->scalar_type(), 0, "cuComputeGradInput",
 using accscalar_t = at::acc_type<scalar_t_0, true>;
 HostLayerNormGradient(
 dout->DATA_PTR<scalar_t_0>(),
...
@@ -23,20 +23,20 @@ using MATH_T = float;
 template <typename T> struct AdagradFunctor {
 __device__ __forceinline__ void
-operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
+operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> *tl,
 const float epsilon, const float lr, adagradMode_t mode,
 const float weight_decay) {
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
-T *g = (T *)tl.addresses[0][tensor_loc];
+T *g = (T *)tl->addresses[0][tensor_loc];
 g += chunk_idx * chunk_size;
-T *p = (T *)tl.addresses[1][tensor_loc];
+T *p = (T *)tl->addresses[1][tensor_loc];
 p += chunk_idx * chunk_size;
-T *h = (T *)tl.addresses[2][tensor_loc];
+T *h = (T *)tl->addresses[2][tensor_loc];
 h += chunk_idx * chunk_size;
 n -= chunk_idx * chunk_size;
@@ -90,7 +90,7 @@ void multi_tensor_adagrad_cuda(
 using namespace at;
 // Assume single type across p,g,h now
-DISPATCH_DOUBLE_FLOAT_AND_HALF(
+DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
 tensor_lists[0][0].scalar_type(), 0, "adagrad",
 multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
 AdagradFunctor<scalar_t_0>(), epsilon, lr,
...
@@ -26,7 +26,7 @@ struct AdamFunctor
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* noop_gmem,
-TensorListMetadata<4>& tl,
+TensorListMetadata<4>* tl,
 const float beta1,
 const float beta2,
 const float beta1_correction,
@@ -40,24 +40,24 @@ struct AdamFunctor
 // if(*noop_gmem == 1)
 // return;
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
 // potentially use to pass in list of scalar
-// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+// int tensor_num = tl->start_tensor_this_launch + tensor_loc;
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
-T* g = (T*)tl.addresses[0][tensor_loc];
+T* g = (T*)tl->addresses[0][tensor_loc];
 g += chunk_idx*chunk_size;
-T* p = (T*)tl.addresses[1][tensor_loc];
+T* p = (T*)tl->addresses[1][tensor_loc];
 p += chunk_idx*chunk_size;
-T* m = (T*)tl.addresses[2][tensor_loc];
+T* m = (T*)tl->addresses[2][tensor_loc];
 m += chunk_idx*chunk_size;
-T* v = (T*)tl.addresses[3][tensor_loc];
+T* v = (T*)tl->addresses[3][tensor_loc];
 v += chunk_idx*chunk_size;
 n -= chunk_idx*chunk_size;
@@ -149,7 +149,7 @@ void multi_tensor_adam_cuda(
 }
 // Assume single type across p,g,m1,m2 now
-DISPATCH_DOUBLE_FLOAT_AND_HALF(
+DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
 tensor_lists[0][0].scalar_type(), 0, "adam",
 multi_tensor_apply<4>(
 BLOCK_SIZE,
...
@@ -2,6 +2,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
+#include <THC/THC.h>
 #include "compat.h"
 #include <assert.h>
@@ -29,7 +30,7 @@ template<typename T, typename U, typename... ArgTypes>
 __global__ void multi_tensor_apply_kernel(
 int chunk_size,
 volatile int* noop_flag,
-T tl,
+T* tl,
 U callable,
 ArgTypes... args)
 {
@@ -56,7 +57,7 @@ void multi_tensor_apply(
 for(int t = 0; t < tensor_lists[l].size(); t++)
 {
 // TODO: Print which tensor fails.
-bool contiguous_memory = tensor_lists[l][t].is_contiguous();
+bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous();
 #ifdef VERSION_GE_1_5
 contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
 #endif
@@ -78,8 +79,15 @@ void multi_tensor_apply(
 for(int t = 0; t < ntensors; t++)
 {
 tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
-for(int d = 0; d < depth; d++)
-tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+for(int d = 0; d < depth; d++) {
+if (tensor_lists[d][t].is_sparse()) {
+at::Tensor dst = at::zeros(tensor_lists[d][t].sizes(), tensor_lists[d][t].options().layout(at::kStrided));
+dst.add_(tensor_lists[d][t]);
+tl.addresses[d][loc_tensor_info] = dst.data_ptr();
+} else {
+tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+}
+}
 loc_tensor_info++;
 int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
@@ -97,11 +105,15 @@ void multi_tensor_apply(
 bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
 if(tensors_full || blocks_full || last_chunk)
 {
+auto storage = at::empty(sizeof(tl), c10::TensorOptions(at::kStrided).dtype(at::kByte).device(at::kCPU).pinned_memory(true));
+auto tl_as_host_pinned_ptr = static_cast<decltype(tl)*>(storage.data_ptr());
+memcpy(tl_as_host_pinned_ptr, &tl, sizeof(tl));
+AT_CUDA_CHECK(THCCachingHostAllocator_recordEvent(tl_as_host_pinned_ptr, stream));
 // using accscalar_t = acc_type<scalar_t, true>;
 multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
 chunk_size,
 noop_flag.DATA_PTR<int>(),
-tl,
+tl_as_host_pinned_ptr,
 callable,
 args...);
...
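For context on how the host-side launcher above is driven from Python: apex exposes these fused kernels through `apex.multi_tensor_apply.multi_tensor_applier`. A hedged usage sketch with the `amp_C.multi_tensor_scale` kernel (apex must be built with its C++/CUDA extensions for `amp_C` to import):

```python
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# noop/overflow flag checked by the kernels; a nonzero value means "skip the update".
overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")

src = [torch.randn(1024, device="cuda") for _ in range(4)]
dst = [torch.empty_like(t) for t in src]

# Scales every tensor in src by 0.5 and writes the result into dst,
# launching as few kernels as possible over the whole list.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 0.5)
```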
@@ -30,7 +30,7 @@ struct AxpbyFunctor
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* noop_gmem,
-TensorListMetadata<3>& tl,
+TensorListMetadata<3>* tl,
 float a,
 float b,
 int arg_to_check)
@@ -39,17 +39,17 @@ struct AxpbyFunctor
 // if(*noop_gmem == 1)
 // return;
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
-x_t* x = (x_t*)tl.addresses[0][tensor_loc];
+x_t* x = (x_t*)tl->addresses[0][tensor_loc];
 x += chunk_idx*chunk_size;
-y_t* y = (y_t*)tl.addresses[1][tensor_loc];
+y_t* y = (y_t*)tl->addresses[1][tensor_loc];
 y += chunk_idx*chunk_size;
-out_t* out = (out_t*)tl.addresses[2][tensor_loc];
+out_t* out = (out_t*)tl->addresses[2][tensor_loc];
 out += chunk_idx*chunk_size;
 n -= chunk_idx*chunk_size;
@@ -138,9 +138,9 @@ void multi_tensor_axpby_cuda(
 // If build times suffer, think about where to put this dispatch,
 // and what logic should be moved out of multi_tensor_apply.
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
 multi_tensor_apply<3>(
 BLOCK_SIZE,
 chunk_size,
...
@@ -30,7 +30,7 @@ struct L2NormFunctor
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* noop_gmem,
-TensorListMetadata<1>& tl,
+TensorListMetadata<1>* tl,
 float* output,
 float* output_per_tensor,
 bool per_tensor,
@@ -40,11 +40,11 @@ struct L2NormFunctor
 // if(*noop_gmem == 1)
 // return;
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
-x_t* x = (x_t*)tl.addresses[0][tensor_loc];
+x_t* x = (x_t*)tl->addresses[0][tensor_loc];
 x += chunk_idx*chunk_size;
 n -= chunk_idx*chunk_size;
@@ -103,7 +103,7 @@ struct L2NormFunctor
 *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
 output[blockIdx.x] += final;
 if(per_tensor)
-output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
+output_per_tensor[(tl->start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
 }
 }
 };
@@ -115,7 +115,7 @@ struct MaxNormFunctor
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* noop_gmem,
-TensorListMetadata<1>& tl,
+TensorListMetadata<1>* tl,
 float* output,
 float* output_per_tensor,
 bool per_tensor,
@@ -125,11 +125,11 @@ struct MaxNormFunctor
 // if(*noop_gmem == 1)
 // return;
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
-x_t* x = (x_t*)tl.addresses[0][tensor_loc];
+x_t* x = (x_t*)tl->addresses[0][tensor_loc];
 x += chunk_idx*chunk_size;
 n -= chunk_idx*chunk_size;
@@ -188,13 +188,17 @@ struct MaxNormFunctor
 *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
 output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
 if(per_tensor)
-output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
+output_per_tensor[(tl->start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
 }
 }
 };
-__global__ void cleanup(
+__global__ void
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(1024)
+#endif
+cleanup(
 float* output,
 float* output_per_tensor,
 float* ret,
@@ -231,7 +235,11 @@ __global__ void cleanup(
 }
 }
-__global__ void cleanup_v2(
+__global__ void
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(1024)
+#endif
+cleanup_v2(
 float* output,
 float* output_per_tensor,
 float* ret,
@@ -322,7 +330,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
 ret_per_tensor = at::empty({0}, float_options);
 }
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
 multi_tensor_apply<1>(
 BLOCK_SIZE,
 chunk_size,
@@ -391,7 +399,7 @@ void multi_tensor_norm_out_cuda(
 output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
 if (norm_type == 0) {
-DISPATCH_FLOAT_AND_HALF(
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
 tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
 multi_tensor_apply<1>(
 BLOCK_SIZE,
@@ -405,7 +413,7 @@ void multi_tensor_norm_out_cuda(
 max_chunks_per_tensor);)
 }
 else {
-DISPATCH_FLOAT_AND_HALF(
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
 tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
 multi_tensor_apply<1>(
 BLOCK_SIZE,
...
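As with the scale kernel, the L2-norm kernel dispatched above is reachable from Python via `multi_tensor_applier`. A hedged sketch of that call pattern (the trailing `False` is the `per_tensor` flag; exact return handling may differ by apex version, and bf16 inputs are only accepted once this change is in place):

```python
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")
grads = [torch.randn(257, device="cuda", dtype=torch.bfloat16) for _ in range(3)]

# Returns (total_norm, per_tensor_norms); with per_tensor=False only the
# first result is meaningful.
total_norm, _ = multi_tensor_applier(
    amp_C.multi_tensor_l2norm, overflow_buf, [grads], False)
print(total_norm.item())
```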
@@ -43,7 +43,7 @@ struct LAMBStage1Functor
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* noop_gmem,
-TensorListMetadata<4>& tl,
+TensorListMetadata<4>* tl,
 const float beta1,
 const float beta2,
 const float beta3,
@@ -59,22 +59,22 @@ struct LAMBStage1Functor
 // if(*noop_gmem == 1)
 // return;
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
 float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
-T* g = (T*)tl.addresses[0][tensor_loc];
+T* g = (T*)tl->addresses[0][tensor_loc];
 g += chunk_idx*chunk_size;
-T* p = (T*)tl.addresses[1][tensor_loc];
+T* p = (T*)tl->addresses[1][tensor_loc];
 p += chunk_idx*chunk_size;
-T* m = (T*)tl.addresses[2][tensor_loc];
+T* m = (T*)tl->addresses[2][tensor_loc];
 m += chunk_idx*chunk_size;
-T* v = (T*)tl.addresses[3][tensor_loc];
+T* v = (T*)tl->addresses[3][tensor_loc];
 v += chunk_idx*chunk_size;
 n -= chunk_idx*chunk_size;
@@ -236,7 +236,7 @@ struct LAMBStage2Functor
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* noop_gmem,
-TensorListMetadata<2>& tl,
+TensorListMetadata<2>* tl,
 const float* per_tensor_param_norm,
 const float* per_tensor_update_norm,
 const float learning_rate,
@@ -247,10 +247,10 @@ struct LAMBStage2Functor
 // if(*noop_gmem == 1)
 // return;
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
-int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+int tensor_num = tl->start_tensor_this_launch + tensor_loc;
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
 MATH_T ratio = learning_rate;
 // nvlamb: apply adaptive learning rate to all parameters
@@ -262,10 +262,10 @@ struct LAMBStage2Functor
 ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
 }
-T* update = (T*)tl.addresses[0][tensor_loc];
+T* update = (T*)tl->addresses[0][tensor_loc];
 update += chunk_idx*chunk_size;
-T* p = (T*)tl.addresses[1][tensor_loc];
+T* p = (T*)tl->addresses[1][tensor_loc];
 p += chunk_idx*chunk_size;
 n -= chunk_idx*chunk_size;
@@ -372,7 +372,7 @@ void multi_tensor_lamb_cuda(
 // We now in-place modify grad to store update before compute its norm
 // Generally this is not a issue since people modify grad in step() method all the time
 // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
 multi_tensor_apply<4>(
 BLOCK_SIZE,
 chunk_size,
@@ -395,7 +395,7 @@ void multi_tensor_lamb_cuda(
 std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
 multi_tensor_apply<2>(
 BLOCK_SIZE,
 chunk_size,
...
@@ -20,7 +20,7 @@ struct LAMBStage1Functor
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* noop_gmem,
-TensorListMetadata<5>& tl,
+TensorListMetadata<5>* tl,
 const float* per_tensor_decay,
 const float beta1,
 const float beta2,
@@ -33,26 +33,26 @@ struct LAMBStage1Functor
 // if(*noop_gmem == 1)
 // return;
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
-int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+int tensor_num = tl->start_tensor_this_launch + tensor_loc;
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
 float decay = per_tensor_decay[tensor_num];
-GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];
+GRAD_T* g = (GRAD_T*)tl->addresses[0][tensor_loc];
 g += chunk_idx*chunk_size;
-T* p = (T*)tl.addresses[1][tensor_loc];
+T* p = (T*)tl->addresses[1][tensor_loc];
 p += chunk_idx*chunk_size;
-T* m = (T*)tl.addresses[2][tensor_loc];
+T* m = (T*)tl->addresses[2][tensor_loc];
 m += chunk_idx*chunk_size;
-T* v = (T*)tl.addresses[3][tensor_loc];
+T* v = (T*)tl->addresses[3][tensor_loc];
 v += chunk_idx*chunk_size;
-UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc];
+UPD_T* update = (UPD_T*)tl->addresses[4][tensor_loc];
 update += chunk_idx*chunk_size;
 n -= chunk_idx*chunk_size;
@@ -128,9 +128,9 @@ void multi_tensor_lamb_stage1_cuda(
 float next_step = float(step+1);
 float beta1_correction = 1.0f - std::pow(beta1, next_step);
 float beta2_correction = 1.0f - std::pow(beta2, next_step);
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
 multi_tensor_apply<5>(
 BLOCK_SIZE,
 chunk_size,
...
@@ -23,7 +23,7 @@ struct LAMBStage2Functor
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* noop_gmem,
-TensorListMetadata<2>& tl,
+TensorListMetadata<2>* tl,
 const float* per_tensor_param_norm,
 const float* per_tensor_update_norm,
 const float learning_rate,
@@ -34,10 +34,10 @@ struct LAMBStage2Functor
 // if(*noop_gmem == 1)
 // return;
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
-int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+int tensor_num = tl->start_tensor_this_launch + tensor_loc;
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
 MATH_T ratio = learning_rate;
 // nvlamb: apply adaptive learning rate to all parameters
@@ -49,10 +49,10 @@ struct LAMBStage2Functor
 ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
 }
-T* p = (T*)tl.addresses[0][tensor_loc];
+T* p = (T*)tl->addresses[0][tensor_loc];
 p += chunk_idx*chunk_size;
-UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc];
+UPD_T* update = (UPD_T*)tl->addresses[1][tensor_loc];
 update += chunk_idx*chunk_size;
 n -= chunk_idx*chunk_size;
@@ -105,8 +105,8 @@ void multi_tensor_lamb_stage2_cuda(
 using namespace at;
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
 multi_tensor_apply<2>(
 BLOCK_SIZE,
 chunk_size,
...
@@ -35,7 +35,7 @@ struct NovoGradFunctor
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* noop_gmem,
-TensorListMetadata<3>& tl,
+TensorListMetadata<3>* tl,
 const float beta1,
 const float beta2,
 const float beta3,
@@ -51,20 +51,20 @@ struct NovoGradFunctor
 // if(*noop_gmem == 1)
 // return;
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
-int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+int tensor_num = tl->start_tensor_this_launch + tensor_loc;
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
 float grad_norm = per_tensor_grad_norm[tensor_num];
-T* g = (T*)tl.addresses[0][tensor_loc];
+T* g = (T*)tl->addresses[0][tensor_loc];
 g += chunk_idx*chunk_size;
-T* p = (T*)tl.addresses[1][tensor_loc];
+T* p = (T*)tl->addresses[1][tensor_loc];
 p += chunk_idx*chunk_size;
-T* m = (T*)tl.addresses[2][tensor_loc];
+T* m = (T*)tl->addresses[2][tensor_loc];
 m += chunk_idx*chunk_size;
 n -= chunk_idx*chunk_size;
@@ -164,7 +164,7 @@ void multi_tensor_novograd_cuda(
 multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type);
 // Assume single type across p,g,m1,m2 now
-DISPATCH_DOUBLE_FLOAT_AND_HALF(
+DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
 tensor_lists[0][0].scalar_type(), 0, "novograd",
 multi_tensor_apply<3>(
 BLOCK_SIZE,
...
@@ -32,21 +32,21 @@ struct ScaleFunctor
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* noop_gmem,
-TensorListMetadata<2>& tl,
+TensorListMetadata<2>* tl,
 float scale)
 {
 // I'd like this kernel to propagate infs/nans.
 // if(*noop_gmem == 1)
 // return;
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
-in_t* in = (in_t*)tl.addresses[0][tensor_loc];
+in_t* in = (in_t*)tl->addresses[0][tensor_loc];
 in += chunk_idx*chunk_size;
-out_t* out = (out_t*)tl.addresses[1][tensor_loc];
+out_t* out = (out_t*)tl->addresses[1][tensor_loc];
 out += chunk_idx*chunk_size;
 n -= chunk_idx*chunk_size;
@@ -121,8 +121,8 @@ void multi_tensor_scale_cuda(
 // If build times suffer, think about where to put this dispatch,
 // and what logic should be moved out of multi_tensor_apply.
-DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
-DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
+DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
 multi_tensor_apply<2>(
 BLOCK_SIZE,
 chunk_size,
...
@@ -32,7 +32,7 @@ struct SGDFunctor
 __device__ __forceinline__ void operator()(
 int chunk_size,
 volatile int* noop_gmem,
-TensorListMetadata<N>& tl,
+TensorListMetadata<N>* tl,
 float wd,
 float momentum,
 float dampening,
@@ -45,23 +45,23 @@ struct SGDFunctor
 // Early exit if we don't need to do anything
 if (*noop_gmem) return;
-int tensor_loc = tl.block_to_tensor[blockIdx.x];
+int tensor_loc = tl->block_to_tensor[blockIdx.x];
-int chunk_idx = tl.block_to_chunk[blockIdx.x];
+int chunk_idx = tl->block_to_chunk[blockIdx.x];
-int n = tl.sizes[tensor_loc];
+int n = tl->sizes[tensor_loc];
-T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
+T_grad* grad_in = (T_grad*)tl->addresses[0][tensor_loc];
 grad_in += chunk_idx*chunk_size;
-T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
+T_weight* weight_in = (T_weight*)tl->addresses[1][tensor_loc];
 weight_in += chunk_idx*chunk_size;
-T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
+T_weight* mom_in = (T_weight*)tl->addresses[2][tensor_loc];
 mom_in += chunk_idx*chunk_size;
 at::Half *model_weights_out = nullptr;
 if(N == 4)
 {
-model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
+model_weights_out = (at::Half*)tl->addresses[3][tensor_loc];
 model_weights_out += chunk_idx*chunk_size;
 }
@@ -166,6 +166,8 @@ void multi_tensor_sgd_cuda(
 // 2. fp32, fp32, fp32, No
 // 3. fp16, fp32, fp32, Yes
 // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
+// 5. bfp16, bfp16, bfp16, No
+// 6. bfp16, fp32, fp32, Yes
 // It's easier to hardcode these possibilities than to use
 // switches etc. to handle the cross-product of cases where
 // we don't want the majority of them.
@@ -268,6 +270,46 @@ void multi_tensor_sgd_cuda(
 wd_after_momentum,
 scale);
 }
+// Case 5. bfp16, bfp16, bfp16, No
+else if(grad_type == at::ScalarType::BFloat16 &&
+weight_type == at::ScalarType::BFloat16 &&
+num_tensors == 3)
+{
+multi_tensor_apply<3>(
+BLOCK_SIZE,
+chunk_size,
+noop_flag,
+tensor_lists,
+SGDFunctor<3, at::BFloat16, at::BFloat16>(),
+wd,
+momentum,
+dampening,
+lr,
+nesterov,
+first_run,
+wd_after_momentum,
+scale);
+}
+// Case 6. bfp16, fp32, fp32, Yes
+else if(grad_type == at::ScalarType::BFloat16 &&
+weight_type == at::ScalarType::Float &&
+num_tensors == 4)
+{
+multi_tensor_apply<4>(
+BLOCK_SIZE,
+chunk_size,
+noop_flag,
+tensor_lists,
+SGDFunctor<4, at::BFloat16, float>(),
+wd,
+momentum,
+dampening,
+lr,
+nesterov,
+first_run,
+wd_after_momentum,
+scale);
+}
 else
 {
 AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
...
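To summarise the dispatch logic above: the host function hardcodes the supported (grad dtype, weight/momentum dtype, number of tensor lists) combinations rather than instantiating the full cross product, and this commit adds the two bfloat16 cases. A hedged Python-side sketch of the same lookup, covering only the cases visible in this hunk (case 1 is elided in the diff and is omitted here):

```python
import torch

# (grad dtype, weight/momentum dtype, number of tensor lists) combinations,
# mirroring the hardcoded branches in multi_tensor_sgd_cuda above; the two
# bfloat16 entries are the cases added by this change.
SUPPORTED_SGD_CASES = {
    (torch.float32, torch.float32, 3),    # case 2: fp32 throughout, no fp16 model copy
    (torch.float16, torch.float32, 4),    # case 3: fp16 grads, fp32 master weights
    (torch.float32, torch.float32, 4),    # case 4: materialize_master_grads=True
    (torch.bfloat16, torch.bfloat16, 3),  # case 5: pure bf16
    (torch.bfloat16, torch.float32, 4),   # case 6: bf16 grads, fp32 master weights
}

def sgd_case_supported(grad_dtype, weight_dtype, num_lists):
    return (grad_dtype, weight_dtype, num_lists) in SUPPORTED_SGD_CASES
```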