Commit 6e14df49 authored by rohithkrn

Merge remote-tracking branch 'rocm_up/master' into apex_amp_bfp16

parents c7fd532c 2d0f9cf2
@@ -6,11 +6,17 @@ from itertools import product

def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
    # Exception handling for 18.04 compatibility
    if check_overflow:
        if model_grad.is_sparse:
            cpu_sum = float(model_grad.float()._values().sum())
        else:
            cpu_sum = float(model_grad.float().sum())
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True
    if master_grad is not model_grad: # copy_ probably internally short-circuits this
        if model_grad.is_sparse:
            master_grad.copy_(model_grad.to_dense())
        else:
            master_grad.copy_(model_grad)
    if scale != 1.0:
        master_grad.mul_(scale)
@@ -19,6 +25,9 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F

def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
    # Exception handling for 18.04 compatibility
    if check_overflow:
        if model_grad.is_sparse:
            cpu_sum = float(model_grad.float()._values().sum())
        else:
            cpu_sum = float(model_grad.float().sum())
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True
...
@@ -14,6 +14,17 @@
#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
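The two helpers above are the building blocks for the vectorized fast paths added throughout this merge: is_aligned() checks that a pointer sits on an ILP*sizeof(T) boundary, and load_store() moves ILP elements in a single aligned_storage-sized transaction. A minimal sketch of how they are intended to be used (the copy_ilp kernel and its shapes are illustrative, not part of the commit):

```cuda
// Sketch only: copy n elements, ILP at a time when alignment allows it.
#include <cstdint>
#include <type_traits>

#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template<typename T>
__global__ void copy_ilp(T* dst, T* src, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  if (n % ILP == 0 && is_aligned(dst) && is_aligned(src)) {
    // Fast path: indices are in units of ILP-wide packets.
    for (int i = tid; i * ILP < n; i += stride)
      load_store(dst, src, i, i);
  } else {
    // Scalar fallback for misaligned pointers or ragged sizes.
    for (int i = tid; i < n; i += stride)
      dst[i] = src[i];
  }
}
```

The same structure (packet loads into registers, per-element math, packet stores) is what the AdamFunctor fast path and the bias/activation kernels below use.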
#include "type_shim.h"

typedef enum{
@@ -99,11 +110,51 @@ struct AdamFunctor
T incoming_v[ILP];
T incoming_g[ILP];
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(p) &&
is_aligned(m) &&
is_aligned(v) &&
is_aligned(g) &&
is_aligned(p_copy))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
GRAD_T tmp_g[ILP];
load_store(incoming_p, p, 0, i_start);
load_store(incoming_m, m, 0, i_start);
load_store(incoming_v, v, 0, i_start);
load_store(tmp_g, g, 0, i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_g[ii] = static_cast<T>(tmp_g[ii]);
T scaled_grad = incoming_g[ii]/grad_scale;
incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(incoming_v[ii] + eps);
else // Mode 1
denom = sqrtf(incoming_v[ii]) + eps;
float update = (incoming_m[ii]/denom) + (decay*incoming_p[ii]);
incoming_p[ii] = incoming_p[ii] - (step_size*update);
if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
}
load_store(p, incoming_p, i_start, 0);
load_store(m, incoming_m, i_start, 0);
load_store(v, incoming_v, i_start, 0);
if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
}
}
else
{
for(int i_start = 0;
    i_start < n && i_start < chunk_size;
    i_start += blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_p[ii] = 0;
incoming_m[ii] = 0;
@@ -124,7 +175,7 @@ struct AdamFunctor
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = i_start + threadIdx.x + ii*blockDim.x;
@@ -144,6 +195,7 @@ struct AdamFunctor
}
}
}
}
};
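For reference, the per-element update performed by both the aligned fast path and the scalar fallback above is the usual Adam step on the descaled gradient, in the kernel's own notation (b1, b2, eps, decay, step_size, grad_scale):

```latex
\hat{g} = g/\text{grad\_scale}, \qquad
m \leftarrow b_1 m + (1-b_1)\hat{g}, \qquad
v \leftarrow b_2 v + (1-b_2)\hat{g}^2,

\text{denom} =
\begin{cases}
\sqrt{v+\epsilon} & \text{ADAM\_MODE\_0}\\
\sqrt{v}+\epsilon & \text{ADAM\_MODE\_1}
\end{cases},
\qquad
p \leftarrow p - \text{step\_size}\,\bigl(m/\text{denom} + \text{decay}\cdot p\bigr)
```

When DEPTH == 5 the kernel also writes the updated parameter back to p_copy in the gradient's precision (the tmp_g reuse above).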
void fused_adam_cuda(
@@ -332,4 +384,3 @@ void fused_adam_cuda_mt(
}
THCudaCheck(cudaGetLastError());
}
@@ -70,7 +70,6 @@
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
@@ -84,6 +83,8 @@
#include "type_shim.h"
#include "compat.h"

#define ALIGN_BYTES 16

using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
@@ -123,7 +124,7 @@ const int max_threads = 1024;
inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
  while (block_size < (max_block_size/2)) block_size *= 2;
  // Launch at least a single warp - the kernel assumes that.
  block_size = std::max(block_size, static_cast<uint64_t>(32));
  return dim3(block_size);
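The changed loop bound stops the doubling one step earlier, so the chosen power of two stays at or below max_block_size instead of potentially overshooting it. A hedged check of the arithmetic (dim_size and ILP values are made up for illustration):

```cpp
// Host-side sketch: old vs. new block-size rounding for dim_size = 1000, ILP = 4.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint64_t pick_block(uint64_t dim_size, int ilp, bool halve_limit) {
  const uint64_t max_threads = 1024;
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ilp, max_threads);  // 250 here
  uint64_t limit = halve_limit ? max_block_size / 2 : max_block_size;
  while (block_size < limit) block_size *= 2;
  return std::max(block_size, uint64_t(32));  // at least one warp
}

int main() {
  printf("old rule: %llu\n", (unsigned long long)pick_block(1000, 4, false));  // 256
  printf("new rule: %llu\n", (unsigned long long)pick_block(1000, 4, true));   // 128
  return 0;
}
```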
@@ -287,29 +288,40 @@ blockReduce(AccumT* smem,
template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(int shift,
          T* data,
          int size,
          const Reduction<T, AccumT>& r,
          AccumT defaultVal)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
  AccumT threadVal = defaultVal;
  int offset = threadIdx.x;

  // shift and do 1
  if(shift > 0){
    data -= shift;
    size += shift;
    if(threadIdx.x >= shift){
      threadVal = r(threadVal, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }
  int last = size % (ILP * blockDim.x);

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(data)[offset];

    for (int j = 0; j < ILP; ++j) {
      threadVal = r(threadVal, v[j]);
    }
  }

  offset = size - last + threadIdx.x;
  // Epilogue
  for (; offset < size; offset += blockDim.x)
    threadVal = r(threadVal, data[offset]);
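The shift parameter threaded through both ilpReduce overloads (computed by the callers as ((uint64_t)ptr) % ALIGN_BYTES / sizeof(scalar_t)) handles rows that do not start on a 16-byte boundary: a scalar pass peels the misaligned head, the bulk is then read through aligned LoadT packets, and a scalar epilogue covers the tail. A self-contained sketch of the same three-phase structure for a plain float sum (the function itself is illustrative, not part of the commit):

```cuda
#include <cstdint>
#include <type_traits>

#define ALIGN_BYTES 16

template <int ILP>
__device__ float sum_with_shift(float* data, int size) {
  typedef typename std::aligned_storage<ILP*sizeof(float), ILP*alignof(float)>::type LoadT;
  float acc = 0.f;
  int offset = threadIdx.x;

  // Phase 1: peel the misaligned head so the packet loads below start on a boundary.
  int shift = ((uint64_t)data) % ALIGN_BYTES / sizeof(float);
  if (shift > 0) {
    data -= shift;                 // step back to the aligned boundary...
    size += shift;                 // ...and account for the extra elements in front
    if (threadIdx.x >= shift)      // threads with index < shift would land before the row
      acc += data[offset];
    size -= blockDim.x;
    data += blockDim.x;
  }

  // Phase 2: vectorized body, one LoadT packet (ILP floats) per iteration.
  int last = size % (ILP * blockDim.x);
  float v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);
  for (; offset * ILP < size - last; offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(data)[offset];
    for (int j = 0; j < ILP; ++j) acc += v[j];
  }

  // Phase 3: scalar epilogue for the ragged tail.
  offset = size - last + threadIdx.x;
  for (; offset < size; offset += blockDim.x) acc += data[offset];

  return acc;  // per-thread partial; a block-level reduce is still needed, as above
}
```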
@@ -319,7 +331,8 @@ ilpReduce(T* data,
template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2, int ILP, typename T, typename AccumT>
__device__ __forceinline__ void
ilpReduce(int shift,
          T* data,
          int size,
          AccumT* reducVal1,
          const Reduction1<T, AccumT>& r1,
@@ -328,27 +341,38 @@ ilpReduce(T* data,
          const Reduction2<T, AccumT>& r2,
          AccumT defaultVal2)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
  AccumT threadVal1 = defaultVal1;
  AccumT threadVal2 = defaultVal2;
  int offset = threadIdx.x;

  // shift and do 1
  if(shift > 0){
    data -= shift;
    size += shift;
    if(threadIdx.x >= shift){
      threadVal1 = r1(threadVal1, data[offset]);
      threadVal2 = r2(threadVal2, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }
  int last = size % (ILP * blockDim.x);

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(data)[offset];

    for (int j = 0; j < ILP; ++j) {
      threadVal1 = r1(threadVal1, v[j]);
      threadVal2 = r2(threadVal2, v[j]);
    }
  }

  offset = size - last + threadIdx.x;
  // Epilogue
  for (; offset < size; offset += blockDim.x) {
    threadVal1 = r1(threadVal1, data[offset]);
@@ -375,17 +399,19 @@ cunn_SoftMaxXEntropyForward(
  // each block handles a sample in the mini-batch
  input += blockIdx.x * classes;
  //output += blockIdx.x * classes;
  const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);
  int64_t label = labels[blockIdx.x];

  // find the max and sum
  accscalar_t threadMax, threadSum, max_k, sum_k;
  ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
      shift, input, classes,
      &threadMax, MaxFloat<scalar_t, accscalar_t>(),
      -at::numeric_limits<accscalar_t>::max(),
      &threadSum, AddFloat<scalar_t, accscalar_t>(),
      static_cast<accscalar_t>(0));
  blockReduce<Max, Add, accscalar_t>(
      sdata,
      &max_k, threadMax, Max<accscalar_t>(),
@@ -393,9 +419,7 @@ cunn_SoftMaxXEntropyForward(
      &sum_k, threadSum, Add<accscalar_t>(),
      static_cast<accscalar_t>(0));

  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
  accscalar_t sumAll = blockReduce<Add, accscalar_t>(
      sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));
@@ -411,10 +435,9 @@ cunn_SoftMaxXEntropyForward(
  }
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
apply(scalar_t *gradInput,
      scalar_t *logits,
      outscalar_t *max_log_sum_exp,
      outscalar_t *gradOutput,
@@ -422,9 +445,6 @@ cunn_SoftMaxXEntropyBackward(
      const float smoothing,
      int classes)
{
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
@@ -433,6 +453,7 @@ cunn_SoftMaxXEntropyBackward(
  int offset = threadIdx.x;
  int last = classes % (ILP * blockDim.x);

  for (; offset < classes - last; offset += blockDim.x * ILP) {
    accscalar_t tmpLogits[ILP];
@@ -457,9 +478,99 @@
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
aligned_apply(int shift,
scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
accscalar_t smooth_positives = 1.0 - smoothing;
accscalar_t smooth_negatives = smoothing / classes;
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
int64_t label = labels[blockIdx.x];
accscalar_t coeff = max_log_sum_exp[blockIdx.x];
int offset = threadIdx.x;
// shift and do 1
if(shift > 0){
logits -= shift;
gradInput -= shift;
classes += shift;
if(threadIdx.x >= shift){
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
classes -= blockDim.x;
gradInput += blockDim.x;
logits += blockDim.x;
shift -= blockDim.x;
}
int last = classes % (ILP * blockDim.x);
typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;
// input
scalar_t v[ILP];
LoadT* value = reinterpret_cast<LoadT*>(&v);
// output
scalar_t r[ILP];
LoadT* result = reinterpret_cast<LoadT*>(&r);
for (; offset * ILP < (classes - last); offset += blockDim.x) {
*value = reinterpret_cast<LoadT*>(logits)[offset];
#pragma unroll
for (int j = 0; j < ILP; ++j) {
r[j] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(v[j]) - coeff) -
static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
reinterpret_cast<LoadT*>(gradInput)[offset] = *result;
}
offset = classes - last + threadIdx.x;
for (; offset < classes; offset += blockDim.x)
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
gradInput += blockIdx.x * classes;
logits += blockIdx.x * classes;
// Do vectorized load/store when input/output have same alignment
const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
if (shift == shift_){
aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
}
else {
apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
}
}
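Both apply() and aligned_apply() compute the same per-element gradient of the label-smoothed cross entropy; writing s for smoothing, C for classes, y_i for labels[i], g_i for gradOutput[i], and mls_i for the saved max_log_sum_exp of row i (so that exp(x_ij - mls_i) is the softmax probability), the value stored into gradInput is:

```latex
\frac{\partial L}{\partial x_{ij}}
  = g_i\left( e^{\,x_{ij}-\mathrm{mls}_i}
      - (1-s)\,\mathbf{1}[j=y_i]
      - \frac{s}{C} \right)
```

i.e. the softmax output minus the label-smoothed target distribution, scaled by the incoming gradient; coeff, smooth_positives and smooth_negatives in the kernels are mls_i, (1-s) and s/C respectively.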
template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
@@ -495,13 +606,13 @@ std::vector<Tensor> host_softmax_xentropy(
  // XXX: it assumes that inner_size == 1
  TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  dim3 grid(outer_size);

  using namespace at;
  DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "host_softmax_xentropy",
    using accscalar_t = at::acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4)/sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    if (!half_to_float) {
      cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
        <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
@@ -564,12 +675,12 @@ Tensor host_softmax_xentropy_backward(
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  dim3 grid(outer_size);

  DISPATCH_FLOAT_AND_HALF(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
    using accscalar_t = acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4)/sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    if (!half_to_float) {
      cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
        <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
...
@@ -183,7 +183,7 @@ class SelfAttnFunc(torch.autograd.Function):
        values_grads = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))
        # Mask and Scaling for Dropout (not a publicly documented op)
        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))
        # Softmax Grad (not a publicly documented op)
        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
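The changed third argument reflects inverted-dropout scaling: elements are kept with probability 1-p and the survivors are rescaled by 1/(1-p), so re-applying the mask in backward must use the same 1/(1-p) factor (not p) for the expectation to be preserved:

```latex
m_i \sim \mathrm{Bernoulli}(1-p), \qquad
\tilde{x}_i = \frac{m_i}{1-p}\, x_i, \qquad
\mathbb{E}[\tilde{x}_i] = \frac{1-p}{1-p}\, x_i = x_i
```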
...
@@ -7,17 +7,19 @@ from .. import amp

class MlpFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, bias, activation, *args):
        output = mlp_cuda.forward(bias, activation, args)
        ctx.save_for_backward(*args)
        ctx.outputs = output
        ctx.bias = bias
        ctx.activation = activation
        return output[0]

    @staticmethod
    def backward(ctx, grad_o):
        grads = mlp_cuda.backward(ctx.bias, ctx.activation, grad_o, ctx.outputs, ctx.saved_tensors)
        del ctx.outputs
        return (None, None, *grads)

mlp_function = amp.half_function(MlpFunction.apply)
@@ -29,16 +31,21 @@ class MLP(torch.nn.Module):
        bias (bool): Default True
        relu (bool): Default True
    """
    def __init__(self, mlp_sizes, bias=True, activation='relu'):
        super(MLP, self).__init__()
        self.num_layers = len(mlp_sizes) - 1
        self.mlp_sizes = copy(mlp_sizes)
        self.bias = 1 if bias else 0

        if activation is 'none':
            self.activation = 0
        elif activation is 'relu':
            self.activation = 1
        elif activation is 'sigmoid':
            self.activation = 2
        else:
            raise TypeError("activation must be relu or none.")
        # ignoring bias = False now
        self.weights = []
        self.biases = []
        for i in range(self.num_layers):
@@ -46,6 +53,7 @@ class MLP(torch.nn.Module):
            self.weights.append(w)
            name = 'weight_{}'.format(i)
            setattr(self, name, w)
            if self.bias:
                b = torch.nn.Parameter(torch.empty(mlp_sizes[i+1]))
                self.biases.append(b)
                name = 'bias_{}'.format(i)
@@ -58,13 +66,14 @@ class MLP(torch.nn.Module):
            dimsum = weight.size(0) + weight.size(1)
            std = math.sqrt(2. / float(dimsum))
            nn.init.normal_(weight, 0., std)
        if self.bias:
            for bias in self.biases:
                std = math.sqrt(1. / float(bias.size(0)))
                nn.init.normal_(bias, 0., std)

    def forward(self, input):
        return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases)

    def extra_repr(self):
        s = F"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}"
        return s
@@ -172,8 +172,8 @@ void cuWelfordMuSigma2(
for (; l+7 < n2; l+=8*numx) {
  for (int k = 0; k < 8; k+=2) {
    float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
    cuWelfordOnlineSum<float>(curr.x,mu,sigma2,count);
    cuWelfordOnlineSum<float>(curr.y,mu,sigma2,count);
  }
}
for (; l < n2; ++l) {
@@ -230,9 +230,15 @@ void cuWelfordMuSigma2(
template<typename U> U rsqrt(U v) {
  return U(1) / sqrt(v);
}
#if defined __HIP_PLATFORM_HCC__
__device__ float rsqrt(float v) {
  return rsqrtf(v);
}
#else
template<> float rsqrt(float v) {
  return rsqrtf(v);
}
#endif
template<> double rsqrt(double v) {
  return rsqrt(v);
}
@@ -293,7 +299,7 @@ void cuApplyLayerNorm(
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
  SharedMemory<U> shared;
  U* buf = shared.getPointer();
  U mu,sigma2;
@@ -531,7 +537,7 @@ void cuComputeGradInput(
    const T* gamma,
    T* grad_input)
{
  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
...
@@ -19,7 +19,9 @@ int mlp_fp(
    int* output_features,
    T** BPtr,
    T* Y,
    T* reserved_space,
    int use_bias,
    int activation);

template <typename T>
int mlp_bp(
@@ -35,11 +37,18 @@ int mlp_bp(
    T* work_space,
    T* dX,
    T** dwPtr,
    T** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation);

std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at::Tensor> inputs) {
  auto num_layers = inputs.size() - 1;
  if (use_bias) {
    // inputs contains (input, weights, biases)
    num_layers /= 2;
  }
  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);
@@ -60,8 +69,10 @@ std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) {
    std::vector<scalar_t*> b_ptr;
    for (int i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
      if (use_bias) {
        b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
      }
    }
    auto result = mlp_fp<scalar_t>(
        inputs[0].data_ptr<scalar_t>(),
        input_features,
@@ -71,37 +82,48 @@ std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) {
        output_features.data(),
        b_ptr.data(),
        out.data_ptr<scalar_t>(),
        reserved_space.data_ptr<scalar_t>(),
        use_bias,
        activation);
  });

  return {out, reserved_space};
}

std::vector<at::Tensor> mlp_backward(
    int use_bias,
    int activation,
    at::Tensor grad_o,
    std::vector<at::Tensor> fprop_outputs,
    std::vector<at::Tensor> inputs) {

  // same code to get sizes and W pointers
  auto num_layers = inputs.size() - 1;
  if (use_bias) {
    // inputs contains (input, weights, biases)
    num_layers /= 2;
  }
  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);

  // TODO: not creating empty tensor for it?
  bool requires_grad = inputs[0].requires_grad();

  std::vector<int> output_features;
  for (int i = 0; i < num_layers; i++) {
    output_features.push_back(inputs[i + 1].size(0));
  }
  // create outputs, length of inputs
  // TODO: not create bias if not needed
  std::vector<at::Tensor> outputs;
  for (int i = 0; i < inputs.size(); i++) {
    outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type())); // clone for testing now
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] {
    std::vector<scalar_t*> w_ptr;
    for (int i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
    }
    std::vector<scalar_t*> outputs_ptr;
    for (int i = 0; i < inputs.size(); i++) {
@@ -127,7 +149,10 @@ std::vector<at::Tensor> mlp_backward(
        work_space.data_ptr<scalar_t>(),
        outputs_ptr[0],
        outputs_ptr.data() + 1,
        outputs_ptr.data() + 1 + num_layers,
        requires_grad,
        use_bias,
        activation);
  });

  return outputs;
...
@@ -10,8 +10,11 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>

// constants for fused bias+relu kernel
#define BIAS_RELU_FW_NTHREADS 128      // forward: number of threads per block
#define BIAS_RELU_BW_NTHREADS_X 32     // backward: number of threads in the feature dim
#define BIAS_RELU_BW_NTHREADS_Y 16     // backward: number of threads in the batch dim
#define BIAS_RELU_RED_PER_THREAD 16    // backward: minimal reduction length per thread

// move to a header later on
#define ILP 4
@@ -42,6 +45,12 @@ __device__ __inline__ float relu(float a) {
  return (retf);
}
// Keep Sigmoid in float only. When using half, cast to float before calling.
__device__ __inline__ float sigmoid(float a) {
float retf = 1.f / (1.f + expf(-a));
return (retf);
}
// FP64 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm(
    cublasHandle_t handle,
@@ -156,9 +165,55 @@ cublasStatus_t mlp_gemm(
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// Bias ADD. Assume input X is [features x batch size], column major.
// Bias is one 'features' long vector, with implicit broadcast.
template <typename T>
__global__ void biasAdd_fprop(T *X, T *b, uint batch_size, uint features) {
T r_x[ILP];
T r_b[ILP];
if(is_aligned(X) && is_aligned(b) && features % ILP ==0) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
int row = tid % (features / ILP);
load_store(r_x, X, 0 , tid);
load_store(r_b, b, 0 , row);
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
r_x[ii] = bias_sum;
}
load_store(X, r_x, tid , 0);
}
} else {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
int row = tid % features;
r_x[ii] = X[idx];
r_b[ii] = b[row];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
r_x[ii] = bias_sum;
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
X[idx] = r_x[ii];
}
}
}
}
}
// Bias ADD + ReLU. Assume input X is [features x batch size], column major.
// Activation supports fused ReLU. Safe to call in-place.
template <typename T>
__global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) {
  T r_x[ILP];
@@ -204,32 +259,308 @@ __global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) {
  }
}
// ReLU. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Relu_fprop(T *X, uint batch_size, uint features) {
T r_x[ILP];
if(is_aligned(X) && features % ILP ==0) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
load_store(r_x, X, 0 , tid);
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
r_x[ii] = relu(static_cast<float>(r_x[ii]));
}
load_store(X, r_x, tid , 0);
}
} else {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
r_x[ii] = X[idx];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
r_x[ii] = relu(static_cast<float>(r_x[ii]));
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
X[idx] = r_x[ii];
}
}
}
}
}
// Sigmoid. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Sigmoid_fprop(T *X, uint batch_size, uint features) {
T r_x[ILP];
if(is_aligned(X) && features % ILP ==0) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
load_store(r_x, X, 0 , tid);
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
}
load_store(X, r_x, tid , 0);
}
} else {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
r_x[ii] = X[idx];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
X[idx] = r_x[ii];
}
}
}
}
}
// ReLU. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Relu_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
T r_dy[ILP];
T r_y[ILP];
if(is_aligned(dY) &&
is_aligned(Y) &&
is_aligned(dX) &&
features % ILP ==0) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
load_store(r_dy, dY, 0 , tid);
load_store(r_y, Y, 0 , tid);
#pragma unroll
for(int ii=0;ii<ILP;ii++){
if ((float)r_y[ii] <= 0.f)
r_dy[ii] = 0;
}
load_store(dX, r_dy, tid, 0);
}
} else {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
r_dy[ii] = dY[idx];
r_y[ii] = Y[idx];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
if ((float)r_y[ii] <= 0.f)
r_dy[ii] = 0;
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
dX[idx] = r_dy[ii];
}
}
}
}
}
// Sigmoid. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Sigmoid_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
T r_dy[ILP];
T r_y[ILP];
if(is_aligned(dY) &&
is_aligned(Y) &&
is_aligned(dX) &&
features % ILP ==0) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
load_store(r_dy, dY, 0 , tid);
load_store(r_y, Y, 0 , tid);
#pragma unroll
for(int ii=0;ii<ILP;ii++){
float grad_out = r_dy[ii];
float out = r_y[ii];
float grad_i = out * ( 1.f - out) * grad_out;
r_dy[ii] = grad_i;
}
load_store(dX, r_dy, tid, 0);
}
} else {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
r_dy[ii] = dY[idx];
r_y[ii] = Y[idx];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
float grad_out = r_dy[ii];
float out = r_y[ii];
float grad_i = out * ( 1.f - out) * grad_out;
r_dy[ii] = grad_i;
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
dX[idx] = r_dy[ii];
}
}
}
}
}
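The factor applied in Sigmoid_bprop is just the logistic derivative evaluated from the saved activation output, which is why the kernel needs only Y and dY and not the pre-activation values: for y = sigma(x) = 1/(1+e^{-x}),

```latex
\frac{dy}{dx} = \sigma(x)\bigl(1-\sigma(x)\bigr) = y(1-y),
\qquad
dX = dY \cdot y\,(1-y)
```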
// Compute grid size for pointwise backward kernel.
// block_x/y is the total number of elements handled per block, not the number of threads
void get_biasAddRelu_bprop_grid_size(
    int yfeat,
    int batch_size,
    int block_x,
    int block_y,
    int* grid_x,
    int* grid_y) {

  *grid_x = (yfeat + block_x - 1) / block_x;

  // Get number of SMs for efficient reduction.
  int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  // can switch to occupancy calculation. use 4 below now for sm_70
  int max_blocks_y = num_SMs * 4 / (*grid_x);
  // block_y comes from the minimal work per thread
  int nRedSplits = (batch_size + block_y - 1) / block_y;
  // prefer increasing the per-thread reduction length over launching more blocks than needed;
  // the kernel adjusts its work, so here we just cap at the maximum block count
  *grid_y = std::min(nRedSplits, max_blocks_y);

  return;
}
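A hedged worked example of the sizing above (the SM count and shapes are assumed, not from the commit): with yfeat = 1024, batch_size = 8192, the non-aligned backward block handling 32 features by 16*16 batch rows, and an 80-SM device,

```cpp
// Illustrative host-side check of get_biasAddRelu_bprop_grid_size's arithmetic.
#include <algorithm>
#include <cstdio>

int main() {
  int yfeat = 1024, batch_size = 8192, num_SMs = 80;         // assumed values
  int block_x = 32;                // BIAS_RELU_BW_NTHREADS_X
  int block_y = 16 * 16;           // BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y
  int grid_x = (yfeat + block_x - 1) / block_x;              // 32 feature tiles
  int max_blocks_y = num_SMs * 4 / grid_x;                   // 10 (occupancy cap)
  int nRedSplits = (batch_size + block_y - 1) / block_y;     // 32 batch slices
  int grid_y = std::min(nRedSplits, max_blocks_y);           // 10 partials per bias column
  printf("grid = (%d, %d)\n", grid_x, grid_y);               // grid = (32, 10)
  return 0;
}
```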
// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
template <typename T, int UNROLL_FACTOR>
__global__ void biasAdd_bprop(
T* dY,
int features,
int batch_size,
volatile float* intermediate,
int* semaphores,
T* db) {
// The feature that this thread is responsible for
int f = blockIdx.x * blockDim.x + threadIdx.x;
// Compute the span this thread is responsible for
// For this block
int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
int b_nStart = blockIdx.y * b_chunkSize;
int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
// For this thread
int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
int nStart = threadIdx.y * chunkSize + b_nStart;
int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;
volatile float* out = intermediate + blockIdx.y * features;
// Flag to trigger last reduction.
__shared__ bool isLastBlock;
// we know block size for now
__shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];
// Accumulate db in FP32 always
float db_local = 0;
if (f < features) {
int nidx = 0;
// Handle non-multiple of UNROLL_FACTOR residue
for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
int row, col, flat_idx;
row = f;
col = nStart + nidx;
flat_idx = col * features + row;
db_local += (float)dY[flat_idx];
}
// Handle meat of work
for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
int row, col, flat_idx;
row = f;
col = nStart + nidx;
flat_idx = col * features + row;
#pragma unroll 4
for (int u = 0; u < UNROLL_FACTOR; u++) {
db_local += (float)dY[flat_idx];
flat_idx += features;
}
}
// naive block reduction on y-dim
int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
smem[linear_idx] = db_local;
}
__syncthreads();
if (f < features) {
if(threadIdx.y == 0) {
for(int yidx = 1; yidx < blockDim.y; yidx++){
db_local += smem[yidx * blockDim.x + threadIdx.x];
}
// block result is in db_local now for all threadIdx.y == 0
// Write out partial result
out[f] = db_local;
}
}
__threadfence();
__syncthreads();
// Increment semaphore and check if this is the last CTA in the grid_y dimension.
// Only thread (0,0) calls this
if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
unsigned int sum_idx;
sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
isLastBlock = (sum_idx == (gridDim.y - 1));
}
__syncthreads();
db_local = 0;
// No block reduction for now, only thread (*,0) do grid reduction
if (isLastBlock && f < features) {
if(threadIdx.y == 0) {
for (int n = 0; n < gridDim.y; n++) {
int row, col;
row = f;
col = n;
db_local += (float)(intermediate[col * features + row]);
}
db[f] = (T)db_local;
}
}
}
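biasAdd_bprop (and the two biasAddRelu variants below) make the bias reduction deterministic with a two-pass scheme: every CTA in a grid-Y column writes a partial sum to scratch, and the last CTA to arrive, detected through a per-column semaphore, folds the partials in a fixed order instead of using floating-point atomics. A stripped-down sketch of that protocol with a 1-D block (not the commit's kernel, which also does the block-level y reduction and UNROLL_FACTOR batching):

```cuda
__global__ void column_sum_2pass(const float* dY, int features, int batch_size,
                                 volatile float* partials,   // gridDim.y * features scratch
                                 int* semaphores,            // gridDim.x ints, zeroed before launch
                                 float* db) {
  int f = blockIdx.x * blockDim.x + threadIdx.x;

  // Pass 1: this CTA reduces its slice of the batch for feature f.
  int chunk = (batch_size + gridDim.y - 1) / gridDim.y;
  int n0 = blockIdx.y * chunk;
  int n1 = min(batch_size, n0 + chunk);
  if (f < features) {
    float acc = 0.f;
    for (int n = n0; n < n1; n++) acc += dY[n * features + f];
    partials[blockIdx.y * features + f] = acc;
  }
  __threadfence();  // make the partial visible before signalling completion

  // Last-CTA detection: atomicAdd returns the old count, so the CTA that sees
  // gridDim.y - 1 is the final arrival for this column of the grid.
  __shared__ bool isLastBlock;
  if (threadIdx.x == 0)
    isLastBlock = (atomicAdd(&semaphores[blockIdx.x], 1) == gridDim.y - 1);
  __syncthreads();

  // Pass 2: the last CTA folds the partials serially, in a fixed order.
  if (isLastBlock && f < features) {
    float total = 0.f;
    for (int n = 0; n < gridDim.y; n++) total += partials[n * features + f];
    db[f] = total;
  }
}
```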
// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
template <typename T, int UNROLL_FACTOR>
@@ -245,14 +576,22 @@ __global__ void biasAddRelu_bprop(
// The feature that this thread is responsible for
int f = blockIdx.x * blockDim.x + threadIdx.x;

// Compute the span this thread is responsible for
// For this block
int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
int b_nStart = blockIdx.y * b_chunkSize;
int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
// For this thread
int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
int nStart = threadIdx.y * chunkSize + b_nStart;
int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;

volatile float* out = intermediate + blockIdx.y * features;
// Flag to trigger last reduction.
__shared__ bool isLastBlock;
// we know block size for now
__shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];
// Accumulate db in FP32 always
float db_local = 0;
@@ -296,15 +635,28 @@ __global__ void biasAddRelu_bprop(
}
}
// naive block reduction on y-dim
int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
smem[linear_idx] = db_local;
}
__syncthreads();
if (f < features) {
if(threadIdx.y == 0) {
for(int yidx = 1; yidx < blockDim.y; yidx++){
db_local += smem[yidx * blockDim.x + threadIdx.x];
}
// block result is in db_local now for all threadIdx.y == 0
// Write out partial result
out[f] = db_local;
}
}
__threadfence();
__syncthreads();

// Increment semaphore and check if this is the last CTA in the grid_y dimension.
// Only thread (0,0) calls this
if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
unsigned int sum_idx;
sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
isLastBlock = (sum_idx == (gridDim.y - 1));
@@ -312,7 +664,9 @@ __global__ void biasAddRelu_bprop(
__syncthreads();

db_local = 0;
// No block reduction for now; only threads with threadIdx.y == 0 do the grid reduction
if (isLastBlock && f < features) {
if(threadIdx.y == 0) {
for (int n = 0; n < gridDim.y; n++) {
int row, col;
row = f;
@@ -321,6 +675,7 @@ __global__ void biasAddRelu_bprop(
}
db[f] = (T)db_local;
}
}
}
// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
@@ -338,10 +693,16 @@ __global__ void biasAddRelu_bprop_aligned(
// The feature that this thread is responsible for
int f = blockIdx.x * blockDim.x + threadIdx.x;

// Compute the span this thread is responsible for
// For this block
int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
int b_nStart = blockIdx.y * b_chunkSize;
int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
// For this thread
int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
int nStart = threadIdx.y * chunkSize + b_nStart;
int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;

volatile float* out = intermediate + blockIdx.y * features;
// Flag to trigger last reduction.
@@ -399,6 +760,27 @@ __global__ void biasAddRelu_bprop_aligned(
}
}
// we know block size for now
__shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y*ILP];
// naive block reduction on y-dim
int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
float* smem_out = smem + ILP * linear_idx;
#pragma unroll
for(int ii=0;ii<ILP;ii++){
smem_out[ii] = db_local[ii]; // reuse local dy buffer
}
__syncthreads();
if(threadIdx.y == 0) {
for(int yidx = 1; yidx < blockDim.y; yidx++){
float* smem_in = smem + ILP * (yidx * blockDim.x + threadIdx.x);
#pragma unroll
for(int ii=0;ii<ILP;ii++){
db_local[ii] += smem_in[ii]; // reuse local dy buffer
}
}
// block result is in db_local now for all threadIdx.y == 0
// TODO: maybe not useful early exit here
if(gridDim.y == 1) {
#pragma unroll
for(int ii=0;ii<ILP;ii++){
@@ -410,13 +792,13 @@ __global__ void biasAddRelu_bprop_aligned(

// Write out partial result
load_store(out, db_local, f, 0);
}
__threadfence();
__syncthreads();

// Increment semaphore and check if this is the last CTA in the grid_y dimension.
// Only thread (0,0) calls this
if (threadIdx.x == 0 && threadIdx.y == 0) {
unsigned int sum_idx;
sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
isLastBlock = (sum_idx == (gridDim.y - 1));
@@ -428,7 +810,10 @@ __global__ void biasAddRelu_bprop_aligned(
db_local[ii] = 0.f;
}
float r_db[ILP];

// No block reduction for now; only threads with threadIdx.y == 0 do the grid reduction
if (isLastBlock) {
if(threadIdx.y == 0){
for (int n = 0; n < gridDim.y; n++) {
int row, col;
row = f;
@@ -445,6 +830,7 @@ __global__ void biasAddRelu_bprop_aligned(
}
load_store(db, r_dy, f, 0);
}
}
}
}

// Lists where the num_layers-1 intermediate Y buffers start in reserved space on fprop, starting
@@ -502,10 +888,20 @@ size_t get_reduction_scratch_space(int batch_size, int num_layers, const int* ou
size_t max_scratch_space = 0;
// Loop over all layers to see which one needs the max scratch space
for (int l = 0; l < num_layers; l++) {
  // need to find max(aligned, not_aligned)
  int tmp, res0, res1;
  int block_x = BIAS_RELU_BW_NTHREADS_X;
  int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
  get_biasAddRelu_bprop_grid_size(
      output_features[l], batch_size, block_x, block_y, &tmp, &res0);
  block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
  get_biasAddRelu_bprop_grid_size(
      output_features[l], batch_size, block_x, block_y, &tmp, &res1);

  max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res0));
  max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res1));
}

return max_scratch_space;
@@ -581,7 +977,9 @@ int mlp_fp(
    int* output_features,
    T** BPtr,
    T* Y,
    T* reserved_space,
    int use_bias,
    int activation) {
T *weight, *input, *output, *bias;
T *reserved_space_x, *reserved_space_y;
reserved_space_x = NULL;
@@ -597,7 +995,9 @@ int mlp_fp(
weight = WPtr[layer];
input = (layer == 0) ? X : reserved_space_x;
output = (layer == num_layers - 1) ? Y : reserved_space_y;
if (use_bias) {
  bias = BPtr[layer];
}
int ifeat = (layer == 0) ? input_features : output_features[layer - 1];
int ofeat = output_features[layer];
@@ -627,12 +1027,33 @@ int mlp_fp(
return 1;
}

const uint &input_size = ofeat;
int num_blocks = 0;
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
// Call biasReLU
if(use_bias == 1) {
if (activation == 0) { // no activation
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
} else if (activation == 1) { // relu
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAddRelu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
} else if (activation == 2) { // sigmoid
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
}
} else {
// don't need to do anything in case of no activation and no bias
if (activation == 1) { // relu
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Relu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
} else if (activation == 2) { // sigmoid
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
}
}
// Set current output as next layer input
reserved_space_x = reserved_space_y;
@@ -660,7 +1081,10 @@ int mlp_bp(
    T* work_space,
    T* dX,
    T** dwPtr,
    T** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation) {
T* weight;
T *dweight, *dx, *dy, *dbias;
T *x, *y;
@@ -719,31 +1143,84 @@ int mlp_bp(
float one = 1.f;
float zero = 0.f;
if (use_bias == 1) {
if (activation == 0) { // no activation
// bgrad
dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
int grid_x, grid_y;
cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
int block_x = BIAS_RELU_BW_NTHREADS_X;
int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
dim3 grid(grid_x, grid_y);
biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
dy, yfeat, batch_size, db_scratch, semaphores, dbias);
// no activation bprop here: point dy_gemm at dy so the GEMMs read it directly
dy_gemm = dy;
} else if (activation == 1) { // relu
dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
int grid_x, grid_y;
cudaMemsetAsync(semaphores, 0, semaphore_size, stream); cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
if(yfeat % (ILP * threadsPerBlock) == 0 && if(yfeat % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0 &&
is_aligned(y) && is_aligned(y) &&
is_aligned(dy) && is_aligned(dy) &&
is_aligned(dy_gemm) && is_aligned(dy_gemm) &&
is_aligned(dbias)){ is_aligned(dbias)){
dim3 grid(grid_x/ILP, grid_y); int block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
dim3 grid(grid_x, grid_y);
biasAddRelu_bprop_aligned<T, 4><<<grid, block, 0, stream>>>( biasAddRelu_bprop_aligned<T, 4><<<grid, block, 0, stream>>>(
y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias); y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
} else { } else {
int block_x = BIAS_RELU_BW_NTHREADS_X;
int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
dim3 grid(grid_x, grid_y); dim3 grid(grid_x, grid_y);
biasAddRelu_bprop<T, 4><<<grid, block, 0, stream>>>( biasAddRelu_bprop<T, 4><<<grid, block, 0, stream>>>(
y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias); y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
} }
} else if (activation == 2) { // sigmoid
// activation backward
int num_blocks = 0;
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
// bgrad, from dy_gemm
dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
int grid_x, grid_y;
cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
int block_x = BIAS_RELU_BW_NTHREADS_X;
int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
dim3 grid(grid_x, grid_y);
biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
dy_gemm, yfeat, batch_size, db_scratch, semaphores, dbias);
}
} else { // no bias below
if (activation == 0) {
// no activation bprop here: point dy_gemm at dy so the GEMMs read it directly
dy_gemm = dy;
} else if (activation == 1) { // relu
int num_blocks = 0;
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Relu_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
} else if (activation == 2) { // sigmoid
int num_blocks = 0;
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
}
}
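// Hedged sketch (not the Relu_bprop/Sigmoid_bprop kernels from this commit): the
// element-wise math the activation-backward branches above are assumed to apply to
// the saved output y before the GEMMs consume dy_gemm.
__global__ void act_bprop_sketch(const float* dy, const float* y,
                                 int count, int activation, float* dy_gemm) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= count) return;
  if (activation == 1)        // relu: gradient passes only where the output was positive
    dy_gemm[i] = y[i] > 0.f ? dy[i] : 0.f;
  else if (activation == 2)   // sigmoid: dL/dx = dL/dy * y * (1 - y)
    dy_gemm[i] = dy[i] * y[i] * (1.f - y[i]);
  else                        // no activation: dy feeds the GEMMs unchanged
    dy_gemm[i] = dy[i];
}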
cublasStatus_t cublas_status; cublasStatus_t cublas_status;
// Call GEMM dgrad // Call GEMM dgrad
if (layer > 0 || requires_grad == 1) {
cublas_status = mlp_gemm( cublas_status = mlp_gemm(
handle, handle,
CUBLAS_OP_N, CUBLAS_OP_N,
...@@ -764,6 +1241,7 @@ int mlp_bp( ...@@ -764,6 +1241,7 @@ int mlp_bp(
printf("GEMM dgrad failed with %d\n", cublas_status); printf("GEMM dgrad failed with %d\n", cublas_status);
return 1; return 1;
} }
}
// Call GEMM wgrad // Call GEMM wgrad
cublas_status = mlp_gemm( cublas_status = mlp_gemm(
...@@ -801,7 +1279,9 @@ template int mlp_fp<float>( ...@@ -801,7 +1279,9 @@ template int mlp_fp<float>(
int* output_features, int* output_features,
float** BPtr, float** BPtr,
float* Y, float* Y,
float* reserved_space); float* reserved_space,
int use_bias,
int activation);
template int mlp_bp<float>( template int mlp_bp<float>(
float* X, float* X,
...@@ -816,7 +1296,10 @@ template int mlp_bp<float>( ...@@ -816,7 +1296,10 @@ template int mlp_bp<float>(
float* work_space, float* work_space,
float* dX, float* dX,
float** dwPtr, float** dwPtr,
float** dbPtr); float** dbPtr,
bool requires_grad,
int use_bias,
int activation);
template int mlp_fp<at::Half>( template int mlp_fp<at::Half>(
at::Half* X, at::Half* X,
...@@ -827,7 +1310,9 @@ template int mlp_fp<at::Half>( ...@@ -827,7 +1310,9 @@ template int mlp_fp<at::Half>(
int* output_features, int* output_features,
at::Half** BPtr, at::Half** BPtr,
at::Half* Y, at::Half* Y,
at::Half* reserved_space); at::Half* reserved_space,
int use_bias,
int activation);
template int mlp_bp<at::Half>( template int mlp_bp<at::Half>(
at::Half* X, at::Half* X,
...@@ -842,7 +1327,10 @@ template int mlp_bp<at::Half>( ...@@ -842,7 +1327,10 @@ template int mlp_bp<at::Half>(
at::Half* work_space, at::Half* work_space,
at::Half* dX, at::Half* dX,
at::Half** dwPtr, at::Half** dwPtr,
at::Half** dbPtr); at::Half** dbPtr,
bool requires_grad,
int use_bias,
int activation);
template int mlp_fp<double>( template int mlp_fp<double>(
double* X, double* X,
...@@ -853,7 +1341,9 @@ template int mlp_fp<double>( ...@@ -853,7 +1341,9 @@ template int mlp_fp<double>(
int* output_features, int* output_features,
double** BPtr, double** BPtr,
double* Y, double* Y,
double* reserved_space); double* reserved_space,
int use_bias,
int activation);
template int mlp_bp<double>( template int mlp_bp<double>(
double* X, double* X,
...@@ -868,7 +1358,10 @@ template int mlp_bp<double>( ...@@ -868,7 +1358,10 @@ template int mlp_bp<double>(
double* work_space, double* work_space,
double* dX, double* dX,
double** dwPtr, double** dwPtr,
double** dbPtr); double** dbPtr,
bool requires_grad,
int use_bias,
int activation);
template size_t get_mlp_bp_workspace_in_bytes<float>( template size_t get_mlp_bp_workspace_in_bytes<float>(
int batch_size, int batch_size,
......
...@@ -13,6 +13,17 @@ ...@@ -13,6 +13,17 @@
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
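// The two helpers above are the core of the aligned fast path: is_aligned checks that
// a pointer sits on an ILP*sizeof(T) boundary, and load_store reinterprets ILP
// consecutive elements as one std::aligned_storage chunk so the compiler can emit a
// single wide (16-byte for four floats) load or store. A minimal illustrative kernel
// built only on those helpers (an example, not part of the commit); it assumes
// n % ILP == 0 and that both pointers pass is_aligned.
template<typename T>
__global__ void vector_copy_sketch(T* dst, T* src, int n) {
  T buf[ILP];
  for (int i = threadIdx.x + blockIdx.x * blockDim.x; i * ILP < n;
       i += blockDim.x * gridDim.x) {
    load_store(buf, src, 0, i);  // one wide load of ILP elements
    load_store(dst, buf, i, 0);  // one wide store
  }
}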
template<typename x_t, typename y_t, typename out_t> template<typename x_t, typename y_t, typename out_t>
struct AxpbyFunctor struct AxpbyFunctor
{ {
...@@ -43,46 +54,74 @@ struct AxpbyFunctor ...@@ -43,46 +54,74 @@ struct AxpbyFunctor
n -= chunk_idx*chunk_size; n -= chunk_idx*chunk_size;
bool finite = true;
x_t r_x[ILP];
y_t r_y[ILP];
out_t r_out[ILP];
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_x, x, 0 , i_start);
load_store(r_y, y, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
if(arg_to_check == -1)
finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));
if(arg_to_check == 0)
finite = finite && isfinite(r_x[ii]);
if(arg_to_check == 1)
finite = finite && isfinite(r_y[ii]);
}
// store
load_store(out, r_out, i_start , 0);
}
}
else
{
// Non-divergent exit condition for __syncthreads, not necessary here // Non-divergent exit condition for __syncthreads, not necessary here
float xs[ILP]; for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
float ys[ILP];
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{ {
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) for(int ii = 0; ii < ILP; ii++)
{ {
xs[ii] = 0; r_x[ii] = 0;
ys[ii] = 0; r_y[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x; int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size) if(i < n && i < chunk_size)
{ {
xs[ii] = static_cast<float>(x[i]); r_x[ii] = x[i];
ys[ii] = static_cast<float>(y[i]); r_y[ii] = y[i];
} }
} }
#pragma unroll
// see note in multi_tensor_scale_kernel.cu
#pragma unroll
for(int ii = 0; ii < ILP; ii++) for(int ii = 0; ii < ILP; ii++)
{ {
int i = i_start + threadIdx.x + ii*blockDim.x; r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
if(i < n && i < chunk_size)
{
out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
bool finite = true;
if(arg_to_check == -1) if(arg_to_check == -1)
finite = (isfinite(xs[ii]) && isfinite(ys[ii])); finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));
if(arg_to_check == 0) if(arg_to_check == 0)
finite = isfinite(xs[ii]); finite = finite && isfinite(r_x[ii]);
if(arg_to_check == 1) if(arg_to_check == 1)
finite = isfinite(ys[ii]); finite = finite && isfinite(r_y[ii]);
if(!finite) }
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. // see note in multi_tensor_scale_kernel.cu
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
out[i] = r_out[ii];
} }
} }
} }
if(!finite)
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
} }
}; };
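// The rewritten functor only hoists the finite flag out of the inner loops; the
// protocol is unchanged: any non-finite input makes some thread write 1 into
// *noop_gmem, and the host decides afterwards whether to apply the step. A hedged
// sketch of how a caller might read that flag (the function below is an assumption,
// not an API from this commit).
bool grads_are_finite_sketch(const int* d_noop_flag, cudaStream_t stream) {
  int h_flag = 0;
  cudaMemcpyAsync(&h_flag, d_noop_flag, sizeof(int),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);  // the flag is only valid once the kernels finish
  return h_flag == 0;             // 0 -> every checked value was finite
}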
......
...@@ -13,6 +13,17 @@ ...@@ -13,6 +13,17 @@
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template<typename x_t> template<typename x_t>
struct L2NormFunctor struct L2NormFunctor
{ {
...@@ -41,12 +52,33 @@ struct L2NormFunctor ...@@ -41,12 +52,33 @@ struct L2NormFunctor
__shared__ float s_vals[512]; __shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be sure... float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP];
for(int i = 0; i < ILP; i++) for(int i = 0; i < ILP; i++)
{
vals[i] = 0.f; vals[i] = 0.f;
r_x[i] = 0;
}
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_x, x, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
float next = static_cast<float>(r_x[ii]);
vals[ii] += next*next;
}
}
}
else
{
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
{ {
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) for(int ii = 0; ii < ILP; ii++)
{ {
int i = i_start + threadIdx.x + ii*blockDim.x; int i = i_start + threadIdx.x + ii*blockDim.x;
...@@ -57,6 +89,7 @@ struct L2NormFunctor ...@@ -57,6 +89,7 @@ struct L2NormFunctor
} }
} }
} }
}
float val = 0.f; float val = 0.f;
for(int i = 0; i < ILP; i++) for(int i = 0; i < ILP; i++)
...@@ -104,12 +137,33 @@ struct MaxNormFunctor ...@@ -104,12 +137,33 @@ struct MaxNormFunctor
__shared__ float s_vals[512]; __shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be sure... float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP];
for(int i = 0; i < ILP; i++) for(int i = 0; i < ILP; i++)
{
vals[i] = 0.f; vals[i] = 0.f;
r_x[i] = 0;
}
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_x, x, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
float next = static_cast<float>(r_x[ii]);
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
}
}
}
else
{
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
{ {
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) for(int ii = 0; ii < ILP; ii++)
{ {
int i = i_start + threadIdx.x + ii*blockDim.x; int i = i_start + threadIdx.x + ii*blockDim.x;
...@@ -120,6 +174,7 @@ struct MaxNormFunctor ...@@ -120,6 +174,7 @@ struct MaxNormFunctor
} }
} }
} }
}
float val = 0.f; float val = 0.f;
for(int i = 0; i < ILP; i++) for(int i = 0; i < ILP; i++)
......
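// Both norm functors above only show the per-thread accumulation into vals/s_vals;
// the block-level reduction is outside this excerpt. A hedged, simplified
// restatement of the overall L2-norm shape (a stand-alone kernel, not code from
// this commit); it assumes blockDim.x is a power of two no larger than 512.
__global__ void l2_norm_squared_sketch(const float* x, int n, float* out) {
  __shared__ float s_vals[512];
  float val = 0.f;
  for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n;
       i += blockDim.x * gridDim.x)
    val += x[i] * x[i];            // per-thread partial sum of squares
  s_vals[threadIdx.x] = val;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride)
      s_vals[threadIdx.x] += s_vals[threadIdx.x + stride];
    __syncthreads();
  }
  if (threadIdx.x == 0)
    atomicAdd(out, s_vals[0]);     // caller takes sqrt of the accumulated total
}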
...@@ -13,6 +13,17 @@ ...@@ -13,6 +13,17 @@
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
typedef enum{ typedef enum{
MOMENT_MODE_0 =0, // L2 regularization mode MOMENT_MODE_0 =0, // L2 regularization mode
MOMENT_MODE_1 =1 // Decoupled weight decay mode MOMENT_MODE_1 =1 // Decoupled weight decay mode
...@@ -68,6 +79,83 @@ struct LAMBStage1Functor ...@@ -68,6 +79,83 @@ struct LAMBStage1Functor
n -= chunk_idx*chunk_size; n -= chunk_idx*chunk_size;
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(g) &&
is_aligned(p) &&
is_aligned(m) &&
is_aligned(v))
{
T l_g[ILP];
T l_p[ILP];
T l_m[ILP];
T l_v[ILP];
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(l_g, g, 0, i_start);
if (decay != 0)
load_store(l_p, p, 0, i_start);
load_store(l_m, m, 0, i_start);
load_store(l_v, v, 0, i_start);
// unpack
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_g[ii] = l_g[ii];
if (decay == 0) {
r_p[ii] = MATH_T(0);
}
else {
r_p[ii] = l_p[ii];
}
r_m[ii] = l_m[ii];
r_v[ii] = l_v[ii];
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
l_p[ii] = r_p[ii];
l_m[ii] = r_m[ii];
l_v[ii] = r_v[ii];
}
// store
load_store(g, l_p, i_start, 0);
load_store(m, l_m, i_start, 0);
load_store(v, l_v, i_start, 0);
}
}
else
{
// see note in multi_tensor_scale_kernel.cu // see note in multi_tensor_scale_kernel.cu
for(int i_start = 0; for(int i_start = 0;
i_start < n && i_start < chunk_size; i_start < n && i_start < chunk_size;
...@@ -137,6 +225,7 @@ struct LAMBStage1Functor ...@@ -137,6 +225,7 @@ struct LAMBStage1Functor
} }
} }
} }
}
}; };
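// The per-element math in the aligned path above (and repeated in the fallback path)
// is Adam-style moment tracking applied to the norm-clipped gradient. A scalar
// restatement of the MOMENT_MODE_0 branch (L2 regularization folded into the scaled
// gradient), using the same variable names; the return value is the raw update that
// the functor stores back into g.
__device__ float lamb_stage1_scalar(float g, float p, float& m, float& v,
                                    float beta1, float beta2, float beta3,
                                    float beta1_correction, float beta2_correction,
                                    float epsilon, float decay,
                                    float clipped_global_grad_norm) {
  float scaled_grad = g / clipped_global_grad_norm;
  scaled_grad += decay * p;                                    // L2 on the scaled grad
  m = m * beta1 + beta3 * scaled_grad;                         // first moment
  v = v * beta2 + (1.f - beta2) * scaled_grad * scaled_grad;   // second moment
  float next_m_unbiased = m / beta1_correction;
  float next_v_unbiased = v / beta2_correction;
  return next_m_unbiased / (sqrtf(next_v_unbiased) + epsilon);
}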
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm. // Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
...@@ -173,6 +262,29 @@ struct LAMBStage2Functor ...@@ -173,6 +262,29 @@ struct LAMBStage2Functor
n -= chunk_idx*chunk_size; n -= chunk_idx*chunk_size;
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(p) &&
is_aligned(update))
{
T r_p[ILP];
T r_update[ILP];
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_p, p, 0, i_start);
load_store(r_update, update, 0, i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * static_cast<MATH_T>(r_update[ii]));
}
load_store(p, r_p, i_start, 0);
}
}
else
{
for(int i_start = 0; for(int i_start = 0;
i_start < n && i_start < chunk_size; i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) i_start += blockDim.x*ILP)
...@@ -205,6 +317,7 @@ struct LAMBStage2Functor ...@@ -205,6 +317,7 @@ struct LAMBStage2Functor
} }
} }
} }
}
}; };
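// Stage 2 is then a single fused update per element: the parameter moves against the
// stage-1 'update' scaled by ratio, which is derived outside this excerpt from the
// per-tensor param_norm and update_norm. A scalar restatement of the aligned path
// above (illustrative only):
template<typename T, typename MATH_T>
__device__ void lamb_stage2_scalar(T& p, T update, MATH_T ratio) {
  p = static_cast<T>(static_cast<MATH_T>(p) - ratio * static_cast<MATH_T>(update));
}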
......
...@@ -15,6 +15,17 @@ ...@@ -15,6 +15,17 @@
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template<typename in_t, typename out_t> template<typename in_t, typename out_t>
struct ScaleFunctor struct ScaleFunctor
{ {
...@@ -40,38 +51,62 @@ struct ScaleFunctor ...@@ -40,38 +51,62 @@ struct ScaleFunctor
n -= chunk_idx*chunk_size; n -= chunk_idx*chunk_size;
bool finite = true;
in_t r_in[ILP];
out_t r_out[ILP];
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_in, in, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
}
// store
load_store(out, r_out, i_start, 0);
}
}
else
{
// Non-divergent exit condition for __syncthreads, not necessary here // Non-divergent exit condition for __syncthreads, not necessary here
float incoming_vals[ILP]; for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{ {
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) for(int ii = 0; ii < ILP; ii++)
{ {
incoming_vals[ii] = 0; r_in[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x; int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size) if(i < n && i < chunk_size)
incoming_vals[ii] = static_cast<float>(in[i]); r_in[ii] = in[i];
} }
// note for clarification to future michael: // note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling // From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive. // the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other. // Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though. // There is still compute ILP benefit from unrolling the loop though.
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) for(int ii = 0; ii < ILP; ii++)
{ {
int i = i_start + threadIdx.x + ii*blockDim.x; int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size) if(i < n && i < chunk_size)
{ out[i] = r_out[ii];
out[i] = static_cast<out_t>(incoming_vals[ii]*scale);
if(!isfinite(incoming_vals[ii]))
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
} }
} }
} }
if(!finite)
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
} }
}; };
......
...@@ -177,7 +177,13 @@ if "--cuda_ext" in sys.argv: ...@@ -177,7 +177,13 @@ if "--cuda_ext" in sys.argv:
'-O3', '-O3',
'--use_fast_math'] + version_dependent_macros})) '--use_fast_math'] + version_dependent_macros}))
else: else:
print ("INFO: Skipping FusedLayerNorm extension.") print ("INFO: Building FusedLayerNorm extension.")
ext_modules.append(
CUDAExtension(name='fused_layer_norm_cuda',
sources=['csrc/layer_norm_cuda.cpp',
'csrc/hip/layer_norm_hip_kernel.hip'],
extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
'nvcc' : []}))
if not is_rocm_pytorch: if not is_rocm_pytorch:
ext_modules.append( ext_modules.append(
......
...@@ -51,6 +51,116 @@ class TestMLP(unittest.TestCase): ...@@ -51,6 +51,116 @@ class TestMLP(unittest.TestCase):
ref_mlp[0].bias.grad.detach().cpu().numpy(), ref_mlp[0].bias.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5) atol=1e-7, rtol=1e-5)
def test_no_bias(self):
for use_activation in ['none', 'relu', 'sigmoid']:
mlp = MLP(mlp_sizes, bias=False, activation=use_activation).cuda()
mlp_layers = []
for i in range(mlp.num_layers):
linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=False)
mlp.weights[i].data.copy_(linear.weight)
mlp_layers.append(linear)
if use_activation == 'relu':
mlp_layers.append(nn.ReLU(inplace=True))
if use_activation == 'sigmoid':
mlp_layers.append(nn.Sigmoid())
ref_mlp = nn.Sequential(*mlp_layers).cuda()
test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.).requires_grad_()
ref_input = test_input.clone().detach().requires_grad_()
mlp_out = mlp(test_input)
ref_out = ref_mlp(ref_input)
np.testing.assert_allclose(
mlp_out.detach().cpu().numpy(),
ref_out.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
# Use the mean as a scalar loss. Multiply by 10 so the gradients are large enough not to zero out
mlp_out.mean().mul(10.).backward()
ref_out.mean().mul(10.).backward()
np.testing.assert_allclose(
test_input.grad.detach().cpu().numpy(),
ref_input.grad.detach().cpu().numpy(),
atol=0, rtol=100)
np.testing.assert_allclose(
mlp.weights[0].grad.detach().cpu().numpy(),
ref_mlp[0].weight.grad.detach().cpu().numpy(),
atol=1e-7, rtol=100)
def test_with_bias(self):
for use_activation in ['none', 'relu', 'sigmoid']:
mlp = MLP(mlp_sizes, bias=True, activation=use_activation).cuda()
mlp_layers = []
for i in range(mlp.num_layers):
linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=True)
mlp.weights[i].data.copy_(linear.weight)
mlp.biases[i].data.copy_(linear.bias)
mlp_layers.append(linear)
if use_activation == 'relu':
mlp_layers.append(nn.ReLU(inplace=True))
if use_activation == 'sigmoid':
mlp_layers.append(nn.Sigmoid())
ref_mlp = nn.Sequential(*mlp_layers).cuda()
test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.).requires_grad_()
ref_input = test_input.clone().detach().requires_grad_()
mlp_out = mlp(test_input)
ref_out = ref_mlp(ref_input)
np.testing.assert_allclose(
mlp_out.detach().cpu().numpy(),
ref_out.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
# Use the mean as a scalar loss. Multiply by 10 so the gradients are large enough not to zero out
mlp_out.mean().mul(10.).backward()
ref_out.mean().mul(10.).backward()
np.testing.assert_allclose(
test_input.grad.detach().cpu().numpy(),
ref_input.grad.detach().cpu().numpy(),
atol=0, rtol=1)
np.testing.assert_allclose(
mlp.weights[0].grad.detach().cpu().numpy(),
ref_mlp[0].weight.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1)
np.testing.assert_allclose(
mlp.biases[0].grad.detach().cpu().numpy(),
ref_mlp[0].bias.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
def test_no_grad(self):
mlp = MLP(mlp_sizes).cuda()
mlp_layers = []
for i in range(mlp.num_layers):
linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])
mlp.weights[i].data.copy_(linear.weight)
mlp.biases[i].data.copy_(linear.bias)
mlp_layers.append(linear)
mlp_layers.append(nn.ReLU(inplace=True))
ref_mlp = nn.Sequential(*mlp_layers).cuda()
test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.)
ref_input = test_input.clone().detach()
mlp_out = mlp(test_input)
ref_out = ref_mlp(ref_input)
np.testing.assert_allclose(
mlp_out.detach().cpu().numpy(),
ref_out.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
# Use the mean as a scalar loss. Multiply by 10 so the gradients are large enough not to zero out
mlp_out.mean().mul(10.).backward()
ref_out.mean().mul(10.).backward()
np.testing.assert_allclose(
mlp.weights[0].grad.detach().cpu().numpy(),
ref_mlp[0].weight.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
def test_performance_half(self): def test_performance_half(self):
mlp = MLP(mlp_sizes).cuda().half() mlp = MLP(mlp_sizes).cuda().half()
......