Commit c2b62b7f authored by JR_ZZU 🌴

delete origin files

parent 2a4864d5
/**
* From PyTorch:
*
* Copyright (c) 2016- Facebook, Inc (Adam Paszke)
* Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
* Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
* Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
* Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
* Copyright (c) 2011-2013 NYU (Clement Farabet)
* Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
* Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
* Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
*
* From Caffe2:
*
* Copyright (c) 2016-present, Facebook Inc. All rights reserved.
*
* All contributions by Facebook:
* Copyright (c) 2016 Facebook Inc.
*
* All contributions by Google:
* Copyright (c) 2015 Google Inc.
* All rights reserved.
*
* All contributions by Yangqing Jia:
* Copyright (c) 2015 Yangqing Jia
* All rights reserved.
*
* All contributions from Caffe:
* Copyright(c) 2013, 2014, 2015, the respective contributors
* All rights reserved.
*
* All other contributions:
* Copyright(c) 2015, 2016 the respective contributors
* All rights reserved.
*
* Caffe2 uses a copyright model similar to Caffe: each contributor holds
* copyright over their contributions to Caffe2. The project versioning records
* all such contribution and copyright details. If a contributor wants to further
* mark their specific copyright on a particular contribution, they should
* indicate their copyright solely in the commit message of the change when it is
* committed.
*
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
* and IDIAP Research Institute nor the names of its contributors may be
* used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh>
#include "type_shim.h"
#include "compat.h"
#define ALIGN_BYTES 16
#ifdef __HIP_PLATFORM_HCC__
#define WARP_SIZE 64
#define SYNCWARP(mask)
#else
#define WARP_SIZE 32
#define SYNCWARP(mask) __syncwarp(mask)
#endif
using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
using at::acc_type;
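// Epilogues applied after the block reductions: the forward epilogue forms
// logsum = max_input + log(sum of exp(x - max_input)) and maps each input x to x - logsum
// (log-softmax); the backward epilogue computes gradOutput - exp(output) * sum.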
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
__device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
: logsum(max_input + std::log(sum)) {}
__device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
: logsum(max_log_sum_exp) {}
__device__ __forceinline__ OutT operator()(T input) const {
return static_cast<OutT>(input - logsum);
}
const AccumT logsum;
};
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
__device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
: sum(sum) {}
__device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
}
const AccumT sum;
};
const int max_threads = 1024;
inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
uint64_t block_size = 1;
uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
while (block_size < (max_block_size/2)) block_size *= 2;
// Launch at least a single warp - the kernel assumes that.
block_size = std::max(block_size, static_cast<uint64_t>(WARP_SIZE));
return dim3(block_size);
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////
template <typename T, typename AccumT>
struct MaxFloat
{
__device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
return ::max(max, (AccumT)v);
}
};
template<typename T, typename AccumT>
struct AddFloat
{
__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + v;
}
};
template<typename T, typename AccumT>
struct SumExpFloat
{
__device__ __forceinline__ SumExpFloat(AccumT v)
: max_k(v) {}
__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + std::exp(v - max_k);
}
const AccumT max_k;
};
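// Block-wide reduction in shared memory: every thread stores its partial value, the first
// warp reduces one WARP_SIZE-sized chunk per lane, and thread 0 combines the per-warp
// partials and broadcasts the result through smem[0].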
template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
const Reduction<AccumT>& r,
AccumT defaultVal)
{
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
__syncthreads();
smem[threadIdx.x] = val;
__syncthreads();
AccumT warpVal = defaultVal;
// First warp will perform per-warp reductions for the remaining warps
uint32_t mask = (((uint64_t)1) << (blockDim.x / WARP_SIZE)) - 1;
if (threadIdx.x < WARP_SIZE) {
int lane = threadIdx.x % WARP_SIZE;
if (lane < blockDim.x / WARP_SIZE) {
#pragma unroll
for (int i = 0; i < WARP_SIZE; ++i) {
warpVal = r(warpVal, smem[lane * WARP_SIZE + i]);
}
SYNCWARP(mask);
smem[lane] = warpVal;
}
}
__syncthreads();
// First thread will perform a reduction of the above per-warp reductions
AccumT blockVal = defaultVal;
if (threadIdx.x == 0) {
for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) {
blockVal = r(blockVal, smem[i]);
}
smem[0] = blockVal;
}
// Sync and broadcast
__syncthreads();
return smem[0];
}
template <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>
__device__ __forceinline__ void
blockReduce(AccumT* smem,
AccumT* reducVal1,
AccumT val1,
const Reduction1<AccumT>& r1,
AccumT defaultVal1,
AccumT* reducVal2,
AccumT val2,
const Reduction2<AccumT>& r2,
AccumT defaultVal2)
{
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
__syncthreads();
smem[threadIdx.x] = val1;
smem[blockDim.x + threadIdx.x] = val2;
__syncthreads();
AccumT warpVal1 = defaultVal1;
AccumT warpVal2 = defaultVal2;
// First warp will perform per-warp reductions for the remaining warps
uint32_t mask = (((uint64_t)1) << (blockDim.x / WARP_SIZE)) - 1;
if (threadIdx.x < WARP_SIZE) {
int lane = threadIdx.x % WARP_SIZE;
if (lane < blockDim.x / WARP_SIZE) {
#pragma unroll
for (int i = 0; i < WARP_SIZE; ++i) {
warpVal1 = r1(warpVal1, smem[lane * WARP_SIZE + i]);
warpVal2 = r2(warpVal2, smem[lane * WARP_SIZE + i + blockDim.x]);
}
SYNCWARP(mask);
smem[lane] = warpVal1;
smem[lane + blockDim.x] = warpVal2;
}
}
__syncthreads();
// First thread will perform a reduction of the above per-warp reductions
AccumT blockVal1 = defaultVal1;
AccumT blockVal2 = defaultVal2;
if (threadIdx.x == 0) {
for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) {
blockVal1 = r1(blockVal1, smem[i]);
blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
}
smem[0] = blockVal1;
smem[blockDim.x] = blockVal2;
}
// Sync and broadcast
__syncthreads();
*reducVal1 = smem[0];
*reducVal2 = smem[blockDim.x];
__syncthreads();
}
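// Per-thread strided reduction over one row using ILP-wide vectorized loads. 'shift' is the
// row pointer's misalignment (in elements) from a 16-byte boundary; the unaligned prefix
// and the tail remainder are handled one element at a time.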
template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(int shift,
T* data,
int size,
const Reduction<T, AccumT>& r,
AccumT defaultVal)
{
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
AccumT threadVal = defaultVal;
int offset = threadIdx.x;
// shift and do 1
if(shift > 0){
data -= shift;
size += shift;
if(threadIdx.x >= shift){
threadVal = r(threadVal, data[offset]);
}
size -= blockDim.x;
data += blockDim.x;
}
int last = size % (ILP * blockDim.x);
T v[ILP];
LoadT* value = reinterpret_cast<LoadT*>(&v);
for (; offset * ILP < (size - last); offset += blockDim.x) {
*value = reinterpret_cast<LoadT*>(data)[offset];
for (int j = 0; j < ILP; ++j) {
threadVal = r(threadVal, v[j]);
}
}
offset = size - last + threadIdx.x;
// Epilogue
for (; offset < size; offset += blockDim.x)
threadVal = r(threadVal, data[offset]);
return threadVal;
}
template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2, int ILP, typename T, typename AccumT>
__device__ __forceinline__ void
ilpReduce(int shift,
T* data,
int size,
AccumT* reducVal1,
const Reduction1<T, AccumT>& r1,
AccumT defaultVal1,
AccumT* reducVal2,
const Reduction2<T, AccumT>& r2,
AccumT defaultVal2)
{
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
AccumT threadVal1 = defaultVal1;
AccumT threadVal2 = defaultVal2;
int offset = threadIdx.x;
// shift and do 1
if(shift > 0){
data -= shift;
size += shift;
if(threadIdx.x >= shift){
threadVal1 = r1(threadVal1, data[offset]);
threadVal2 = r2(threadVal2, data[offset]);
}
size -= blockDim.x;
data += blockDim.x;
}
int last = size % (ILP * blockDim.x);
T v[ILP];
LoadT* value = reinterpret_cast<LoadT*>(&v);
for (; offset * ILP < (size - last); offset += blockDim.x) {
*value = reinterpret_cast<LoadT*>(data)[offset];
for (int j = 0; j < ILP; ++j) {
threadVal1 = r1(threadVal1, v[j]);
threadVal2 = r2(threadVal2, v[j]);
}
}
offset = size - last + threadIdx.x;
// Epilogue
for (; offset < size; offset += blockDim.x) {
threadVal1 = r1(threadVal1, data[offset]);
threadVal2 = r2(threadVal2, data[offset]);
}
*reducVal1 = threadVal1;
*reducVal2 = threadVal2;
}
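// Forward kernel: one block per sample. It reduces the row's max and raw sum in one pass,
// then reduces sum(exp(x - max)), and writes the label-smoothed cross-entropy loss together
// with max + log(sum_exp), which the backward pass reuses.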
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyForward(
accscalar_t *losses,
outscalar_t *max_log_sum_exp,
scalar_t *input,
int64_t *labels,
int64_t classes,
const float smoothing)
{
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<accscalar_t*>(smem);
// forward pointers to batch[blockIdx.x]
// each block handles a sample in the mini-batch
input += blockIdx.x * classes;
//output += blockIdx.x * classes;
const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);
int64_t label = labels[blockIdx.x];
// find the max and sum
accscalar_t threadMax, threadSum, max_k, sum_k;
ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
shift, input, classes,
&threadMax, MaxFloat<scalar_t, accscalar_t>(),
-at::numeric_limits<accscalar_t>::max(),
&threadSum, AddFloat<scalar_t, accscalar_t>(),
static_cast<accscalar_t>(0));
blockReduce<Max, Add, accscalar_t>(
sdata,
&max_k, threadMax, Max<accscalar_t>(),
-at::numeric_limits<accscalar_t>::max(),
&sum_k, threadSum, Add<accscalar_t>(),
static_cast<accscalar_t>(0));
accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
accscalar_t sumAll = blockReduce<Add, accscalar_t>(
sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));
Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
// calculate per element loss with label smoothing
// reserve max + log_sum_exp for bprop
if (threadIdx.x == 0) {
accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));
losses[blockIdx.x] = (max_k + std::log(sumAll) - sum_k / classes) \
* smoothing - log_prob * (1 - smoothing);
max_log_sum_exp[blockIdx.x] = max_k + std::log(sumAll);
}
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
apply(scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
accscalar_t smooth_positives = 1.0 - smoothing;
accscalar_t smooth_negatives = smoothing / classes;
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
int64_t label = labels[blockIdx.x];
accscalar_t coeff = max_log_sum_exp[blockIdx.x];
int offset = threadIdx.x;
int last = classes % (ILP * blockDim.x);
for (; offset < classes - last; offset += blockDim.x * ILP) {
accscalar_t tmpLogits[ILP];
#pragma unroll
for (int j = 0; j < ILP; ++j) {
tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
}
#pragma unroll
for (int j = 0; j < ILP; ++j)
gradInput[offset + j * blockDim.x] = tmpGradOutput * (
std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(
(offset + j * blockDim.x == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
for (; offset < classes; offset += blockDim.x)
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>((offset == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
aligned_apply(int shift,
scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
accscalar_t smooth_positives = 1.0 - smoothing;
accscalar_t smooth_negatives = smoothing / classes;
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
int64_t label = labels[blockIdx.x];
accscalar_t coeff = max_log_sum_exp[blockIdx.x];
int offset = threadIdx.x;
// shift and do 1
if(shift > 0){
logits -= shift;
gradInput -= shift;
classes += shift;
if(threadIdx.x >= shift){
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
classes -= blockDim.x;
gradInput += blockDim.x;
logits += blockDim.x;
shift -= blockDim.x;
}
int last = classes % (ILP * blockDim.x);
typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;
// input
scalar_t v[ILP];
LoadT* value = reinterpret_cast<LoadT*>(&v);
// output
scalar_t r[ILP];
LoadT* result = reinterpret_cast<LoadT*>(&r);
for (; offset * ILP < (classes - last); offset += blockDim.x) {
*value = reinterpret_cast<LoadT*>(logits)[offset];
#pragma unroll
for (int j = 0; j < ILP; ++j) {
r[j] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(v[j]) - coeff) -
static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
reinterpret_cast<LoadT*>(gradInput)[offset] = *result;
}
offset = classes - last + threadIdx.x;
for (; offset < classes; offset += blockDim.x)
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
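// Backward kernel: one block per sample. gradInput = gradOutput * (softmax(logits) - target),
// where the label-smoothed target puts (1 - smoothing) extra mass on the true class and
// smoothing / classes everywhere. Vectorized loads/stores are used when logits and gradInput
// share the same alignment offset.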
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
gradInput += blockIdx.x * classes;
logits += blockIdx.x * classes;
// Do vectorized load/store when input/output have same alignment
const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
if (shift == shift_){
aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
}
else {
apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
}
}
template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
const Tensor & input_,
const Tensor & labels_,
const float smoothing,
const bool half_to_float){
if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half || input_.type().scalarType() == ScalarType::BFloat16,"conversion is supported for Half and BFloat16 type only");
AT_ASSERTM(labels_.type().scalarType() == ScalarType::Long,"Label type should be CUDA Long");
auto input = input_.contiguous();
Tensor max_log_sum_exp = at::empty_like(labels_, half_to_float ? input.options().dtype(ScalarType::Float) : input.options());
Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));
static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
std::is_same<acc_type<at::Half, true>, double>::value,
"accscalar_t for half should be float or double");
AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported");
AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional");
AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples");
AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0");
const int64_t dim = 1;
int64_t outer_size = 1;
int64_t dim_size = input.size(dim);
int64_t inner_size = 1;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
for (int64_t i = 0; i < dim; ++i)
outer_size *= input.size(i);
for (int64_t i = dim + 1; i < input.dim(); ++i)
inner_size *= input.size(i);
// This kernel spawns a block per each element in the batch.
// XXX: it assumes that inner_size == 1
TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");
dim3 grid(outer_size);
using namespace at;
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(input.scalar_type(), 0, "host_softmax_xentropy",
using accscalar_t = at::acc_type<scalar_t_0, true>;
const int ILP = sizeof(float4)/sizeof(scalar_t_0);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
if (!half_to_float) {
cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
<<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
losses.DATA_PTR<accscalar_t>(), max_log_sum_exp.DATA_PTR<scalar_t_0>(),
input.DATA_PTR<scalar_t_0>(), labels_.DATA_PTR<int64_t>(),
dim_size, smoothing
);
} else {
cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
losses.DATA_PTR<accscalar_t>(), max_log_sum_exp.DATA_PTR<accscalar_t>(),
input.DATA_PTR<scalar_t_0>(), labels_.DATA_PTR<int64_t>(),
dim_size, smoothing
);
}
);
C10_CUDA_CHECK(cudaGetLastError());
std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
return ret;
}
template<template<typename, typename, typename> class Epilogue>
Tensor host_softmax_xentropy_backward(
const at::Tensor &grad_loss,
const at::Tensor &logits_,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing,
bool half_to_float) {
const int64_t dim = 1;
Tensor gI = at::empty_like(logits_);
if (grad_loss.numel() == 0) {
return gI;
}
auto grad = grad_loss.contiguous();
auto logits = logits_.contiguous();
static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
std::is_same<acc_type<at::Half, true>, double>::value,
"accscalar_t for half should be float or double");
if (grad.dim() == 0) grad = grad.view(1);
AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported");
AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional");
AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0");
AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples");
AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples");
int64_t outer_size = 1;
int64_t dim_size = logits.size(dim);
int64_t inner_size = 1;
for (int64_t i = 0; i < dim; ++i)
outer_size *= logits.size(i);
for (int64_t i = dim + 1; i < logits.dim(); ++i)
inner_size *= logits.size(i);
// See descriptions of kernels above.
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");
dim3 grid(outer_size);
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
using accscalar_t = acc_type<scalar_t_0, true>;
const int ILP = sizeof(float4)/sizeof(scalar_t_0);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
if (!half_to_float) {
cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
gI.DATA_PTR<scalar_t_0>(), logits.DATA_PTR<scalar_t_0>(),
max_log_sum_exp.DATA_PTR<scalar_t_0>(),
grad.DATA_PTR<scalar_t_0>(), labels.DATA_PTR<int64_t>(),
smoothing, dim_size
);
} else {
cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
gI.DATA_PTR<scalar_t_0>(), logits.DATA_PTR<scalar_t_0>(),
max_log_sum_exp.DATA_PTR<accscalar_t>(),
grad.DATA_PTR<accscalar_t>(), labels.DATA_PTR<int64_t>(),
smoothing, dim_size
);
}
);
C10_CUDA_CHECK(cudaGetLastError());
return gI;
}
std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const bool half_to_float){
return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, half_to_float);
}
at::Tensor softmax_xentropy_backward_cuda(
const at::Tensor &grad_loss,
const at::Tensor &logits,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing) {
bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType();
if (half_to_float) {
AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && (logits.type().scalarType() == ScalarType::Half || logits.type().scalarType() == ScalarType::BFloat16)), "expected input and grad types to match, or input to be at::Half or at::BFloat16 and grad to be at::Float");
}
return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float);
}
import torch
import torch.nn.functional as F
import argparse
from apex.contrib.multihead_attn import SelfMultiheadAttn
from apex.contrib.multihead_attn import EncdecMultiheadAttn
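# Sweeps a range of random seeds and checks that the 'fast' (C++) multihead attention
# implementation matches the 'default' (Python) implementation in both forward outputs
# and input gradients.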
parser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')
parser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')
parser.add_argument('--num-seqs-start', default=5, type=int, help='Start Range of Number of Sequences')
parser.add_argument('--num-seqs-stop', default=80, type=int, help='Stop Range of Number of Sequences')
parser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')
parser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
parser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
parser.add_argument('--seed-start', default=1, type=int, help='Start of Seed Range to Sweep')
parser.add_argument('--seed-end', default=100, type=int, help='End of Seed Range to Sweep')
parser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')
parser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')
parser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')
parser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')
parser.add_argument('--ref', action='store_true', help='Reference implementation in Python PyTorch.')
parser.add_argument('--native', action='store_true', help='torch.nn.MultiheadAttention Version.')
parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
parser.add_argument('--eval', action='store_true', help='Inference only, no backward pass.')
args = parser.parse_args()
assert args.seq_length % 64 == 0, "Sequence Length should be a multiple of 64!"
if not torch.cuda.is_available():
raise NotImplementedError('Running on CPU is not supported')
torch.cuda.set_device(0)
dropout_prob = 0.1
for seed in range(args.seed_start, args.seed_end+1) :
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
ref_layer = None
if args.encdec_attn :
ref_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')
else :
ref_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')
ref_layer.cuda()
ref_layer.half()
ref_layer.reset_parameters()
ref_inputs = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
ref_inputs_kv = None
if args.encdec_attn :
ref_inputs_kv = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
ref_grads = torch.randn_like(ref_inputs)
ref_outputs,_ = ref_layer.forward(ref_inputs,
ref_inputs_kv,
ref_inputs_kv,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=(not args.eval))
ref_outputs.backward(ref_grads)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
tst_layer = None
if args.encdec_attn :
tst_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')
else:
tst_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')
tst_layer.cuda()
tst_layer.half()
tst_layer.reset_parameters()
tst_inputs = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
tst_inputs_kv = None
if args.encdec_attn :
tst_inputs_kv = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
assert torch.equal(ref_inputs,tst_inputs), "ERROR: Inputs are different!"
tst_grads = torch.randn_like(tst_inputs)
tst_outputs,_ = tst_layer.forward(tst_inputs,
tst_inputs_kv,
tst_inputs_kv,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=(not args.eval))
tst_outputs.backward(tst_grads)
fwd_close = torch.equal(ref_outputs, tst_outputs)
bwd_close = torch.equal(ref_inputs.grad, tst_inputs.grad)
diff_fwd = ref_outputs - tst_outputs
diff_cnt_fwd = diff_fwd.ne(0.0).sum()
diff_accum_fwd = diff_fwd.abs().sum()
diff_bwd = ref_inputs.grad - tst_inputs.grad
diff_cnt_bwd = diff_bwd.ne(0.0).sum()
diff_accum_bwd = diff_bwd.abs().sum()
print(">>> Seed: ", seed, fwd_close, diff_cnt_fwd.item(), diff_accum_fwd.item(), bwd_close, diff_cnt_bwd.item(), diff_accum_bwd.item())
import torch
import torch.nn.functional as F
import argparse
from apex.contrib.multihead_attn import SelfMultiheadAttn
from apex.contrib.multihead_attn import EncdecMultiheadAttn
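# Benchmarks forward and backward time per multihead attention layer over a range of batch
# sizes (number of sequences), for the fast C++, Python reference, and native
# torch.nn.MultiheadAttention implementations.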
parser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')
parser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')
parser.add_argument('--num-seqs-start', default=10, type=int, help='Start Range of Number of Sequences')
parser.add_argument('--num-seqs-stop', default=120, type=int, help='Stop Range of Number of Sequences')
parser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')
parser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
parser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
parser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')
parser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')
parser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')
parser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')
parser.add_argument('--ref', action='store_true', help='Reference implementation in Python PyTorch.')
parser.add_argument('--native', action='store_true', help='torch.nn.MultiheadAttention Version.')
parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
parser.add_argument('--biases', action='store_true', help='Execute multihead attention with Linear Biases.')
args = parser.parse_args()
if not torch.cuda.is_available():
raise NotImplementedError('Running on CPU is not supported')
torch.cuda.set_device(0)
torch.manual_seed(111)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(111)
attn_layers = []
for idx in range(0, args.layers) :
if args.encdec_attn :
if args.ref :
attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=False, impl='default'))
else :
attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))
else :
if args.native :
attn_layers.append(torch.nn.MultiheadAttention(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases))
elif args.ref :
attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='default'))
else :
attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))
attn_layers[idx].cuda()
attn_layers[idx].half()
if not args.native :
attn_layers[idx].reset_parameters()
start_evt_fwd = []
start_evt_bwd = []
stop_evt_bwd = []
for recorded_trial in range(0, args.trials) :
start_evt_fwd.append(torch.cuda.Event(enable_timing=True))
start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))
for sequences in range(args.num_seqs_start, args.num_seqs_stop + args.num_seqs_inc, args.num_seqs_inc) :
inputs = torch.randn(args.seq_length, sequences, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
grads = torch.randn_like(inputs)
for trial in range(0, args.trials + args.warmup_trials) :
layer_inputs = inputs
evt_idx = trial - args.warmup_trials
if evt_idx >= 0 :
start_evt_fwd[evt_idx].record()
for lyr_idx in range(0, args.layers) :
if args.native :
outputs,_ = attn_layers[lyr_idx].forward(layer_inputs,
layer_inputs,
layer_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None)
else :
outputs,_ = attn_layers[lyr_idx].forward(layer_inputs,
layer_inputs,
layer_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
layer_inputs = outputs
if evt_idx >= 0 :
start_evt_bwd[evt_idx].record()
if not args.fwd :
layer_inputs.backward(grads)
if evt_idx >= 0 :
stop_evt_bwd[evt_idx].record()
torch.cuda.synchronize()
elapsed_time_fwd = 0.0
elapsed_time_bwd = 0.0
for evt_idx in range(0, args.trials) :
elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])
elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])
print("[ {} Attn {} ]Total Tokens: {:4d} Sequences: {:3d} Sequence Length: {:3d} Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
'Encdec' if args.encdec_attn else 'Self', \
'Norm&Add' if args.norm_add else '', \
sequences*args.seq_length, \
sequences, \
args.seq_length, \
elapsed_time_fwd / ( args.trials * args.layers ), \
elapsed_time_bwd / ( args.trials * args.layers )))
###############################################################################
# Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
###############################################################################
import torch
import torch.nn.functional as F
import fmhalib as mha
class FMHAFun(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors):
batch_size = cu_seqlens.numel() - 1
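# Batches with fewer than 4 sequences take the alternate *_nl kernel path with max_s set
# to 512; larger batches use the standard fwd/bwd kernels.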
if batch_size < 4:
max_s = 512
context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, True, zero_tensors, None)
else:
context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, False, zero_tensors, None)
ctx.save_for_backward(qkv, S_dmask)
ctx.cu_seqlens = cu_seqlens
ctx.p_dropout = p_dropout
ctx.max_s = max_s
ctx.zero_tensors = zero_tensors
return context
@staticmethod
def backward(ctx, dout):
qkv, S_dmask = ctx.saved_tensors
batch_size = ctx.cu_seqlens.numel() - 1
if batch_size < 4:
dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
else:
dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
return dqkv, None, None, None, None, None
class FMHA(torch.nn.Module):
def __init__(self, config):
super(FMHA, self).__init__()
self.p_dropout = config.attention_probs_dropout_prob
self.h = config.num_attention_heads
self.hidden_size = config.hidden_size
self.d = self.hidden_size // self.h
assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"
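# qkv arrives packed with its trailing dimensions flattened and is reshaped to
# (total_tokens, 3, num_heads, head_dim) before being handed to the fused kernel.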
def forward(self, qkv, cu_seqlens, max_s, is_training=True, zero_tensors=False):
ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training, zero_tensors)
return ctx.view(-1, self.hidden_size)
try:
import torch
import focal_loss_cuda
from .focal_loss import focal_loss
del torch
del focal_loss_cuda
del focal_loss
except ImportError as err:
print("apex was installed without --focal_loss flag, apex.contrib.focal_loss is not available")
import torch
import focal_loss_cuda
class FocalLoss(torch.autograd.Function):
@staticmethod
def forward(
ctx,
cls_output,
cls_targets_at_level,
num_positives_sum,
num_real_classes,
alpha,
gamma,
label_smoothing=0.0,
):
loss, partial_grad = focal_loss_cuda.forward(
cls_output,
cls_targets_at_level,
num_positives_sum,
num_real_classes,
alpha,
gamma,
label_smoothing,
)
ctx.save_for_backward(partial_grad, num_positives_sum)
return loss
@staticmethod
def backward(ctx, grad_loss):
partial_grad, num_positives_sum = ctx.saved_tensors
# The backward kernel is actually in-place to save memory space,
# partial_grad and grad_input are the same tensor.
grad_input = focal_loss_cuda.backward(grad_loss, partial_grad, num_positives_sum)
return grad_input, None, None, None, None, None, None
def focal_loss(
cls_output: torch.Tensor,
cls_targets_at_level: torch.Tensor,
num_positive_sum: torch.Tensor,
num_real_classes: int,
alpha: float,
gamma: float,
label_smoothing: float = 0.0,
) -> torch.Tensor:
"""Fused focal loss function."""
return FocalLoss.apply(
cls_output,
cls_targets_at_level,
num_positive_sum,
num_real_classes,
alpha,
gamma,
label_smoothing,
)
try:
import torch
import bnp
from .batch_norm import BatchNorm2d_NHWC
del torch
del bnp
del batch_norm
except ImportError as err:
print("apex was installed without --bnp flag, contrib.groupbn is not available")
import torch
import numpy as np
from torch.nn.modules.batchnorm import _BatchNorm
import bnp
def check_if_rocm_pytorch():
is_rocm_pytorch = False
# Compare numeric version components; a plain string comparison misorders e.g. '1.10' vs '1.5'.
if tuple(int(v) for v in torch.__version__.split('.')[:2]) >= (1, 5):
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)
return is_rocm_pytorch
IS_ROCM_PYTORCH = check_if_rocm_pytorch()
def check_and_convert_channels_last(tensor, torch_channels_last):
if torch_channels_last:
channels_last = tensor.is_contiguous(memory_format = torch.channels_last)
if not channels_last:
tensor = tensor.to(memory_format = torch.channels_last)
return tensor
class bn_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, torch_channels_last, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
x = check_and_convert_channels_last(x, torch_channels_last)
if is_train:
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)
ctx.torch_channels_last = torch_channels_last
ctx.epsilon = epsilon
ctx.momentum = mom
ctx.ret_cta = ret_cta
ctx.fuse_relu = fuse_relu
ctx.my_data = my_data
ctx.pair_data = pair_data
ctx.magic = magic
ctx.pair_data2 = pair_data2
ctx.pair_data3 = pair_data3
ctx.bn_group = bn_group
ctx.bwd_occup = bwd_occup
ctx.bwd_grid_x = bwd_grid_x
ctx.multi_stream = multi_stream
res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)
return res
else:
return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu)
@staticmethod
def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables
grad_y = check_and_convert_channels_last(grad_y, ctx.torch_channels_last)
x = check_and_convert_channels_last(x, ctx.torch_channels_last)
epsilon = ctx.epsilon
mom = ctx.momentum
ret_cta = ctx.ret_cta
fuse_relu = ctx.fuse_relu
my_data = ctx.my_data
pair_data = ctx.pair_data
magic = ctx.magic
pair_data2 = ctx.pair_data2
pair_data3 = ctx.pair_data3
bn_group = ctx.bn_group
bwd_occup = ctx.bwd_occup
bwd_grid_x = ctx.bwd_grid_x
multi_stream = ctx.multi_stream
dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)
return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class bn_addrelu_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, torch_channels_last, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
x = check_and_convert_channels_last(x, torch_channels_last)
z = check_and_convert_channels_last(z, torch_channels_last)
if is_train:
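# Allocate the bitmask the fused add+ReLU kernel uses to record which elements were zeroed
# in the forward pass, so the backward pass can apply the same mask (an inference from the
# kernel arguments; the exact layout is defined inside the bnp extension).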
if IS_ROCM_PYTORCH:
if torch_channels_last:
nhw = x.shape[0] * x.shape[2] * x.shape[3]
else:
nhw = x.shape[0] * x.shape[1] * x.shape[2]
shape = int(((nhw + 3) & ~3) * 2 * grid_dim_y)
bitmask = torch.cuda.LongTensor(shape)
else:
bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
ctx.torch_channels_last = torch_channels_last
ctx.epsilon = epsilon
ctx.momentum = mom
ctx.ret_cta = ret_cta
ctx.my_data = my_data
ctx.pair_data = pair_data
ctx.magic = magic
ctx.pair_data2 = pair_data2
ctx.pair_data3 = pair_data3
ctx.bn_group = bn_group
ctx.bwd_occup = bwd_occup
ctx.bwd_grid_x = bwd_grid_x
ctx.multi_stream = multi_stream
res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)
return res
else:
return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon)
@staticmethod
def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables
grad_y = check_and_convert_channels_last(grad_y, ctx.torch_channels_last)
x = check_and_convert_channels_last(x, ctx.torch_channels_last)
epsilon = ctx.epsilon
mom = ctx.momentum
ret_cta = ctx.ret_cta
my_data = ctx.my_data
pair_data = ctx.pair_data
magic = ctx.magic
pair_data2 = ctx.pair_data2
pair_data3 = ctx.pair_data3
bn_group = ctx.bn_group
bwd_occup = ctx.bwd_occup
bwd_grid_x = ctx.bwd_grid_x
multi_stream = ctx.multi_stream
dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)
return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class BatchNorm2d_NHWC(_BatchNorm):
# if using BatchNorm2d_NHWC simultaneously with multiple streams set multi_stream to True
def __init__(self, num_features, fuse_relu=False, bn_group=1, torch_channels_last=False,max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False):
super(BatchNorm2d_NHWC, self).__init__(num_features)
self.fuse_relu = fuse_relu
self.torch_channels_last = torch_channels_last
self.multi_stream = multi_stream
self.minibatch_mean = torch.cuda.FloatTensor(num_features)
self.minibatch_riv = torch.cuda.FloatTensor(num_features)
#default to distributed bn disabled
self.bn_group = bn_group
self.max_cta_per_sm = max_cta_per_sm #used only in training fwd and bwd
self.cta_launch_margin = cta_launch_margin #used only in training fwd and bwd
self.my_data = None
self.pair_data = None
self.pair_data2 = None
self.pair_data3 = None
self.local_rank = 0
self.magic = torch.IntTensor([0])
#calculate cta per sm occupancies
assert(max_cta_per_sm>0) # won't be able to do much with 0 CTAs :)
self.fwd_occupancy = min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)
self.bwd_occupancy = min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm)
self.addrelu_fwd_occupancy = min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm)
self.addrelu_bwd_occupancy = min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm)
#calculate grid dimensions based on occupancy numbers
mp_count = torch.cuda.get_device_properties(None).multi_processor_count
self.fwd_grid_dim_x = max(mp_count*self.fwd_occupancy - cta_launch_margin , 1)
self.bwd_grid_dim_x = max(mp_count*self.bwd_occupancy - cta_launch_margin , 1)
self.addrelu_fwd_grid_dim_x = max(mp_count*self.addrelu_fwd_occupancy - cta_launch_margin , 1)
self.addrelu_bwd_grid_dim_x = max(mp_count*self.addrelu_bwd_occupancy - cta_launch_margin , 1)
self.grid_dim_y = (num_features + 63) // 64
# allocate scratch space used by implementation
# TODO: this scratch space should not be exposed to user code. It only needs one-time
# initialization, and the same buffer can be reused across iterations. It is allocated here,
# rather than requested from the caching allocator, to avoid re-initializing it every iteration.
self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0)
#FIXME: turn pair handles into an array
if bn_group>1:
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
assert(world_size >= bn_group)
assert(world_size % bn_group == 0)
bn_sync_steps = 1
if (bn_group==4):
bn_sync_steps = 2
if (bn_group==8):
bn_sync_steps = 3
self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))
self.my_data = bnp.get_data_ptr(self.ipc_buffer)
# we are walking on very thin ice here by utilizing internal `_share_cuda_()`
self.storage = self.ipc_buffer.storage()
self.share_cuda = self.storage._share_cuda_()
internal_cuda_mem = self.share_cuda
# internal_cuda_mem[1]: ipc_mem_handle
my_handle = torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8))
# internal_cuda_mem[3]: offset
my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]])
handles_all = torch.empty(world_size, my_handle.size(0), dtype=my_handle.dtype, device=my_handle.device)
handles_l = list(handles_all.unbind(0))
torch.distributed.all_gather(handles_l, my_handle)
offsets_all = torch.empty(world_size, my_offset.size(0), dtype=my_offset.dtype, device=my_offset.device)
offsets_l = list(offsets_all.unbind(0))
torch.distributed.all_gather(offsets_l, my_offset)
#whom do I actually care about? that would be local_rank XOR 1
self.pair_handle = handles_l[local_rank ^ 1].cpu().contiguous()
pair_offset = offsets_l[local_rank ^ 1].cpu()
self.pair_data = bnp.get_remote_data_ptr(self.pair_handle, pair_offset)
if bn_group>2:
self.pair_handle2 = handles_l[local_rank ^ 2].cpu().contiguous()
pair_offset2 = offsets_l[local_rank ^ 2].cpu()
self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)
if bn_group>4:
self.pair_handle3 = handles_l[local_rank ^ 4].cpu().contiguous()
pair_offset3 = offsets_l[local_rank ^ 4].cpu()
self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)
#FIXME: get magic value into C code and eliminate from here
self.magic = torch.IntTensor([2])
self.local_rank = local_rank
def forward(self, x, z=None):
if z is not None:
assert(self.fuse_relu==True)
return bn_addrelu_NHWC_impl.apply(x, z,
self.weight, self.bias,
self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv, self.grid_dim_y, self.ret_cta,
self.momentum,
self.eps, self.training, self.torch_channels_last, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x,
self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x,
self.multi_stream)
else:
return bn_NHWC_impl.apply(x,
self.weight, self.bias,
self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv, self.ret_cta,
self.momentum,
self.eps, self.fuse_relu, self.training, self.torch_channels_last, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
self.fwd_occupancy, self.fwd_grid_dim_x,
self.bwd_occupancy, self.bwd_grid_dim_x,
self.multi_stream)
def __del__(self):
if self.bn_group>1:
bnp.close_remote_data(self.pair_handle)
if self.bn_group>2:
bnp.close_remote_data(self.pair_handle2)
if self.bn_group>4:
bnp.close_remote_data(self.pair_handle3)
from .index_mul_2d import index_mul_2d
import torch
import fused_index_mul_2d
class IndexMul2d_(torch.autograd.Function):
'''
Currently only supports indexing along dimension 0 of a 2-dimensional tensor.
The indexed slice of in1 must have the same shape as in2; this kernel does not support broadcasting.
The datatype must be float32 or float16. A usage sketch appears at the end of this file.
'''
@staticmethod
def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor) -> torch.Tensor:
assert in2.size(0) == idx1.size(0)
if ((in1.dtype != torch.float32 and in1.dtype != torch.half) or in2.dtype != in1.dtype):
raise RuntimeError("input1'dtype and input2's dtype must be fp32 or fp16. And input type must be same")
if (in1.dim() != 2 or in2.dim() != 2):
raise RuntimeError("in1 and in2 must be 2-dimension tensor.")
if (idx1.dim() != 1):
raise RuntimeError("idx1 must be 1-dimension tensor.")
if not in1.is_contiguous():
in1 = in1.contiguous()
if not in2.is_contiguous():
in2 = in2.contiguous()
if not idx1.is_contiguous():
idx1 = idx1.contiguous()
assert in1.is_contiguous()
assert in2.is_contiguous()
assert idx1.is_contiguous()
out = torch.empty_like(in2)
if (in1.dtype == torch.float32):
fused_index_mul_2d.float_forward(
out,
in1,
in2,
idx1)
elif (in1.dtype == torch.half):
fused_index_mul_2d.half_forward(
out,
in1,
in2,
idx1)
ctx.for_backwards = (in1, in2, idx1)
return out
@staticmethod
def backward(ctx, grad_out):
in1, in2, idx1 = ctx.for_backwards
grad_in1, grad_in2 = index_mul_2d_backward(in1, in2, idx1, grad_out)
return grad_in1, grad_in2, None
class IndexMul2dBackward_(torch.autograd.Function):
@staticmethod
def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor,
grad_out: torch.Tensor) -> torch.Tensor:
if not in1.is_contiguous():
in1 = in1.contiguous()
if not in2.is_contiguous():
in2 = in2.contiguous()
if not idx1.is_contiguous():
idx1 = idx1.contiguous()
if not grad_out.is_contiguous():
grad_out = grad_out.contiguous()
assert in1.is_contiguous()
assert in2.is_contiguous()
assert idx1.is_contiguous()
assert grad_out.is_contiguous()
grad_in1 = torch.zeros_like(in1)
grad_in2 = torch.empty_like(in2)
if (in1.dtype == torch.float32):
fused_index_mul_2d.float_backward(
grad_in1,
grad_in2,
grad_out,
in1,
in2,
idx1)
elif (in1.dtype == torch.half):
fused_index_mul_2d.half_backward(
grad_in1,
grad_in2,
grad_out,
in1,
in2,
idx1)
ctx.for_backwards = (in1, in2, idx1, grad_out)
return grad_in1, grad_in2
@staticmethod
def backward(ctx, grad_grad_in1, grad_grad_in2):
if not grad_grad_in1.is_contiguous():
grad_grad_in1 = grad_grad_in1.contiguous()
if not grad_grad_in2.is_contiguous():
grad_grad_in2 = grad_grad_in2.contiguous()
assert grad_grad_in1.is_contiguous()
assert grad_grad_in2.is_contiguous()
in1, in2, idx1, grad_out = ctx.for_backwards
grad_in1 = torch.zeros_like(in1)
grad_in2 = torch.empty_like(in2)
grad_grad_out = torch.empty_like(grad_out)
if (in1.dtype == torch.float32):
fused_index_mul_2d.float_backward_backward(
grad_grad_out,
grad_in1,
grad_in2,
grad_out,
grad_grad_in1,
grad_grad_in2,
in1,
in2,
idx1)
elif (in1.dtype == torch.half):
fused_index_mul_2d.half_backward_backward(
grad_grad_out,
grad_in1,
grad_in2,
grad_out,
grad_grad_in1,
grad_grad_in2,
in1,
in2,
idx1)
return grad_in1, grad_in2, None, grad_grad_out
index_mul_2d = IndexMul2d_.apply
index_mul_2d_backward = IndexMul2dBackward_.apply
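# Minimal usage sketch (illustrative, not part of the original module). It assumes the
# fused_index_mul_2d extension is built, a CUDA device is available, int64 indices, and
# that the kernel computes out[i] = in1[idx1[i]] * in2[i] row-wise.
if __name__ == "__main__":
    in1 = torch.randn(16, 32, dtype=torch.half, device="cuda", requires_grad=True)
    in2 = torch.randn(64, 32, dtype=torch.half, device="cuda", requires_grad=True)
    idx1 = torch.randint(0, 16, (64,), dtype=torch.int64, device="cuda")
    out = index_mul_2d(in1, in2, idx1)   # same shape as in2: (64, 32)
    out.sum().backward()                 # gradients flow back to both in1 and in2
    print(out.shape, in1.grad.shape, in2.grad.shape)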
from .layer_norm import FastLayerNorm
import torch
from torch.nn import init
from apex._autocast_utils import _cast_if_autocast_enabled
import fast_layer_norm
class FastLayerNormFN(torch.autograd.Function):
@staticmethod
def forward(ctx, x, gamma, beta, epsilon):
x = x.contiguous()
gamma = gamma.contiguous()
beta = beta.contiguous()
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon)
ctx.save_for_backward(x, gamma, mu, rsigma)
return ymat.view(x.shape)
@staticmethod
def backward(ctx, dy):
# assert dy.is_contiguous()
dy = dy.contiguous() # this happens!
x, gamma, mu, rsigma = ctx.saved_tensors
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dymat = dy.view(xmat.shape)
dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma)
dx = dxmat.view(x.shape)
return dx, dgamma, dbeta, None
def _fast_layer_norm(x, weight, bias, epsilon):
args = _cast_if_autocast_enabled(x, weight, bias, epsilon)
with torch.cuda.amp.autocast(enabled=False):
return FastLayerNormFN.apply(*args)
class FastLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5):
super().__init__()
self.epsilon = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size))
self.bias = torch.nn.Parameter(torch.empty(hidden_size))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, x):
return _fast_layer_norm(x, self.weight, self.bias, self.epsilon)
# Fast Multihead Attention
This implementation has two main features:
* A C++ implementation to avoid the CPU overheads of PyTorch found with smaller batch sizes.
* The removal of all copies and transposes found in standard implementations of Multihead Attention.
| | Python Version | C++ Version |
| :----------------------------------------- | :------------: | :---------: |
| Layer Norm and Residual Add Variant | X | X |
| Includes Linear Biases | X | |
| Reduces CPU Overheads | | X |
| Fuses masking with Softmax | | X |
| Removes Transposes and Copies | X | X |
| Includes Self and Encoder/Decoder Variants | X | X |
## How to Instantiate
`SelfMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`
`EncdecMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`
`impl` has two options:
* `fast` uses C++ Version
* `default` uses Python Version
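A minimal usage sketch (illustrative only; it assumes a CUDA device and fp16 inputs shaped `(seq_len, batch, hidden_dim)`):
```
import torch
from apex.contrib.multihead_attn import SelfMultiheadAttn

attn = SelfMultiheadAttn(1024, 16, dropout=0.1, bias=False, include_norm_add=False, impl='fast')
attn.cuda().half()
attn.reset_parameters()

inputs = torch.randn(64, 10, 1024, dtype=torch.float16, device='cuda', requires_grad=True)
outputs, _ = attn(inputs, inputs, inputs,
                  key_padding_mask=None, need_weights=False,
                  attn_mask=None, is_training=True)
outputs.backward(torch.randn_like(outputs))
```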
## Instructions to build on Linux
```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" ./
```
## Try Performance Tests Yourself!
The performance test script can be found here:
```
cd contrib/examples/multihead_attn
```
#### Fast Multihead Attention with Python Reference Implementation
```
python perf_test_multihead_attn.py --ref
```
#### Fast Multihead Attention with C++ Implementation
```
python perf_test_multihead_attn.py
```
#### Compare with `torch.nn.MultiheadAttention`
```
python perf_test_multihead_attn.py --native
```
#### Test your own range!
```
python perf_test_multihead_attn.py --seq-length 64 --num-seqs-start 10 --num-seqs-stop 120 --num-seqs-inc 5
```
## Performance Comparisons
* Performance was measured with 64 token sequence lengths on an NVIDIA TitanV card.
* Time is measured across multiple layers to simulate an in-model scenario.
![Multihead Attention Forward](MHA_fwd.png)
![Multihead Attention Backward](MHA_bwd.png)
from .self_multihead_attn import SelfMultiheadAttn
from .encdec_multihead_attn import EncdecMultiheadAttn
from .mask_softmax_dropout_func import fast_mask_softmax_dropout_func
import math
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from .encdec_multihead_attn_func import encdec_attn_func
from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func
from apex.normalization.fused_layer_norm import FusedLayerNorm
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor
out = F.dropout(x, p=prob, training=True)
out = residual + out
return out
class EncdecMultiheadAttn(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(self, embed_dim, num_heads, dropout=0.0, bias=False, include_norm_add=False, impl="fast"):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.bias = bias
self.include_norm_add = include_norm_add
self.impl = impl
self.scaling = self.head_dim ** -0.5
self.in_proj_weight_q = Parameter(torch.empty(embed_dim, embed_dim))
self.in_proj_weight_kv = Parameter(torch.empty(2 * embed_dim, embed_dim))
self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
if self.bias:
assert impl != "fast", "ERROR! The Fast implementation does not support biases!"
self.in_proj_bias_q = Parameter(torch.empty(embed_dim))
self.in_proj_bias_kv = Parameter(torch.empty(2 * embed_dim))
self.out_proj_bias = Parameter(torch.empty(embed_dim))
else:
self.register_parameter("in_proj_bias_q", None)
self.register_parameter("in_proj_bias_kv", None)
self.in_proj_bias_q = None
self.in_proj_bias_kv = None
self.out_proj_bias = None
if self.include_norm_add:
if impl == "fast":
self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim))
self.lyr_nrm = None
else:
self.register_parameter("lyr_norm_gamma_weights", None)
self.register_parameter("lyr_norm_beta_weights", None)
self.lyr_nrm_gamma_weights = None
self.lyr_nrm_beta_weights = None
self.lyr_nrm = FusedLayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add:
if impl == "fast":
self.attn_func = fast_encdec_attn_norm_add_func
elif impl == "default":
self.attn_func = encdec_attn_func
else:
assert False, "Unsupported impl: {} !".format(impl)
else:
if impl == "fast":
self.attn_func = fast_encdec_attn_func
elif impl == "default":
self.attn_func = encdec_attn_func
else:
assert False, "Unsupported impl: {} !".format(impl)
def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight_q)
# in_proj_weight_kv has shape [2 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)
# therefore xavier_uniform gain should be set to sqrt(1.5).
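# (xavier_uniform_ samples from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)),
# so a gain of sqrt(1.5) gives the [2*hidden, hidden] matrix the same bound that a
# [hidden, hidden] matrix gets with the default gain of 1.)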
nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
nn.init.constant_(self.in_proj_bias_q, 0.0)
nn.init.constant_(self.in_proj_bias_kv, 0.0)
nn.init.constant_(self.out_proj_bias, 0.0)
if self.include_norm_add:
if self.impl == "fast":
nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights)
else:
self.lyr_nrm.reset_parameters()
def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`attn_mask` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
if key_padding_mask is not None:
assert attn_mask is None, "ERROR: attn_mask and key_padding_mask should not both be defined!"
mask = key_padding_mask
elif attn_mask is not None:
mask = attn_mask
else:
mask = None
if self.include_norm_add:
if self.impl == "fast":
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
query,
key,
self.lyr_nrm_gamma_weights,
self.lyr_nrm_beta_weights,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
mask,
self.dropout,
)
else:
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
self.scaling,
lyr_nrm_results,
key,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
self.in_proj_bias_q,
self.in_proj_bias_kv,
self.out_proj_bias,
mask,
self.dropout,
)
if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else:
outputs = outputs + query
else:
if self.impl == "fast":
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
query,
key,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
mask,
self.dropout,
)
else:
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
self.scaling,
query,
key,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
self.in_proj_bias_q,
self.in_proj_bias_kv,
self.out_proj_bias,
mask,
self.dropout,
)
return outputs, None
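# A minimal smoke test of the Python ("default") path; this sketch is illustrative
# and not part of the original module. Sizes and batch shape are arbitrary, and a
# GPU is required because both code paths allocate CUDA tensors internally.
if __name__ == "__main__" and torch.cuda.is_available():
    embed_dim, num_heads = 1024, 16
    attn = EncdecMultiheadAttn(embed_dim, num_heads, dropout=0.1, bias=False,
                               include_norm_add=False, impl="default").cuda()
    q = torch.randn(64, 8, embed_dim, device="cuda")  # [tgt_len, batch, embed_dim]
    k = torch.randn(80, 8, embed_dim, device="cuda")  # [src_len, batch, embed_dim]
    out, _ = attn(q, k, k, is_training=False)         # dropout is skipped at inference
    print(out.shape)                                  # torch.Size([64, 8, 1024])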
import torch
import torch.nn.functional as F
class EncdecAttnFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
use_time_mask,
is_training,
heads,
scale,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
input_biases_q,
input_biases_kv,
output_biases,
mask,
dropout_prob,
):
use_biases_t = torch.tensor([input_biases_q is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
head_dim = inputs_q.size(2) // heads
# Input Linear GEMM Q
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim (1024), embed_dim (1024)] (transpose [0,1])
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
if use_biases_t[0]:
input_lin_q_results = torch.addmm(
input_biases_q,
inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
input_weights_q.transpose(0, 1),
beta=1.0,
alpha=1.0,
)
else:
input_lin_q_results = torch.mm(
inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), input_weights_q.transpose(0, 1)
)
input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0))
# Input Linear GEMM KV
# input1: (activations) [seql_k, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1])
# output: [seql_k, seqs, embed_dim*2]
# GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x embed_dim*2 ) = (seql_k*seqs x embed_dim*2)
if use_biases_t[0]:
input_lin_kv_results = torch.addmm(
input_biases_kv,
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
input_weights_kv.transpose(0, 1),
beta=1.0,
alpha=1.0,
)
else:
input_lin_kv_results = torch.mm(
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
input_weights_kv.transpose(0, 1),
)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1), input_weights_kv.size(0))
# Slice out k,v from one big Input Linear output (should only impact metadata, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
# input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]
queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads, head_dim)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1) * heads, 2, head_dim)
keys = input_lin_kv_results[:, :, 0, :]
values = input_lin_kv_results[:, :, 1, :]
# Matmul1 Batched GEMMs
# The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
# a separate elementwise operation.
# Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
# output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul1_results = torch.empty(
(queries.size(1), queries.size(0), keys.size(0)), dtype=queries.dtype, device=torch.device("cuda")
)
matmul1_results = torch.baddbmm(
matmul1_results,
queries.transpose(0, 1),
keys.transpose(0, 1).transpose(1, 2),
out=matmul1_results,
beta=0.0,
alpha=scale_t[0],
)
if mask is not None:
# Self Attention Time Mask
if use_time_mask:
assert len(mask.size()) == 2, "Timing mask is not 2D!"
assert mask.size(0) == mask.size(1), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float("-inf"))
# Key Padding Mask
else:
batches, seql_q, seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float("-inf"))
matmul1_results = matmul1_results.view(seqs * heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout - is not executed for inference
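# (torch._fused_dropout takes the probability of *keeping* an element and scales
# the kept values by its reciprocal, hence 1.0 - dropout_prob below.)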
if is_training:
dropout_results, dropout_mask = torch._fused_dropout(softmax_results, p=(1.0 - dropout_prob_t[0]))
else:
dropout_results = softmax_results
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor is specified up front to obtain its non-standard (transposed) layout.
# Because PyTorch cannot currently run autograd through an op that writes to a
# user-specified output tensor, this function has to define its own backward pass.
# Input1: from_softmax [seqs*heads, seql_q, seql_k]
# Input2: (values) [seql_v, seqs*heads, head_dim] transpose(0,1)
# Output: [seql_q, seqs*heads, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)
matmul2_results = torch.empty(
(dropout_results.size(1), dropout_results.size(0), values.size(2)),
dtype=dropout_results.dtype,
device=torch.device("cuda"),
).transpose(1, 0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0, 1), out=matmul2_results)
matmul2_results = (
matmul2_results.transpose(0, 1).contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
)
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0]:
outputs = torch.addmm(
output_biases,
matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
output_weights.transpose(0, 1),
beta=1.0,
alpha=1.0,
)
else:
outputs = torch.mm(
matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
output_weights.transpose(0, 1),
)
outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))
ctx.save_for_backward(
use_biases_t,
heads_t,
scale_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
(
use_biases_t,
heads_t,
scale_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob_t,
) = ctx.saved_tensors
head_dim = inputs_q.size(2) // heads_t[0]
# Slice out k,v from one big Input Linear output (should only impact metadata, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
# input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]
queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads_t[0], head_dim)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1) * heads_t[0], 2, head_dim)
keys = input_lin_kv_results[:, :, 0, :]
values = input_lin_kv_results[:, :, 1, :]
# Slice out k,v from one big set of gradients entering the input linear's bprop (should only impact metadata, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared beforehand to properly slice out query, key, and value grads.
input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results)
queries_grads = torch.empty_like(queries)
keys_grads = input_lin_kv_results_grads[:, :, 0, :]
values_grads = input_lin_kv_results_grads[:, :, 1, :]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
output_lin_grads = torch.mm(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights
)
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))
# Output Linear GEMM - WGRAD
# Input1: (data grads) [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output:               [ embed_dim, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0, 1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)),
)
output_lin_grads = output_lin_grads.view(
output_grads.size(0), output_grads.size(1) * heads_t[0], head_dim
).transpose(0, 1)
if use_biases_t[0]:
output_bias_grads = torch.sum(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0
)
else:
output_bias_grads = None
# Matmul2 - DGRAD1
# Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
# Output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))
# Matmul2 - DGRAD2
# Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output:               [seql_k, seqs*heads, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
values_grads = torch.bmm(dropout_results.transpose(1, 2), output_lin_grads, out=values_grads.transpose(0, 1))
# Mask and Scaling for Dropout (not a publicly documented op)
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))
# Softmax Grad (not a publicly documented op)
# Note: the signature of torch._softmax_backward_data changed across PyTorch releases;
# older versions took a tensor as the last argument (the original code passed
# softmax_results), while newer versions expect the input dtype instead.
softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results.dtype)
# Matmul1 - DGRAD1
# Input1: (data grads) [seqs*heads, seql_q, seql_k]
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(
queries_grads.transpose(0, 1),
softmax_grads,
keys.transpose(0, 1),
out=queries_grads.transpose(0, 1),
beta=0.0,
alpha=scale_t[0],
)
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(
keys_grads.transpose(0, 1),
softmax_grads.transpose(1, 2),
queries.transpose(0, 1),
out=keys_grads.transpose(0, 1),
beta=0.0,
alpha=scale_t[0],
)
# Input Q Linear GEMM - DGRAD
# input1: (data grads) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim (1024), embed_dim (1024)]
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
queries_grads = queries_grads.transpose(0, 1).view(inputs_q.size(0) * inputs_q.size(1), heads_t[0] * head_dim)
input_q_grads = torch.mm(queries_grads, input_weights_q)
input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
# Input KV Linear GEMM - DGRAD
# input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)]
# input2: (weights) [embed_dim*2 (2048), embed_dim (1024)]
# output: [seql_k, seqs, embed_dim]
# GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = (seql_k*seqs x embed_dim)
input_lin_kv_results_grads = input_lin_kv_results_grads.view(
inputs_kv.size(0) * inputs_kv.size(1), heads_t[0] * 2 * head_dim
)
input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv)
input_kv_grads = input_kv_grads.view(inputs_kv.size(0), inputs_kv.size(1), inputs_kv.size(2))
# Input Q Linear GEMM - WGRAD
# input1: (data grads) [seql_q*seqs, embed_dim(1024)]
# input2: (activations) [seql_q*seqs, embed_dim(1024)]
# output: [embed_dim, embed_dim]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (embed_dim x embed_dim)
input_weight_q_grads = torch.mm(
queries_grads.transpose(0, 1), inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2))
)
# Input KV Linear GEMM - WGRAD
# input1: (data grads) [seql_k*seqs, 2*embed_dim(2048)]
# input2: (activations) [seql_k*seqs, embed_dim(1024)]
# output: [2*embed_dim, embed_dim]
# GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)
input_weight_kv_grads = torch.mm(
input_lin_kv_results_grads.transpose(0, 1),
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
)
if use_biases_t[0]:
input_bias_grads_q = torch.sum(queries_grads, 0)
input_bias_grads_kv = torch.sum(input_lin_kv_results_grads, 0)
else:
input_bias_grads_q = None
input_bias_grads_kv = None
return (
None,
None,
None,
None,
input_q_grads,
input_kv_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
input_bias_grads_q,
input_bias_grads_kv,
output_bias_grads,
None,
None,
)
encdec_attn_func = EncdecAttnFunc.apply
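# encdec_attn_func is the Python reference path selected by EncdecMultiheadAttn when
# impl="default". Called directly, it takes the same positional arguments as
# EncdecAttnFunc.forward (minus ctx):
#   encdec_attn_func(use_time_mask, is_training, heads, scale,
#                    inputs_q, inputs_kv, input_weights_q, input_weights_kv,
#                    output_weights, input_biases_q, input_biases_kv,
#                    output_biases, mask, dropout_prob)
# backward() must return one gradient slot per forward input, with None for the
# non-tensor arguments (use_time_mask, is_training, heads, scale, mask,
# dropout_prob), which is why it returns a 14-element tuple.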
import torch
import fast_multihead_attn
class FastEncdecAttnFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
pad_mask,
dropout_prob,
):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = pad_mask is not None
(
input_lin_q_results,
input_lin_kv_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs,
) = fast_multihead_attn.encdec_multihead_attn_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
ctx.save_for_backward(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob_t,
) = ctx.saved_tensors
(
input_q_grads,
input_kv_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
) = fast_multihead_attn.encdec_multihead_attn_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob_t[0],
)
return (
None,
None,
None,
input_q_grads,
input_kv_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
None,
None,
)
fast_encdec_attn_func = FastEncdecAttnFunc.apply
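# fast_encdec_attn_func dispatches to the fast_multihead_attn C++/CUDA extension
# imported above, so it is only usable when apex was built with the
# --fast_multihead_attn flag shown in the build instructions in the README above.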