Commit c2b62b7f authored by JR_ZZU 🌴

delete origin files

parent 2a4864d5
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include "ln.h"
/*
Supported type combinations:

input    compute    weights    output
=====================================
fp32     fp32       fp32       fp32
fp16     fp32       fp16       fp16
bf16     fp32       bf16       bf16
fp32     fp32       fp16       fp16
fp32     fp32       bf16       bf16

Remarks:
Output type = Weight type
Compute always in FP32
*/
namespace layer_norm {
// Registries mapping config keys to kernel launchers. The definitions live in
// the forward/backward launcher translation units (FWD_FUNCS / BWD_FUNCS); the
// REGISTER_*_LAUNCHER macros populate them at static-initialization time.
// FwdRegistry FWD_FUNCS;
// BwdRegistry BWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
uint32_t get_type_id(torch::Dtype dtype){
if( dtype == torch::kFloat16 ) {
return TypeId<fp16>::Value;
} else if( dtype == torch::kBFloat16 ) {
return TypeId<bf16>::Value;
} else if( dtype == torch::kFloat32 ) {
return TypeId<fp32>::Value;
} else {
TORCH_CHECK(false, "Type not supported: ", dtype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
using namespace layer_norm;
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6);
uint64_t launcher_key = (type_key << 32) | hidden_size;
return launcher_key;
}
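// Illustrative example (the actual TypeId values are defined in ln.h, which is
// not shown here; assume TypeId<fp16>::Value == 1 and TypeId<fp32>::Value == 0):
// for wtype = itype = otype = fp16, ctype = fp32 and hidden_size = 1024,
//   type_key     = 1 | (1 << 2) | (1 << 4) | (0 << 6) = 0x15
//   launcher_key = (0x15ull << 32) | 1024 = 0x1500000400
// i.e. the four 2-bit type ids occupy bits 32..39 and hidden_size bits 0..31.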
} // namespace layer_norm
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));
if( iter != layer_norm::FWD_FUNCS.end() ) {
return iter->second;
} else {
TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));
if( iter != layer_norm::BWD_FUNCS.end() ) {
return iter->second;
} else {
TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size
const at::Tensor &gamma, // hidden_size
const at::Tensor &beta, // hidden_size
const float epsilon
) {
auto itype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = wtype;
auto ctype = torch::kFloat32;
TORCH_CHECK(beta.scalar_type() == wtype);
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(gamma.is_cuda());
TORCH_CHECK(beta.is_cuda());
TORCH_CHECK(x.is_contiguous());
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
const int rows = sizes[0];
const int cols = sizes[1];
auto hidden_size = gamma.numel();
TORCH_CHECK(gamma.sizes() == beta.sizes());
TORCH_CHECK(hidden_size == cols);
TORCH_CHECK(epsilon >= 0.f);
auto opts = x.options();
auto z = torch::empty(sizes, opts.dtype(otype));
auto mu = torch::empty({ rows }, opts.dtype(ctype));
auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
launch_params.props = at::cuda::getCurrentDeviceProperties();
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
// Request the kernel launcher.
auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
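// The launcher is invoked twice: this first call (configure_params == true)
// only fills in launch_params (ctas_per_col, barrier_size, workspace_bytes)
// so the workspace/barrier tensors can be allocated below; the second call at
// the end of this function performs the actual kernel launch.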
at::Tensor workspace, barrier;
// Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.beta = beta.data_ptr();
params.z = z.data_ptr();
params.epsilon = epsilon;
if( launch_params.barrier_size > 0 ) {
auto options = x.options();
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
// Launch the kernel.
launcher(launch_params, false);
return { z, mu, rsigma };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> ln_bwd(const at::Tensor &dz, // BxSxhidden_size
const at::Tensor &x, // BxSxhidden_size
const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma // hidden_size
) {
auto itype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = wtype;
auto ctype = torch::kFloat32;
TORCH_CHECK(dz.dtype() == otype);
TORCH_CHECK(mu.dtype() == ctype);
TORCH_CHECK(rsigma.dtype() == ctype);
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(dz.is_cuda());
TORCH_CHECK(mu.is_cuda());
TORCH_CHECK(rsigma.is_cuda());
TORCH_CHECK(gamma.is_cuda());
TORCH_CHECK(x.is_contiguous());
TORCH_CHECK(dz.is_contiguous());
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(dz.sizes() == sizes);
auto rows = sizes[0];
auto cols = sizes[1];
auto hidden_size = gamma.numel();
TORCH_CHECK(mu.numel() == rows);
TORCH_CHECK(mu.sizes() == rsigma.sizes());
TORCH_CHECK(gamma.numel() == cols);
auto options = x.options();
auto dx = torch::empty_like(x);
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
launch_params.props = at::cuda::getCurrentDeviceProperties();
auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size);
launcher(launch_params, true);
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype));
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype));
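// dgamma_part / dbeta_part are per-CTA-row-group partial reductions of shape
// [ctas_per_col, hidden_size] in FP32; ln_bwd_finalize_kernel reduces them
// over the ctas_per_col rows to produce the final dgamma / dbeta.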
at::Tensor workspace, barrier;
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.dz = dz.data_ptr();
params.dx = dx.data_ptr();
params.dbeta = dbeta.data_ptr();
params.dgamma = dgamma.data_ptr();
params.dbeta_part = dbeta_part.data_ptr();
params.dgamma_part = dgamma_part.data_ptr();
if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this?
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
launcher(launch_params, false);
return { dx, dgamma, dbeta, dgamma_part, dbeta_part };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "CUDA LayerNorm";
m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel");
m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel");
}
#pragma once
namespace layer_norm {
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum;
constexpr float rn = 1.f / float(COLS);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
}
}
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
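// Standard LayerNorm input gradient: with y = (x - mu) * rs and dy = gamma * dz,
//   dx = rs * (dy - mean(dy) - y * mean(dy * y))
// mdy_local and mdyy_local now hold mean(dy) and mean(dy * y) for this row;
// the loop below applies the formula elementwise.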
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
if( WARPS_M == 1 ) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dz_sum[NUM_RES];
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
for( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dzy_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
for( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) {
*dbeta_part = cta_dz_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA;
}
}
}
template<typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
void ln_bwd_finalize_kernel(BwdParams params)
{
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
Sum<reduce_t> sum;
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local));
for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
dbeta_part.load_from(params.dbeta_part, idx);
dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
}
}
void * smem_gamma = smem_;
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
const int write_row = warp;
const int write_col = lane ^ write_row;
const int write_idx = write_row * THREADS_PER_WARP + write_col;
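// XOR swizzle: write_col = lane ^ write_row permutes each warp's target
// columns so that the transposed reads below (read_col = w ^ read_row) have
// the lanes of a warp touching distinct 4B shared memory banks, keeping both
// the write and the read pass free of bank conflicts.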
dgamma_local.store_to(smem_gamma, write_idx);
dbeta_local.store_to(smem_beta, write_idx);
__syncthreads();
// It would probably be safe to reuse the first row of smem_beta and smem_gamma
void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
// More than one iter iff ROWS_PER_CTA < 32.
for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
memset(&dbeta_local, 0, sizeof(dbeta_local));
memset(&dgamma_local, 0, sizeof(dgamma_local));
// Load beta and gamma transposed
if(read_row < Kernel_traits::ROWS_PER_CTA){
dbeta_local.load_from(smem_beta, read_idx);
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
compute_t b_i = dbeta_local.data.elt[it];
compute_t g_i = dgamma_local.data.elt[it];
b_i = reducer.allreduce(b_i, sum);
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
dbeta_local.data.elt[it] = b_i;
}
// Leader stores the result at the current column.
if(lane == 0){
dgamma_local.store_to(smem_gamma_out, w);
dbeta_local.store_to(smem_beta_out, w);
}
}
// All writes done.
__syncthreads();
// Pack and store: 2-wide stores with half the threads.
if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
dbeta_vec2.load_from(smem_beta_out, lane);
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
}
dgamma_out2.store_to(params.dgamma, col_out);
dbeta_out2.store_to(params.dbeta, col_out);
}
}
}
} // namespace layer_norm
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_bwd_kernels.cuh"
using namespace layer_norm;
BwdRegistry layer_norm::BWD_FUNCS;
template<
typename weight_t,
typename input_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL
>
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG_MAIN
>;
auto kernel = &ln_bwd_kernel<Kernel_traits>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
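// barrier_size: two ints per CTA row-group -- the b0/b1 ping-pong barriers
// used by InterCTASync (b0 at [bidm], b1 at [bidm + ctas_per_col]).
// workspace_bytes: one reduce_t staging slot per (row-group, warp_m, CTA)
// triple, doubled for two staging buffers (inferred from the w0_ offset
// computation in Reducer; the full Reducer lives in ln_utils.cuh).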
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
#if defined(__HIP_PLATFORM_HCC__)
CHECK_CUDA(hipFuncSetAttribute((const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
#else
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
#endif
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
} else {
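// With CTAS_PER_ROW > 1, the CTAs of a row-group spin on a global-memory
// barrier (InterCTASync), so every CTA of the grid must be resident on the
// device at the same time; a cooperative launch guarantees that co-residency,
// whereas a plain <<<...>>> launch could deadlock.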
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
#if defined(__HIP_PLATFORM_HCC__)
hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
#else
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
#endif
}
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
output_t,
compute_t,
index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
}
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
// REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
// REGISTER_BWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
// REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
// REGISTER_BWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);
REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);
// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_fwd_kernels.cuh"
using namespace layer_norm;
FwdRegistry layer_norm::FWD_FUNCS;
template<
typename weight_t,
typename input_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG
>
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG
>;
auto kernel = &ln_fwd_kernel<Kernel_traits>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
#if defined(__HIP_PLATFORM_HCC__)
CHECK_CUDA(hipFuncSetAttribute((const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
#else
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
#endif
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
#if defined(__HIP_PLATFORM_HCC__)
hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
#else
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
#endif
}
}
REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);
#pragma once
#if defined(__HIP_PLATFORM_HCC__)
#include "ln_utils.cuh"
#else
#include "ln.h"
#endif
namespace layer_norm {
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
Wvec gamma[LDGS];
Wvec beta[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
beta[it].load_from(params.beta, idx);
idx += VEC_COLS_PER_LDG;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
Ivec x[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
x[it].load_from(params.x, idx);
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_ij = compute_t(x[it].data.elt[jt]);
xf[it * NUM_ELTS + jt] = x_ij;
}
idx += VEC_COLS_PER_LDG;
}
stats_t s = stats.compute(xf, rn);
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
mu_ptr[row] = mu;
}
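// m2 is the (unnormalized) sum of squared deviations produced by Stats, so
// rn * m2 is the row variance (assumed from Stats' usage here; its full
// definition is not shown in this commit).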
compute_t rs = rsqrtf(rn * m2 + params.epsilon);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
rs_ptr[row] = rs;
}
Ovec z[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
output_t g_ij = gamma[it].data.elt[jt];
output_t b_ij = beta[it].data.elt[jt];
z[it].data.elt[jt] = (g_ij * y_ij + b_ij);
}
z[it].store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
}
}
}
} // namespace layer_norm
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace layer_norm {
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_
>
struct Kernel_traits_base {
using weight_t = weight_t_;
using input_t = input_t_;
using output_t = output_t_;
using compute_t = compute_t_;
using index_t = index_t_;
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
enum { THREADS_PER_WARP = 32 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_,
uint32_t BYTES_PER_LDG_,
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
weight_t_,
input_t_,
output_t_,
compute_t_,
index_t_,
THREADS_PER_CTA_>
>
struct Kernel_traits_finalize : public Base {
enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
// Bytes per global load from the input.
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
// Number of elements fetched by a global load.
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
// Bytes per global store of the weights.
enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
static_assert(sizeof(compute_t_) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
// Shared memory size to transpose the CTA result.
enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
// Shared memory size to coalesce the CTA result.
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
// Shared memory requirement per CTA.
enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
// The type of the reducer.
using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
// Condition for the whole CTA to participate in syncthreads.
static_assert(COLS % Base::THREADS_PER_WARP == 0);
enum { CTAS = COLS / Base::THREADS_PER_WARP };
};
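// Worked example for the finalize traits used by the backward launchers
// (THREADS_PER_CTA = 32 * 32, BYTES_PER_LDG_FINAL = 4, compute_t = fp32,
// HIDDEN_SIZE = 1024): ELTS_PER_LDG = 4 / 4 = 1, COLS = 1024 * 4 / 4 = 1024,
// ROWS_PER_CTA = 1024 / 32 = 32, CTAS = COLS / 32 = 32, and
// SMEM_BYTES_TRANSPOSE = 1024 * 4 = 4096 bytes per transpose buffer.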
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t HIDDEN_SIZE_,
uint32_t CTAS_PER_ROW_,
uint32_t WARPS_M_,
uint32_t WARPS_N_,
uint32_t BYTES_PER_LDG_ = 16,
typename Base = Kernel_traits_base<
HIDDEN_SIZE_,
weight_t_,
input_t_,
output_t_,
compute_t_,
index_t_,
WARPS_M_*WARPS_N_*THREADS_PER_WARP
>
>
struct Kernel_traits : public Base {
using input_t = typename Base::input_t;
using weight_t = typename Base::weight_t;
using compute_t = typename Base::compute_t;
using output_t = typename Base::output_t;
using index_t = typename Base::index_t;
enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = HIDDEN_SIZE_ };
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
enum { ROWS_PER_CTA = WARPS_M };
enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
// Assume that each thread can handle the same number of elements in the output and weights as in the input.
static_assert(sizeof(input_t) >= sizeof(output_t));
static_assert(sizeof(input_t) >= sizeof(weight_t));
// The number of columns fetched per load from input: one per thread.
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
// The total number of vectorized loads/stores per hidden vector.
enum { VEC_COLS = COLS / ELTS_PER_LDG };
// The number of loads per thread for the input.
enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
//static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};
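// Worked example matching REGISTER_FWD_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16):
// input_t = fp16, HIDDEN_SIZE = 1024, CTAS_PER_ROW = 1, WARPS_M = 4, WARPS_N = 1,
// BYTES_PER_LDG = 16 => NUM_ELTS = 16 / 2 = 8, THREADS_PER_ROW = 32,
// THREADS_PER_CTA = 128, ROWS_PER_CTA = 4, VEC_COLS = 1024 / 8 = 128,
// VEC_COLS_PER_LDG = 32, LDGS = 4: one warp per row, four 16-byte loads of
// 8 halves per thread (32 * 4 * 8 = 1024 columns, satisfying the asserts above).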
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
#pragma once
#include <cassert>
#if defined(__HIP_PLATFORM_HCC__)
#include "hip/hip_fp16.h"
#include "hip/hip_bfloat16.h"
#else
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#endif
#include "ln.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr uint32_t THREADS_PER_WARP = 32;
////////////////////////////////////////////////////////////////////////////////////////////////////
inline void check_cuda_(cudaError_t status, const char *file, int line) {
if( status != cudaSuccess ) {
fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
exit(status);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(ans) \
{ check_cuda_((ans), __FILE__, __LINE__); }
////////////////////////////////////////////////////////////////////////////////////////////////////
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
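// For example, REGISTER_FWD_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16)
// expands to roughly:
//   void ln_fwd_768_fp32_fp32_fp32_fp32(LaunchParams<FwdParams> &lp, const bool cfg) {
//       launch_<fp32, fp32, fp32, fp32, uint32_t, 768, 1, 4, 1, 16>(lp, cfg);
//   }
//   static FwdRegistrar<fp32, fp32, fp32, fp32, 768> reg_768_fp32_fp32_fp32_fp32(
//       ln_fwd_768_fp32_fp32_fp32_fp32);
// where FwdRegistrar (declared in ln.h, not shown) is expected to insert the
// function into FWD_FUNCS under the key from get_key(...).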
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
const bool configure_params) { \
launch_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 operator+(const float2 & a, const float2 & b){
return {a.x + b.x, a.y + b.y};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void operator+=(float2 & a, const float2 & b){
a.x += b.x;
a.y += b.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Sum {
inline __device__ Sum(){}
inline __device__ T operator()(const T &a, const T &b){
return a + b;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){
#if defined(__HIP_PLATFORM_HCC__)
return __shfl_xor(x, idx);
#else
return __shfl_xor_sync(uint32_t(-1), x, idx);
#endif
}
template<>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx){
return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) };
}
template<typename T>
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){
#if defined(__HIP_PLATFORM_HCC__)
return __shfl_down(x, idx);
#else
return __shfl_down_sync(uint32_t(-1), x, idx);
#endif
}
template<>
inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx){
return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint16 {
uint4 u;
uint4 v;
uint4 s;
uint4 t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint8 {
uint4 u;
uint4 v;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES>
struct BytesToType {};
template<>
struct BytesToType<64> {
using Type = uint16;
static_assert(sizeof(Type) == 64);
};
template<>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<>
struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<>
struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<>
struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<>
struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<>
struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeToVec2 {};
template<>
struct TypeToVec2<float> {
using Type = float2;
};
template<>
struct TypeToVec2<half> {
using Type = half2;
};
#if 0
template<>
struct TypeToVec2<nv_bfloat16> {
using Type = nv_bfloat162;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int INDEX>
struct Get {
template<typename T, typename R>
static inline __device__ R of(const T &vec);
};
template<>
template<typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
return vec.x;
}
template<>
template<typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
return vec.y;
}
template<>
template<typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
return vec.z;
}
template<>
template<typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
return vec.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Src, typename Dst>
struct Converter{
static inline __device__ Dst convert(const Src &from) {
return Dst(from);
}
};
template<>
struct Converter<float2, half2>{
static inline __device__ half2 convert(const float2 &x) {
return __float22half2_rn(x);
}
};
#if defined(__HIP_PLATFORM_HCC__)
template<>
struct Converter<float, half>{
static inline __device__ half convert(const float &x) {
return __float2half(x);
}
};
template<>
struct Converter<half, float>{
static inline __device__ float convert(const half &x) {
return __half2float(x);
}
};
template<>
struct Converter<float, hip_bfloat16>{
static inline __device__ hip_bfloat16 convert(const float &x) {
return hip_bfloat16::round_to_bfloat16(x);
}
};
template<>
struct Converter<hip_bfloat16, float>{
static inline __device__ float convert(const hip_bfloat16 &x) {
return float(x);
}
};
#endif
#if 0
template<>
struct Converter<float2, nv_bfloat162>{
static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
return __float22bfloat162_rn(x);
#else
// Note: the two halves must not alias each other in the union, otherwise the
// second assignment would overwrite the first; use an element array instead.
union {
    nv_bfloat162 raw;
    nv_bfloat16 elt[2];
} tmp;
tmp.elt[0] = __float2bfloat16_rn(x.x);
tmp.elt[1] = __float2bfloat16_rn(x.y);
return tmp.raw;
#endif
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Zeros{
static inline __device__ T get() {
return T(0.f);
}
};
template<>
struct Zeros<float2>{
static inline __device__ float2 get() {
return make_float2(0.f, 0.f);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
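// Vec is a small fixed-size vector. The union aliases NUM_ELT scalar
// elements with a single type of the same overall width, so a whole Vec is
// moved by one vectorized access (e.g. Vec<half, 8> transfers 16 bytes as
// one uint4).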
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
Alias_type data;
template<typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
template<typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
}
inline __device__ void load_from(const void *base_ptr, const size_t idx) {
this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
}
inline __device__ void store_to(void *base_ptr, const size_t idx) {
static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
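// Grid-wide barrier for the CTAS_PER_ROW CTAs that cooperate on one row.
// Two global-memory counters (b0_/b1_) are alternated across four phases
// (count up / count down on each), so the barrier is reusable every
// iteration without being re-initialized.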
template<uint32_t CTAS_PER_ROW>
struct InterCTASync {
template<typename Params>
inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn)
: phase_counter_(0)
, b0_(params.barrier + bidm) // The barrier for this group of CTAs.
, b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs.
{
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
}
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
#if defined(__HIP_PLATFORM_HCC__)
atomicAdd(barrier, step);
for( int found = -1; found != expected; ) {
// asm volatile("global_load_dword %0, %1, off;" : "=v"(found) : "v"(barrier));
found = atomicCAS(barrier, expected, expected);
}
#else
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
for( int found = -1; found != expected; ) {
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
}
#endif
}
inline __device__ void sync(){
// ALL THREADS MUST ENTER!
// We switch barrier every iteration.
int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
// We decrement every other iteration.
bool dec = phase_counter_ & 0x2;
int step = dec ? -1 : 1;
int expected = dec ? 0 : CTAS_PER_ROW;
// There are only 4 phases: up/down for b0/b1.
phase_counter_ = (phase_counter_ + 1) & 0x3;
if( threadIdx.x == 0 ) {
spin_wait_(barrier, step, expected);
}
// CTA waits for thread 0
__syncthreads();
}
int phase_counter_;
int * b0_;
int * b1_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
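// Multi-CTA reducer: values are first combined inside each CTA (Base), then
// the per-CTA partials are exchanged through a global-memory workspace
// (double-buffered via w0_/w1_) guarded by InterCTASync.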
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
using InterCTASync = InterCTASync<CTAS_PER_ROW>;
using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
using Type = typename Base::Type;
enum { SMEM_BYTES = Base::SMEM_BYTES };
enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
// size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, inter_cta_(params, bidm, bidn)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
{
}
template<typename Op>
inline __device__ T allreduce(T data, Op &op) {
data = Base::reduce(data, op);
// We switch workspace every iteration.
T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
// The lane-0 leader of warp_n == 0 holds the CTA-local result.
if( this->warp_n_ == 0 && this->lane_ == 0 ) {
workspace[bidn_] = data;
}
inter_cta_.sync();
static_assert(CTAS_PER_ROW <= 32);
T total = Zeros<T>::get();
if(this->lane_ < CTAS_PER_ROW){
total = workspace[this->lane_];
}
total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
return total;
}
InterCTASync inter_cta_;
T *w0_;
T *w1_;
int bidn_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
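// Base case: one warp per row group, so reductions are pure register
// shuffles; `reduce` leaves the result in lane 0 only, `allreduce` in all
// lanes.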
template<typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
using Type = T;
enum { SMEM_BYTES = 0 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: warp_n_(warp_n)
, lane_(lane)
{
}
template<typename Op>
static inline __device__ T allreduce_(T data, Op &op) {
#pragma unroll
for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
data = op(data, warp_shuffle_xor(data, it));
}
return data;
}
template<typename Op>
inline __device__ T allreduce(T data, Op &op) {
return allreduce_(data, op);
}
template<typename Op>
inline __device__ T reduce(T data, Op &op){
// only lane 0 holds the result!
#pragma unroll
for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
data = op(data, warp_shuffle_down(data, it));
}
return data;
}
int warp_n_;
int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
using Base = Reducer<T, 1, WARPS_M, 1>;
using Type = T;
enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, use0_(true)
{
smem0_ = &static_cast<T *>(smem)[warp_m * WARPS_N];
smem1_ = smem0_ + WARPS_M * WARPS_N;
}
template<typename Op>
inline __device__ T allreduce(T data, Op & op) {
T * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
data = Base::reduce(data, op);
if( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
return out;
}
template<typename Op>
inline __device__ T reduce(T data, Op &op) {
T * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// only intra-CTA group leader holds the result!
data = Base::reduce(data, op);
if( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
if( this->warp_n_ == 0 && this->lane_ == 0 ) {
#pragma unroll
for( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
}
return out;
}
T * smem0_;
T * smem1_;
bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
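// Merge per-lane (count n, mean m, m2) statistics across a warp with the
// pairwise update of Chan et al.:
//   m_ab  = (n_a * m_a + n_b * m_b) / (n_a + n_b)
//   m2_ab = m2_a + m2_b + (m_a - m_b)^2 * n_a * n_b / (n_a + n_b)
// Lane 0 ends up with the valid result and broadcasts it to the warp.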
template<typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){
// Assume at least the leftmost lane is valid; start with step = next_pow2(num_active) / 2 (we might get NaNs otherwise).
int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
#pragma unroll
for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
// Exchange
T n_b = warp_shuffle_down(n_a, step);
T m_b = warp_shuffle_down(m_a, step);
T m2_b = warp_shuffle_down(m2_a, step);
// Update
const T n_ab = n_a + n_b; // We can handle one of them being 0, not both.
const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
const T delta = m_a - m_b;
const T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
const T m_ab = (n_a * m_a + n_b * m_b) * rn_ab;
n_a = n_ab;
m_a = m_ab;
m2_a = m2_ab;
}
// Intra-warp broadcast (only lane 0 has valid stats).
#if defined(__HIP_PLATFORM_HCC__)
m_a = __shfl(m_a, 0);
m2_a = __shfl(m2_a, 0);
#else
m_a = __shfl_sync(uint32_t(-1), m_a, 0);
m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
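// Multi-CTA statistics: each CTA computes block-local (mean, m2), publishes
// them to the double-buffered workspace, and after an InterCTASync a single
// warp merges the CTAS_PER_ROW partial results.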
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
// This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.
using InterCTASync = InterCTASync<CTAS_PER_ROW>;
using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
using stats_t = typename BlockStats::stats_t;
enum { SMEM_BYTES = BlockStats::SMEM_BYTES };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: inter_cta_(params, bidm, bidn)
, block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
, warp_n_(warp_n)
, lane_(lane)
{
}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
// TODO rn is not really needed here..
constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
stats_t block_stats = block_stats_.compute(elts, block_rn);
stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
if( warp_n_ == 0 && lane_ == 0 ) {
workspace[bidn_] = block_stats;
}
// Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
inter_cta_.sync();
T n = Zeros<T>::get();
T m = Zeros<T>::get();
T m2 = Zeros<T>::get();
// Assume the CTA group size in N is at most 32, so we can finalize with a single warp.
static_assert(CTAS_PER_ROW <= 32);
// Every warp does the final reduction locally.
if( lane_ < CTAS_PER_ROW ) {
stats_t result = workspace[lane_];
n = ELTS_PER_ROW_PER_CTA;
m = layer_norm::Get<0>::of<stats_t, T>(result);
m2 = layer_norm::Get<1>::of<stats_t, T>(result);
}
warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);
return { m, m2 };
}
InterCTASync inter_cta_;
BlockStats block_stats_;
stats_t *w0_;
stats_t *w1_;
int bidn_;
int warp_n_;
int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
using WarpStats = Stats<T, 1, WARPS_M, 1>;
using stats_t = typename WarpStats::stats_t;
enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
, use0_(true)
{
smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;
smem1_ = smem0_ + WARPS_M * WARPS_N;
}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
stats_t * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// Compute warp local for all WARPS_N
constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
stats_t warp_stats = warp_stats_.compute(elts, warp_rn);
// Each warp's leader stores its stats
const auto warp_n = warp_stats_.reducer_.warp_n_;
const auto lane = warp_stats_.reducer_.lane_;
if( lane == 0 ) {
smem[warp_n] = warp_stats;
}
__syncthreads();
T n = Zeros<T>::get();
T m = Zeros<T>::get();
T m2 = Zeros<T>::get();
// Assume there are at most 32 warps, so we can finalize with a single warp
static_assert(WARPS_N <= 32);
if(lane < WARPS_N){
stats_t result = smem[lane];
n = N * THREADS_PER_WARP;
m = layer_norm::Get<0>::of<stats_t, T>(result);
m2 = layer_norm::Get<1>::of<stats_t, T>(result);
}
warp_chan_upd_dynamic(m, m2, n, WARPS_N);
return { m, m2 };
}
WarpStats warp_stats_;
stats_t * smem0_;
stats_t * smem1_;
bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
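// Single-warp statistics: the mean is one shuffle-based allreduce of the
// per-thread sums (scaled by rn = 1/row_length), and m2 is the allreduced
// sum of squared deviations from that mean.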
template<typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
using stats_t = typename TypeToVec2<T>::Type;
// The simple Warp reducer.
using Reducer = Reducer<T, 1, WARPS_M, 1>;
enum { SMEM_BYTES = 0 };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)
{
}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
auto sum = Sum<T>();
T m = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < N; it++ ) {
m += elts[it];
}
m = reducer_.allreduce(m, sum) * rn;
T m2 = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < N; it++ ) {
T diff = (elts[it] - m);
m2 += diff * diff;
}
m2 = reducer_.allreduce(m2, sum);
return {m, m2};
}
Reducer reducer_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.cuh"
#include "softmax.cuh"
// symbol to be automatically resolved by PyTorch libs
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const half *pad_mask, float dropout_prob) {
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
const int k_seq_len = q_seq_len;
// const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *input_ptr = static_cast<void *>(input.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
attn_batches * q_seq_len);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
if (is_training) {
// use at:: function so that C++ version generates the same random mask as
// python version
auto dropout_tuple =
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
return {dropout_results, dropout_mask, softmax_results};
}
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
// const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// TODO: streams could be used in backprop; this first implementation uses a
// single stream
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(output_grads.data_ptr()),
static_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
// backward pass is completely in-place
return output_grads;
}
} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
#pragma once
#include <ATen/ATen.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAContext.h>
#include <curand_kernel.h>
namespace {
constexpr int UNROLL = 4;
} // namespace
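// Fused dropout kernel. Each thread draws four uniforms per iteration from a
// Philox counter-based RNG; an element is kept with probability p (the keep
// probability: callers pass 1 - dropout_prob), survivors are scaled by 1/p,
// and the 0/1 keep flag is recorded in `mask`. totalElements is rounded up
// to a full grid-stride multiple so every thread performs the same number of
// RNG draws, matching the counter offset reserved by the host launcher.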
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void
apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs,
uint8_t *mask, IndexType totalElements, accscalar_t p,
std::pair<uint64_t, uint64_t> seeds) {
accscalar_t pinv = accscalar_t(1) / p;
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, idx, seeds.second, &state);
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
float4 rand = curand_uniform4(&state);
scalar_t src[UNROLL];
rand.x = rand.x <= p;
rand.y = rand.y <= p;
rand.z = rand.z <= p;
rand.w = rand.w <= p;
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
src[ii] = inputs[li];
}
}
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
outputs[li] = src[ii] * (&rand.x)[ii] * pinv;
mask[li] = (uint8_t)(&rand.x)[ii];
}
}
__syncthreads();
}
}
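// As above, fused with the residual connection:
// outputs = add_inputs + dropout(inputs).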
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_dropout_add_kernel(scalar_t const *inputs,
scalar_t const *add_inputs,
scalar_t *outputs, uint8_t *mask,
IndexType totalElements, accscalar_t p,
std::pair<uint64_t, uint64_t> seeds) {
accscalar_t pinv = accscalar_t(1) / p;
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, idx, seeds.second, &state);
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
float4 rand = curand_uniform4(&state);
scalar_t src[UNROLL];
scalar_t add_src[UNROLL];
rand.x = rand.x <= p;
rand.y = rand.y <= p;
rand.z = rand.z <= p;
rand.w = rand.w <= p;
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
src[ii] = inputs[li];
add_src[ii] = add_inputs[li];
}
}
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;
outputs[li] =
static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);
mask[li] = (uint8_t)(&rand.x)[ii];
}
}
__syncthreads();
}
}
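// Plain elementwise residual add (used on the inference path, where no
// dropout mask is needed).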
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_add_kernel(scalar_t const *inputs,
scalar_t const *add_inputs, scalar_t *outputs,
IndexType totalElements) {
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
scalar_t src[UNROLL];
scalar_t add_src[UNROLL];
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
src[ii] = inputs[li];
add_src[ii] = add_inputs[li];
}
}
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
outputs[li] = src[ii] + add_src[ii];
}
}
__syncthreads();
}
}
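// Dropout backward: re-apply the saved keep mask and rescale, i.e.
// outputs = inputs * mask * scale with scale = 1 / (1 - dropout_prob).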
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_masked_scale_kernel(scalar_t const *inputs,
scalar_t *outputs, uint8_t const *mask,
IndexType totalElements,
accscalar_t scale) {
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
scalar_t src[UNROLL];
scalar_t msk[UNROLL];
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
src[ii] = static_cast<scalar_t>(inputs[li]);
msk[ii] = static_cast<scalar_t>(mask[li]);
}
}
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
outputs[li] = static_cast<accscalar_t>(src[ii]) * scale *
static_cast<accscalar_t>(msk[ii]);
}
}
}
}
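// Host launcher: the grid is capped at one full wave of resident blocks,
// enough Philox counter range for all per-thread draws is reserved while
// holding the generator lock, and the kernel is launched on the current
// CUDA stream.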
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_fused_dropout_cuda(scalar_t const *inputs, scalar_t *outputs,
uint8_t *mask, IndexType totalElements,
accscalar_t p) {
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
// Number of random values generated per thread, used to offset the Philox
// counter in the RNG state
int64_t counter_offset =
((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs =
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
counter_offset);
}
apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, outputs, mask, totalElements, p, rng_engine_inputs);
C10_CUDA_CHECK(cudaGetLastError());
}
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_dropout_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
scalar_t *outputs, uint8_t *mask,
IndexType totalElements, accscalar_t p) {
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
// Number of random values generated per thread, used to offset the Philox
// counter in the RNG state
int64_t counter_offset =
((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs =
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
counter_offset);
}
apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, add_inputs, outputs, mask, totalElements, p,
rng_engine_inputs);
C10_CUDA_CHECK(cudaGetLastError());
}
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
scalar_t *outputs, IndexType totalElements) {
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
apex_add_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, add_inputs, outputs, totalElements);
C10_CUDA_CHECK(cudaGetLastError());
}
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_masked_scale_cuda(scalar_t const *inputs, scalar_t *outputs,
uint8_t const *mask, IndexType totalElements,
accscalar_t scale) {
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, outputs, mask, totalElements, scale);
C10_CUDA_CHECK(cudaGetLastError());
}
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"
namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {
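// Encoder-decoder attention forward:
//   Q = W_q(inputs_q), [K;V] = W_kv(inputs_kv)         (rocblas_gemm_ex)
//   S = softmax(scale * Q K^T [+ mask])                (batched GEMM + softmax)
//   O = dropout(S) V, outputs = W_out(O)               (batched GEMM + GEMM)
// K and V are interleaved in input_lin_kv_results: the V pointer is offset
// by head_dim and both use a batch stride of 2 * head_dim.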
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob) {
const int embed_dim = inputs_q.size(2);
const int sequences = inputs_q.size(1);
const int q_seq_len = inputs_q.size(0);
const int k_seq_len = inputs_kv.size(0);
const int batches_q = sequences * q_seq_len;
const int batches_kv = sequences * k_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_q_dim = embed_dim;
const int output_lin_kv_dim = 2 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 * head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Intermediate result + output tensor allocations (dropout intermediates
// are generated by apex_fused_dropout_cuda below, not by ATen)
auto act_options = inputs_q.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_q_results =
torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor input_lin_kv_results =
torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs_q, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
void *k_lin_results_ptr =
static_cast<void *>(input_lin_kv_results.data_ptr());
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
assert(softmax_success);
if (is_training) {
apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
(1.0f - dropout_prob));
}
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_q_results,
input_lin_kv_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
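// Backward retraces the forward pass in reverse order: output linear
// dgrad/wgrad, the two matmul2 dgrads, dropout-mask re-application, softmax
// backward, the two matmul1 dgrads (with the same 1/sqrt(head_dim) scale),
// and finally the input linear dgrads/wgrads for Q and KV.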
const int embed_dim = inputs_q.size(2);
const int sequences = inputs_q.size(1);
const int q_seq_len = inputs_q.size(0);
const int k_seq_len = inputs_kv.size(0);
const int batches_q = sequences * q_seq_len;
const int batches_kv = sequences * k_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_q_dim = embed_dim;
const int output_lin_kv_dim = 2 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 * head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: streams could be used in backprop; this first implementation uses a
// single stream
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
torch::Tensor input_q_grads = torch::empty_like(inputs_q);
torch::Tensor input_kv_grads = torch::empty_like(inputs_kv);
torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q);
torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);
torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
at::Tensor input_lin_kv_output_grads =
torch::empty_like(input_lin_kv_results);
auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
auto v_lin_results_ptr =
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;
auto q_lin_grads_ptr =
static_cast<half *>(input_lin_q_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(matmul2_grads.data_ptr()),
static_cast<at::Half*>(matmul2_grads.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
dropout_elems,
(1.0 / (1.0 - dropout_prob)));
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
input_kv_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads
};
}
} // end namespace rocblas_gemmex
} // end namespace encdec
} // end namespace multihead_attn
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.cuh"
#include "layer_norm.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"
namespace multihead_attn {
namespace encdec_norm_add {
namespace rocblas_gemmex {
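// Same attention block as multihead_attn::encdec above, with two additions:
// a layer norm is applied to inputs_q before the Q projection, and the
// output linear result goes through a fused dropout-add residual
// (outputs = inputs_q + dropout(output_linear)).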
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
)
{
const int embed_dim = inputs_q.size(2);
const int sequences = inputs_q.size(1);
const int q_seq_len = inputs_q.size(0);
const int k_seq_len = inputs_kv.size(0);
const int batches_q = sequences * q_seq_len;
const int batches_kv = sequences * k_seq_len;
const int total_tokens_q = batches_q * embed_dim;
const int head_dim = embed_dim / heads;
const int output_lin_q_dim = embed_dim;
const int output_lin_kv_dim = 2 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 *head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Intermediate result + output tensor allocations (dropout intermediates
// are generated by the fused dropout kernels below, not by ATen)
auto act_options = inputs_q.options().requires_grad(false);
auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor lyr_nrm_mean = torch::empty({batches_q}, lyr_nrm_options);
torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options);
torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options);
torch::Tensor input_lin_q_results =
torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor input_lin_kv_results =
torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options);
torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options);
torch::Tensor outputs = torch::empty_like(inputs_q, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
void *k_lin_results_ptr =
static_cast<void *>(input_lin_kv_results.data_ptr());
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
HostApplyLayerNorm<at::Half, float>(
static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
static_cast<float *>(lyr_nrm_mean.data_ptr()),
static_cast<float *>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half *>(inputs_q.data_ptr()),
static_cast<int>(batches_q), // n1
static_cast<int>(embed_dim), // n2
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_q_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_kv_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
assert(softmax_success);
if (is_training) {
apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
(1.0f - dropout_prob));
}
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()),
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// End-of-block Dropout-Add
if (is_training) {
apex_dropout_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs_q.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()),
static_cast<uint8_t *>(dropout_add_mask.data_ptr()), total_tokens_q,
(1.0f - dropout_prob));
} else {
apex_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs_q.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()), total_tokens_q);
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
input_lin_q_results,
input_lin_kv_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
dropout_add_mask,
outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const &dropout_add_mask, float dropout_prob) {
const int embed_dim = inputs_q.size(2);
const int sequences = inputs_q.size(1);
const int q_seq_len = inputs_q.size(0);
const int k_seq_len = inputs_kv.size(0);
const int batches_q = sequences * q_seq_len;
const int batches_kv = sequences * k_seq_len;
const int total_tokens_q = batches_q * embed_dim;
const int head_dim = embed_dim / heads;
const int output_lin_q_dim = embed_dim;
const int output_lin_kv_dim = 2 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 * head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: streams could be used in backprop; this first implementation uses a
// single stream
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
torch::Tensor input_q_grads = torch::empty_like(inputs_q);
torch::Tensor input_kv_grads = torch::empty_like(inputs_kv);
torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights);
torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights);
torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q);
torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);
torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations
at::Tensor dropout_add_grads = torch::empty_like(output_grads);
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
at::Tensor input_lin_kv_output_grads =
torch::empty_like(input_lin_kv_results);
at::Tensor input_lin_q_grads = torch::empty_like(inputs_q);
auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
auto v_lin_results_ptr =
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;
auto q_lin_grads_ptr =
static_cast<half *>(input_lin_q_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
// Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()),
static_cast<at::Half*>(dropout_add_grads.data_ptr()),
static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),
total_tokens_q,
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
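// Each rocblas_gemm_ex call in this file passes the same buffer for C and D
// with beta = 0, i.e. D = alpha * op(A) * op(B), accumulated in fp32
// (compute_type f32_r) and stored back as fp16.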
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Matmul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
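// Reference identities for the four strided-batched GEMMs in this section
// (column-major, per attention head): with forward Matmul1 S = scale*K^T*Q
// and Matmul2 O = V*P, the gradients computed are
//   dP = V^T * dO         (Matmul2 Dgrad1, above)
//   dV = dO * P^T         (Matmul2 Dgrad2)
//   dQ = scale * K * dS   (Matmul1 Dgrad1)
//   dK = scale * Q * dS^T (Matmul1 Dgrad2)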
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(matmul2_grads.data_ptr()),
static_cast<at::Half*>(matmul2_grads.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
dropout_elems,
(1.0 / (1.0 - dropout_prob)));
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
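// For reference, this computes the standard softmax Jacobian-vector product
// in place: dx_i = y_i * (dy_i - sum_j dy_j * y_j), with y = softmax_results.
// Note that assert() compiles to a no-op under NDEBUG, so a dispatch failure
// would pass silently in release builds.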
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>(
static_cast<const half*>(input_lin_q_grads.data_ptr()),
static_cast<half const*>(output_grads.data_ptr()),
static_cast<const float*>(lyr_nrm_mean.data_ptr()),
static_cast<const float*>(lyr_nrm_invvar.data_ptr()),
inputs_q,
static_cast<int>(batches_q), // n1
static_cast<int>(embed_dim), // n2
static_cast<const half*>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()),
1.0e-5,
static_cast<half*>(input_q_grads.data_ptr()),
static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()),
static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
);
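// The residual-add gradient is fused here: cuComputeGradInput adds dout_resid
// (the raw output_grads) to the layer-norm input gradient, so no separate
// elementwise add is needed for the skip connection.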
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_q_grads, input_kv_grads, lyr_nrm_gamma_grads,
lyr_nrm_beta_grads, input_weight_q_grads, input_weight_kv_grads,
output_weight_grads};
}
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/DeviceUtils.cuh>
namespace {
template <typename U>
__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) {
count = count + U(1);
U delta = curr - mu;
U lmean = mu + delta / count;
mu = lmean;
U delta2 = curr - lmean;
sigma2 = sigma2 + delta * delta2;
}
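// After each update, mu is the running mean and sigma2 is the running sum of
// squared deviations M2 = sum_i (x_i - mu)^2; it is divided by n2 only once,
// at the end of cuWelfordMuSigma2.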
template <typename U>
__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
U &mu, U &sigma2, U &count) {
U delta = muB - mu;
U nA = count;
U nB = countB;
count = count + countB;
U nX = count;
if (nX > U(0)) {
nA = nA / nX;
nB = nB / nX;
mu = nA * mu + nB * muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else {
mu = U(0);
sigma2 = U(0);
}
}
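// Chan et al.'s pairwise merge of two Welford partials (nA, muA, M2A) and
// (nB, muB, M2B):
//   n  = nA + nB
//   mu = (nA*muA + nB*muB) / n
//   M2 = M2A + M2B + (muB - muA)^2 * nA*nB / n
// The code factors nA/n and nB/n out first, which is the numerically safer
// form of the same algebra.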
template <typename T, typename U>
__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
const int n2, const int i1, U &mu, U &sigma2,
U *buf) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U count = U(0);
mu = U(0);
sigma2 = U(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T *lvals = vals + i1 * n2;
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l + k]);
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
U muB = WARP_SHFL(mu, srcLaneB, 32);
U countB = WARP_SHFL(count, srcLaneB, 32);
U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
// lane 0 (threadIdx.x == 0) of each warp now holds that warp's reduced values
// inter-warp reductions
if (blockDim.y > 1) {
U *ubuf = (U *)buf;
U *ibuf = (U *)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2 * wrt_y] = mu;
ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2 * threadIdx.y];
U sigma2B = ubuf[2 * threadIdx.y + 1];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
}
__syncthreads();
}
// threadIdx.x == 0 && threadIdx.y == 0 is the only thread with the correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1] / U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2 / U(n2), 0, 32);
}
}
}
template <>
__device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals,
const int n1, const int n2, const int i1,
float &mu, float &sigma2, float *buf) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float count = 0.0f;
mu = float(0);
sigma2 = float(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const at::Half *lvals = vals + i1 * n2;
int l = 8 * thrx;
if ((((size_t)lvals) & 3) != 0) {
// pointer is 16-bit but not 32-bit aligned: let the first thread consume
// the first element so the remaining accesses can use aligned __half2 loads
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr, mu, sigma2, count);
}
++l;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for (; l + 7 < n2; l += 8 * numx) {
for (int k = 0; k < 8; k += 2) {
float2 curr = __half22float2(*((__half2 *)(lvals + l + k)));
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr, mu, sigma2, count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
float muB = WARP_SHFL(mu, srcLaneB, 32);
float countB = WARP_SHFL(count, srcLaneB, 32);
float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
// lane 0 (threadIdx.x == 0) of each warp now holds that warp's reduced values
// inter-warp reductions
if (blockDim.y > 1) {
float *ubuf = (float *)buf;
float *ibuf = (float *)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2 * wrt_y] = mu;
ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2 * threadIdx.y];
float sigma2B = ubuf[2 * threadIdx.y + 1];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
}
__syncthreads();
}
// threadIdx.x == 0 && threadIdx.y == 0 is the only thread with the correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1] / float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2 / float(n2), 0, 32);
}
}
}
template <typename U> __device__ U rsqrt(U v) {
  return U(1) / sqrt(v);
}
#if defined __HIP_PLATFORM_HCC__
__device__ float rsqrt(float v) { return rsqrtf(v); }
#else
template <> __device__ float rsqrt(float v) { return rsqrtf(v); }
#endif
template <> __device__ double rsqrt(double v) { return rsqrt(v); }
// SharedMemory is declared but deliberately left undefined, so only the
// float/double specializations below can be instantiated
// (https://github.com/NVIDIA/apex/issues/246).
template <typename T> struct SharedMemory;
template <> struct SharedMemory<float> {
__device__ float *getPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
template <> struct SharedMemory<double> {
__device__ double *getPointer() {
extern __shared__ double s_double[];
return s_double;
}
};
template <typename T, typename U>
__global__ void
cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean,
U *__restrict__ invvar, const T *__restrict__ vals,
const int n1, const int n2, const U epsilon,
const T *__restrict__ gamma, const T *__restrict__ beta) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U *buf = shared.getPointer();
U mu, sigma2;
cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
const T *lvals = vals + i1 * n2;
T *ovals = output_vals + i1 * n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = gamma[i] * static_cast<T>(c_invvar * (curr - mu)) + beta[i];
}
} else {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<T>(c_invvar * (curr - mu));
}
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean[i1] = mu;
invvar[i1] = c_invvar;
}
}
}
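// Per row i1: y = gamma * (x - mu) * invvar + beta, with
// invvar = rsqrt(sigma2 + epsilon); mu and invvar are stashed per row so the
// backward kernels below can reuse them instead of recomputing statistics.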
template <typename T, typename U>
__device__ void cuLoadWriteStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
const T *input, const T *dout, const int i1_end, const int n2,
const U *__restrict__ mean, const U *__restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] =
curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template <typename T, typename U>
__device__ void cuLoadAddStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
const T *input, const T *dout, const int i1_end, const int n2,
const U *__restrict__ mean, const U *__restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] +=
curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template <typename T, typename U>
__global__ void cuComputePartGradGammaBeta(
const T *__restrict__ dout, const T *__restrict__ input, const int n1,
const int n2, const U *__restrict__ mean, const U *__restrict__ invvar,
U epsilon, U *part_grad_gamma, U *part_grad_beta) {
const int numsegs_n1 =
(n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
const int i1_beg_plus_one =
(blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x + 1;
const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
const int thr_load_row_off =
(threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U *buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y *
// blockDim.y + (blockDim.y -
// 1)*(blockDim.x/blockDim.y) elements
U *warp_buf1 = (U *)buf;
U *warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off,
row_stride, warp_buf1, warp_buf2, input, dout,
i1_end, n2, mean, invvar);
for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
i1_block += blockDim.y * blockDim.y) {
cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off,
row_stride, warp_buf1, warp_buf2, input, dout,
i1_end, n2, mean, invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k * blockDim.y;
int idx1 = row1 * row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
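// Each block emits one partial row per column tile:
//   part_grad_beta[b][j]  = sum over its rows i of dout[i][j]
//   part_grad_gamma[b][j] = sum over its rows i of dout[i][j] * (input[i][j] - mu[i]) * invvar[i]
// cuComputeGradGammaBeta below reduces the part_size partial rows per column.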
template <typename T, typename U>
__global__ void
cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta,
const int part_size, const int n1, const int n2,
T *grad_gamma, T *grad_beta) {
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U *buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp first sums its part_size / blockDim.y partial rows sequentially
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U *part_grad_gamma_ptr =
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U *part_grad_beta_ptr =
part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions;
++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
sum_beta += part_grad_beta_ptr[warp_offset * n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx + nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx + nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template <typename T, typename U>
__global__ void
cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid,
const T *__restrict__ input, const int n1, const int n2,
const U *__restrict__ mean, const U *__restrict__ invvar,
U epsilon, const T *gamma, T *grad_input) {
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T *k_input = input + i1 * n2;
const T *k_dout = dout + i1 * n2;
const T *k_dout_resid = dout_resid + i1 * n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * static_cast<U>(gamma[l + k]);
sum_loss2 +=
c_loss * static_cast<U>(gamma[l + k]) * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * static_cast<U>(gamma[l]);
sum_loss2 +=
c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U *buf = shared.getPointer();
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2 * wrt_i] = sum_loss1;
buf[2 * wrt_i + 1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2 * read_i];
sum_loss2 += buf[2 * read_i + 1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2 * threadIdx.x] = sum_loss1;
buf[2 * threadIdx.x + 1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y != 0) {
sum_loss1 = buf[2 * threadIdx.x];
sum_loss2 = buf[2 * threadIdx.x + 1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T *k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
const T c_resid = static_cast<T>(k_dout_resid[l]);
U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input) + c_resid;
}
} else {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
const T c_resid = static_cast<T>(k_dout_resid[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input) + c_resid;
}
}
}
}
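// Per element, with xhat = (x - mu) * invvar and g = gamma (or 1 if absent):
//   dx = (invvar / n2) * (n2*g*dy - sum_l g*dy - xhat * sum_l g*dy*xhat) + dout_resid
// where the two sums run over the row; this is the usual layer-norm input
// gradient plus the fused residual term.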
template <typename T, typename U>
void HostApplyLayerNorm(T *output, U *mean, U *invvar, const T *input, int n1,
int n2, double epsilon, const T *gamma, const T *beta) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32, 4, 1);
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}
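// Launch shape: a (32, 4) block is four warps cooperating on one row at a
// time, with blockIdx.y striding over rows; shared memory holds the per-warp
// (mu, sigma2, count) partials for the inter-warp Welford merge.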
template <typename T, typename U>
void HostLayerNormGradient(const T *dout, const T *dout_resid, const U *mean,
const U *invvar, const at::Tensor &input, int n1,
int n2, const T *gamma, const T *beta,
double epsilon, T *grad_input, T *grad_gamma,
T *grad_beta) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32, 4, 1);
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
const int nshared2_a =
2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty(
{part_size, n2},
input.options().dtype(input.scalar_type() == at::ScalarType::Half
? at::ScalarType::Float
: input.scalar_type()));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout, static_cast<T *>(input.data_ptr()), n1, n2, mean, invvar,
U(epsilon), static_cast<U *>(part_grad_gamma.data_ptr()),
static_cast<U *>(part_grad_beta.data_ptr()));
const dim3 threads3(32, 8, 1);
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
static_cast<U *>(part_grad_gamma.data_ptr()),
static_cast<U *>(part_grad_beta.data_ptr()), part_size, n1, n2,
grad_gamma, grad_beta);
}
// compute grad_input
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32, 4, 1);
int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout, dout_resid, static_cast<T *>(input.data_ptr()), n1, n2, mean,
invvar, U(epsilon), gamma, grad_input);
}
} // namespace
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.cuh"
#include "softmax.cuh"
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const uint8_t *pad_mask,
float dropout_prob) {
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *input_ptr = static_cast<void *>(input.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
attn_batches * q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
if (is_training) {
// use at:: function so that C++ version generates the same random mask as
// python version
auto dropout_tuple =
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
return {dropout_results, dropout_mask, softmax_results};
}
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
const uint8_t *padding_mask, float dropout_prob) {
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// TODO: multiple streams could overlap the independent GEMMs in backprop;
// for now everything is issued on the single current stream.
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if (padding_mask == nullptr) {
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(output_grads.data_ptr()),
static_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
} else {
dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,
false>(
static_cast<half *>(output_grads.data_ptr()),
static_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
static_cast<uint8_t const *>(padding_mask), 1.0 / (1.0 - dropout_prob),
k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream);
}
// backward pass is completely in-place
return output_grads;
}
} // namespace mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
#include <vector>
#include <cuda_fp16.h>
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
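// Convenience guard for binding-level argument checks, e.g.
//   CHECK_INPUT(inputs_q); // asserts inputs_q is a contiguous CUDA tensor
// (inputs_q here is just an illustrative name).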
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const half *pad_mask, float dropout_prob);
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob);
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
torch::Tensor const &input,
torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
               "Only HALF is supported");
}
return fwd_cuda(is_training, heads, input,
use_mask ? static_cast<const half *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
dropout_prob);
}
} // namespace additive_mask_softmax_dropout
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const uint8_t *pad_mask,
float dropout_prob);
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
const uint8_t *padding_mask, float dropout_prob);
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
torch::Tensor const &input,
torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(is_training, heads, input,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
torch::Tensor const &padding_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
use_mask
? static_cast<const uint8_t *>(padding_mask.data_ptr())
: nullptr,
dropout_prob);
}
} // end namespace mask_softmax_dropout
} // end namespace fused_softmax
namespace encdec {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob);
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
input_weights_q, input_weights_kv, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_q_results, input_lin_kv_results,
inputs_q, inputs_kv, input_weights_q, input_weights_kv,
output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // end namespace encdec
namespace encdec_norm_add {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const &dropout_add_mask, float dropout_prob);
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q,
input_weights_kv, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const &dropout_add_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_q_results, input_lin_kv_results,
lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q,
inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
input_weights_q, input_weights_kv, output_weights,
dropout_mask, dropout_add_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
namespace self {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob);
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
use_time_mask, is_training, heads, inputs, input_weights, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_results, inputs, input_weights,
output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // end namespace self
namespace self_bias {
namespace rocblas_gemmex {
std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const &output_biases, const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
// torch::Tensor const& input_biases,
// torch::Tensor const& output_biases,
torch::Tensor const &dropout_mask, float dropout_prob);
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
output_weights, input_biases, output_biases,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_results, inputs, input_weights,
output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // namespace self_bias
namespace self_bias_additive_mask {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
torch::Tensor const &input_biases,
torch::Tensor const &output_biases,
const half *pad_mask, float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
// torch::Tensor const& softmax_results,
torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
// torch::Tensor const& input_biases,
// torch::Tensor const& output_biases,
torch::Tensor const &dropout_mask, float dropout_prob);
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const &output_biases, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(use_mask, "no mask is not supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
"Only Half is supported");
}
return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights,
output_weights, input_biases, output_biases,
use_mask ? static_cast<const half *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
bmm1_results, pad_mask, input_lin_results, inputs,
input_weights, output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // namespace self_bias_additive_mask
namespace self_norm_add {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
float dropout_prob);
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &pad_mask, float dropout_prob) {
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights,
lyr_nrm_beta_weights, input_weights, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr()) : nullptr,
dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_results, lyr_nrm_results,
lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights,
lyr_nrm_beta_weights, input_weights, output_weights,
dropout_mask, dropout_add_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("additive_mask_softmax_dropout_forward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
"Self Multihead Attention masked softmax dropout -- Forward.");
m.def("additive_mask_softmax_dropout_backward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
"Self Multihead Attention masked softmax dropout -- Backward.");
m.def("mask_softmax_dropout_forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
"Self Multihead Attention masked softmax dropout -- Forward.");
m.def("mask_softmax_dropout_backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
"Self Multihead Attention masked softmax dropout -- Backward.");
m.def("encdec_multihead_attn_forward", &multihead_attn::encdec::rocblas_gemmex::fwd,
"Encdec Multihead Attention Forward.");
m.def("encdec_multihead_attn_backward", &multihead_attn::encdec::rocblas_gemmex::bwd,
"Encdec Multihead Attention Backward.");
m.def("encdec_multihead_attn_norm_add_forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd,
"Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def(
"encdec_multihead_attn_norm_add_backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd,
"Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
m.def("self_attn_forward", &multihead_attn::self::rocblas_gemmex::fwd,
"Self Multihead Attention Forward.");
m.def("self_attn_backward", &multihead_attn::self::rocblas_gemmex::bwd,
"Self Multihead Attention Backward.");
m.def("self_attn_bias_forward", &multihead_attn::self_bias::rocblas_gemmex::fwd,
"Self Multihead Attention with Bias -- Forward.");
m.def("self_attn_bias_backward", &multihead_attn::self_bias::rocblas_gemmex::bwd,
"Self Multihead Attention with Bias -- Backward.");
m.def("self_attn_bias_additive_mask_forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd,
"Self Multihead Attention with Bias -- Forward.");
m.def("self_attn_bias_additive_mask_backward",
&multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd,
"Self Multihead Attention with Bias -- Backward.");
m.def("self_attn_norm_add_forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd,
"Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("self_attn_norm_add_backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd,
"Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
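// Illustrative Python-side call (module handle and argument values assumed,
// not part of this file): once built, the extension is loaded like any other
// PyTorch C++ extension and the bindings above are called positionally, e.g.
//   results = ext.self_attn_bias_forward(use_mask, use_time_mask, is_training,
//                                        heads, inputs, input_weights,
//                                        output_weights, input_biases,
//                                        output_biases, pad_mask, dropout_prob)
// where the argument order matches the fwd() signatures bound above.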
#undef CHECK_CUDA
#undef CHECK_CONTIGUOUS
#undef CHECK_INPUT
#pragma once
// Philox counter-based RNG (device-side CUDA implementation).
namespace {
class Philox {
public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
key.x = (unsigned int)seed;
key.y = (unsigned int)(seed >> 32);
counter = make_uint4(0, 0, 0, 0);
counter.z = (unsigned int)(subsequence);
counter.w = (unsigned int)(subsequence >> 32);
STATE = 0;
incr_n(offset / 4);
}
__device__ inline uint4 operator()() {
if (STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
// 7-round Philox: six rounds in the loop plus the final round below
for (int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A);
key_.y += (kPhilox10B);
}
output = single_round(counter_, key_);
incr();
}
// return a float4 directly
// unsigned long ret;
// switch(STATE) {
// case 0: ret = output.x; break;
// case 1: ret = output.y; break;
// case 2: ret = output.z; break;
// case 3: ret = output.w; break;
//}
// STATE = (STATE + 1) % 4;
return output;
}
private:
uint4 counter;
uint4 output;
uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ inline void incr() {
if (++counter.x)
return;
if (++counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a * b;
}
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
unsigned int hi1;
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
return ret;
}
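// The constants below are the standard Philox4x32 multipliers (kPhiloxSA/SB)
// and key-schedule (Weyl) increments (kPhilox10A/B) from Salmon et al.,
// "Parallel Random Numbers: As Easy as 1, 2, 3" (SC 2011).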
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
static const unsigned long kPhiloxSA = 0xD2511F53;
static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
// Inverse of 2^32.
constexpr float M_RAN_INVM32 = 2.3283064e-10f;
__device__ __inline__ float4 uniform4(uint4 x) {
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
x.w * M_RAN_INVM32);
}
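// Usage sketch (illustrative only; the kernels in this repo construct their
// own generators): each thread builds a Philox stream from
// (seed, subsequence, offset) and maps the raw counter output to U[0,1):
//   Philox rng(seed, thread_id, offset);
//   float4 r = uniform4(rng());            // four independent U[0,1) samples
//   bool keep = r.x < (1.0f - dropout_prob);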
} // namespace
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const half* pad_mask, float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
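// Layout note: input_lin_results is [q_seq_len, sequences, 3*embed_dim] with
// Q, K and V interleaved in blocks of head_dim per head, so each of the
// attn_batches strided GEMMs starts batch_stride = 3*head_dim halves apart
// and advances lead_dim = attn_batches * 3 * head_dim halves per time step;
// the K and V pointers below are simply offset by head_dim and 2*head_dim.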
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 intermediate results + output (dropout intermediates are produced by the
// fused masked-softmax-dropout kernel below, not by ATen)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results =
torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor bmm1_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Pointers into the interleaved Q, K, and V activations produced by the input linear layer
void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void *k_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *bmm1_results_ptr = static_cast<void *>(bmm1_results.data_ptr());
void *dropout_results_ptr = static_cast<void *>(dropout_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
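// In rocBLAS' column-major view the call above computes
//   C(output_lin_dim x batches) = W^T(output_lin_dim x embed_dim) * X(embed_dim x batches) + C
// with beta = 1, so the biases copied into input_lin_results act as the "+ b"
// term; in row-major PyTorch terms this is Y = X * W^T + b with fp32
// accumulation (compute type rocblas_datatype_f32_r).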
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Padded Softmax
bool softmax_success = false;
if (is_training) {
softmax_success =
dispatch_additive_masked_softmax_dropout<half, half, float>(
reinterpret_cast<half *>(dropout_results_ptr),
(is_training)
? reinterpret_cast<uint8_t *>(dropout_mask.data_ptr<uint8_t>())
: nullptr,
reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask,
attn_batches * q_seq_len * q_seq_len, k_seq_len, k_seq_len,
attn_batches * q_seq_len, attn_batches * q_seq_len / sequences,
1.0f - dropout_prob, stream);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half *>(
dropout_results_ptr), // in eval mode this buffer actually holds the
                      // softmax results; one name keeps Matmul2's input uniform
reinterpret_cast<const half *>(bmm1_results_ptr), pad_mask, k_seq_len,
k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, bmm1_results, dropout_results,
dropout_mask, matmul2_results, outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: independent GEMMs in the backward pass could overlap on multiple
// streams; for now everything runs on the single current stream
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
torch::Tensor input_grads = torch::empty_like(inputs);
torch::Tensor input_weight_grads = torch::empty_like(input_weights);
torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
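// When built for ROCm with rocBLAS >= 2.42, rocblas_gemm_flags_fp16_alt_impl
// selects an alternate fp16 GEMM implementation with better numerical
// behavior (denormal handling on MI200-class GPUs); it is requested only when
// the optional backward-pass guard reports an active backward pass.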
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Fused: apply the dropout mask, scale by 1/(1 - p), and compute the softmax
// gradient in one kernel (softmax recomputed from bmm1_results below)
dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half* const>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(bmm1_results.data_ptr()),
reinterpret_cast<half const*>(pad_mask.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len,
stream);
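// Memory/compute trade-off: the forward pass of this variant saved the
// pre-softmax bmm1_results, so the kernel above recomputes the additive-mask
// softmax on the fly instead of reading a stored softmax output, then folds
// in the dropout mask and the 1/(1 - p) scale.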
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
} // end namespace rocblas_gemmex
} // end namespace self_bias_additive_mask
} // end namespace multihead_attn
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"
namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemmex {
std::vector<torch::Tensor>
fwd_cuda(bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs, torch::Tensor const &input_weights,
torch::Tensor const &output_weights, torch::Tensor const &input_biases,
torch::Tensor const &output_biases, const uint8_t *pad_mask,
float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results =
torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Pointers into the interleaved Q, K, and V activations produced by the input linear layer
void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void *k_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
if (is_training) {
// use at:: function so that C++ version generates the same random mask as
// python version
auto dropout_tuple =
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
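// at::_fused_dropout takes the keep probability (here 1 - dropout_prob) and
// returns {in * mask / keep_prob, byte mask}; the mask is saved so bwd_cuda
// can replay the same scaling.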
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
dropout_mask, matmul2_results, outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: independent GEMMs in the backward pass could overlap on multiple
// streams; for now everything runs on the single current stream
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
torch::Tensor input_grads = torch::empty_like(inputs);
torch::Tensor input_weight_grads = torch::empty_like(input_weights);
torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Fused: apply the dropout mask and the 1/(1 - p) scale, then the softmax
// gradient, using the softmax results saved by the forward pass
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
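// The fused kernel above first un-drops the incoming gradient
// (grad *= mask / (1 - p)) and then applies the softmax Jacobian with the
// saved outputs y: grad_in_i = y_i * (grad_i - sum_j grad_j * y_j) per row.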
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"
namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 intermediate results + output (dropout intermediates are produced by the
// apex fused dropout kernel below, not by ATen)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results =
torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Pointers into the interleaved Q, K, and V activations produced by the input linear layer
void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void *k_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
assert(softmax_success);
if (is_training) {
apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
(1.0f - dropout_prob));
}
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
dropout_mask, matmul2_results, outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results, torch::Tensor const &inputs,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: independent GEMMs in the backward pass could overlap on multiple
// streams; for now everything runs on the single current stream
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
torch::Tensor input_grads = torch::empty_like(inputs);
torch::Tensor input_weight_grads = torch::empty_like(input_weights);
torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(matmul2_grads.data_ptr()),
static_cast<at::Half*>(matmul2_grads.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
dropout_elems,
(1.0 / (1.0 - dropout_prob)));
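// Mirrors the forward dropout scaling: grad = grad * mask / (1 - p), so
// positions dropped in the forward pass receive zero gradient.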
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
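// Standard softmax backward with saved outputs y = softmax_results:
// grad_x_i = y_i * (grad_y_i - sum_j grad_y_j * y_j) over each row of
// length k_seq_len.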
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
input_weight_grads,
output_weight_grads
};
}
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.cuh"
#include "layer_norm.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"
namespace multihead_attn {
namespace self_norm_add {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int total_tokens = batches * embed_dim;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
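// Worked example of the constants above (illustrative numbers, not from the
// source): with heads = 16, embed_dim = 1024, sequences = 8, q_seq_len = 64:
//   head_dim       = 1024 / 16    = 64
//   output_lin_dim = 3 * 1024     = 3072   (fused Q/K/V projection width)
//   attn_batches   = 16 * 8       = 128    (one GEMM batch per head per sequence)
//   lead_dim       = 128 * 3 * 64 = 24576
//   batch_stride   = 3 * 64       = 192    (halfs between consecutive heads)
//   scale          = 1 / sqrt(64) = 0.125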
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor lyr_nrm_mean = torch::empty({batches}, lyr_nrm_options);
torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options);
torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options);
torch::Tensor input_lin_results =
torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor output_lin_results = torch::empty_like(inputs, act_options);
torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void *k_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
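// Layout implied by the offsets above (an interpretation): the fused input
// projection stores Q, K and V interleaved per head,
//   [ q_0 .. q_{head_dim-1} | k_0 .. k_{head_dim-1} | v_0 .. v_{head_dim-1} ]
// so K begins head_dim halfs after Q, V begins 2 * head_dim halfs after Q,
// and batch_stride = 3 * head_dim steps from one head's block to the next.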
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
HostApplyLayerNorm<at::Half, float>(
static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
static_cast<float *>(lyr_nrm_mean.data_ptr()),
static_cast<float *>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half *>(inputs.data_ptr()),
static_cast<int>(batches), // n1
static_cast<int>(embed_dim), // n2
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
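// LayerNorm math for reference (standard definition; the HostApplyLayerNorm
// internals are assumed): for each of the `batches` rows x of width embed_dim,
//   y = gamma * (x - mean(x)) * rsqrt(var(x) + 1e-5) + beta,
// with mean and the inverse sigma cached in lyr_nrm_mean / lyr_nrm_invvar
// for reuse in the backward pass.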
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
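// In math terms (a reading of the batched GEMM above, not authoritative):
// with A = K under 't' and B = Q under 'n', each of the attn_batches GEMMs
// produces the pre-softmax attention scores
//   S = scale * K^T Q,  shape (k_seq_len x q_seq_len),
// i.e. S_ij = (q_j . k_i) / sqrt(head_dim).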
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
assert(softmax_success);
if (is_training) {
apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
(1.0f - dropout_prob));
}
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()),
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// End-of-block Dropout-Add
if (is_training) {
apex_dropout_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()),
static_cast<uint8_t *>(dropout_add_mask.data_ptr()), total_tokens,
(1.0f - dropout_prob));
} else {
apex_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()), total_tokens);
}
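// Effect of the two branches above (a sketch; the apex_* kernel semantics
// are inferred from their arguments): during training,
//   outputs = dropout(output_lin_results, keep = 1 - dropout_prob) / (1 - dropout_prob) + inputs,
// with dropout_add_mask recording which elements were kept; at inference,
// simply outputs = output_lin_results + inputs.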
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_results,
softmax_results, dropout_results, dropout_mask, matmul2_results,
dropout_add_mask, outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
const int k_seq_len = q_seq_len;
const int batches = sequences * q_seq_len;
const int total_tokens = batches * embed_dim;
const int head_dim = embed_dim / heads;
const int output_lin_dim = 3 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// TODO: streams could be used in backprop as well, but this first version
// keeps everything on a single stream
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
torch::Tensor input_grads = torch::empty_like(inputs);
torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights);
torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights);
torch::Tensor input_weight_grads = torch::empty_like(input_weights);
torch::Tensor output_weight_grads = torch::empty_like(output_weights);
// Intermediate Tensor Allocations
torch::Tensor dropout_add_grads = torch::empty_like(output_grads);
torch::Tensor output_lin_grads = torch::empty_like(matmul2_results);
torch::Tensor matmul2_grads = torch::empty_like(dropout_results);
torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
torch::Tensor input_lin_grads = torch::empty_like(inputs);
auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
auto k_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
auto v_lin_results_ptr =
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;
auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
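// Example of the version gate above: for rocBLAS 2.45,
//   PYTORCH_ROCBLAS_VERSION_DECIMAL = 2 * 100 + 45 = 245 >= 242,
// so rocblas_gemm_flags_fp16_alt_impl is requested during backward passes
// whenever the build provides BACKWARD_PASS_GUARD.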
// Dropout Add Backward
apex_masked_scale_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_grads.data_ptr()),
static_cast<at::Half *>(dropout_add_grads.data_ptr()),
static_cast<uint8_t const *>(dropout_add_mask.data_ptr()), total_tokens,
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by 1 / (1 - dropout_prob)
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(matmul2_grads.data_ptr()),
static_cast<at::Half*>(matmul2_grads.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
dropout_elems,
(1.0 / (1.0 - dropout_prob)));
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half, float>(
static_cast<const half *>(input_lin_grads.data_ptr()),
static_cast<const half *>(output_grads.data_ptr()),
static_cast<const float *>(lyr_nrm_mean.data_ptr()),
static_cast<const float *>(lyr_nrm_invvar.data_ptr()), inputs,
static_cast<int>(batches), // n1
static_cast<int>(embed_dim), // n2
static_cast<const half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const half *>(lyr_nrm_beta_weights.data_ptr()), 1.0e-5,
static_cast<half *>(input_grads.data_ptr()),
static_cast<half *>(lyr_nrm_gamma_grads.data_ptr()),
static_cast<half *>(lyr_nrm_beta_grads.data_ptr()));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads,
input_weight_grads, output_weight_grads};
}
} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
#pragma once
#include "philox.cuh"
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <assert.h>
#include <cfloat>
#include <cmath>
#include <cuda_fp16.h>
#include <limits>
#include <stdint.h>
#ifdef __HIP_PLATFORM_HCC__
#define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
#else
#define APEX_WARP_SHFL_XOR __shfl_xor_sync
#endif
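// Usage sketch of the macro above: the softmax kernels below perform
// butterfly reductions, which finish in log2(WARP_SIZE) shuffle rounds
// (offsets 16, 8, 4, 2, 1 for a 32-lane warp), e.g. for a max:
//   for (int offset = 32 / 2; offset > 0; offset /= 2) {
//     float other = APEX_WARP_SHFL_XOR(0xffffffff, v, offset, 32);
//     v = v > other ? v : other;  // every lane ends with the warp max
//   }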
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_mask(Datatype *dst, Datatype value,
const uint8_t *src);
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_additive_mask(Datatype *dst,
const Datatype *additive_mask);
template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst,
const __half *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst,
const __half *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
const uint8_t *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
const uint8_t *src) {
// reinterpret as half2 so all four bytes move in one 32-bit load/store
*((half2 *)dst) = *((half2 *)src);
}
template <>
__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value,
const uint8_t *src) {
if (*src == 1) {
*dst = value;
}
}
template <>
__device__ __inline__ void
apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) {
*dst += *additive_mask;
}
template <>
__device__ __inline__ void
apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {
*dst += *additive_mask;
*(dst + 1) += *(additive_mask + 1);
*(dst + 2) += *(additive_mask + 2);
*(dst + 3) += *(additive_mask + 3);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp Softmax forward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// WARP_BATCH: number of batches processed by one warp.
// WARP_ITERATIONS: number of iterations required for one warp to iterate
// over all data. WARP_SIZE: number of threads working on a single batch;
// has to be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void softmax_warp_forward(input_t *dst, const output_t *src,
int batch_size, int stride,
int element_count) {
assert(ELEMENTS_PER_LDG_STG == 1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements_input[i][it + element] =
-std::numeric_limits<float>::infinity();
}
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it], src + i * element_count + it * WARP_SIZE);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
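// The kernel above computes the numerically stable softmax
//   y_i = exp(x_i - m) / sum_j exp(x_j - m),  m = max_j x_j,
// which is algebraically identical to exp(x_i) / sum_j exp(x_j) but cannot
// overflow in the exponential; out-of-range lanes hold -inf and contribute
// exp(-inf) = 0 to the sum.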
// WARP_BATCH: number of batches processed by one warp.
// WARP_ITERATIONS: number of iterations required for one warp to iterate
// over all data. WARP_SIZE: number of threads working on a single batch;
// has to be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using softmax_forward_func = void (*)(input_t *dst, const output_t *src,
int batch_size, int stride,
int element_count);
template <typename input_t, typename output_t, typename acc_t>
bool warp_softmax_kernel(int log2_elements, int &warp_size,
int &batches_per_warp,
softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
switch (log2_elements) {
case 0: // 1
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 16, 1>;
break;
case 5: // 32
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 32, 1>;
break;
case 6: // 64
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 2, 32, 1>;
break;
case 7: // 128
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 4, 32, 1>;
break;
case 8: // 256
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 8, 32, 1>;
break;
case 9: // 512
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>;
break;
case 10: // 1024
kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>;
break;
default:
return false;
}
return true;
}
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements,
int softmax_elements_stride, int batch_count) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_softmax_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, batch_count, softmax_elements_stride, softmax_elements);
return true;
}
return false;
}
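// Worked launch example for dispatch_softmax (illustrative numbers): for
// softmax_elements = 768 the next power of two is 1024, so log2_elements = 10,
// warp_size = 32, batches_per_warp = 1, and each thread runs 32 iterations.
// With threads_per_block = 128: warps_per_block = 4, batches_per_block = 4,
// and batch_count = 8192 rows launch blocks = (8192 + 4 - 1) / 4 = 2048.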
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward_vec4(
output_t *dst, uint8_t *dropout_mask, const input_t *src,
const input_t *pad_mask, int batch_size, int stride, int element_count,
int pad_batch_stride, at::PhiloxCudaState philox_args, float p) {
assert(ELEMENTS_PER_LDG_STG == 4);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
threadIdx.x;
acc_t pinv = acc_t(1) / p;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
// vectorize if element_count is multiple of 4, else don't vectorize
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
src += thread_offset;
dst += thread_offset;
dropout_mask += thread_offset;
// load data from global memory
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const half *curr_mask = pad_mask + pad_thread_offset; // assumes input_t == half
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
// masking_value is a large negative value
elements_input[i][it + element] = -10000;
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
src + itr_idx);
apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it],
curr_mask +
itr_jmp); //(__half)-std::numeric_limits<float>::infinity()
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
auto seeds = at::cuda::philox::unpack(philox_args);
Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));
uint8_t rands[WARP_BATCH][WARP_ITERATIONS];
float4 rand_num;
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
rand_num = uniform4(ph());
// keep each element with probability p (p is the keep probability)
rands[i][it] = rand_num.x <= p;
rands[i][it + 1] = rand_num.y <= p;
rands[i][it + 2] = rand_num.z <= p;
rands[i][it + 3] = rand_num.w <= p;
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(
dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]);
}
}
}
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = rands[i][it + element] *
(pinv * (elements[i][it + element] / sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
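// Dropout math used above (standard inverted dropout; note p is the keep
// probability here, see the dispatch function below): each element survives
// with probability p and is scaled by pinv = 1 / p, so
//   E[out_i] = p * (softmax_i / p) + (1 - p) * 0 = softmax_i,
// leaving activations unbiased with no rescaling needed at inference time.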
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward(
output_t *dst, uint8_t *dropout_mask, const input_t *src,
const input_t *pad_mask, int batch_size, int stride, int element_count,
int pad_batch_stride, at::PhiloxCudaState philox_args, float p) {
assert(ELEMENTS_PER_LDG_STG == 1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
threadIdx.x;
acc_t pinv = acc_t(1) / p;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
// vectorize if element_count is multiple of 4, else don't vectorize
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
int thread_offset = first_batch * stride + local_idx;
src += thread_offset;
dst += thread_offset;
dropout_mask += thread_offset;
// load data from global memory
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset =
((first_batch + i) / pad_batch_stride) * stride + local_idx;
const half *curr_mask = pad_mask + pad_thread_offset; // assumes input_t == half
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += 1) {
int element_index = local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0; element < 1; ++element) {
// masking_value is a large negative value
elements_input[i][it + element] = -10000;
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, 1>(&elements_input[i][it], src + itr_idx);
apply_additive_mask<input_t, 1>(&elements_input[i][it],
curr_mask + itr_jmp);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
curandStatePhilox4_32_10_t state;
auto seeds = at::cuda::philox::unpack(philox_args);
curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += 1) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
output_t out[1];
acc_t softmax_out[1];
uint8_t dropout_mask_temp[1];
// generate a vector of random numbers here
float rand = curand_uniform(&state);
float *rand_ptr = (float *)(&rand);
#pragma unroll
for (int element = 0; element < 1; ++element) {
softmax_out[element] = (elements[i][it + element] / sum[i]);
rand_ptr[element] = rand_ptr[element] <= p;
out[element] = rand_ptr[element] * pinv * softmax_out[element];
dropout_mask_temp[element] =
rand_ptr[element] > 0.5; // just to distinguish 0.0f and 1.0f
}
copy_vector<output_t, 1>(dst + i * element_count + it * WARP_SIZE, out);
copy_vector<uint8_t, 1>(dropout_mask + i * element_count +
it * WARP_SIZE,
dropout_mask_temp);
} else {
break;
}
}
}
}
// WARP_BATCH: number of batches processed by one warp.
// WARP_ITERATIONS: number of iterations required for one warp to iterate
// over all data. WARP_SIZE: number of threads working on a single batch;
// has to be a power of two. ELEMENTS_PER_LDG_STG is 1 for the scalar kernel
// above and 4 for the _vec4 variant.
template <typename input_t, typename output_t, typename acc_t>
using additive_masked_softmax_dropout_forward_func = void (*)(
output_t *dst, uint8_t *dropout_mask, const input_t *src,
const input_t *pad_mask, int batch_size, int stride, int element_count,
int pad_batch_stride, at::PhiloxCudaState philox_args, float p);
template <typename input_t, typename output_t, typename acc_t>
bool warp_additive_masked_softmax_dropout_kernel(
int element_count, int log2_elements, int &warp_size, int &batches_per_warp,
additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t>
&kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
bool flag_vec4 = (element_count % 4 == 0);
switch (log2_elements) {
case 0: // 1
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 16, 1>;
break;
case 5: // 32
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 1, 32, 1>;
break;
case 6: // 64
kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 2, 32, 1>;
break;
case 7: // 128
if (flag_vec4)
kernel = &additive_masked_softmax_dropout_warp_forward_vec4<
input_t, output_t, acc_t, 2, 4, 32, 4>;
else
kernel =
&additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 2, 4, 32, 1>;
break;
case 8: // 256
if (flag_vec4)
kernel = &additive_masked_softmax_dropout_warp_forward_vec4<
input_t, output_t, acc_t, 1, 8, 32, 4>;
else
kernel =
&additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 1, 8, 32, 1>;
break;
case 9: // 512
if (flag_vec4)
kernel = &additive_masked_softmax_dropout_warp_forward_vec4<
input_t, output_t, acc_t, 1, 16, 32, 4>;
else
kernel =
&additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 1, 16, 32, 1>;
break;
case 10: // 1024
if (flag_vec4)
kernel = &additive_masked_softmax_dropout_warp_forward_vec4<
input_t, output_t, acc_t, 1, 32, 32, 4>;
else
kernel =
&additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 1, 32, 32, 1>;
break;
case 11: // 2048
if (flag_vec4)
kernel = &additive_masked_softmax_dropout_warp_forward_vec4<
input_t, output_t, acc_t, 1, 64, 32, 4>;
else
kernel =
&additive_masked_softmax_dropout_warp_forward<input_t, output_t,
acc_t, 1, 64, 32, 1>;
break;
default:
return false;
}
return true;
}
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax_dropout(
output_t *dst, uint8_t *dropout_mask, const input_t *src,
const input_t *pad_mask, int totalElements, int softmax_elements,
int softmax_elements_stride, int batch_count, int pad_batch_stride, float p,
cudaStream_t streamid) // p is the probability to keep, not drop
{
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 2048) {
// compute function index. there's a function for each power of two size up
// to 2048.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t>
kernel;
int warp_size, batches_per_warp;
if (!warp_additive_masked_softmax_dropout_kernel<input_t, output_t, acc_t>(
softmax_elements, log2_elements, warp_size, batches_per_warp,
kernel)) {
return false;
}
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
c10::optional<at::Generator> gen_;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
int64_t counter_offset = (totalElements / (blocks * threads_per_block) + 1);
at::PhiloxCudaState rng_engine_inputs;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(counter_offset);
}
// compute launch size
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(
dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride,
softmax_elements, pad_batch_stride, rng_engine_inputs, p);
return true;
}
return false;
}
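// Offset arithmetic in the dispatch above (a sketch): each of the
// blocks * threads_per_block threads draws at most
// ceil(totalElements / (blocks * threads_per_block)) uniforms, so reserving
// counter_offset = totalElements / (blocks * threads_per_block) + 1 Philox
// increments per launch keeps random streams from overlapping across
// successive launches that share the generator.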
// WARP_BATCH: number of batches processed by one warp.
// WARP_ITERATIONS: number of iterations required for one warp to iterate
// over all data. WARP_SIZE: number of threads working on a single batch;
// has to be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void additive_masked_softmax_warp_forward(
input_t *dst, const output_t *src, const input_t *pad_mask, int batch_size,
int stride, int element_count, int pad_batch_stride) {
assert(ELEMENTS_PER_LDG_STG == 1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
src += thread_offset;
dst += thread_offset;
// load data from global memory
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const half *curr_mask = pad_mask + pad_thread_offset; // assumes input_t == half
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
// masking_value is a large negative value
elements_input[i][it + element] = -10000;
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
src + itr_idx);
// apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
// (__half)-std::numeric_limits<float>::infinity(),
// curr_mask + itr_jmp);
elements_input[i][it] += *(curr_mask + itr_jmp);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
// WARP_BATCH: number of batches processed by one warp.
// WARP_ITERATIONS: number of iterations required for one warp to iterate
// over all data. WARP_SIZE: number of threads working on a single batch;
// has to be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using additive_masked_softmax_forward_func = void (*)(
input_t *dst, const output_t *src, const half *pad_mask, int batch_size,
int stride, int element_count, int pad_batch_stride);
template <typename input_t, typename output_t, typename acc_t>
bool warp_additive_masked_softmax_kernel(
int log2_elements, int &warp_size, int &batches_per_warp,
additive_masked_softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
switch (log2_elements) {
case 0: // 1
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 1, 1>;
break;
case 1: // 2
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 2, 1>;
break;
case 2: // 4
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 4, 1>;
break;
case 3: // 8
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 8, 1>;
break;
case 4: // 16
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 16, 1>;
break;
case 5: // 32
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
1, 32, 1>;
break;
case 6: // 64
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
2, 32, 1>;
break;
case 7: // 128
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,
4, 32, 1>;
break;
case 8: // 256
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,
8, 32, 1>;
break;
case 9: // 512
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,
16, 32, 1>;
break;
case 10: // 1024
kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,
32, 32, 1>;
break;
default:
return false;
}
return true;
}
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax(output_t *dst, const input_t *src,
const input_t *pad_mask,
int softmax_elements,
int softmax_elements_stride,
int batch_count, int pad_batch_stride) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
additive_masked_softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, pad_mask, batch_count, softmax_elements_stride,
softmax_elements, pad_batch_stride);
return true;
}
return false;
}
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_additive_masked_softmax_stream(
output_t *dst, const input_t *src, const input_t *pad_mask,
int softmax_elements, int softmax_elements_stride, int batch_count,
int pad_batch_stride, cudaStream_t streamid) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
additive_masked_softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(
dst, src, pad_mask, batch_count, softmax_elements_stride,
softmax_elements, pad_batch_stride);
return true;
}
return false;
}
// WARP_BATCH: number of batches processed by one warp.
// WARP_ITERATIONS: number of iterations required for one warp to iterate
// over all data. WARP_SIZE: number of threads working on a single batch;
// has to be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void
masked_softmax_warp_forward(input_t *dst, const output_t *src,
const uint8_t *pad_mask, int batch_size, int stride,
int element_count, int pad_batch_stride) {
assert(ELEMENTS_PER_LDG_STG == 1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
src += thread_offset;
dst += thread_offset;
// load data from global memory
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const uint8_t *curr_mask = pad_mask + pad_thread_offset;
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements_input[i][it + element] =
-std::numeric_limits<float>::infinity();
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
src + itr_idx);
apply_mask<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it],
(__half)-std::numeric_limits<float>::infinity(),
curr_mask + itr_jmp);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
// WARP_BATCH: number of batches processed by one warp.
// WARP_ITERATIONS: number of iterations required for one warp to iterate
// over all data. WARP_SIZE: number of threads working on a single batch;
// has to be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using masked_softmax_forward_func = void (*)(input_t *dst, const output_t *src,
const uint8_t *pad_mask,
int batch_size, int stride,
int element_count,
int pad_batch_stride);
template <typename input_t, typename output_t, typename acc_t>
bool warp_masked_softmax_kernel(
int log2_elements, int &warp_size, int &batches_per_warp,
masked_softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
switch (log2_elements) {
case 0: // 1
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 16, 1>;
break;
case 5: // 32
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 32, 1>;
break;
case 6: // 64
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 2, 32, 1>;
break;
case 7: // 128
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 4, 32, 1>;
break;
case 8: // 256
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 8, 32, 1>;
break;
case 9: // 512
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>;
break;
case 10: // 1024
kernel =
&masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>;
break;
default:
return false;
}
return true;
}
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_masked_softmax(output_t *dst, const input_t *src,
const uint8_t *pad_mask, int softmax_elements,
int softmax_elements_stride, int batch_count,
int pad_batch_stride) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
masked_softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_masked_softmax_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, pad_mask, batch_count, softmax_elements_stride,
softmax_elements, pad_batch_stride);
return true;
}
return false;
}
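// A minimal host-side sketch of how dispatch_masked_softmax might be driven
// from an extension entry point. The tensor names, the row flattening and the
// pad_batch_stride value are illustrative assumptions.
inline void masked_softmax_forward_sketch(at::Tensor &dst, const at::Tensor &src,
                                          const at::Tensor &pad_mask,
                                          int pad_batch_stride) {
  const int seq_len = src.size(-1);              // softmax row length
  const int batch_count = src.numel() / seq_len; // e.g. B * heads * seq_len rows
  // fp16 in, fp16 out, fp32 accumulation, as in the kernels above
  TORCH_CHECK((dispatch_masked_softmax<__half, __half, float>(
                  reinterpret_cast<__half *>(dst.data_ptr<at::Half>()),
                  reinterpret_cast<const __half *>(src.data_ptr<at::Half>()),
                  pad_mask.data_ptr<uint8_t>(), seq_len, seq_len, batch_count,
                  pad_batch_stride)),
              "row length > 1024 is not supported");
}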
// WARP_BATCH: number of batches processed per warp.
// WARP_ITERATIONS: the number of iterations required for one warp to iterate
// over all data. WARP_SIZE: number of threads cooperating on a single batch;
// has to be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void time_masked_softmax_warp_forward(
    output_t *dst, const input_t *src, const uint8_t *pad_mask, int batch_size,
    int stride, int element_count, int mod_seq_len) {
assert(ELEMENTS_PER_LDG_STG == 1);
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
src += thread_offset;
dst += thread_offset;
// load data from global memory
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
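    // the pad mask has only mod_seq_len rows; rows of the flattened batch
    // reuse them cyclically (hence the modulo below)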
int pad_thread_offset = ((first_batch + i) % mod_seq_len) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const uint8_t *curr_mask = pad_mask + pad_thread_offset;
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements_input[i][it + element] =
-std::numeric_limits<float>::infinity();
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
src + itr_idx);
apply_mask<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it],
(__half)-std::numeric_limits<float>::infinity(),
curr_mask + itr_jmp);
}
}
}
// convert input_t to acc_t
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
// WARP_BATCH: number of batches processed per warp.
// WARP_ITERATIONS: the number of iterations required for one warp to iterate
// over all data. WARP_SIZE: number of threads cooperating on a single batch;
// has to be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using time_masked_softmax_forward_func =
    void (*)(output_t *dst, const input_t *src, const uint8_t *pad_mask,
             int batch_size, int stride, int element_count, int mod_seq_len);
template <typename input_t, typename output_t, typename acc_t>
bool warp_time_masked_softmax_kernel(
int log2_elements, int &warp_size, int &batches_per_warp,
time_masked_softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
switch (log2_elements) {
case 0: // 1
kernel =
&time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel =
&time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel =
&time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel =
&time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1,
16, 1>;
break;
case 5: // 32
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1,
32, 1>;
break;
case 6: // 64
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 2,
32, 1>;
break;
case 7: // 128
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 4,
32, 1>;
break;
case 8: // 256
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 8,
32, 1>;
break;
case 9: // 512
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 16,
32, 1>;
break;
case 10: // 1024
kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 32,
32, 1>;
break;
default:
return false;
}
return true;
}
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_time_masked_softmax(output_t *dst, const input_t *src,
const uint8_t *pad_mask, int softmax_elements,
int softmax_elements_stride, int batch_count,
int mod_seq_len) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
time_masked_softmax_forward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_time_masked_softmax_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, pad_mask, batch_count, softmax_elements_stride,
softmax_elements, mod_seq_len);
return true;
}
return false;
}
int log2_ceil_native(int value) {
int log2_value = 0;
while ((1 << log2_value) < value)
++log2_value;
return log2_value;
}
template <typename T>
__device__ __forceinline__ T
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE>
__device__ __forceinline__ void warp_reduce_sum(acc_t *sum) {
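  // XOR butterfly reduction: after log2(WARP_SIZE) rounds every lane in the
  // warp holds the complete sum for each of the WARP_BATCH rows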
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = sum[i] + b;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp softmax backward functions as fused variants of
// at::softmax_backward_data function
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// The softmax backward data function is taken from native PyTorch; the
// elementwise mul is fused into the epilogue, along with the masking and
// scaling needed to fuse dropout
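// Softmax backward: for a softmax output y and upstream gradient dy,
//   dx_i = y_i * dy_i - y_i * sum_j (y_j * dy_j).
// The kernels below fuse the dropout mask m and scale s into the load,
// storing grad_reg = (m * dy * s) * y up front, reduce sum_j over the row,
// and form dx in the epilogue.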
template <typename input_t, typename output_t, typename acc_t,
int log2_elements, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward_masked_dgrad(
output_t *gradInput, const input_t *grad, const input_t *output,
const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int batch_size,
int stride, int element_count, int heads) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x % WARP_SIZE;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
mask += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
// to one loop, but I think doing so would obfuscate the logic of the
// algorithm, thus I chose to keep the nested loops. This should have no
// impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] =
(input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] *
(acc_t)grad[i * element_count + it * WARP_SIZE] *
(acc_t)scale) *
output[i * element_count + it * WARP_SIZE];
output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
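        // pad_mask has one row of element_count entries per batch, while
        // total_ind walks a [batches, heads, query, key] volume (assuming
        // stride == element_count): the division recovers the batch index,
        // the modulo the key position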
int total_ind = thread_offset + i * element_count + it * WARP_SIZE;
int pad_mask_ind =
element_count *
(total_ind / (heads * element_count * element_count)) +
total_ind % element_count;
uint8_t pad_mask_element = 1 - pad_mask[pad_mask_ind];
if (pad_mask_element == 0)
gradInput[i * element_count + it * WARP_SIZE] = 0;
else {
if (is_log_softmax) {
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
} else {
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
}
}
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_masked_out(
output_t *grad_input, const input_t *grad, const input_t *output,
const uint8_t *mask, const uint8_t *pad_mask, acc_t scale,
int softmax_elements, int softmax_elements_stride, int batch_count,
int heads) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
0, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 1: // 2
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
1, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 2: // 4
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
2, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 3: // 8
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
3, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 4: // 16
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
4, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 5: // 32
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
5, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 6: // 64
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
6, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 7: // 128
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
7, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 8: // 256
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
8, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 9: // 512
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
9, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 10: // 1024
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
10, is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
default:
break;
}
}
}
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_masked_out_stream(
output_t *grad_input, const input_t *grad, const input_t *output,
const uint8_t *mask, const uint8_t *pad_mask, acc_t scale,
int softmax_elements, int softmax_elements_stride, int batch_count,
int heads, cudaStream_t streamid) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
0, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 1: // 2
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
1, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 2: // 4
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
2, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 3: // 8
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
3, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 4: // 16
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
4, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 5: // 32
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
5, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 6: // 64
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
6, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 7: // 128
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
7, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 8: // 256
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
8, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 9: // 512
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
9, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
case 10: // 1024
masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t,
10, is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, pad_mask, scale, batch_count,
softmax_elements_stride, softmax_elements, heads);
break;
default:
break;
}
}
}
template <typename input_t, typename output_t, typename acc_t,
int log2_elements, bool is_log_softmax>
__global__ void
masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad,
const input_t *output, const uint8_t *mask,
acc_t scale, int batch_size, int stride,
int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x % WARP_SIZE;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
mask += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
// to one loop, but I think doing so would obfuscate the logic of the
// algorithm, thus I chose to keep the nested loops. This should have no
// impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] =
(input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] *
(acc_t)grad[i * element_count + it * WARP_SIZE] *
(acc_t)scale) *
output[i * element_count + it * WARP_SIZE];
output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
if (is_log_softmax) {
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
} else {
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
}
// note: WARP_SIZE may not carry a default here, since a defaulted template
// parameter cannot precede the non-defaulted ones that follow it
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
          int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG,
          bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward_recompute(
output_t *gradInput, const input_t *grad, const input_t *softmax_input,
const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size,
int stride, int pad_batch_stride, int element_count) {
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x % WARP_SIZE;
  // vectorize if a row length is a multiple of 4; parentheses are required
  // since `==` binds tighter than `&` (the actual vector width is chosen in
  // the dispatcher via the ELEMENTS_PER_LDG_STG template parameter)
  int flag_vec4 = (element_count & 3) == 0;
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
softmax_input += thread_offset;
gradInput += thread_offset;
mask += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
// to one loop, but I think doing so would obfuscate the logic of the
// algorithm, thus I chose to keep the nested loops. This should have no
// impact on performance because the loops are unrolled anyway.
// load data from global memory
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const input_t *curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
// masking_value is a large negative value
elements_input[i][it + element] = -10000;
grad_reg[i][it + element] = acc_t(0);
}
if (element_index < batch_element_count) {
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
softmax_input + itr_idx);
apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it],
curr_mask +
itr_jmp); //(__half)-std::numeric_limits<float>::infinity()
uint8_t mask_temp[ELEMENTS_PER_LDG_STG];
input_t grad_temp[ELEMENTS_PER_LDG_STG];
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(&mask_temp[0],
mask + itr_idx);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_temp[0],
grad + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
grad_reg[i][it + element] =
((acc_t)mask_temp[element] * (acc_t)grad_temp[element] *
(acc_t)scale);
}
}
}
}
// load data from global memory
// convert input_t to acc_t
// TODO : remove this, input is already acc_t type in register
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = elements_input[i][it];
}
}
constexpr uint32_t FULL_MASK = 0xffffffff;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
float val[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it++) {
elements[i][it] = elements[i][it] / sum[i];
grad_reg[i][it] = grad_reg[i][it] * elements[i][it];
}
}
acc_t grad_sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
grad_sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
grad_sum[i] += grad_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(grad_sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t grad_input_reg[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; element++) {
if (is_log_softmax) {
grad_input_reg[element] =
(grad_reg[i][it + element] -
std::exp(elements[i][it + element]) * grad_sum[i]);
} else {
grad_input_reg[element] = (grad_reg[i][it + element] -
elements[i][it + element] * grad_sum[i]);
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
gradInput + i * element_count + it * WARP_SIZE, grad_input_reg);
}
}
}
}
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
using masked_scale_softmax_warp_backward_recompute_func = void (*)(
output_t *gradInput, const input_t *grad, const input_t *softmax_input,
const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size,
int stride, int pad_batch_stride, int element_count);
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
bool masked_scale_softmax_warp_backward_recompute_kernel(
int element_count, int log2_elements, int &warp_size, int &batches_per_warp,
masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t,
is_log_softmax> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
bool flag_vec4 = (element_count % 4 == 0);
switch (log2_elements) {
case 0: // 1
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 1, 1, is_log_softmax>;
break;
case 1: // 2
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 2, 1, is_log_softmax>;
break;
case 2: // 4
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 4, 1, is_log_softmax>;
break;
case 3: // 8
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 8, 1, is_log_softmax>;
break;
case 4: // 16
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 16, 1, is_log_softmax>;
break;
case 5: // 32
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 1, 32, 1, is_log_softmax>;
break;
case 6: // 64
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 2, 32, 1, is_log_softmax>;
break;
case 7: // 128
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 2, 4, 32, 1, is_log_softmax>;
break;
case 8: // 256
if (flag_vec4)
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 8, 32, 4, is_log_softmax>;
else
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 8, 32, 1, is_log_softmax>;
break;
case 9: // 512
if (flag_vec4)
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 16, 32, 4, is_log_softmax>;
else
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 16, 32, 1, is_log_softmax>;
break;
case 10: // 1024
if (flag_vec4)
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 32, 32, 4, is_log_softmax>;
else
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 32, 32, 1, is_log_softmax>;
break;
case 11: // 2048
if (flag_vec4)
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 64, 32, 4, is_log_softmax>;
else
kernel = &masked_scale_softmax_warp_backward_recompute<
input_t, output_t, acc_t, 1, 64, 32, 1, is_log_softmax>;
break;
default:
return false;
}
return true;
}
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
bool dispatch_masked_scale_softmax_backward_recompute(
output_t *grad_input, const input_t *grad, const input_t *softmax_input,
const input_t *pad_mask, const uint8_t *mask, acc_t scale,
int softmax_elements, int softmax_elements_stride, int pad_batch_stride,
int batch_count, cudaStream_t streamid) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 2048) {
    // compute function index. there's a function for each power of two size up
    // to 2048.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t,
is_log_softmax>
kernel;
int warp_size, batches_per_warp;
if (!masked_scale_softmax_warp_backward_recompute_kernel<
input_t, output_t, acc_t, is_log_softmax>(
softmax_elements, log2_elements, warp_size, batches_per_warp,
kernel)) {
return false;
}
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
// compute launch size
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(
grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count,
softmax_elements_stride, pad_batch_stride, softmax_elements);
return true;
}
return false;
}
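// The recompute variant never reads the saved softmax output: it reloads the
// pre-softmax logits plus the additive pad_mask and redoes the forward
// max/exp/sum in registers, trading FLOPs for activation memory.
// A minimal call sketch, assuming fp16 scores flattened to rows of length
// seq_len and one additive mask row shared per batch (names illustrative):
inline void masked_scale_softmax_bwd_recompute_sketch(
    at::Tensor &grad_input, const at::Tensor &grad,
    const at::Tensor &softmax_input, const at::Tensor &pad_mask,
    const at::Tensor &dropout_mask, float scale, int heads) {
  const int seq_len = softmax_input.size(-1);
  const int batch_count = softmax_input.numel() / seq_len;
  // one pad-mask row is shared by heads * seq_len attention rows
  const int pad_batch_stride = heads * seq_len;
  TORCH_CHECK(
      (dispatch_masked_scale_softmax_backward_recompute<__half, __half, float,
                                                        false>(
          reinterpret_cast<__half *>(grad_input.data_ptr<at::Half>()),
          reinterpret_cast<const __half *>(grad.data_ptr<at::Half>()),
          reinterpret_cast<const __half *>(softmax_input.data_ptr<at::Half>()),
          reinterpret_cast<const __half *>(pad_mask.data_ptr<at::Half>()),
          dropout_mask.data_ptr<uint8_t>(), scale, seq_len, seq_len,
          pad_batch_stride, batch_count, at::cuda::getCurrentCUDAStream())),
      "row length > 2048 is not supported");
}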
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_stream(
output_t *grad_input, const input_t *grad, const input_t *output,
const uint8_t *mask, acc_t scale, int softmax_elements,
int softmax_elements_stride, int batch_count, cudaStream_t streamid) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 1: // 2
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 2: // 4
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 3: // 8
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 4: // 16
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 5: // 32
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 6: // 64
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 7: // 128
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 8: // 256
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 9: // 512
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10,
is_log_softmax>
<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, mask, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
// The elementwise multiplication called in at::softmax_backward_data is fused
// inside the softmax dgrad kernel; as a result of the fusion, the intermediate
// multiplication result is kept in fp32 registers instead of fp16
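// With grad_reg = dy * y accumulated in fp32, the stored result is
//   dx_i = dy_i * y_i - y_i * sum_j (dy_j * y_j),
// i.e. the standard softmax Jacobian-vector product, just without
// round-tripping the dy * y product through fp16 memory.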
template <typename input_t, typename output_t, typename acc_t,
int log2_elements, bool is_log_softmax>
__global__ void
softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad,
const input_t *output, int batch_size,
int stride, int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x % WARP_SIZE;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
// to one loop, but I think doing so would obfuscate the logic of the
// algorithm, thus I chose to keep the nested loops. This should have no
// impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE] *
output[i * element_count + it * WARP_SIZE];
output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0]; //* output_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it]; // * output_reg[i][it];
}
}
warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
if (is_log_softmax) {
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
} else {
gradInput[i * element_count + it * WARP_SIZE] =
(grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
}
template <typename input_t, typename output_t, typename acc_t,
bool is_log_softmax>
void dispatch_softmax_backward_fused_native(
output_t *grad_input, const input_t *grad, const input_t *output,
int softmax_elements, int softmax_elements_stride, int batch_count) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil_native(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 0,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 1: // 2
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 1,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 2: // 4
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 2,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 3: // 8
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 3,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 4: // 16
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 4,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 5: // 32
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 5,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 6: // 64
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 6,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 7: // 128
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 7,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 8: // 256
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 8,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 9: // 512
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 9,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 10: // 1024
softmax_warp_backward_fused_native<input_t, output_t, acc_t, 10,
is_log_softmax>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
break;
default:
break;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp softmax backward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
                                      const input_t *output, int batch_size,
                                      int stride, int element_count) {
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it],
output + i * element_count +
it * WARP_SIZE);
}
}
}
// convert half to floating point
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
grad_reg[i][it] = grad_reg_input[i][it];
output_reg[i][it] = output_reg_input[i][it];
}
}
// compute thread local sum
acc_t sum[WARP_BATCH] = {0};
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += grad_reg[i][it] * output_reg[i][it];
}
}
// reduction sum
constexpr uint32_t FULL_MASK = 0xffffffff;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_reg[i][it + element] *
(grad_reg[i][it + element] - sum[i]));
}
// store them in global memory
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
gradInput + i * element_count + it * WARP_SIZE, out);
}
}
}
}
// WARP_BATCH: number of batches processed per warp.
// WARP_ITERATIONS: the number of iterations required for one warp to iterate
// over all data. WARP_SIZE: number of threads cooperating on a single batch;
// has to be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using softmax_backward_func = void (*)(output_t *gradInput, const input_t *grad,
const input_t *output, int batch_size,
int stride, int element_count);
template <typename input_t, typename output_t, typename acc_t>
bool warp_softmax_backward_kernel(
int log2_elements, int &warp_size, int &batches_per_warp,
softmax_backward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
switch (log2_elements) {
case 0: // 1
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 16, 1>;
break;
case 5: // 32
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 32, 1>;
break;
case 6: // 64
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 2, 32, 1>;
break;
case 7: // 128
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 4, 32, 1>;
break;
case 8: // 256
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 8, 32, 1>;
break;
case 9: // 512
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 16, 32, 1>;
break;
case 10: // 1024
kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 32, 32, 1>;
break;
default:
return false;
}
return true;
}
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad,
const input_t *output, int softmax_elements,
int softmax_elements_stride, int batch_count) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
softmax_backward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
return true;
}
return false;
}
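// A minimal call sketch mirroring the forward dispatch (fp16 in/out, fp32
// accumulation; names illustrative):
inline void softmax_backward_sketch(at::Tensor &grad_input,
                                    const at::Tensor &grad,
                                    const at::Tensor &output) {
  const int seq_len = output.size(-1);
  const int batch_count = output.numel() / seq_len;
  TORCH_CHECK((dispatch_softmax_backward<__half, __half, float>(
                  reinterpret_cast<__half *>(grad_input.data_ptr<at::Half>()),
                  reinterpret_cast<const __half *>(grad.data_ptr<at::Half>()),
                  reinterpret_cast<const __half *>(output.data_ptr<at::Half>()),
                  seq_len, seq_len, batch_count)),
              "row length > 1024 is not supported");
}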
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad,
const input_t *output,
int softmax_elements,
int softmax_elements_stride,
int batch_count, cudaStream_t streamid) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
// compute function index. there's a function for each power of two size up
// to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
softmax_backward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, streamid>>>(
grad_input, grad, output, batch_count, softmax_elements_stride,
softmax_elements);
return true;
}
return false;
}
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void
masked_softmax_warp_backward(output_t *gradInput, const input_t *grad,
                             const input_t *output, const uint8_t *pad_mask,
                             int batch_size, int stride, int element_count,
                             int pad_batch_stride) {
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it],
output + i * element_count +
it * WARP_SIZE);
}
}
}
  // convert the loaded values to the accumulation type
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
for (int it = 0; it < WARP_ITERATIONS; ++it) {
grad_reg[i][it] = grad_reg_input[i][it];
output_reg[i][it] = output_reg_input[i][it];
}
}
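  // Note (added for clarity): per row, softmax backward computes
  //   dL/dx_i = y_i * (dL/dy_i - sum_j dL/dy_j * y_j),
  // so sum[i] below accumulates the dot product of grad and output that is
  // subtracted from each incoming gradient when the results are written out.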
// compute thread local sum
acc_t sum[WARP_BATCH] = {0};
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += grad_reg[i][it] * output_reg[i][it];
}
}
  // warp-wide reduction: XOR butterfly shuffles leave the full sum in every lane
constexpr uint32_t FULL_MASK = 0xffffffff;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride +
ELEMENTS_PER_LDG_STG * local_idx;
const uint8_t *curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_reg[i][it + element] *
(grad_reg[i][it + element] - sum[i]));
}
// store them in global memory
int itr_jmp = it * WARP_SIZE;
int itr_idx = i * element_count + itr_jmp;
// It is kind of unfortunate this has to be here to zero something out
// that is close to zero in the first place
        apply_mask<output_t, ELEMENTS_PER_LDG_STG>(&out[0], 0.0,
                                                   curr_mask + itr_jmp);
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + itr_idx, out);
}
}
}
}
// WARP_BATCH: number of batches processed per warp.
// WARP_ITERATIONS: number of iterations required for one warp to iterate over
// all of a batch's data.
// WARP_SIZE: number of threads working on a single batch; has to be a power
// of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using masked_softmax_backward_func =
void (*)(output_t *gradInput, const input_t *grad, const input_t *output,
const uint8_t *pad_mask, int batch_size, int stride,
int element_count, int pad_batch_stride);
template <typename input_t, typename output_t, typename acc_t>
bool warp_masked_softmax_backward_kernel(
int log2_elements, int &warp_size, int &batches_per_warp,
masked_softmax_backward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
switch (log2_elements) {
case 0: // 1
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 1, 1>;
break;
case 1: // 2
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 2, 1>;
break;
case 2: // 4
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 4, 1>;
break;
case 3: // 8
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 8, 1>;
break;
case 4: // 16
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 16, 1>;
break;
case 5: // 32
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 32, 1>;
break;
case 6: // 64
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 2, 32, 1>;
break;
case 7: // 128
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 4, 32, 1>;
break;
case 8: // 256
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 8, 32, 1>;
break;
case 9: // 512
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 16, 32, 1>;
break;
case 10: // 1024
kernel =
&masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 32, 32, 1>;
break;
default:
return false;
}
return true;
}
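// Note (added for clarity): every case above satisfies
//   WARP_SIZE * WARP_ITERATIONS == (1 << log2_elements),
// and each kernel's WARP_BATCH matches batches_per_warp as chosen above: 2
// when next_power_of_two <= 128 (cases 0-7) and 1 for cases 8-10.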
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad,
const input_t *output,
const uint8_t *pad_mask,
int softmax_elements,
int softmax_elements_stride,
int batch_count, int pad_batch_stride) {
if (softmax_elements == 0) {
return true;
} else if (softmax_elements <= 1024) {
    // compute the kernel index: there is a specialized kernel for each
    // power-of-two size up to 1024.
int log2_elements = 0;
while ((1 << log2_elements) < softmax_elements)
++log2_elements;
masked_softmax_backward_func<input_t, output_t> kernel;
int warp_size, batches_per_warp;
if (!warp_masked_softmax_backward_kernel<input_t, output_t, acc_t>(
log2_elements, warp_size, batches_per_warp, kernel)) {
return false;
}
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
// compute warps per block.
int warps_per_block = (threads_per_block / warp_size);
// compute launch size
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch
kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, pad_mask, batch_count,
softmax_elements_stride, softmax_elements, pad_batch_stride);
return true;
}
return false;
}
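// Illustrative usage sketch (added; not part of the original source; all
// names are hypothetical). A plausible attention layout: rows are ordered
// [batch][head][query], each row holds seq_len elements, and pad_batch_stride
// = heads * seq_len makes every head and query position of one batch share a
// single pad-mask row, consistent with the indexing in the kernel above.
//
//   bool ok = dispatch_masked_softmax_backward<__half, __half, float>(
//       d_grad_input, d_grad, d_softmax_output, d_pad_mask,
//       /*softmax_elements=*/seq_len,
//       /*softmax_elements_stride=*/seq_len,
//       /*batch_count=*/batches * heads * seq_len,
//       /*pad_batch_stride=*/heads * seq_len);
//   // ok == false means seq_len exceeded 1024 and another path must be used.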
} // namespace