Unverified commit e96b76b0 authored by shenggan, committed by GitHub

Merge pull request #1 from hpcaitech/refactor_kernel

refactor kernel implementation
parents 1c0a3d39 f206de08
@@ -79,10 +79,6 @@ If you want to benchmark with [OpenFold](https://github.com/aqlaboratory/openfold)
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256 --openfold
```
## Acknowledgements
The CUDA implementations of LayerNorm and Softmax are modified from [OneFlow](https://github.com/Oneflow-Inc/oneflow). Thanks to OneFlow for the high-performance CUDA implementations; we mainly added support for bfloat16 precision.
## Cite us
Cite this paper if you use FastFold in your research publications.
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/*
This code is modified from https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/layer_norm.cuh
*/
#ifndef FASTFOLD_LAYER_NORM_H_
#define FASTFOLD_LAYER_NORM_H_
#include <assert.h>
#include <math_constants.h>
#include <algorithm>
#include <type_traits>
#include <cub/cub.cuh>
namespace fastfold {
namespace layer_norm {
constexpr int kWarpSize = 32;
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; }
};
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); }
};
template <template <typename> class ReductionOp, typename T, int thread_group_width = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
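// Note (added for clarity): WarpAllReduce leaves the reduced value in every
// lane of the thread group via butterfly __shfl_xor_sync exchanges, while
// BlockAllReduce broadcasts the cub::BlockReduce result, which is only valid
// in thread 0, to the whole block through shared memory.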
template <typename T>
__inline__ __device__ T Div(T a, T b);
template <>
__inline__ __device__ at::BFloat16 Div<at::BFloat16>(at::BFloat16 a, at::BFloat16 b) {
return a / b;
}
template <>
__inline__ __device__ float Div<float>(float a, float b) {
return __fdividef(a, b);
}
template <>
__inline__ __device__ double Div<double>(double a, double b) {
return a / b;
}
template <typename T>
__inline__ __device__ T Rsqrt(T x);
template <>
__inline__ __device__ at::BFloat16 Rsqrt<at::BFloat16>(at::BFloat16 x) {
return rsqrt(x);
}
template <>
__inline__ __device__ float Rsqrt<float>(float x) {
return rsqrt(x);
}
template <>
__inline__ __device__ double Rsqrt<double>(double x) {
return rsqrt(x);
}
template <class Func>
inline cudaError_t GetNumBlocks(Func func, int64_t block_size, size_t dynamic_smem_size,
int64_t max_blocks, int64_t waves, int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count;
{
cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
int max_active_blocks;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, func, block_size, dynamic_smem_size);
if (err != cudaSuccess) {
return err;
}
}
*num_blocks =
std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * max_active_blocks * waves));
return cudaSuccess;
}
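// Worked example (hypothetical numbers, added for clarity): on a device with
// sm_count = 80 and an occupancy of max_active_blocks = 4, waves = 32 caps the
// grid at min(max_blocks, 80 * 4 * 32) = min(max_blocks, 10240) blocks, and
// the std::max above guarantees that at least one block is launched.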
template <typename T>
struct DefaultComputeType {
using type = T;
};
template <>
struct DefaultComputeType<half> {
using type = float;
};
#if CUDA_VERSION >= 11000
template <>
struct DefaultComputeType<nv_bfloat16> {
using type = float;
};
#endif // CUDA_VERSION >= 11000
template <typename T, int N>
struct GetPackType {
using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};
template <typename T, int N>
using PackType = typename GetPackType<T, N>::type;
template <typename T, int N>
union Pack {
static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");
__device__ Pack() {
// do nothing
}
PackType<T, N> storage;
T elem[N];
};
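// Example (added for clarity): Pack<float, 4> aliases four floats with a
// 16-byte sized and aligned storage type, so the loaders/storers below can
// move a whole pack with a single vectorized global-memory access.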
template <typename SRC, typename DST>
struct DirectLoad {
DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {}
template <int N>
__device__ void load(DST* dst, int64_t row, int64_t col) const {
Pack<SRC, N> pack;
const int64_t offset = (row * row_size + col) / N;
pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = static_cast<DST>(pack.elem[i]);
}
}
const SRC* src;
int64_t row_size;
};
template <typename SRC, typename DST>
struct DirectStore {
DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
Pack<DST, N> pack;
const int64_t offset = (row * row_size + col) / N;
#pragma unroll
for (int i = 0; i < N; ++i) {
pack.elem[i] = static_cast<DST>(src[i]);
}
*(reinterpret_cast<PackType<DST, N>*>(dst) + offset) = pack.storage;
}
DST* dst;
int64_t row_size;
};
template <typename SRC, typename DST, bool do_scale, bool do_center>
struct AffineStore {
AffineStore(DST* normalized, DST* y, int64_t row_size, const DST* gamma, const DST* beta)
: normalized(normalized), y(y), row_size(row_size), gamma(gamma), beta(beta) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
Pack<DST, N> y_pack;
Pack<DST, N> normalized_pack;
Pack<DST, N> gamma_pack;
Pack<DST, N> beta_pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t gamma_offset = col / N;
if (do_scale) {
gamma_pack.storage = *(reinterpret_cast<const PackType<DST, N>*>(gamma) + gamma_offset);
} else {
#pragma unroll
for (int i = 0; i < N; ++i) {
gamma_pack.elem[i] = 1;
}
}
if (do_center) {
beta_pack.storage = *(reinterpret_cast<const PackType<DST, N>*>(beta) + gamma_offset);
} else {
#pragma unroll
for (int i = 0; i < N; ++i) {
beta_pack.elem[i] = 0;
}
}
#pragma unroll
for (int i = 0; i < N; ++i) {
DST normalized_i = static_cast<DST>(src[i]);
if (do_scale) {
normalized_pack.elem[i] = normalized_i;
}
if (do_scale || do_center) {
y_pack.elem[i] = normalized_i * gamma_pack.elem[i] + beta_pack.elem[i];
} else {
y_pack.elem[i] = normalized_i;
}
}
*(reinterpret_cast<PackType<DST, N>*>(y) + offset) = y_pack.storage;
if (do_scale) {
*(reinterpret_cast<PackType<DST, N>*>(normalized) + offset) = normalized_pack.storage;
}
}
DST* normalized;
DST* y;
int64_t row_size;
const DST* gamma;
const DST* beta;
};
template <typename SRC, typename DST, bool do_scale>
struct ScaleLoad {
ScaleLoad(const SRC* src, const SRC* gamma, int64_t row_size)
: src(src), gamma(gamma), row_size(row_size) {}
template <int N>
__device__ void load(DST* dst, int64_t row, int64_t col) const {
Pack<SRC, N> src_pack;
Pack<SRC, N> gamma_pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t gamma_offset = col / N;
src_pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset);
if (do_scale) {
gamma_pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(gamma) + gamma_offset);
} else {
#pragma unroll
for (int i = 0; i < N; ++i) {
gamma_pack.elem[i] = static_cast<SRC>(1);
}
}
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = static_cast<DST>(src_pack.elem[i] * gamma_pack.elem[i]);
}
}
const SRC* src;
const SRC* gamma;
int64_t row_size;
};
template <typename SRC, typename DST, bool do_add>
struct AddStore {
AddStore(const DST* add_to_output, DST* dst, int64_t row_size)
: add_to_output(add_to_output), dst(dst), row_size(row_size) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
Pack<DST, N> add_to_output_pack;
Pack<DST, N> dst_pack;
const int64_t offset = (row * row_size + col) / N;
if (do_add) {
add_to_output_pack.storage =
*(reinterpret_cast<const PackType<DST, N>*>(add_to_output) + offset);
}
#pragma unroll
for (int i = 0; i < N; ++i) {
if (do_add) {
dst_pack.elem[i] = static_cast<DST>(src[i]) + add_to_output_pack.elem[i];
} else {
dst_pack.elem[i] = static_cast<DST>(src[i]);
}
}
*(reinterpret_cast<PackType<DST, N>*>(dst) + offset) = dst_pack.storage;
}
const DST* add_to_output;
DST* dst;
int64_t row_size;
};
template <typename T>
inline __device__ void WelfordCombine(T val, T* mean, T* m2, T* count) {
// Use Welford's online algorithm to compute mean and variance
// For more details you can refer to:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
*count += 1;
T delta1 = val - *mean;
*mean += Div(delta1, *count);
T delta2 = val - *mean;
*m2 += delta1 * delta2;
}
template <typename T>
inline __device__ void WelfordCombine(T b_mean, T b_m2, T b_count, T* mean, T* m2, T* count) {
if (b_count == 0) {
return;
}
T new_count = *count + b_count;
T nb_over_n = Div(b_count, new_count);
T delta = b_mean - *mean;
*mean += delta * nb_over_n;
*m2 += b_m2 + delta * delta * (*count) * nb_over_n;
*count = new_count;
}
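// Worked example (added for clarity, values are illustrative): merging the
// partial statistics (mean=1, m2=0, count=2) and (mean=3, m2=0, count=2) gives
// new_count=4, mean=2 and m2 = 0 + 0 + (3 - 1)^2 * 2 * (2 / 4) = 4, which is
// the sum of squared deviations of the merged set {1, 1, 3, 3}.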
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_mean, T thread_m2, T thread_count, T* mean,
T* m2, T* count) {
*mean = thread_mean;
*m2 = thread_m2;
*count = thread_count;
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
T b_mean = __shfl_down_sync(0xffffffff, *mean, mask);
T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask);
T b_count = __shfl_down_sync(0xffffffff, *count, mask);
WelfordCombine(b_mean, b_m2, b_count, mean, m2, count);
}
}
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_mean, T thread_m2, T thread_count, T* mean,
T* m2, T* count) {
WelfordWarpReduce<T, thread_group_width>(thread_mean, thread_m2, thread_count, mean, m2, count);
*mean = __shfl_sync(0xffffffff, *mean, 0, thread_group_width);
*m2 = __shfl_sync(0xffffffff, *m2, 0, thread_group_width);
*count = __shfl_sync(0xffffffff, *count, 0, thread_group_width);
}
template <typename T>
__inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T thread_count,
T* result_mean, T* result_m2, T* result_count) {
__shared__ T mean_shared[kWarpSize];
__shared__ T m2_shared[kWarpSize];
__shared__ T count_shared[kWarpSize];
__shared__ T mean_result_broadcast;
__shared__ T m2_result_broadcast;
__shared__ T count_result_broadcast;
const int lid = threadIdx.x % kWarpSize;
const int wid = threadIdx.x / kWarpSize;
T warp_mean = 0;
T warp_m2 = 0;
T warp_count = 0;
WelfordWarpReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);
__syncthreads();
if (lid == 0) {
mean_shared[wid] = warp_mean;
m2_shared[wid] = warp_m2;
count_shared[wid] = warp_count;
}
__syncthreads();
if (wid == 0) {
if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid];
warp_m2 = m2_shared[lid];
warp_count = count_shared[lid];
} else {
warp_mean = static_cast<T>(0);
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
__syncwarp();
T block_mean = 0;
T block_m2 = 0;
T block_count = 0;
WelfordWarpReduce(warp_mean, warp_m2, warp_count, &block_mean, &block_m2, &block_count);
if (lid == 0) {
mean_result_broadcast = block_mean;
m2_result_broadcast = block_m2;
count_result_broadcast = block_count;
}
}
__syncthreads();
*result_mean = mean_result_broadcast;
*result_m2 = m2_result_broadcast;
*result_count = count_result_broadcast;
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding>
__global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
static_assert(cols_per_thread % pack_size == 0, "");
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
constexpr int num_packs = cols_per_thread / pack_size;
assert(cols <= cols_per_thread * thread_group_width);
ComputeType buf[rows_per_access][cols_per_thread];
const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int64_t num_global_thread_group = gridDim.x * blockDim.y;
const int64_t lane_id = threadIdx.x;
const int64_t step = num_global_thread_group * rows_per_access;
for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {
ComputeType thread_mean[rows_per_access];
ComputeType thread_m2[rows_per_access];
ComputeType thread_count[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
thread_mean[row_id] = 0;
thread_m2[row_id] = 0;
thread_count[row_id] = 0;
ComputeType* row_buf = buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < num_packs; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
const int pack_offset = pack_id * pack_size;
if (!padding || col < cols) {
load.template load<pack_size>(row_buf + pack_offset, row + row_id, col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id,
thread_m2 + row_id, thread_count + row_id);
}
} else {
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
row_buf[pack_offset + i] = 0;
}
}
}
}
ComputeType warp_mean[rows_per_access];
ComputeType warp_m2[rows_per_access];
ComputeType warp_count[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
int global_row_id = row + row_id;
ComputeType* row_buf = buf[row_id];
WelfordWarpAllReduce<ComputeType, thread_group_width>(
thread_mean[row_id], thread_m2[row_id], thread_count[row_id], warp_mean + row_id,
warp_m2 + row_id, warp_count + row_id);
ComputeType row_mean = warp_mean[row_id];
ComputeType row_variance =
max(Div(warp_m2[row_id], warp_count[row_id]), static_cast<ComputeType>(0.0));
ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
if (lane_id == 0) {
mean[global_row_id] = row_mean;
inv_variance[global_row_id] = row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_buf[i] = (row_buf[i] - row_mean) * row_inv_var;
}
#pragma unroll
for (int i = 0; i < num_packs; ++i) {
const int col = (i * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
store.template store<pack_size>(row_buf + i * pack_size, global_row_id, col);
}
}
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding>
inline cudaError_t LaunchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int block_size = 128;
constexpr int waves = 32;
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
const int64_t num_blocks =
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err =
GetNumBlocks(LayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, padding>,
block_size, 0, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread, thread_group_width,
rows_per_access, padding><<<grid_dim_x, block_dim, 0, stream>>>(
load, store, rows, cols, epsilon, mean, inv_variance);
return cudaPeekAtLastError();
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access>
inline cudaError_t DispatchLayerNormWarpImplPadding(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols == cols_per_thread * thread_group_width) {
return LaunchLayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, false>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else {
return LaunchLayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, true>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchLayerNormWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, \
kWarpSize, 1>(stream, load, store, rows, cols, \
epsilon, mean, inv_variance); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(5)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(7)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(9)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(11)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(13)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(15)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(17)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(19)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(21)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(23)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(25)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(27)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(29)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(31)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchLayerNormWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, \
kWarpSize, 1>(stream, load, store, rows, cols, \
epsilon, mean, inv_variance); \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 4, cudaError_t>::type DispatchLayerNormWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, \
kWarpSize, 1>(stream, load, store, rows, cols, \
epsilon, mean, inv_variance); \
}
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD, typename STORE, typename ComputeType>
struct DispatchLayerNormWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols % 4 == 0) {
return DispatchLayerNormWarpImplCols<LOAD, STORE, ComputeType, 4>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else if (cols % 2 == 0) {
return DispatchLayerNormWarpImplCols<LOAD, STORE, ComputeType, 2>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else {
return DispatchLayerNormWarpImplCols<LOAD, STORE, ComputeType, 1>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
return DispatchLayerNormWarpImplPackSize<LOAD, STORE, ComputeType>()(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto* buf = reinterpret_cast<ComputeType*>(shared_buf);
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_mean = 0;
ComputeType thread_m2 = 0;
ComputeType thread_count = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
buf[i * num_packs + pack_id] = pack[i];
WelfordCombine(pack[i], &thread_mean, &thread_m2, &thread_count);
}
}
ComputeType row_mean = 0;
ComputeType row_m2 = 0;
ComputeType row_count = 0;
WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
&row_count);
ComputeType row_variance = max(Div(row_m2, row_count), static_cast<ComputeType>(0.0));
ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
if (threadIdx.x == 0) {
mean[row] = row_mean;
inv_variance[row] = row_inv_var;
}
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack[i] = (buf[i * num_packs + pack_id] - row_mean) * row_inv_var;
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
inline cudaError_t LaunchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,
int smem, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err =
GetNumBlocks(LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>,
block_size, smem, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, smem, stream>>>(load, store, rows, cols, epsilon, mean,
inv_variance);
return cudaPeekAtLastError();
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
inline cudaError_t TryDispatchLayerNormBlockSMemImplBlockSize(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(ComputeType);
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_1,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>,
block_size_conf_1, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_4,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4>,
block_size_conf_4, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_3,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3>,
block_size_conf_3, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_2,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2>,
block_size_conf_2, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
}
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
}
template <typename LOAD, typename STORE, typename ComputeType>
struct TryDispatchLayerNormBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance, bool* success) {
if (cols % 4 == 0) {
return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 4>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, success);
} else if (cols % 2 == 0) {
return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 2>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, success);
} else {
return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 1>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, success);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t TryDispatchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance, bool* success) {
return TryDispatchLayerNormBlockSMemImplPackSize<LOAD, STORE, ComputeType>()(
stream, load, store, rows, cols, epsilon, mean, inv_variance, success);
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon,
ComputeType* mean, ComputeType* inv_variance) {
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_mean = 0;
ComputeType thread_m2 = 0;
ComputeType thread_count = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
WelfordCombine(pack[i], &thread_mean, &thread_m2, &thread_count);
}
}
ComputeType row_mean = 0;
ComputeType row_m2 = 0;
ComputeType row_count = 0;
WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
&row_count);
ComputeType row_variance = max(Div(row_m2, row_count), static_cast<ComputeType>(0.0));
ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
if (threadIdx.x == 0) {
mean[row] = row_mean;
inv_variance[row] = row_inv_var;
}
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
const int pack_offset = pack_id * pack_size;
load.template load<pack_size>(pack, row, pack_offset);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack[i] = (pack[i] - row_mean) * row_inv_var;
}
store.template store<pack_size>(pack, row, pack_offset);
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
inline cudaError_t LaunchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(
LayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>, block_size,
0, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, 0, stream>>>(load, store, rows, cols, epsilon, mean,
inv_variance);
return cudaPeekAtLastError();
}
template <typename LOAD, typename STORE, typename ComputeType>
struct DispatchLayerNormBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols % 4 == 0) {
return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 4>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else if (cols % 2 == 0) {
return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 2>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else {
return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 1>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
return DispatchLayerNormBlockUncachedImplPackSize<LOAD, STORE, ComputeType>()(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols <= 1024) {
return DispatchLayerNormWarpImpl<LOAD, STORE, ComputeType>(stream, load, store, rows, cols,
epsilon, mean, inv_variance);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err = TryDispatchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance,
&dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
return cudaSuccess;
}
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
return DispatchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
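/*
Illustrative usage sketch (added for clarity, not part of the original header):
for a contiguous float input x of shape [rows, cols], the forward pass could be
dispatched roughly as below. The names x, y, mean_buf and inv_var_buf are
hypothetical host-provided device pointers.
    using Load = DirectLoad<float, float>;
    using Store = DirectStore<float, float>;
    Load load(x, cols);
    Store store(y, cols);
    cudaError_t err = DispatchLayerNorm<Load, Store, float>(
        stream, load, store, rows, cols, 1e-5, mean_buf, inv_var_buf);
*/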
/*
LayerNormGrad dx:
normalized = (x - mean) * inv_var
sum_stats1 = sum(scaled_dy)
sum_stats2 = sum(scaled_dy * normalized)
dx = cols * scaled_dy - sum_stats1 - normalized * sum_stats2
dx *= inv_var / cols
*/
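/*
Single-row reference sketch of the formula above (added for illustration; x,
scaled_dy, dx, mean and inv_var are hypothetical per-row host-side names, and
the kernels below vectorize and parallelize these two passes):
    double sum1 = 0, sum2 = 0;
    for (int64_t c = 0; c < cols; ++c) {
        double normalized = (x[c] - mean) * inv_var;
        sum1 += scaled_dy[c];
        sum2 += scaled_dy[c] * normalized;
    }
    for (int64_t c = 0; c < cols; ++c) {
        double normalized = (x[c] - mean) * inv_var;
        dx[c] = (cols * scaled_dy[c] - sum1 - normalized * sum2) * inv_var / cols;
    }
*/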
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int cols_per_thread, int thread_group_width, int rows_per_access,
bool padding>
__global__ void LayerNormGradWarpImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
static_assert(cols_per_thread % pack_size == 0, "");
constexpr int pack_per_thread = cols_per_thread / pack_size;
assert(cols <= cols_per_thread * thread_group_width);
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
ComputeType normalized_buf[rows_per_access][cols_per_thread];
ComputeType dy_buf[rows_per_access][cols_per_thread];
const ComputeType one_over_cols =
static_cast<ComputeType>(1.0) / static_cast<ComputeType>(cols);
const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int64_t num_global_thread_group = gridDim.x * blockDim.y;
const int lane_id = threadIdx.x;
const int64_t step = num_global_thread_group * rows_per_access;
for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {
ComputeType sum_stats1[rows_per_access];
ComputeType sum_stats2[rows_per_access];
ComputeType inv_variance_buf[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
const int global_row_id = row + row_id;
ComputeType mean_val = mean[global_row_id];
inv_variance_buf[row_id] = inv_variance[global_row_id];
sum_stats1[row_id] = 0;
sum_stats2[row_id] = 0;
ComputeType* row_normalized_buf = normalized_buf[row_id];
ComputeType* row_dy_buf = dy_buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
const int pack_offset = pack_id * pack_size;
if (!padding || col < cols) {
load_x.template load<pack_size>(row_normalized_buf + pack_offset, global_row_id,
col);
load_scaled_dy.template load<pack_size>(row_dy_buf + pack_offset, global_row_id,
col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
const int col_id = pack_offset + i;
// row_normalized_buf store x
row_normalized_buf[col_id] =
(row_normalized_buf[col_id] - mean_val) * inv_variance_buf[row_id];
sum_stats1[row_id] += row_dy_buf[col_id];
sum_stats2[row_id] += row_dy_buf[col_id] * row_normalized_buf[col_id];
}
}
}
}
ComputeType warp_sum_stats1[rows_per_access];
ComputeType warp_sum_stats2[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
warp_sum_stats1[row_id] =
WarpAllReduce<SumOp, ComputeType, thread_group_width>(sum_stats1[row_id]);
warp_sum_stats2[row_id] =
WarpAllReduce<SumOp, ComputeType, thread_group_width>(sum_stats2[row_id]);
}
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
const int global_row_id = row + row_id;
ComputeType* row_normalized_buf = normalized_buf[row_id];
ComputeType* row_dy_buf = dy_buf[row_id];
const ComputeType inv_variance_over_cols = inv_variance_buf[row_id] * one_over_cols;
#pragma unroll
for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
for (int i = 0; i < pack_size; ++i) {
const int col_id = pack_id * pack_size + i;
row_dy_buf[col_id] =
(cols * row_dy_buf[col_id] - warp_sum_stats1[row_id] -
row_normalized_buf[col_id] * warp_sum_stats2[row_id]) *
inv_variance_over_cols;
}
store.template store<pack_size>(row_dy_buf + pack_id * pack_size, global_row_id,
col);
}
}
}
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int cols_per_thread, int thread_group_width, int rows_per_access,
bool padding>
inline cudaError_t LaunchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
constexpr int block_size = 128;
constexpr int waves = 32;
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
const int64_t num_blocks =
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(
LayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access, padding>,
block_size, 0, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, padding>
<<<grid_dim_x, block_dim, 0, stream>>>(load_x, load_scaled_dy, store, mean, inv_variance,
rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int cols_per_thread, int thread_group_width, int rows_per_access>
inline cudaError_t DispatchLayerNormGradWarpImplPadding(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols == cols_per_thread * thread_group_width) {
return LaunchLayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access,
false>(stream, load_x, load_scaled_dy, store, mean,
inv_variance, rows, cols);
} else {
return LaunchLayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access,
true>(stream, load_x, load_scaled_dy, store, mean,
inv_variance, rows, cols);
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchLayerNormGradWarpImplCols(
cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, \
ComputeType, pack_size, pack_size, \
thread_group_width, 2>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} else { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, \
ComputeType, pack_size, pack_size, \
thread_group_width, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, col, kWarpSize, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(5)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(7)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(9)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(11)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(13)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(15)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(17)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(19)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(21)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(23)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(25)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(27)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(29)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(31)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchLayerNormGradWarpImplCols(
cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, \
ComputeType, pack_size, pack_size, \
thread_group_width, 2>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} else { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, \
ComputeType, pack_size, pack_size, \
thread_group_width, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, col, kWarpSize, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
struct DispatchLayerNormGradWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0) {
return DispatchLayerNormGradWarpImplCols<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, 2>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
} else {
return DispatchLayerNormGradWarpImplCols<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, 1>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
}
};
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
return DispatchLayerNormGradWarpImplPackSize<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>()(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int block_size>
__global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean,
const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[];
auto* normalized_buf = reinterpret_cast<ComputeType*>(grad_shared_buf);
auto* dy_buf = normalized_buf + cols;
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
const ComputeType one_over_cols =
static_cast<ComputeType>(1.0) / static_cast<ComputeType>(cols);
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType sum_stats1 = 0;
ComputeType sum_stats2 = 0;
const ComputeType mean_val = mean[row];
const ComputeType inv_variance_val = inv_variance[row];
const ComputeType inv_variance_over_cols = inv_variance_val * one_over_cols;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
load_x.template load<pack_size>(x_pack, row, pack_id * pack_size);
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
const int buf_offset = i * num_packs + pack_id;
ComputeType normalized = (x_pack[i] - mean_val) * inv_variance_val;
normalized_buf[buf_offset] = normalized;
dy_buf[buf_offset] = dy_pack[i];
sum_stats1 += dy_pack[i];
sum_stats2 += dy_pack[i] * normalized;
}
}
const ComputeType row_sum_stats1 =
BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats1);
const ComputeType row_sum_stats2 =
BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats2);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
const int buf_offset = i * num_packs + pack_id;
pack[i] = (cols * dy_buf[buf_offset] - row_sum_stats1 -
normalized_buf[buf_offset] * row_sum_stats2) *
inv_variance_over_cols;
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int block_size>
inline cudaError_t LaunchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance, int smem,
const int64_t rows, const int64_t cols) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err =
GetNumBlocks(LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size>,
block_size, smem, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, smem, stream>>>(load_x, load_scaled_dy, store, mean,
inv_variance, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
inline cudaError_t TryDispatchLayerNormGradBlockSMemImplBlockSize(
cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,
const int64_t cols, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(ComputeType) * 2;
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_1,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_1>,
block_size_conf_1, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_4,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_4>,
block_size_conf_4, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size_conf_4>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_3,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_3>,
block_size_conf_3, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size_conf_3>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_2,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_2>,
block_size_conf_2, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size_conf_2>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);
}
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_1>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
struct TryDispatchLayerNormGradBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols, bool* success) {
if (cols % 2 == 0) {
return TryDispatchLayerNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType, 2>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success);
} else {
return TryDispatchLayerNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType, 1>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success);
}
}
};
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline cudaError_t TryDispatchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols,
bool* success) {
return TryDispatchLayerNormGradBlockSMemImplPackSize<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType>()(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success);
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int block_size>
__global__ void LayerNormGradBlockUncachedImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean,
const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
const ComputeType one_over_cols =
static_cast<ComputeType>(1.0) / static_cast<ComputeType>(cols);
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
const ComputeType mean_val = mean[row];
const ComputeType inv_variance_val = inv_variance[row];
const ComputeType inv_variance_over_cols = inv_variance_val * one_over_cols;
ComputeType sum_stats1 = 0;
ComputeType sum_stats2 = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
load_x.template load<pack_size>(x_pack, row, pack_id * pack_size);
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
sum_stats1 += dy_pack[i];
sum_stats2 += dy_pack[i] * (x_pack[i] - mean_val) * inv_variance_val;
}
}
const ComputeType row_sum_stats1 =
BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats1);
const ComputeType row_sum_stats2 =
BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats2);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
load_x.template load<pack_size>(x_pack, row, pack_id * pack_size);
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
dy_pack[i] = (cols * dy_pack[i] - row_sum_stats1 -
(x_pack[i] - mean_val) * inv_variance_val * row_sum_stats2) *
inv_variance_over_cols;
}
store.template store<pack_size>(dy_pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
inline cudaError_t LaunchLayerNormGradBlockUncachedImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err =
GetNumBlocks(LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size>,
block_size, 0, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size><<<grid_dim_x, block_size, 0, stream>>>(
load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
struct DispatchLayerNormGradBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0 && cols > kWarpSize) {
return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
2>(stream, load_x, load_scaled_dy, store,
mean, inv_variance, rows, cols);
} else {
return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
1>(stream, load_x, load_scaled_dy, store,
mean, inv_variance, rows, cols);
}
}
};
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormGradBlockUncachedImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
return DispatchLayerNormGradBlockUncachedImplPackSize<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType>()(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols <= 1024) {
return DispatchLayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err =
TryDispatchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols,
&dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
return cudaSuccess;
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
return DispatchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
} // namespace layer_norm
} // namespace fastfold
#endif // FASTFOLD_LAYER_NORM_H_
\ No newline at end of file
// part of code modified from https://github.com/NVIDIA/apex
#include <cooperative_groups.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <THC/THCDeviceUtils.cuh>
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#include "layer_norm.cuh"
#include "type_shim.h"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
......@@ -12,19 +19,175 @@
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
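// Welford's online algorithm: fold one sample at a time into a running
// (mean, M2, count) triple, where variance = M2 / count.  The second overload
// merges two partial triples (Chan et al.'s parallel update), which is what the
// warp-shuffle reduction below needs.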
inline __device__ void WelfordOnline(float val, float* mean, float* m2, float* count) {
*count += 1;
float delta1 = val - *mean;
*mean += delta1 / (*count);
float delta2 = val - *mean;
*m2 += delta1 * delta2;
}
inline __device__ void WelfordOnline(float b_mean, float b_m2, float b_count, float* mean,
float* m2, float* count) {
if (b_count == 0) {
return;
}
float new_count = *count + b_count;
float nb_n = b_count / new_count;
float delta = b_mean - *mean;
*mean += delta * nb_n;
*m2 += b_m2 + delta * delta * (*count) * nb_n;
*count = new_count;
}
__inline__ __device__ void WelfordWarpAllReduce(float thread_mean, float thread_m2,
float thread_count, float* mean, float* m2,
float* count) {
*mean = thread_mean;
*m2 = thread_m2;
*count = thread_count;
for (int mask = 1; mask < 32; mask *= 2) {
float b_mean = __shfl_down_sync(0xffffffff, *mean, mask);
float b_m2 = __shfl_down_sync(0xffffffff, *m2, mask);
float b_count = __shfl_down_sync(0xffffffff, *count, mask);
WelfordOnline(b_mean, b_m2, b_count, mean, m2, count);
}
*mean = __shfl_sync(0xffffffff, *mean, 0, 32);
*m2 = __shfl_sync(0xffffffff, *m2, 0, 32);
*count = __shfl_sync(0xffffffff, *count, 0, 32);
}
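// Forward kernels: each 128-thread block handles four rows, one warp per row.
// A lane keeps cols / 32 contiguous elements in registers (buf holds at most 32),
// so cols is assumed to be a multiple of 32 and no larger than 1024.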
__global__ void fastfold_layernorm_fp32(float* input, float* output, float* gamma, float* beta,
float* mean, float* invvar, int rows, int cols,
double epsilon) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int lane_id = threadidx_y;
float buf[32];
    float thread_mean = 0.f;
    float thread_m2 = 0.f;
    float thread_count = 0.f;
float warp_mean;
float warp_m2;
float warp_count;
float* row_input = input + row_offset * cols;
float* row_output = output + row_offset * cols;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
buf[i] = row_input[lane_id * cols_per_thread + i];
}
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
}
WelfordWarpAllReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);
float row_mean = warp_mean;
float row_variance = max(warp_m2 / warp_count, 0.f);
float row_inv_var = rsqrt(row_variance + epsilon);
if (lane_id == 0) {
mean[row_offset] = row_mean;
invvar[row_offset] = row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = (buf[i] - row_mean) * row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
buf[i] * gamma[lane_id * cols_per_thread + i] + beta[lane_id * cols_per_thread + i];
}
}
__global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* output,
at::BFloat16* gamma, at::BFloat16* beta, float* mean,
float* invvar, int rows, int cols, double epsilon) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int lane_id = threadidx_y;
float buf[32];
    float thread_mean = 0.f;
    float thread_m2 = 0.f;
    float thread_count = 0.f;
float warp_mean;
float warp_m2;
float warp_count;
at::BFloat16* row_input = input + row_offset * cols;
at::BFloat16* row_output = output + row_offset * cols;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
}
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
}
WelfordWarpAllReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);
float row_mean = warp_mean;
float row_variance = max(warp_m2 / warp_count, 0.f);
float row_inv_var = rsqrt(row_variance + epsilon);
if (lane_id == 0) {
mean[row_offset] = row_mean;
invvar[row_offset] = row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = (buf[i] - row_mean) * row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(buf[i]) * gamma[lane_id * cols_per_thread + i] +
beta[lane_id * cols_per_thread + i];
}
}
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input,
                     int rows, int cols, at::IntArrayRef normalized_shape, at::Tensor* gamma,
                     at::Tensor* beta, double epsilon) {
    // One 128-thread block normalizes four rows (one warp per row), so rows is
    // assumed to be a multiple of 4 here.
    int grid = rows / 4;
dim3 block(128);
if (output->dtype() == torch::kFloat32) {
fastfold_layernorm_fp32<<<grid, block>>>(
(float*)input->data_ptr(), (float*)output->data_ptr(), (float*)gamma->data_ptr(),
(float*)beta->data_ptr(), (float*)mean->data_ptr(), (float*)invvar->data_ptr(), rows,
cols, epsilon);
} else {
fastfold_layernorm_bfp16<<<grid, block>>>(
(at::BFloat16*)input->data_ptr(), (at::BFloat16*)output->data_ptr(),
(at::BFloat16*)gamma->data_ptr(), (at::BFloat16*)beta->data_ptr(),
(float*)mean->data_ptr(), (float*)invvar->data_ptr(), rows, cols, epsilon);
}
}
template <typename T>
......@@ -208,6 +371,116 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const U* part_g
}
}
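// Per-row input gradient (apex formulation), with sigma = 1 / invvar:
//   dx_j = (1 / (n2 * sigma)) * ( n2 * dy_j * gamma_j
//                                 - sum_k(dy_k * gamma_k)
//                                 - (x_j - mu) / sigma * sum_k(dy_k * gamma_k * (x_k - mu) / sigma) )
// sum_loss1 / sum_loss2 accumulate the two row sums, term1 = invvar / n2, and the
// gamma factors drop out when gamma == NULL.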
template <typename T, typename U, typename V>
__global__ void cuComputeGradInput(const V* __restrict__ dout, const T* __restrict__ input,
const int n1, const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar, U epsilon, const V* gamma,
T* grad_input) {
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T* k_input = input + i1 * n2;
const V* k_dout = dout + i1 * n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * gamma[l + k];
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2 * wrt_i] = sum_loss1;
buf[2 * wrt_i + 1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2 * read_i];
sum_loss2 += buf[2 * read_i + 1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2 * threadIdx.x] = sum_loss1;
buf[2 * threadIdx.x + 1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y != 0) {
sum_loss1 = buf[2 * threadIdx.x];
sum_loss2 = buf[2 * threadIdx.x + 1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
}
}
template <typename T, typename U, typename V>
void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, at::Tensor* input, int n1,
int n2, const V* gamma, const V* beta, double epsilon, T* grad_input,
......@@ -236,6 +509,14 @@ void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, at::Te
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>(), part_size, n1, n2,
grad_gamma, grad_beta);
}
// compute grad_input
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32, 4, 1);
int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon), gamma, grad_input);
}
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar,
......@@ -243,20 +524,6 @@ void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, at::Tensor* in
at::Tensor* gamma, at::Tensor* beta, double epsilon,
at::Tensor* grad_input, at::Tensor* grad_gamma,
at::Tensor* grad_beta) {
    // Dispatch on the (input, gamma) dtypes and run the apex-style gradient kernels.
    using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(), "cuda_layer_norm_gradient_kernel",
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/*
This code is modified from https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/softmax.cuh
*/
#ifndef FASTFOLD_SOFTMAX_H_
#define FASTFOLD_SOFTMAX_H_
#include <assert.h>
#include <cuda.h>
#include <math_constants.h>
#include <cub/cub.cuh>
#include "ATen/ATen.h"
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000
namespace fastfold {
namespace softmax {
constexpr int kWarpSize = 32;
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; }
};
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); }
};
template <template <typename> class ReductionOp, typename T, int thread_group_width = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
template <typename T>
__inline__ __device__ T Inf();
template <>
__inline__ __device__ at::BFloat16 Inf<at::BFloat16>() {
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
}
template <>
__inline__ __device__ float Inf<float>() {
return CUDART_INF_F;
}
template <>
__inline__ __device__ double Inf<double>() {
return CUDART_INF;
}
template <typename T>
__inline__ __device__ T Exp(T x);
template <>
__inline__ __device__ at::BFloat16 Exp<at::BFloat16>(at::BFloat16 x) {
return exp(x);
}
template <>
__inline__ __device__ float Exp<float>(float x) {
return __expf(x);
}
template <>
__inline__ __device__ double Exp<double>(double x) {
return exp(x);
}
template <typename T>
__inline__ __device__ T Div(T a, T b);
template <>
__inline__ __device__ at::BFloat16 Div<at::BFloat16>(at::BFloat16 a, at::BFloat16 b) {
return a / b;
}
template <>
__inline__ __device__ float Div<float>(float a, float b) {
return __fdividef(a, b);
}
template <>
__inline__ __device__ double Div<double>(double a, double b) {
return a / b;
}
template <typename T>
__inline__ __device__ T Log(T x);
template <>
__inline__ __device__ at::BFloat16 Log<at::BFloat16>(at::BFloat16 x) {
return log(x);
}
template <>
__inline__ __device__ float Log<float>(float x) {
return __logf(x);
}
template <>
__inline__ __device__ double Log<double>(double x) {
return log(x);
}
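// Pick a grid large enough for roughly `waves` waves of resident blocks per SM
// (sm_count * max_threads_per_SM / block_size blocks per wave), clamped to
// [1, max_blocks].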
inline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,
int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count;
{
cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
int tpm;
{
cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
if (err != cudaSuccess) {
return err;
}
}
*num_blocks =
std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));
return cudaSuccess;
}
template <typename T>
struct DefaultComputeType {
using type = T;
};
template <>
struct DefaultComputeType<half> {
using type = float;
};
#if CUDA_VERSION >= 11000
template <>
struct DefaultComputeType<nv_bfloat16> {
using type = float;
};
#endif // CUDA_VERSION >= 11000
template <typename T, int N>
struct GetPackType {
using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};
template <typename T, int N>
using PackType = typename GetPackType<T, N>::type;
template <typename T, int N>
union Pack {
static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");
__device__ Pack() {
// do nothing
}
PackType<T, N> storage;
T elem[N];
};
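// Pack bundles N elements into one aligned word (PackType) so the load/store
// helpers below move pack_size values per memory transaction.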
template <typename SRC, typename DST>
struct DirectLoad {
DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {}
template <int N>
__device__ void load(DST* dst, int64_t row, int64_t col) const {
Pack<SRC, N> pack;
const int64_t offset = (row * row_size + col) / N;
pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = static_cast<DST>(pack.elem[i]);
}
}
const SRC* src;
int64_t row_size;
};
template <typename SRC, typename DST>
struct DirectStore {
DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
Pack<DST, N> pack;
const int64_t offset = (row * row_size + col) / N;
#pragma unroll
for (int i = 0; i < N; ++i) {
pack.elem[i] = static_cast<DST>(src[i]);
}
*(reinterpret_cast<PackType<DST, N>*>(dst) + offset) = pack.storage;
}
DST* dst;
int64_t row_size;
};
template <typename SRC, typename DST>
struct ScaleMaskLoad {
ScaleMaskLoad(const SRC* src, const SRC* mask, int64_t row_size, int64_t head, SRC scale)
: src(src), mask(mask), row_size(row_size), head(head), scale(scale) {}
template <int N>
__device__ void load(DST* dst, int64_t row, int64_t col) {
softmax::Pack<SRC, N> pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t mask_offset = ((row / (head * row_size)) * row_size + col) / N;
pack.storage = *(reinterpret_cast<const softmax::PackType<SRC, N>*>(src) + offset);
softmax::Pack<SRC, N> mask_pack;
mask_pack.storage =
*(reinterpret_cast<const softmax::PackType<SRC, N>*>(mask) + mask_offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
if (mask_pack.elem[i] == 0) {
                // Lowest finite bf16 so masked logits vanish after softmax (assumed
                // fix; the original 0x7F7F is the positive maximum).
                dst[i] = static_cast<DST>(c10::BFloat16(0xFF7F, c10::BFloat16::from_bits()));
} else {
dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(scale);
}
}
}
const SRC* src;
const SRC* mask;
int64_t row_size;
int64_t head;
SRC fill;
SRC scale;
};
template <typename SRC, typename DST>
struct ScaleMaskStore {
ScaleMaskStore(DST* dst, const DST* mask, int64_t row_size, int64_t head, DST scale)
: dst(dst), mask(mask), row_size(row_size), head(head), scale(scale) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
softmax::Pack<DST, N> pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t mask_offset = ((row / (head * row_size)) * row_size + col) / N;
softmax::Pack<DST, N> mask_pack;
mask_pack.storage =
*(reinterpret_cast<const softmax::PackType<DST, N>*>(mask) + mask_offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
if (mask_pack.elem[i] == 0) {
pack.elem[i] = c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
} else {
pack.elem[i] = static_cast<DST>(src[i]) * static_cast<DST>(scale);
}
}
*(reinterpret_cast<softmax::PackType<DST, N>*>(dst) + offset) = pack.storage;
}
DST* dst;
const DST* mask;
int64_t row_size;
int64_t head;
DST fill;
DST scale;
};
template <typename SRC, typename DST>
struct ScaleMaskBiasLoad {
ScaleMaskBiasLoad(const SRC* src, const SRC* mask, const SRC* bias, int64_t row_size,
int64_t head, SRC scale)
: src(src), mask(mask), bias(bias), row_size(row_size), head(head), scale(scale) {}
template <int N>
__device__ void load(DST* dst, int64_t row, int64_t col) {
softmax::Pack<SRC, N> pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t mask_offset = ((row / (head * row_size)) * row_size + col) / N;
const int64_t bias_offset = ((row % (head * row_size)) * row_size + col) / N;
pack.storage = *(reinterpret_cast<const softmax::PackType<SRC, N>*>(src) + offset);
softmax::Pack<SRC, N> mask_pack;
softmax::Pack<SRC, N> bias_pack;
mask_pack.storage =
*(reinterpret_cast<const softmax::PackType<SRC, N>*>(mask) + mask_offset);
bias_pack.storage =
*(reinterpret_cast<const softmax::PackType<SRC, N>*>(bias) + bias_offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
if (mask_pack.elem[i] == 0) {
                // Assumed fix, matching ScaleMaskLoad: lowest finite bf16 so masked
                // logits drop out of the softmax.
                dst[i] = static_cast<DST>(c10::BFloat16(0xFF7F, c10::BFloat16::from_bits()));
} else {
dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(scale);
dst[i] += static_cast<DST>(bias_pack.elem[i]);
}
}
}
const SRC* src;
const SRC* mask;
const SRC* bias;
int64_t row_size;
int64_t head;
SRC fill;
SRC scale;
};
enum class Algorithm {
kSoftmax = 0,
kLogSoftmax = 1,
};
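// Warp implementation: each group of thread_group_width lanes owns rows_per_access
// rows and keeps cols_per_thread values per row in registers, so the max-reduce,
// exp/sum, and normalization passes never spill to shared or global memory.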
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding, Algorithm algorithm>
__global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) {
static_assert(cols_per_thread % pack_size == 0, "");
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
constexpr int num_packs = cols_per_thread / pack_size;
assert(cols <= cols_per_thread * thread_group_width);
ComputeType buf[rows_per_access][cols_per_thread];
const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int num_global_thread_group = gridDim.x * blockDim.y;
const int lane_id = threadIdx.x;
const int64_t step = num_global_thread_group * rows_per_access;
for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {
ComputeType thread_max[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
thread_max[row_id] = -Inf<ComputeType>();
ComputeType* row_buf = buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < num_packs; ++pack_id) {
const int pack_offset = pack_id * pack_size;
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
load.template load<pack_size>(row_buf + pack_offset, row + row_id, col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_max[row_id] = max(thread_max[row_id], row_buf[pack_offset + i]);
}
} else {
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
row_buf[pack_offset + i] = -Inf<ComputeType>();
}
}
}
}
ComputeType warp_max[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
warp_max[row_id] =
WarpAllReduce<MaxOp, ComputeType, thread_group_width>(thread_max[row_id]);
}
ComputeType thread_sum[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
thread_sum[row_id] = 0;
ComputeType* row_buf = buf[row_id];
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
if (algorithm == Algorithm::kSoftmax) {
row_buf[i] = Exp(row_buf[i] - warp_max[row_id]);
thread_sum[row_id] += row_buf[i];
} else if (algorithm == Algorithm::kLogSoftmax) {
row_buf[i] -= warp_max[row_id];
thread_sum[row_id] += Exp(row_buf[i]);
} else {
__trap();
}
}
}
ComputeType warp_sum[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
warp_sum[row_id] =
WarpAllReduce<SumOp, ComputeType, thread_group_width>(thread_sum[row_id]);
}
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
ComputeType* row_buf = buf[row_id];
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
if (algorithm == Algorithm::kSoftmax) {
row_buf[i] = Div(row_buf[i], warp_sum[row_id]);
} else if (algorithm == Algorithm::kLogSoftmax) {
row_buf[i] -= Log(warp_sum[row_id]);
} else {
__trap();
}
}
#pragma unroll
for (int i = 0; i < num_packs; ++i) {
const int col = (i * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
store.template store<pack_size>(row_buf + i * pack_size, row + row_id, col);
}
}
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding, Algorithm algorithm>
inline cudaError_t LaunchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 128;
constexpr int waves = 32;
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
const int64_t num_blocks =
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread, thread_group_width,
rows_per_access, padding, algorithm>
<<<grid_dim_x, block_dim, 0, stream>>>(load, store, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxWarpImplPadding(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
if (cols == cols_per_thread * thread_group_width) {
return LaunchSoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, false, algorithm>(
stream, load, store, rows, cols);
} else {
return LaunchSoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, true, algorithm>(
stream, load, store, rows, cols);
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 2, algorithm>( \
stream, load, store, rows, cols); \
} else { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 1, algorithm>( \
stream, load, store, rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, \
1, algorithm>(stream, load, store, rows, cols); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(5)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(7)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(9)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(11)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(13)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(15)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(17)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(19)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(21)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(23)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(25)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(27)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(29)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(31)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 2, algorithm>( \
stream, load, store, rows, cols); \
} else { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 1, algorithm>( \
stream, load, store, rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, \
1, algorithm>(stream, load, store, rows, cols); \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
struct DispatchSoftmaxWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols % 2 == 0) {
return DispatchSoftmaxWarpImplCols<LOAD, STORE, ComputeType, 2, algorithm>(
stream, load, store, rows, cols);
} else {
return DispatchSoftmaxWarpImplCols<LOAD, STORE, ComputeType, 1, algorithm>(
stream, load, store, rows, cols);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxWarpImplPackSize<LOAD, STORE, ComputeType, algorithm>()(
stream, load, store, rows, cols);
}
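// Shared-memory block implementation: the whole row is staged once in dynamic
// shared memory, so the exp/sum and normalization passes re-read it from smem
// instead of reloading from global memory.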
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,
Algorithm algorithm>
__global__ void SoftmaxBlockSMemImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto* buf = reinterpret_cast<ComputeType*>(shared_buf);
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = cols / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_max = -Inf<ComputeType>();
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
buf[i * num_packs + pack_id] = pack[i];
thread_max = max(thread_max, pack[i]);
}
}
const ComputeType row_max = BlockAllReduce<MaxOp, ComputeType, block_size>(thread_max);
ComputeType thread_sum = 0;
for (int col = tid; col < cols; col += block_size) {
if (algorithm == Algorithm::kSoftmax) {
const ComputeType exp_x = Exp(buf[col] - row_max);
buf[col] = exp_x;
thread_sum += exp_x;
} else {
const ComputeType x = buf[col] - row_max;
buf[col] = x;
thread_sum += Exp(x);
}
}
const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
pack[i] = Div(buf[i * num_packs + pack_id], row_sum);
} else if (algorithm == Algorithm::kLogSoftmax) {
pack[i] = buf[i * num_packs + pack_id] - Log(row_sum);
} else {
__trap();
}
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,
Algorithm algorithm>
inline cudaError_t LaunchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, int smem,
const int64_t rows, const int64_t cols) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size, algorithm>
<<<grid_dim_x, block_size, smem, stream>>>(load, store, rows, cols);
return cudaPeekAtLastError();
}
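// Probe block sizes from 1024 down to 256 and launch the largest one whose
// occupancy (active blocks per SM with cols * sizeof(ComputeType) bytes of smem)
// matches that of 128 threads; if even 128 threads cannot stay resident,
// set *success = false so the caller falls back to the uncached implementation.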
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream, LOAD load,
STORE store, const int64_t rows,
const int64_t cols, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(ComputeType);
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_1,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1, algorithm>,
block_size_conf_1, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_4,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4, algorithm>,
block_size_conf_4, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4,
algorithm>(stream, load, store, smem, rows, cols);
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_3,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3, algorithm>,
block_size_conf_3, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3,
algorithm>(stream, load, store, smem, rows, cols);
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_2,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2, algorithm>,
block_size_conf_2, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2,
algorithm>(stream, load, store, smem, rows, cols);
}
*success = true;
return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1,
algorithm>(stream, load, store, smem, rows, cols);
}
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
struct TryDispatchSoftmaxBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, bool* success) {
if (cols % 2 == 0) {
return TryDispatchSoftmaxBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 2, algorithm>(
stream, load, store, rows, cols, success);
} else {
return TryDispatchSoftmaxBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 1, algorithm>(
stream, load, store, rows, cols, success);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
bool* success) {
return TryDispatchSoftmaxBlockSMemImplPackSize<LOAD, STORE, ComputeType, algorithm>()(
stream, load, store, rows, cols, success);
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,
Algorithm algorithm>
__global__ void SoftmaxBlockUncachedImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = cols / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_max = -Inf<ComputeType>();
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_max = max(thread_max, pack[i]);
}
}
const ComputeType row_max = BlockAllReduce<MaxOp, ComputeType, block_size>(thread_max);
ComputeType thread_sum = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_sum += Exp(pack[i] - row_max);
}
}
const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
pack[i] = Div(Exp(pack[i] - row_max), row_sum);
} else if (algorithm == Algorithm::kLogSoftmax) {
pack[i] = (pack[i] - row_max) - Log(row_sum);
} else {
__trap();
}
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
inline cudaError_t LaunchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size, algorithm>
<<<grid_dim_x, block_size, 0, stream>>>(load, store, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
struct DispatchSoftmaxBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols % 2 == 0) {
return LaunchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, 2, algorithm>(
stream, load, store, rows, cols);
} else {
return LaunchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, 1, algorithm>(
stream, load, store, rows, cols);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxBlockUncachedImplPackSize<LOAD, STORE, ComputeType, algorithm>()(
stream, load, store, rows, cols);
}
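// Top-level softmax dispatch: rows with at most 1024 columns use the warp
// implementation; wider rows first try the shared-memory block implementation and
// fall back to the uncached block implementation when the smem footprint would
// hurt occupancy.  The double compute type always takes the uncached path.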
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols <= 1024) {
return DispatchSoftmaxWarpImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err =
TryDispatchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols);
}
return cudaSuccess;
}
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols);
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols <= 1024) {
return DispatchSoftmaxWarpImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(
stream, load, store, rows, cols);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err =
TryDispatchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(
stream, load, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType,
Algorithm::kLogSoftmax>(stream, load, store,
rows, cols);
}
return cudaSuccess;
}
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(
stream, load, store, rows, cols);
}
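// Softmax backward:    dx_i = (dy_i - sum_j(dy_j * y_j)) * y_i
// LogSoftmax backward: dx_i = dy_i - exp(y_i) * sum_j(dy_j)
// Each gradient kernel below reduces the row-wide sum once, then applies it elementwise.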
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int cols_per_thread, int thread_group_width, int rows_per_access, bool padding,
Algorithm algorithm>
__global__ void SoftmaxGradWarpImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,
const int64_t cols) {
static_assert(cols_per_thread % pack_size == 0, "");
constexpr int pack_per_thread = cols_per_thread / pack_size;
assert(cols <= cols_per_thread * thread_group_width);
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
ComputeType y_buf[rows_per_access][cols_per_thread];
ComputeType dy_buf[rows_per_access][cols_per_thread];
const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int num_global_thread_group = gridDim.x * blockDim.y;
const int lane_id = threadIdx.x;
const int64_t step = num_global_thread_group * rows_per_access;
for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {
ComputeType thread_sum[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
thread_sum[row_id] = 0;
ComputeType* row_y_buf = y_buf[row_id];
ComputeType* row_dy_buf = dy_buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {
const int pack_offset = pack_id * pack_size;
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
load_y.template load<pack_size>(row_y_buf + pack_offset, row + row_id, col);
load_dy.template load<pack_size>(row_dy_buf + pack_offset, row + row_id, col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
thread_sum[row_id] +=
row_y_buf[pack_offset + i] * row_dy_buf[pack_offset + i];
} else if (algorithm == Algorithm::kLogSoftmax) {
thread_sum[row_id] += row_dy_buf[pack_offset + i];
} else {
__trap();
}
}
}
}
}
ComputeType warp_sum[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
warp_sum[row_id] =
WarpAllReduce<SumOp, ComputeType, thread_group_width>(thread_sum[row_id]);
}
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
ComputeType* row_y_buf = y_buf[row_id];
ComputeType* row_dy_buf = dy_buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {
const int pack_offset = pack_id * pack_size;
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
row_dy_buf[pack_offset + i] =
(row_dy_buf[pack_offset + i] - warp_sum[row_id]) *
row_y_buf[pack_offset + i];
} else if (algorithm == Algorithm::kLogSoftmax) {
row_dy_buf[pack_offset + i] -=
Exp(row_y_buf[pack_offset + i]) * warp_sum[row_id];
} else {
__trap();
}
}
store.template store<pack_size>(row_dy_buf + pack_offset, row + row_id, col);
}
}
}
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int cols_per_thread, int thread_group_width, int rows_per_access, bool padding,
Algorithm algorithm>
inline cudaError_t LaunchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy,
STORE store, const int64_t rows, const int64_t cols) {
constexpr int block_size = 128;
constexpr int waves = 32;
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
const int64_t num_blocks =
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, padding, algorithm>
<<<grid_dim_x, block_dim, 0, stream>>>(load_y, load_dy, store, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int cols_per_thread, int thread_group_width, int rows_per_access, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxGradWarpImplPadding(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols == cols_per_thread * thread_group_width) {
return LaunchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access,
false, algorithm>(stream, load_y, load_dy, store, rows,
cols);
} else {
return LaunchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access, true,
algorithm>(stream, load_y, load_dy, store, rows, cols);
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxGradWarpImplCols(
cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,
const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 2, \
algorithm>(stream, load_y, load_dy, store, \
rows, cols); \
} else { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 1, \
algorithm>(stream, load_y, load_dy, store, \
rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, \
col, kWarpSize, 1, algorithm>( \
stream, load_y, load_dy, store, rows, cols); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(5)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(7)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(9)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(11)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(13)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(15)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(17)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(19)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(21)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(23)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(25)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(27)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(29)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(31)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxGradWarpImplCols(
cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,
const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 2, \
algorithm>(stream, load_y, load_dy, store, \
rows, cols); \
} else { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 1, \
algorithm>(stream, load_y, load_dy, store, \
rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, \
col, kWarpSize, 1, algorithm>( \
stream, load_y, load_dy, store, rows, cols); \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
struct DispatchSoftmaxGradWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0) {
return DispatchSoftmaxGradWarpImplCols<LOAD_Y, LOAD_DY, STORE, ComputeType, 2,
algorithm>(stream, load_y, load_dy, store, rows,
cols);
} else {
return DispatchSoftmaxGradWarpImplCols<LOAD_Y, LOAD_DY, STORE, ComputeType, 1,
algorithm>(stream, load_y, load_dy, store, rows,
cols);
}
}
};
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
inline cudaError_t DispatchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy,
STORE store, const int64_t rows,
const int64_t cols) {
return DispatchSoftmaxGradWarpImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType, algorithm>()(
stream, load_y, load_dy, store, rows, cols);
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int block_size, Algorithm algorithm>
__global__ void SoftmaxGradBlockSMemImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[];
auto* y_buf = reinterpret_cast<ComputeType*>(grad_shared_buf);
auto* dy_buf = y_buf + cols;
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = cols / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_sum = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType y_pack[pack_size];
ComputeType dy_pack[pack_size];
load_y.template load<pack_size>(y_pack, row, pack_id * pack_size);
load_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
y_buf[i * num_packs + pack_id] = y_pack[i];
dy_buf[i * num_packs + pack_id] = dy_pack[i];
if (algorithm == Algorithm::kSoftmax) {
thread_sum += y_pack[i] * dy_pack[i];
} else if (algorithm == Algorithm::kLogSoftmax) {
thread_sum += dy_pack[i];
} else {
__trap();
}
}
}
const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
pack[i] = (dy_buf[i * num_packs + pack_id] - row_sum) *
y_buf[i * num_packs + pack_id];
} else if (algorithm == Algorithm::kLogSoftmax) {
pack[i] = dy_buf[i * num_packs + pack_id] -
Exp(y_buf[i * num_packs + pack_id]) * row_sum;
} else {
__trap();
}
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int block_size, Algorithm algorithm>
inline cudaError_t LaunchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store, int smem,
const int64_t rows, const int64_t cols) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size, algorithm>
<<<grid_dim_x, block_size, smem, stream>>>(load_y, load_dy, store, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows,
const int64_t cols, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(ComputeType) * 2;
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_1,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_1, algorithm>,
block_size_conf_1, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_4,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_4, algorithm>,
block_size_conf_4, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_4, algorithm>(
stream, load_y, load_dy, store, smem, rows, cols);
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_3,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_3, algorithm>,
block_size_conf_3, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_3, algorithm>(
stream, load_y, load_dy, store, smem, rows, cols);
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_2,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_2, algorithm>,
block_size_conf_2, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_2, algorithm>(
stream, load_y, load_dy, store, smem, rows, cols);
}
*success = true;
return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_1, algorithm>(stream, load_y, load_dy,
store, smem, rows, cols);
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
struct TryDispatchSoftmaxGradBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols, bool* success) {
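        // Use a pack (vector) width of 2 when cols is even so each load/store moves two
        // elements at a time; otherwise fall back to scalar (pack_size 1) accesses.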
if (cols % 2 == 0) {
return TryDispatchSoftmaxGradBlockSMemImplBlockSize<LOAD_Y, LOAD_DY, STORE, ComputeType,
2, algorithm>(
stream, load_y, load_dy, store, rows, cols, success);
} else {
return TryDispatchSoftmaxGradBlockSMemImplBlockSize<LOAD_Y, LOAD_DY, STORE, ComputeType,
1, algorithm>(
stream, load_y, load_dy, store, rows, cols, success);
}
}
};
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols,
bool* success) {
return TryDispatchSoftmaxGradBlockSMemImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType,
algorithm>()(stream, load_y, load_dy, store,
rows, cols, success);
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int block_size, Algorithm algorithm>
__global__ void SoftmaxGradBlockUncachedImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
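    // "Uncached" block variant: y and dy are re-read from global memory in both passes
    // instead of being staged in shared memory or registers, so it handles arbitrarily
    // wide rows at the cost of extra bandwidth. Pass 1 accumulates the per-row reduction
    // (sum of y*dy for softmax, sum of dy for log-softmax); pass 2 applies the gradient.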
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = cols / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_sum = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType y_pack[pack_size];
ComputeType dy_pack[pack_size];
load_y.template load<pack_size>(y_pack, row, pack_id * pack_size);
load_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
thread_sum += y_pack[i] * dy_pack[i];
} else if (algorithm == Algorithm::kLogSoftmax) {
thread_sum += dy_pack[i];
} else {
__trap();
}
}
}
const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType y_pack[pack_size];
ComputeType dy_pack[pack_size];
load_y.template load<pack_size>(y_pack, row, pack_id * pack_size);
load_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
dy_pack[i] = (dy_pack[i] - row_sum) * y_pack[i];
} else if (algorithm == Algorithm::kLogSoftmax) {
dy_pack[i] -= Exp(y_pack[i]) * row_sum;
} else {
__trap();
}
}
store.template store<pack_size>(dy_pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
inline cudaError_t LaunchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size,
algorithm>
<<<grid_dim_x, block_size, 0, stream>>>(load_y, load_dy, store, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
struct DispatchSoftmaxGradBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0 && cols > kWarpSize) {
return LaunchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, 2,
algorithm>(stream, load_y, load_dy, store,
rows, cols);
} else {
return LaunchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, 1,
algorithm>(stream, load_y, load_dy, store,
rows, cols);
}
}
};
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
inline cudaError_t DispatchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxGradBlockUncachedImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType,
algorithm>()(stream, load_y, load_dy, store,
rows, cols);
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
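    // Dispatch strategy: use the warp-per-row kernel for rows of up to 1024 columns; for
    // wider rows, first try the shared-memory block kernel (row cached in smem) and fall
    // back to the uncached block kernel when the shared-memory footprint prevents even the
    // smallest block configuration from being resident.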
if (cols <= 1024) {
return DispatchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kSoftmax>(stream, load_y, load_dy, store,
rows, cols);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err = TryDispatchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE,
ComputeType, Algorithm::kSoftmax>(
stream, load_y, load_dy, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kSoftmax>(
stream, load_y, load_dy, store, rows, cols);
}
return cudaSuccess;
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kSoftmax>(stream, load_y, load_dy, store,
rows, cols);
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols <= 1024) {
return DispatchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kLogSoftmax>(stream, load_y, load_dy, store,
rows, cols);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err =
TryDispatchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kLogSoftmax>(
stream, load_y, load_dy, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kLogSoftmax>(
stream, load_y, load_dy, store, rows, cols);
}
return cudaSuccess;
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kLogSoftmax>(stream, load_y, load_dy,
store, rows, cols);
}
} // namespace softmax
} // namespace fastfold
#endif // FASTFOLD_SOFTMAX_H_
\ No newline at end of file
#include <torch/extension.h>
at::Tensor softmax(at::Tensor input, int rows, int cols);
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor input, int rows, int cols);
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, int rows, int cols);
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
float scale);
......@@ -15,8 +15,8 @@ at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tenso
int cols, float scale);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &softmax, "Softmax forward (CUDA)");
m.def("backward_affine", &softmax_gradient, "Softmax backward (CUDA)");
m.def("forward", &softmax, "Softmax forward (CUDA)");
m.def("backward", &softmax_gradient, "Softmax backward (CUDA)");
m.def("fused_scale_mask_softmax_forward", &fused_scale_mask_softmax_forward,
"Softmax forward (CUDA)");
......
#include <math_constants.h>
#include <torch/extension.h>
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#include "softmax.cuh"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
......@@ -11,80 +13,559 @@
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
__inline__ __device__ float WarpAllReduceMax(float val) {
for (int mask = 1; mask < 32; mask *= 2) {
val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
__inline__ __device__ float WarpAllReduceSum(float val) {
for (int mask = 1; mask < 32; mask *= 2) {
val += __shfl_xor_sync(0xffffffff, val, mask);
}
return val;
}
////////////////
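// Hand-written warp softmax used by the host wrappers below. Layout assumptions:
// each 128-thread block holds 4 warps and each warp owns one row (grid = rows / 4);
// each lane keeps cols / 32 consecutive elements in registers, so cols is expected to
// be a multiple of 32 and at most 32 * 32 = 1024. Each row is processed in three steps:
// row max, exp and sum, then normalization, with reductions done via warp shuffles.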
__global__ void fastfold_softmax_fp32(float *input, float *output, int rows, int cols) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
float buf[32];
int lane_id = threadidx_y;
float *row_input = input + row_offset * cols;
float *row_output = output + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
buf[i] = row_input[lane_id * cols_per_thread + i];
}
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
}
}
__global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output, int rows,
int cols) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
float buf[32];
int lane_id = threadidx_y;
at::BFloat16 *row_input = input + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
}
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
}
}
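// Softmax backward with the same warp-per-row layout: for each row, compute
// warp_sum = sum_j y_j * dy_j and then dx_i = (dy_i - warp_sum) * y_i.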
__global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float *d_input, int rows,
int cols) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
float *row_d_output = d_output + row_offset * cols;
float *row_output = output + row_offset * cols;
float *row_d_input = d_input + row_offset * cols;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
y_buf[i] = row_output[lane_id * cols_per_thread + i];
dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_d_input[lane_id * cols_per_thread + i] = (dy_buf[i] - warp_sum) * y_buf[i];
}
}
__global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,
at::BFloat16 *d_input, int rows, int cols) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
at::BFloat16 *row_d_output = d_output + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *row_d_input = d_input + row_offset * cols;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_d_input[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>((dy_buf[i] - warp_sum) * y_buf[i]);
}
}
at::Tensor softmax(at::Tensor input, int rows, int cols) {
CHECK_INPUT(input);
at::Tensor output = at::empty_like(input);
fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
int64_t(cols));
fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)output.data_ptr(),
int64_t(cols));
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(cuda_stream, load,
store, rows, cols);
int grid = rows / 4;
dim3 block(128);
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_fp32<<<grid, block>>>((float *)input.data_ptr(),
(float *)output.data_ptr(), rows, cols);
} else {
fastfold_softmax_bfp16<<<grid, block>>>((at::BFloat16 *)input.data_ptr(),
(at::BFloat16 *)output.data_ptr(), rows, cols);
}
return output;
}
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor input, int rows, int cols) {
CHECK_INPUT(input);
at::Tensor grad_input = at::empty_like(input);
fastfold::softmax::DirectLoad<at::BFloat16, float> load_d((at::BFloat16 *)d_output.data_ptr(),
int64_t(cols));
fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
int64_t(cols));
fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)grad_input.data_ptr(),
int64_t(cols));
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmaxGrad<decltype(load), decltype(load_d), decltype(store),
float>(cuda_stream, load, load_d, store, rows, cols);
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, int rows, int cols) {
CHECK_INPUT(output);
at::Tensor grad_input = at::empty_like(output);
int grid = rows / 4;
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
fastfold_softmax_grad_fp32<<<grid, block>>>((float *)d_output.data_ptr(),
(float *)output.data_ptr(),
(float *)grad_input.data_ptr(), rows, cols);
} else {
fastfold_softmax_grad_bfp16<<<grid, block>>>(
(at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)grad_input.data_ptr(), rows, cols);
}
return grad_input;
}
////////////////
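// Fused scale + mask softmax: unmasked elements are multiplied by `scale`, while masked
// elements (mask == 0) are set to -inf so they receive zero probability. The mask row is
// picked as row_offset / (head * cols), i.e. one mask row of length cols is shared by all
// heads and query rows of a batch element; this indexing appears to assume a flattened
// (batch, head, query) row order with the number of query rows equal to cols.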
__global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, float *output, int rows,
int cols, float scale, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
float buf[32];
int lane_id = threadidx_y;
float *row_input = input + row_offset * cols;
float *row_output = output + row_offset * cols;
float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * CUDART_INF_F;
} else {
buf[i] = row_input[lane_id * cols_per_thread + i] * scale;
}
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
}
}
__global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
at::BFloat16 *output, int rows, int cols,
float scale, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
float buf[32];
int lane_id = threadidx_y;
at::BFloat16 *row_input = input + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * CUDART_INF_F;
} else {
buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]) * scale;
}
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
}
}
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
float scale) {
CHECK_INPUT(input);
CHECK_INPUT(mask);
int head = input.sizes()[2];
at::Tensor output = at::empty_like(input);
// (const SRC* src, const int8_t* mask, int64_t row_size, SRC scale)
fastfold::softmax::ScaleMaskLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
(at::BFloat16 *)mask.data_ptr(),
int64_t(cols), int64_t(head), scale);
fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)output.data_ptr(),
int64_t(cols));
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(cuda_stream, load,
store, rows, cols);
int grid = rows / 4;
dim3 block(128);
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_fp32<<<grid, block>>>(
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)output.data_ptr(), rows,
cols, scale, head);
} else {
fastfold_softmax_scale_mask_bfp16<<<grid, block>>>(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)output.data_ptr(), rows, cols, scale, head);
}
return output;
}
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
int rows, int cols, float scale) {
CHECK_INPUT(input);
__global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *output,
float *d_input, float *mask, int rows,
int cols, float scale, int head) {
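    // Backward of the fused scale + mask softmax: apply the usual softmax gradient,
    // multiply by `scale` (chain rule through the forward scaling), and write a zero
    // gradient at masked positions.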
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
float *row_d_output = d_output + row_offset * cols;
float *row_output = output + row_offset * cols;
float *row_d_input = d_input + row_offset * cols;
float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
y_buf[i] = row_output[lane_id * cols_per_thread + i];
dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
row_d_input[lane_id * cols_per_thread + i] =
scale * ((dy_buf[i] - warp_sum) * y_buf[i]);
} else {
            row_d_input[lane_id * cols_per_thread + i] = 0.f;  // zero gradient at masked positions
}
}
}
__global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,
at::BFloat16 *d_input, at::BFloat16 *mask,
int rows, int cols, float scale, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
at::BFloat16 *row_d_output = d_output + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *row_d_input = d_input + row_offset * cols;
at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
row_d_input[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
} else {
            row_d_input[lane_id * cols_per_thread + i] =
                static_cast<at::BFloat16>(0.f);  // zero gradient at masked positions
}
}
}
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output,
at::Tensor mask, int rows, int cols, float scale) {
CHECK_INPUT(output);
CHECK_INPUT(mask);
int head = input.sizes()[2];
at::Tensor grad_input = at::empty_like(input);
fastfold::softmax::DirectLoad<at::BFloat16, float> load_d((at::BFloat16 *)d_output.data_ptr(),
int64_t(cols));
fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
int64_t(cols));
// (DST* dst, const int8_t* mask, int64_t row_size, DST scale)
fastfold::softmax::ScaleMaskStore<float, at::BFloat16> store(
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), int64_t(cols),
int64_t(head), scale);
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmaxGrad<decltype(load), decltype(load_d), decltype(store),
float>(cuda_stream, load, load_d, store, rows, cols);
int head = output.sizes()[2];
at::Tensor grad_input = at::empty_like(output);
int grid = rows / 4;
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_grad_fp32<<<grid, block>>>(
(float *)d_output.data_ptr(), (float *)output.data_ptr(),
(float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
} else {
fastfold_softmax_scale_mask_grad_bfp16<<<grid, block>>>(
(at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
scale, head);
}
return grad_input;
}
////////////////
__global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask, float *bias,
float *output, int rows, int cols,
float scale, int head) {
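    // Same as the scale + mask kernel, but a bias row is added after scaling. The bias row
    // is picked as row_offset % (head * cols), i.e. the bias is indexed per (head, query
    // row) and is effectively broadcast across the batch dimension.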
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
float buf[32];
int lane_id = threadidx_y;
float *row_input = input + row_offset * cols;
float *row_output = output + row_offset * cols;
float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
float *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * CUDART_INF_F;
} else {
buf[i] = row_input[lane_id * cols_per_thread + i] * scale +
bias_ptr[lane_id * cols_per_thread + i];
}
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
}
}
__global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
at::BFloat16 *bias, at::BFloat16 *output,
int rows, int cols, float scale, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
float buf[32];
int lane_id = threadidx_y;
at::BFloat16 *row_input = input + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
at::BFloat16 *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * CUDART_INF_F;
} else {
buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]) * scale;
buf[i] += static_cast<float>(bias_ptr[lane_id * cols_per_thread + i]);
}
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
}
}
at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
int rows, int cols, float scale) {
CHECK_INPUT(input);
......@@ -92,40 +573,45 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
CHECK_INPUT(bias);
int head = input.sizes()[2];
at::Tensor output = at::empty_like(input);
// (const SRC* src, const int8_t* mask, int64_t row_size, SRC scale)
fastfold::softmax::ScaleMaskBiasLoad<at::BFloat16, float> load(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)bias.data_ptr(), int64_t(cols), int64_t(head), scale);
fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)output.data_ptr(),
int64_t(cols));
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(cuda_stream, load,
store, rows, cols);
int grid = rows / 4;
dim3 block(128);
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_bias_fp32<<<grid, block>>>(
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
(float *)output.data_ptr(), rows, cols, scale, head);
} else {
fastfold_softmax_scale_mask_bias_bfp16<<<grid, block>>>(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols, scale,
head);
}
return output;
}
at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input,
at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor output,
at::Tensor mask, at::Tensor bias, int rows,
int cols, float scale) {
CHECK_INPUT(input);
CHECK_INPUT(output);
CHECK_INPUT(mask);
int head = input.sizes()[2];
// CHECK_INPUT(bias);
at::Tensor grad_input = at::empty_like(input);
fastfold::softmax::DirectLoad<at::BFloat16, float> load_d((at::BFloat16 *)d_output.data_ptr(),
int64_t(cols));
fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
int64_t(cols));
// (DST* dst, const int8_t* mask, int64_t row_size, DST scale)
fastfold::softmax::ScaleMaskStore<float, at::BFloat16> store(
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), int64_t(cols),
int64_t(head), scale);
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmaxGrad<decltype(load), decltype(load_d), decltype(store),
float>(cuda_stream, load, load_d, store, rows, cols);
int head = output.sizes()[2];
at::Tensor grad_input = at::empty_like(output);
int grid = rows / 4;
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_grad_fp32<<<grid, block>>>(
(float *)d_output.data_ptr(), (float *)output.data_ptr(),
(float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
} else {
fastfold_softmax_scale_mask_grad_bfp16<<<grid, block>>>(
(at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
scale, head);
}
return grad_input;
}
\ No newline at end of file
// modified from https://github.com/NVIDIA/apex
#include <ATen/ATen.h>
#include "compat.h"
......
......@@ -14,7 +14,7 @@ class SoftmaxAffineFunction(torch.autograd.Function):
input_ = input.contiguous()
ctx.cols = input_.shape[-1]
ctx.rows = reduce(mul, input.shape[:-1])
output = fastfold_softmax_cuda.forward_affine(input_, ctx.rows, ctx.cols)
output = fastfold_softmax_cuda.forward(input_, ctx.rows, ctx.cols)
ctx.save_for_backward(output)
return output
......@@ -25,7 +25,7 @@ class SoftmaxAffineFunction(torch.autograd.Function):
output = ctx.saved_tensors[0]
grad_input = None
grad_input = fastfold_softmax_cuda.backward_affine(grad_output.contiguous(), output,
grad_input = fastfold_softmax_cuda.backward(grad_output.contiguous(), output,
ctx.rows, ctx.cols)
return grad_input
......