Unverified Commit db92ee13 authored by Jithun Nair's avatar Jithun Nair Committed by GitHub
Browse files

Merge pull request #64 from ROCmSoftwarePlatform/IFU-master-2021-12-08

IFU-master-2021-12-08
parents d150afdc 68364b49
import logging
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
import warnings
if torch.distributed.is_available():
from . import parallel
......@@ -22,3 +24,19 @@ from . import pyprof
#common utilties to run tests on ROCm.
from . import testing
from . import transformer
# Logging utilities mainly for apex.transformer module
class RankInfoFormatter(logging.Formatter):
def format(self, record):
from apex.transformer.parallel_state import get_rank_info
record.rank_info = get_rank_info()
return super().format(record)
_library_root_logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(RankInfoFormatter("%(asctime)s - %(name)s - %(levelname)s - %(rank_info)s - %(message)s"))
_library_root_logger.addHandler(handler)
_library_root_logger.propagate = False
from typing import Optional
import torch
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
if not torch.is_autocast_enabled():
return torch.float or dtype
else:
return torch.get_autocast_gpu_dtype()
def _cast_if_autocast_enabled(*args):
if not torch.is_autocast_enabled():
return args
......
#pragma once
#include <unordered_map>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Params>
struct LaunchParams{
size_t workspace_bytes;
size_t barrier_size;
cudaDeviceProp * props;
cudaStream_t stream;
Params params;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ParamsBase {
ParamsBase()
: ctas_per_col(0)
, rows(0)
, cols(0)
, x(nullptr)
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, workspace(nullptr)
, barrier(nullptr)
{
}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// Input is interpreted as matrix. We normalize across columns.
int rows;
int cols;
// Common data pointers.
void *x;
void *mu;
void *rs;
void *gamma;
// Multi-CTA workspace in gmem.
void *workspace;
// Multi-CTA sync barriers in gmem.
int *barrier;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FwdParams : public ParamsBase {
FwdParams()
: ParamsBase()
, z(nullptr)
, beta(nullptr)
, epsilon(0.f)
{
}
// Output of LN FWD.
void *z;
void *beta;
float epsilon;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct BwdParams : public ParamsBase {
BwdParams()
: ParamsBase()
, dz(nullptr)
, dbeta_part(nullptr)
, dgamma_part(nullptr)
, dx(nullptr)
, dbeta(nullptr)
, dgamma(nullptr)
{
}
// Input: gradient wrt. LN FWD output.
void *dz;
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
// Output: Dgrad.
void *dx;
// Output: Wgrad.
void *dbeta;
void *dgamma;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
extern FwdRegistry FWD_FUNCS;
extern BwdRegistry BWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeId{};
template<>
struct TypeId<fp16>{
constexpr static uint32_t Value = 0;
};
template<>
struct TypeId<bf16>{
constexpr static uint32_t Value = 1;
};
template<>
struct TypeId<fp32>{
constexpr static uint32_t Value = 2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int S>
struct Type2Key{
constexpr static uint32_t Value = TypeId<T>::Value << S;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};
template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};
template<typename T>
struct OutputType2Key : public Type2Key<T, 4>{};
template<typename T>
struct ComputeType2Key : public Type2Key<T, 6>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C>
struct Types2Key{
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size){
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar{
FwdRegistrar(FwdFunction f){
uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE);
FWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdRegistrar{
BwdRegistrar(BwdFunction f){
uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE);
BWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
void ln_fwd_cuda(at::Tensor &y, at::Tensor &mu, at::Tensor &rsigma,
const at::Tensor &x, const at::Tensor &gamma,
const at::Tensor &beta, const float epsilon, const int rows, const int cols,
cudaStream_t stream);
#include "ln.h"
void ln_bwd_cuda(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,
const at::Tensor &dw, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int rows, const int cols, cudaStream_t stream);
/*
Supported Type combinations:
input compute weights output
=======================================
fp32 fp32 fp32 fp32
fp16 fp32 fp16 fp16
bf16 fp32 bf16 bf16
fp32 fp32 fp16 fp16
fp32 fp32 bf16 bf16
Remarks:
Output type = Weight type
Compute always in FP32
*/
namespace layer_norm {
// Create registries and provide runtime versions of config hash functions.
FwdRegistry FWD_FUNCS;
BwdRegistry BWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
uint32_t get_type_id(torch::Dtype dtype){
if( dtype == torch::kFloat16 ) {
return TypeId<fp16>::Value;
} else if( dtype == torch::kBFloat16 ) {
return TypeId<bf16>::Value;
} else if( dtype == torch::kFloat32 ) {
return TypeId<fp32>::Value;
} else {
TORCH_CHECK(false, "Type not supported: ", dtype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
using namespace layer_norm;
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6);
uint64_t launcher_key = (type_key << 32) | hidden_size;
return launcher_key;
}
} // namespace layer_norm
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));
if( iter != layer_norm::FWD_FUNCS.end() ) {
return iter->second;
} else {
TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));
if( iter != layer_norm::BWD_FUNCS.end() ) {
return iter->second;
} else {
TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size
const at::Tensor &gamma, // hidden_size
const at::Tensor &beta, // hidden_size
const float epsilon
) {
auto itype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = wtype;
auto ctype = torch::kFloat32;
TORCH_CHECK(beta.scalar_type() == wtype);
TORCH_CHECK(x.is_cuda())
TORCH_CHECK(gamma.is_cuda())
......@@ -28,79 +99,148 @@ std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size
const int rows = sizes[0];
const int cols = sizes[1];
auto dtype = x.scalar_type();
TORCH_CHECK(gamma.dtype() == dtype);
TORCH_CHECK(beta.dtype() == dtype);
auto hidden_size = gamma.numel();
TORCH_CHECK(gamma.sizes() == beta.sizes());
TORCH_CHECK(gamma.numel() == cols);
TORCH_CHECK(hidden_size == cols);
TORCH_CHECK(epsilon >= 0.f);
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto opts = x.options();
auto y = torch::empty_like(x);
auto z = torch::empty(sizes, opts.dtype(otype));
auto opts = x.options();
auto mu = torch::empty({ rows }, opts.dtype(ctype));
auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
auto mu = torch::empty({rows}, opts.dtype(torch::kFloat32));
auto rsigma = torch::empty({rows}, opts.dtype(torch::kFloat32));
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
ln_fwd_cuda(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, stream);
launch_params.props = at::cuda::getCurrentDeviceProperties();
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
return {y, mu, rsigma};
}
// Request the kernel launcher.
auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
at::Tensor workspace, barrier;
// Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.beta = beta.data_ptr();
params.z = z.data_ptr();
params.epsilon = epsilon;
if( launch_params.barrier_size > 0 ) {
auto options = x.options();
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
// Launch the kernel.
launcher(launch_params, false);
return { z, mu, rsigma };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> ln_bwd(const at::Tensor &dw, // BxSxhidden_size
std::vector<at::Tensor> ln_bwd(const at::Tensor &dz, // BxSxhidden_size
const at::Tensor &x, // BxSxhidden_size
const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma // hidden_size
) {
auto itype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = wtype;
auto ctype = torch::kFloat32;
TORCH_CHECK(dz.dtype() == otype);
TORCH_CHECK(mu.dtype() == ctype);
TORCH_CHECK(rsigma.dtype() == ctype);
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(dw.is_cuda());
TORCH_CHECK(dz.is_cuda());
TORCH_CHECK(mu.is_cuda());
TORCH_CHECK(rsigma.is_cuda());
TORCH_CHECK(gamma.is_cuda());
TORCH_CHECK(x.is_contiguous());
TORCH_CHECK(dw.is_contiguous());
TORCH_CHECK(dz.is_contiguous());
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(dw.sizes() == sizes);
TORCH_CHECK(dz.sizes() == sizes);
auto rows = sizes[0];
auto cols = sizes[1];
auto dtype = x.scalar_type();
TORCH_CHECK(dw.dtype() == dtype);
TORCH_CHECK(gamma.dtype() == dtype);
TORCH_CHECK(mu.dtype() == torch::kFloat32);
TORCH_CHECK(rsigma.dtype() == torch::kFloat32);
TORCH_CHECK(mu.sizes() == rsigma.sizes());
auto hidden_size = gamma.numel();
TORCH_CHECK(mu.numel() == rows);
TORCH_CHECK(mu.sizes() == rsigma.sizes());
TORCH_CHECK(gamma.numel() == cols);
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto options = x.options();
auto dx = torch::empty_like(x);
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);
ln_bwd_cuda(dx, dgamma, dbeta, dw, x, mu, rsigma, gamma, rows, cols, stream);
return {dx, dgamma, dbeta};
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
launch_params.props = at::cuda::getCurrentDeviceProperties();
auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size);
launcher(launch_params, true);
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype));
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype));
at::Tensor workspace, barrier;
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.dz = dz.data_ptr();
params.dx = dx.data_ptr();
params.dbeta = dbeta.data_ptr();
params.dgamma = dgamma.data_ptr();
params.dbeta_part = dbeta_part.data_ptr();
params.dgamma_part = dgamma_part.data_ptr();
if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this?
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
launcher(launch_params, false);
return { dx, dgamma, dbeta, dgamma_part, dbeta_part };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "CUDA LayerNorm"; // optional module docstring
m.doc() = "CUDA LayerNorm";
m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel");
m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel");
}
#pragma once
namespace layer_norm {
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum;
constexpr float rn = 1.f / float(COLS);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
}
}
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
if( WARPS_M == 1 ) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dz_sum[NUM_RES];
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
for( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dzy_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
for( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) {
*dbeta_part = cta_dz_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA;
}
}
}
template<typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
void ln_bwd_finalize_kernel(BwdParams params)
{
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
Sum<reduce_t> sum;
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local));
for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
dbeta_part.load_from(params.dbeta_part, idx);
dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
}
}
void * smem_gamma = smem_;
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
const int write_row = warp;
const int write_col = lane ^ write_row;
const int write_idx = write_row * THREADS_PER_WARP + write_col;
dgamma_local.store_to(smem_gamma, write_idx);
dbeta_local.store_to(smem_beta, write_idx);
__syncthreads();
// It would be probably safe to reuse the first row of smem_beta and smem_gamma
void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
// More than one iter iff ROWS_PER_CTA < 32.
for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
memset(&dbeta_local, 0, sizeof(dbeta_local));
memset(&dgamma_local, 0, sizeof(dgamma_local));
// Load beta and gamma transposed
if(read_row < Kernel_traits::ROWS_PER_CTA){
dbeta_local.load_from(smem_beta, read_idx);
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
compute_t b_i = dbeta_local.data.elt[it];
compute_t g_i = dgamma_local.data.elt[it];
b_i = reducer.allreduce(b_i, sum);
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
dbeta_local.data.elt[it] = b_i;
}
// Leader stores the result at the current column.
if(lane == 0){
dgamma_local.store_to(smem_gamma_out, w);
dbeta_local.store_to(smem_beta_out, w);
}
}
// All writes done.
__syncthreads();
// Pack and store: 2-wide stores with half the threads.
if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
dbeta_vec2.load_from(smem_beta_out, lane);
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
}
dgamma_out2.store_to(params.dgamma, col_out);
dbeta_out2.store_to(params.dbeta, col_out);
}
}
}
} // namespace layer_norm
#include "utils.cuh"
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ATen/cuda/CUDAContext.h"
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(void * __restrict__ dx_,
void * __restrict__ dg_,
void * __restrict__ db_,
const void * __restrict__ dw_,
const void * __restrict__ x_,
const void * __restrict__ mu_,
const void * __restrict__ rs_,
const void * __restrict__ g_,
const int rows
){
using Vec = typename Ktraits::Vec;
enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = BYTES_PER_ROW / Ktraits::BYTES_PER_ROW_PER_CTA };
static_assert(LDGS * Ktraits::BYTES_PER_ROW_PER_CTA == BYTES_PER_ROW, "");
enum { NUM_ELTS = Vec::NUM_ELTS };
using vec_t = typename Ktraits::vec_t;
using base_t = typename Ktraits::base_t;
using compute_t = typename Ktraits::compute_t;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP;
const int warp = tidx / THREADS_PER_WARP;
const int warp_m = warp / Ktraits::WARPS_N;
const int warp_n = warp % Ktraits::WARPS_N;
const int tid_r = warp_n * THREADS_PER_WARP + lane;
const int r = bidx * Ktraits::ROWS_PER_CTA + warp_m;
const int c = warp_n * THREADS_PER_WARP + lane;
const char *dw_ptr = static_cast<const char *>(dw_);
const char *x_ptr = static_cast<const char *>(x_);
const char *g_ptr = static_cast<const char *>(g_);
char *dx_ptr = static_cast<char *>(dx_);
const compute_t *mu_ptr = static_cast<const compute_t *>(mu_);
const compute_t *rs_ptr = static_cast<const compute_t *>(rs_);
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS, "");
// smem for final reduction
//__shared__ compute_t smem_[ROWS_PER_CTA * COLS];
extern __shared__ compute_t smem_[];
// static_assert(sizeof(smem_dw_sum) == 32*1024,"");
// Using the grid stride loop we can assign multiple rows to each thread
// by using a number of CTAs smaller than rows / ROWS_PER_CTA
// We accumulate them here, one in smem, one in registers, because the smem
// capacity is limited compute_t * dw_sum = &smem_dw_sum[warp_m * COLS + tid_r
// * LDGS * NUM_ELTS];
compute_t dwy_sum[LDGS * NUM_ELTS];
compute_t dw_sum[LDGS * NUM_ELTS];
memset(dwy_sum, 0, sizeof(compute_t) * LDGS * NUM_ELTS);
memset(dw_sum, 0, sizeof(compute_t) * LDGS * NUM_ELTS);
// Debug 8 rows, 4B, 1024 cols
__shared__ compute_t smem_mdy[ROWS_PER_CTA * WARPS_N];
__shared__ compute_t smem_mdyy[ROWS_PER_CTA * WARPS_N];
compute_t *mdy_shared = &smem_mdy[warp_m * WARPS_N];
compute_t *mdyy_shared = &smem_mdyy[warp_m * WARPS_N];
constexpr float rn = 1.f / float(COLS);
Vec gamma[LDGS];
int col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
gamma[it].load_from(g_ptr + col * BYTES_PER_LDG);
col += Ktraits::THREADS_PER_ROW;
}
// TODO if ROWS_PER_CTA does not divice rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
const compute_t mu_r = mu_ptr[row];
const compute_t rs_r = rs_ptr[row];
Vec dw[LDGS], x[LDGS], dx[LDGS];
int col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dw[it].load_from(dw_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
x[it].load_from(x_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
// local reductions
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < Vec::NUM_ELTS; jt++) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = gamma[it].data.elt[jt] * dw[it].data.elt[jt];
compute_t dw_tmp = dw[it].data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dwy_sum[it * NUM_ELTS + jt] += dw_tmp * y_tmp;
dw_sum[it * NUM_ELTS + jt] += dw_tmp;
}
}
// reduction across row for mdy, mdyy
if (WARPS_N == 1) { // no need to go through smem!
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mdy_local += __shfl_xor_sync(uint32_t(-1), mdy_local, it);
mdyy_local += __shfl_xor_sync(uint32_t(-1), mdyy_local, it);
}
mdy_local *= rn;
mdyy_local *= rn;
#include "ln_bwd_kernels.cuh"
using namespace layer_norm;
template<
typename weight_t,
typename input_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL
>
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG_MAIN
>;
auto kernel = &ln_bwd_kernel<Kernel_traits>;
if( configure_params ) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
} else {
#pragma unroll
for (int it = 16; it > 0; it /= 2) {
mdy_local += __shfl_down_sync(uint32_t(-1), mdy_local, it);
mdyy_local += __shfl_down_sync(uint32_t(-1), mdyy_local, it);
} // lane 0 holds the result!
if (lane == 0) {
mdy_shared[warp_n] = mdy_local;
mdyy_shared[warp_n] = mdyy_local;
}
__syncthreads();
if (warp_n == 0 && lane == 0) {
mdy_local = 0.f;
mdyy_local = 0.f;
for (int it = 0; it < WARPS_N; it++) {
mdy_local += mdy_shared[it];
mdyy_local += mdyy_shared[it];
}
mdy_shared[0] = mdy_local;
mdyy_shared[0] = mdyy_local;
}
__syncthreads();
mdy_local = mdy_shared[0] * rn;
mdyy_local = mdyy_shared[0] * rn;
}
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp =
compute_t(rs_r) * (dy_tmp - mdyy_local * y_tmp - mdy_local);
dx[it].data.elt[jt] = dx_tmp;
}
}
col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dx[it].store_to(dx_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += Ktraits::THREADS_PER_ROW;
}
} // end: grid stride loop
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
compute_t *smem_write;
smem_write = &smem_[warp_m * COLS + tid_r * NUM_ELTS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
smem_write[jt] = dw_sum[it * NUM_ELTS + jt];
}
smem_write += THREADS_PER_ROW * NUM_ELTS;
}
__syncthreads();
compute_t cta_dw_sum[NUM_RES];
memset(cta_dw_sum, 0, sizeof(compute_t) * NUM_RES);
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dw_sum[jt] += smem_[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
smem_write = &smem_[warp_m * COLS + tid_r * NUM_ELTS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
smem_write[jt] = dwy_sum[it * NUM_ELTS + jt];
}
smem_write += THREADS_PER_ROW * NUM_ELTS;
}
__syncthreads();
compute_t cta_dwy_sum[NUM_RES];
memset(cta_dwy_sum, 0, sizeof(compute_t) * NUM_RES);
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dwy_sum[jt] +=
smem_[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
compute_t *dgamma_part = static_cast<compute_t *>(dg_) + bidx * COLS + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dgamma_part = cta_dwy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
compute_t *dbeta_part = static_cast<compute_t *>(db_) + bidx * COLS + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dbeta_part = cta_dw_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA;
}
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
output_t,
compute_t,
index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
}
template<typename Ktraits, typename out_t>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_finalize_kernel(void * __restrict__ dg_,
void * __restrict__ db_,
const void * __restrict__ dg_part_,
const void * __restrict__ db_part_,
const int rows
){
using Vec = typename Ktraits::Vec;
enum { NUM_ELTS = Vec::NUM_ELTS };
using vec_t = typename Ktraits::vec_t;
using base_t = typename Ktraits::base_t;
using compute_t = typename Ktraits::compute_t;
enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum {VEC_COLS = BYTES_PER_ROW / BYTES_PER_LDG};
//dbg
static_assert(VEC_COLS == COLS / NUM_ELTS, "");
//static_assert(VEC_COLS == 1024,"");
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP;
const int warp = tidx / THREADS_PER_WARP;
const int warp_m = warp / Ktraits::WARPS_N;
const int warp_n = warp % Ktraits::WARPS_N;
const int tid_c = warp_n * THREADS_PER_WARP + lane;
const int c =bidx * THREADS_PER_ROW + tid_c;
const int r = warp_m;
__shared__ compute_t smem_[(WARPS_M - 1) * THREADS_PER_ROW * NUM_ELTS];
//Will probably run this with WARPS_N = 1 and grid = 1024 / (32*4) = 8, or NUM_ELTS=1 and grid = 32
// and WARPS_M = 4 (or 1??)
for(int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW){
const char* dg_part_ptr = static_cast<const char*>(dg_part_) + r * BYTES_PER_ROW + col * BYTES_PER_LDG;
const char* db_part_ptr = static_cast<const char*>(db_part_) + r * BYTES_PER_ROW + col * BYTES_PER_LDG;
compute_t dg_sum[NUM_ELTS];
compute_t db_sum[NUM_ELTS];
memset(dg_sum, 0, sizeof(compute_t) * NUM_ELTS);
memset(db_sum, 0, sizeof(compute_t) * NUM_ELTS);
#pragma unroll
for(int row = r; row < rows;row += ROWS_PER_CTA){
Vec dg;
Vec db;
dg.load_from(dg_part_ptr);
db.load_from(db_part_ptr);
dg_part_ptr += ROWS_PER_CTA * BYTES_PER_ROW;
db_part_ptr += ROWS_PER_CTA * BYTES_PER_ROW;
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dg_sum[jt] += dg.data.elt[jt];
db_sum[jt] += db.data.elt[jt];
}
}
// Finalize the reduction across rows of the CTA
compute_t * smem_write;
smem_write = smem_ + (warp_m -1) *THREADS_PER_ROW * NUM_ELTS + tid_c;
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);
REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
if (warp_m > 0) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
*smem_write = dg_sum[jt];
smem_write+=THREADS_PER_ROW;
}
}
__syncthreads();
compute_t *smem_read ;
smem_read = smem_ + tid_c ;
if (warp_m == 0) {
#pragma unroll
for (int it = 0; it < WARPS_M - 1; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dg_sum[jt] += *smem_read;
smem_read += THREADS_PER_ROW;
}
}
}
__syncthreads();
smem_write = smem_ + (warp_m -1) *THREADS_PER_ROW * NUM_ELTS + tid_c;
if (warp_m > 0) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
*smem_write = db_sum[jt];
smem_write+=THREADS_PER_ROW;
}
}
__syncthreads();
smem_read = smem_ + tid_c;
if (warp_m == 0) {
#pragma unroll
for (int it = 0; it < WARPS_M - 1; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
db_sum[jt] += *smem_read;
smem_read += THREADS_PER_ROW;
}
}
using vout_t = typename Vec_type<sizeof(out_t) * NUM_ELTS>::Type;
union {
vout_t raw;
out_t elt[NUM_ELTS];
} dg_out, db_out;
// out_t dg_out[NUM_ELTS], db_out[NUM_ELTS];
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dg_out.elt[jt] = dg_sum[jt];
db_out.elt[jt] = db_sum[jt];
}
vout_t *dg_ptr = reinterpret_cast<vout_t *>(dg_) + col ;
vout_t *db_ptr = reinterpret_cast<vout_t *>(db_) + col ;
*dg_ptr = dg_out.raw;
*db_ptr = db_out.raw;
}
}
}
template<typename scalar_t>
void launch(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,
at::Tensor &dgamma_part, at::Tensor &dbeta_part,
const at::Tensor &dw, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int rows, const int cols, const int gridx, cudaStream_t stream){
if (cols == 1024) {
using Ktraits = Kernel_traits<scalar_t, 1024, 4, 1>;
if (Ktraits::SMEM_BYTES >= 48 * 1024) {
AT_CUDA_CHECK(cudaFuncSetAttribute(
ln_bwd_kernel<Ktraits>, cudaFuncAttributeMaxDynamicSharedMemorySize,
Ktraits::SMEM_BYTES));
}
ln_bwd_kernel<Ktraits>
<<<gridx, Ktraits::THREADS_PER_CTA, Ktraits::SMEM_BYTES, stream>>>(
dx.data_ptr(), dgamma_part.data_ptr(), dbeta_part.data_ptr(),
dw.data_ptr(), x.data_ptr(), mu.data_ptr(), rsigma.data_ptr(),
gamma.data_ptr(), rows);
using Ktraits2 = Kernel_traits<float, 1024, 16, 1, 4>;
constexpr int grid2 =
DIVUP(1024, Ktraits2::THREADS_PER_ROW * Ktraits2::Vec::NUM_ELTS);
ln_bwd_finalize_kernel<Ktraits2, scalar_t>
<<<grid2, Ktraits2::THREADS_PER_CTA, 0, stream>>>(
dgamma.data_ptr(), dbeta.data_ptr(), dgamma_part.data_ptr(),
dbeta_part.data_ptr(), gridx);
} else {
assert(false && "Not implemented");
}
AT_CUDA_CHECK(cudaPeekAtLastError());
}
void ln_bwd_cuda(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,
const at::Tensor &dw, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int rows, const int cols, cudaStream_t stream) {
const auto dtype = x.scalar_type();
const auto props = at::cuda::getCurrentDeviceProperties();
const int smCount = props->multiProcessorCount;
// Launch 2 CTAs per SM
const int grid = 2 * smCount;
//request workspace for two-step reduction. We always reduce in FP32.
auto opts = x.options();
auto dbeta_part = torch::empty({grid, cols}, opts.dtype(torch::kFloat32));
auto dgamma_part = torch::empty({grid, cols}, opts.dtype(torch::kFloat32));
if (dtype == torch::kFloat16) {
launch<half>(dx, dgamma, dbeta, dgamma_part, dbeta_part, dw, x, mu, rsigma, gamma, rows, cols, grid, stream);
} else if (dtype == torch::kFloat32) {
launch<float>(dx, dgamma, dbeta, dgamma_part, dbeta_part, dw, x, mu, rsigma, gamma, rows, cols, grid, stream);
} else {
assert(false && "Not implemented");
}
}
\ No newline at end of file
#include "utils.cuh"
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ATen/cuda/CUDAContext.h"
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(
void *__restrict__ y_, void *__restrict__ mu_, void *__restrict__ rsigma_,
const void *__restrict__ x_, const void *__restrict__ gamma_,
const void *__restrict__ beta_, const float epsilon, int rows) {
using Vec = typename Ktraits::Vec;
using base_t = typename Ktraits::base_t;
using compute_t = typename Ktraits::compute_t;
enum { NUM_ELTS = Vec::NUM_ELTS };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };
static_assert(BYTES_PER_LDG == 16, "");
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = BYTES_PER_ROW / Ktraits::BYTES_PER_ROW_PER_CTA };
static_assert(LDGS * Ktraits::BYTES_PER_ROW_PER_CTA == BYTES_PER_ROW, "");
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP;
const int warp = tidx / THREADS_PER_WARP;
const int warp_n = warp % WARPS_N;
const int warp_m = warp / WARPS_N;
const int c = warp_n * THREADS_PER_WARP + lane;
const int r = bidx * ROWS_PER_CTA + warp_m;
const char *x_ptr = static_cast<const char *>(x_);
const char *g_ptr = static_cast<const char *>(gamma_);
const char *b_ptr = static_cast<const char *>(beta_);
char *y_ptr = static_cast<char *>(y_);
compute_t *mu_ptr = static_cast<compute_t *>(mu_);
compute_t *rs_ptr = static_cast<compute_t *>(rsigma_);
Vec gamma[LDGS];
Vec beta[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
gamma[it].load_from(g_ptr + col * BYTES_PER_LDG);
beta[it].load_from(b_ptr + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
x[it].load_from(x_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
xf[it * NUM_ELTS + jt] = compute_t(x[it].data.elt[jt]);
}
}
compute_t mu_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
mu_local += xf[it * NUM_ELTS + jt];
}
}
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
}
mu_local *= rn;
if(lane == 0){
mu_ptr[row] = mu_local;
}
compute_t var_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t diff = xf[it * NUM_ELTS + jt] - mu_local;
var_local += diff * diff;
}
}
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
}
compute_t rsigma = rsqrtf(var_local * rn + epsilon);
if(lane == 0){
rs_ptr[row] = rsigma;
}
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
base_t tmp = (rsigma * (xf[it * NUM_ELTS + jt] - mu_local));
x[it].data.elt[jt] = gamma[it].data.elt[jt] * tmp + beta[it].data.elt[jt];
}
}
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
x[it].store_to(y_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
}
}
template<typename scalar_t>
void launch(
at::Tensor & y, // BxSxhidden_size
at::Tensor & mu,
at::Tensor & rsigma,
const at::Tensor & x, // BxSxhidden_size
const at::Tensor & gamma,
const at::Tensor & beta,
const float epsilon,
const int rows,
const int cols,
const int max_gridx,
cudaStream_t stream
){
if (cols == 1024) {
using Ktraits = Kernel_traits<scalar_t, 1024, 4, 1>;
const int grid =
std::min<int>(DIVUP(rows, Ktraits::ROWS_PER_CTA), max_gridx);
ln_fwd_kernel<Ktraits><<<grid, Ktraits::THREADS_PER_CTA, 0, stream>>>(
y.data_ptr(), mu.data_ptr(), rsigma.data_ptr(), x.data_ptr(),
gamma.data_ptr(), beta.data_ptr(), epsilon, rows);
#include "ln_fwd_kernels.cuh"
using namespace layer_norm;
template<
typename weight_t,
typename input_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG
>
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG
>;
auto kernel = &ln_fwd_kernel<Kernel_traits>;
if( configure_params ) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
assert(false && "Not implemented");
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
}
AT_CUDA_CHECK(cudaPeekAtLastError());
}
void ln_fwd_cuda(
at::Tensor & y, // BxSxhidden_size
at::Tensor & mu,
at::Tensor & rsigma,
const at::Tensor & x, // BxSxhidden_size
const at::Tensor & gamma,
const at::Tensor & beta,
const float epsilon,
const int rows, const int cols,
cudaStream_t stream
){
const auto dtype = x.scalar_type();
const auto props = at::cuda::getCurrentDeviceProperties();
const int max_gridx = props->maxGridSize[0];
//TODO
// - Using dispatch macro costs 1% perf wtf?!?!
// - Tune FP32 warps
// - Add more sizes
if (dtype == torch::kFloat16) {
launch<half>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream);
} else if (dtype == torch::kFloat32) {
launch<float>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream);
} else {
assert(false && "Not implemented");
}
REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);
}
#pragma once
#include "ln.h"
namespace layer_norm {
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
Wvec gamma[LDGS];
Wvec beta[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
beta[it].load_from(params.beta, idx);
idx += VEC_COLS_PER_LDG;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
Ivec x[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
x[it].load_from(params.x, idx);
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_ij = compute_t(x[it].data.elt[jt]);
xf[it * NUM_ELTS + jt] = x_ij;
}
idx += VEC_COLS_PER_LDG;
}
stats_t s = stats.compute(xf, rn);
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
mu_ptr[row] = mu;
}
compute_t rs = rsqrtf(rn * m2 + params.epsilon);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
rs_ptr[row] = rs;
}
Ovec z[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
output_t g_ij = gamma[it].data.elt[jt];
output_t b_ij = beta[it].data.elt[jt];
z[it].data.elt[jt] = (g_ij * y_ij + b_ij);
}
z[it].store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
}
}
}
} // namespace layer_norm
#pragma once
constexpr uint32_t THREADS_PER_WARP = 32;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename dtype, int COLS_, int WARPS_M_, int WARPS_N_,
int BYTES_PER_LDG_ = 16>
struct Kernel_traits {
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = COLS_ };
namespace layer_norm {
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_
>
struct Kernel_traits_base {
using weight_t = weight_t_;
using input_t = input_t_;
using output_t = output_t_;
using compute_t = compute_t_;
using index_t = index_t_;
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
enum { THREADS_PER_WARP = 32 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_,
uint32_t BYTES_PER_LDG_,
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
weight_t_,
input_t_,
output_t_,
compute_t_,
index_t_,
THREADS_PER_CTA_>
>
struct Kernel_traits_finalize : public Base {
enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
// Bytes per global load from the input.
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
// Number of elements fetched by a global load.
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
// Bytes per global store of the weights.
enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
// Shared memory size to transpose the CTA result.
enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
// Shared memory size to coalsece the CTA result.
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
// Shared memory requirement per CTA.
enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
// The type of the reducer.
using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
// Condition for the whole CTA to participate in syncthreads.
static_assert(COLS % Base::THREADS_PER_WARP == 0);
enum { CTAS = COLS / Base::THREADS_PER_WARP };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using Vec = Vec<dtype, BYTES_PER_LDG>;
using vec_t = typename Vec::vec_t;
using base_t = typename Vec::base_t;
using packed_t = typename Vec::packed_t;
using compute_t = typename Vec::compute_t;
using packed_compute_t = typename Vec::packed_compute_t;
template<
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t HIDDEN_SIZE_,
uint32_t CTAS_PER_ROW_,
uint32_t WARPS_M_,
uint32_t WARPS_N_,
uint32_t BYTES_PER_LDG_ = 16,
typename Base = Kernel_traits_base<
HIDDEN_SIZE_,
weight_t_,
input_t_,
output_t_,
compute_t_,
index_t_,
WARPS_M_*WARPS_N_*THREADS_PER_WARP
>
>
struct Kernel_traits : public Base {
using input_t = typename Base::input_t;
using weight_t = typename Base::weight_t;
using compute_t = typename Base::compute_t;
using output_t = typename Base::output_t;
using index_t = typename Base::index_t;
enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = HIDDEN_SIZE_ };
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
enum { ROWS_PER_CTA = WARPS_M };
enum { BYTES_PER_ROW = COLS * sizeof(base_t) };
enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
enum {SMEM_BYTES = ROWS_PER_CTA * COLS * sizeof(compute_t)};
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
// Assume that each thread can handle the same number of elements in the output and weights as in the input.
static_assert(sizeof(input_t) >= sizeof(output_t));
static_assert(sizeof(input_t) >= sizeof(weight_t));
// The number of columns fetched per load from input: one per thread.
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
// The total number of vectorized loads/stores per hidden vector.
enum { VEC_COLS = COLS / ELTS_PER_LDG };
// The number of loads per thread for the input.
enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
//static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
#pragma once
#include <cassert>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "ln.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr uint32_t THREADS_PER_WARP = 32;
////////////////////////////////////////////////////////////////////////////////////////////////////
inline void check_cuda_(cudaError_t status, const char *file, int line) {
if( status != cudaSuccess ) {
fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
exit(status);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(ans) \
{ check_cuda_((ans), __FILE__, __LINE__); }
////////////////////////////////////////////////////////////////////////////////////////////////////
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
const bool configure_params) { \
launch_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 operator+(const float2 & a, const float2 & b){
return {a.x + b.x, a.y + b.y};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void operator+=(float2 & a, const float2 & b){
a.x += b.x;
a.y += b.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Sum {
inline __device__ Sum(){}
inline __device__ T operator()(const T &a, const T &b){
return a + b;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){
return __shfl_xor_sync(uint32_t(-1), x, idx);
}
template<>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx){
return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) };
}
template<typename T>
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){
return __shfl_down_sync(uint32_t(-1), x, idx);
}
template<>
inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx){
return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint16 {
uint4 u;
uint4 v;
uint4 s;
uint4 t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint8 {
uint4 u;
uint4 v;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES>
struct BytesToType {};
template<>
struct BytesToType<64> {
using Type = uint16;
static_assert(sizeof(Type) == 64);
};
template<>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<>
struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<>
struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<>
struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<>
struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<>
struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeToVec2 {};
template<>
struct TypeToVec2<float> {
using Type = float2;
};
template<>
struct TypeToVec2<half> {
using Type = half2;
};
template<>
struct TypeToVec2<nv_bfloat16> {
using Type = nv_bfloat162;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int INDEX>
struct Get {
template<typename T, typename R>
static inline __device__ R of(const T &vec);
};
template<>
template<typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
return vec.x;
}
template<>
template<typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
return vec.y;
}
template<>
template<typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
return vec.z;
}
template<>
template<typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
return vec.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Src, typename Dst>
struct Converter{
static inline __device__ Dst convert(const Src &from) {
return Dst(from);
}
};
template<>
struct Converter<float2, half2>{
static inline __device__ half2 convert(const float2 &x) {
return __float22half2_rn(x);
}
};
template<>
struct Converter<float2, nv_bfloat162>{
static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
return __float22bfloat162_rn(x);
#else
union {
nv_bfloat162 raw;
nv_bfloat16 x;
nv_bfloat16 y;
} tmp;
tmp.x = __float2bfloat16_rn(x.x);
tmp.y = __float2bfloat16_rn(x.y);
return tmp.raw;
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Zeros{
static inline __device__ T get() {
return T(0.f);
}
};
template<>
struct Zeros<float2>{
static inline __device__ float2 get() {
return make_float2(0.f, 0.f);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
Alias_type data;
template<typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
template<typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
}
inline __device__ void load_from(const void *base_ptr, const size_t idx) {
this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
}
inline __device__ void store_to(void *base_ptr, const size_t idx) {
static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<uint32_t CTAS_PER_ROW>
struct InterCTASync {
template<typename Params>
inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn)
: phase_counter_(0)
, b0_(params.barrier + bidm) // The barrier for this group of CTAs.
, b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs.
{
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
}
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
for( int found = -1; found != expected; ) {
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
}
}
inline __device__ void sync(){
// ALL THREADS MUST ENTER!
// We switch barrier every iteration.
int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
// We decrement every other iteration.
bool dec = phase_counter_ & 0x2;
int step = dec ? -1 : 1;
int expected = dec ? 0 : CTAS_PER_ROW;
// There are only 4 phases: up/down for b0/b1.
phase_counter_ = (phase_counter_ + 1) & 0x3;
if( threadIdx.x == 0 ) {
spin_wait_(barrier, step, expected);
}
// CTA waits for thread 0
__syncthreads();
}
int phase_counter_;
int * b0_;
int * b1_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
using InterCTASync = InterCTASync<CTAS_PER_ROW>;
using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
using Type = typename Base::Type;
enum { SMEM_BYTES = Base::SMEM_BYTES };
enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
// size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, inter_cta_(params, bidm, bidn)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
{
}
template<typename Op>
inline __device__ T allreduce(T data, Op &op) {
data = Base::reduce(data, op);
// We switch workspace every iteration.
T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
// Warp leaders 0 hold the CTA-local results.
if( this->warp_n_ == 0 && this->lane_ == 0 ) {
workspace[bidn_] = data;
}
inter_cta_.sync();
static_assert(CTAS_PER_ROW <= 32);
T total = Zeros<T>::get();
if(this->lane_ < CTAS_PER_ROW){
total = workspace[this->lane_];
}
total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
return total;
}
InterCTASync inter_cta_;
T *w0_;
T *w1_;
int bidn_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
using Type = T;
enum { SMEM_BYTES = 0 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: warp_n_(warp_n)
, lane_(lane)
{
}
template<typename Op>
static inline __device__ T allreduce_(T data, Op &op) {
#pragma unroll
for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
data = op(data, warp_shuffle_xor(data, it));
}
return data;
}
template<typename Op>
inline __device__ T allreduce(T data, Op &op) {
return allreduce_(data, op);
}
template<typename Op>
inline __device__ T reduce(T data, Op &op){
// only lane 0 holds the result!
#pragma unroll
for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
data = op(data, warp_shuffle_down(data, it));
}
return data;
}
int warp_n_;
int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
using Base = Reducer<T, 1, WARPS_M, 1>;
using Type = T;
enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, use0_(true)
{
smem0_ = &static_cast<T *>(smem)[warp_m * WARPS_N];
smem1_ = smem0_ + WARPS_M * WARPS_N;
}
template<typename Op>
inline __device__ T allreduce(T data, Op & op) {
T * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
data = Base::reduce(data, op);
if( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
return out;
}
template<typename Op>
inline __device__ T reduce(T data, Op &op) {
T * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// only intra-CTA group leader holds the result!
data = Base::reduce(data, op);
if( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
if( this->warp_n_ == 0 && this->lane_ == 0 ) {
#pragma unroll
for( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
}
return out;
}
T * smem0_;
T * smem1_;
bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){
//Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
#pragma unroll
for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
// Exchange
T n_b = warp_shuffle_down(n_a, step);
T m_b = warp_shuffle_down(m_a, step);
T m2_b = warp_shuffle_down(m2_a, step);
// Update
const T n_ab = n_a + n_b; // We can handle one of them being 0, not both.
const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
const T delta = m_a - m_b;
const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;
n_a = n_ab;
m_a = m_ab;
m2_a = m2_ab;
}
// Intra-warp broadcast (only lane 0 has valid stats).
m_a = __shfl_sync(uint32_t(-1), m_a, 0);
m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
// This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.
using InterCTASync = InterCTASync<CTAS_PER_ROW>;
using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
using stats_t = typename BlockStats::stats_t;
enum { SMEM_BYTES = BlockStats::SMEM_BYTES };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: inter_cta_(params, bidm, bidn)
, block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
, warp_n_(warp_n)
, lane_(lane)
{
}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
// TODO rn is not really needed here..
constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
stats_t block_stats = block_stats_.compute(elts, block_rn);
stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
if( warp_n_ == 0 && lane_ == 0 ) {
workspace[bidn_] = block_stats;
}
// Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
inter_cta_.sync();
T n = Zeros<T>::get();
T m = Zeros<T>::get();
T m2 = Zeros<T>::get();
// Assume CTA group size in N less than 32, such that we can finalize with a single warp.
static_assert(CTAS_PER_ROW <= 32);
// Every warp does the final reduction locally.
if( lane_ < CTAS_PER_ROW ) {
stats_t result = workspace[lane_];
n = ELTS_PER_ROW_PER_CTA;
m = layer_norm::Get<0>::of<stats_t, T>(result);
m2 = layer_norm::Get<1>::of<stats_t, T>(result);
}
warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);
return { m, m2 };
}
InterCTASync inter_cta_;
BlockStats block_stats_;
stats_t *w0_;
stats_t *w1_;
int bidn_;
int warp_n_;
int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
using WarpStats = Stats<T, 1, WARPS_M, 1>;
using stats_t = typename WarpStats::stats_t;
enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
, use0_(true)
{
smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;
smem1_ = smem0_ + WARPS_M * WARPS_N;
}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
stats_t * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// Compute warp local for all WARPS_N
constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
stats_t warp_stats = warp_stats_.compute(elts, warp_rn);
//Each warp warp leader stores its stats
const auto warp_n = warp_stats_.reducer_.warp_n_;
const auto lane = warp_stats_.reducer_.lane_;
if( lane == 0 ) {
smem[warp_n] = warp_stats;
}
__syncthreads();
T n = Zeros<T>::get();
T m = Zeros<T>::get();
T m2 = Zeros<T>::get();
// Assume that there are less than 32 warps, such that we can finalize with a single warp
static_assert(WARPS_N <= 32);
if(lane < WARPS_N){
stats_t result = smem[lane];
n = N * THREADS_PER_WARP;
m = layer_norm::Get<0>::of<stats_t, T>(result);
m2 = layer_norm::Get<1>::of<stats_t, T>(result);
}
warp_chan_upd_dynamic(m, m2, n, WARPS_N);
return { m, m2 };
}
WarpStats warp_stats_;
stats_t * smem0_;
stats_t * smem1_;
bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
using stats_t = typename TypeToVec2<T>::Type;
// The simple Warp reducer.
using Reducer = Reducer<T, 1, WARPS_M, 1>;
enum { SMEM_BYTES = 0 };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)
{
}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
auto sum = Sum<T>();
T m = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < N; it++ ) {
m += elts[it];
}
m = reducer_.allreduce(m, sum) * rn;
T m2 = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < N; it++ ) {
T diff = (elts[it] - m);
m2 += diff * diff;
}
m2 = reducer_.allreduce(m2, sum);
return {m, m2};
}
Reducer reducer_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
#pragma once
#include "torch/extension.h"
#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
[&] { \
const auto &the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
template <int Bytes> struct Vec_type {};
template <> struct Vec_type<16> {
using Type = uint4;
static __device__ inline Type zero() { return make_uint4(0, 0, 0, 0); }
};
template <> struct Vec_type<8> {
using Type = uint2;
static __device__ inline Type zero() { return make_uint2(0, 0); }
};
template <> struct Vec_type<4> {
using Type = uint32_t;
static __device__ inline Type zero() { return 0; }
};
template <> struct Vec_type<2> {
using Type = uint16_t;
static __device__ inline Type zero() { return 0; }
};
template <typename T> struct TypeInfo {
using base_t = T;
using packed_t = T;
using compute_t = float;
using packed_compute_t = float;
};
template <> struct TypeInfo<half> {
using base_t = half;
using packed_t = half2;
using compute_t = float;
using packed_compute_t = float2;
};
template <typename dtype, int Bytes> struct Vec {
using base_t = typename TypeInfo<dtype>::base_t;
using packed_t = typename TypeInfo<dtype>::packed_t;
using compute_t = typename TypeInfo<dtype>::compute_t;
using packed_compute_t = typename TypeInfo<dtype>::packed_compute_t;
static_assert(Bytes % sizeof(base_t) == 0, "");
static_assert(Bytes % sizeof(packed_t) == 0, "");
enum { BYTES_PER_THREAD = Bytes };
enum { NUM_ELTS = Bytes / sizeof(base_t) };
enum { NUM_PACKED = Bytes / sizeof(packed_t) };
using vec_t = typename Vec_type<Bytes>::Type;
using store_t = union {
vec_t raw;
base_t elt[NUM_ELTS];
packed_t packed[NUM_PACKED];
};
store_t data;
__device__ Vec() { data.raw = Vec_type<Bytes>::zero(); }
__device__ inline void load_from(const char *ptr) {
data.raw = *reinterpret_cast<const vec_t *>(ptr);
}
__device__ inline void load_or_zero(const char *ptr, const bool is_valid) {
data.raw = is_valid ? *reinterpret_cast<const vec_t *>(ptr)
: Vec_type<Bytes>::zero();
}
__device__ inline void store_to(char *ptr) const {
*reinterpret_cast<vec_t *>(ptr) = data.raw;
}
__device__ inline void store_valid(char *ptr, const bool is_valid) const {
if (is_valid)
*reinterpret_cast<vec_t *>(ptr) = data.raw;
}
};
#include <torch/extension.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const half* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const half *pad_mask, float dropout_prob);
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
);
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool is_training,
int heads,
torch::Tensor const& input,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
torch::Tensor const &input,
torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, "Only BYTE is supported");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
"Only BYTE is supported");
}
return fwd_cuda(
is_training,
heads,
input,
use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
return fwd_cuda(is_training, heads, input,
use_mask ? static_cast<const half *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
torch::Tensor bwd(
bool use_mask,
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
softmax_results,
dropout_mask,
dropout_prob
);
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
dropout_prob);
}
} // end namespace mask_softmax_dropout
} // namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
m.def("forward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
"Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
"Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <vector>
#include <math.h>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "softmax.h"
#include "dropout.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const half* pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const half *pad_mask, float dropout_prob) {
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
......@@ -41,63 +35,54 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
void *input_ptr = static_cast<void *>(input.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
attn_batches * q_seq_len);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
// use at:: function so that C++ version generates the same random mask as
// python version
auto dropout_tuple =
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
return {
dropout_results,
dropout_mask,
softmax_results
};
return {dropout_results, dropout_mask, softmax_results};
}
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
......@@ -109,23 +94,20 @@ torch::Tensor bwd_cuda(
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len, stream);
//backward pass is completely in-place
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(output_grads.data_ptr()),
static_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
// backward pass is completely in-place
return output_grads;
}
}
}
}
} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
......@@ -11,33 +11,22 @@
const int UNROLL = 4;
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
__global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
scalar_t *outputs,
uint8_t *mask,
IndexType totalElements,
accscalar_t p,
std::pair<uint64_t, uint64_t> seeds
)
{
accscalar_t pinv = accscalar_t(1)/p;
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void
apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs,
uint8_t *mask, IndexType totalElements, accscalar_t p,
std::pair<uint64_t, uint64_t> seeds) {
accscalar_t pinv = accscalar_t(1) / p;
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(
seeds.first,
idx,
seeds.second,
&state);
curand_init(seeds.first, idx, seeds.second, &state);
IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx;
linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x*UNROLL) {
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
float4 rand = curand_uniform4(&state);
scalar_t src[UNROLL];
rand.x = rand.x <= p;
......@@ -54,7 +43,7 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
outputs[li] = src[ii]*(&rand.x)[ii]*pinv;
outputs[li] = src[ii] * (&rand.x)[ii] * pinv;
mask[li] = (uint8_t)(&rand.x)[ii];
}
}
......@@ -62,34 +51,23 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
}
}
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_dropout_add_kernel(scalar_t const *inputs,
scalar_t const *add_inputs,
scalar_t *outputs,
uint8_t *mask,
IndexType totalElements,
accscalar_t p,
std::pair<uint64_t, uint64_t> seeds
)
{
accscalar_t pinv = accscalar_t(1)/p;
scalar_t *outputs, uint8_t *mask,
IndexType totalElements, accscalar_t p,
std::pair<uint64_t, uint64_t> seeds) {
accscalar_t pinv = accscalar_t(1) / p;
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(
seeds.first,
idx,
seeds.second,
&state);
curand_init(seeds.first, idx, seeds.second, &state);
IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx;
linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x*UNROLL) {
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
float4 rand = curand_uniform4(&state);
scalar_t src[UNROLL];
scalar_t add_src[UNROLL];
......@@ -108,7 +86,8 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;
outputs[li] = static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);
outputs[li] =
static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);
mask[li] = (uint8_t)(&rand.x)[ii];
}
}
......@@ -116,22 +95,16 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
}
}
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
__global__ void apex_add_kernel( scalar_t const *inputs,
scalar_t const *add_inputs,
scalar_t *outputs,
IndexType totalElements
)
{
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_add_kernel(scalar_t const *inputs,
scalar_t const *add_inputs, scalar_t *outputs,
IndexType totalElements) {
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx;
linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x*UNROLL) {
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
scalar_t src[UNROLL];
scalar_t add_src[UNROLL];
for (int ii = 0; ii < UNROLL; ii++) {
......@@ -151,23 +124,17 @@ __global__ void apex_add_kernel( scalar_t const *inputs,
}
}
template<typename scalar_t,
typename accscalar_t,
typename IndexType
>
template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_masked_scale_kernel(scalar_t const *inputs,
scalar_t *outputs,
uint8_t const *mask,
scalar_t *outputs, uint8_t const *mask,
IndexType totalElements,
accscalar_t scale
)
{
accscalar_t scale) {
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx;
linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x*UNROLL)
{
IndexType rounded_size =
((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx; linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x * UNROLL) {
scalar_t src[UNROLL];
scalar_t msk[UNROLL];
for (int ii = 0; ii < UNROLL; ii++) {
......@@ -180,33 +147,34 @@ __global__ void apex_masked_scale_kernel(scalar_t const *inputs,
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
outputs[li] = static_cast<accscalar_t>(src[ii]) * scale * static_cast<accscalar_t>(msk[ii]);
outputs[li] = static_cast<accscalar_t>(src[ii]) * scale *
static_cast<accscalar_t>(msk[ii]);
}
}
}
}
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
void apex_fused_dropout_cuda(scalar_t const *inputs,
scalar_t *outputs,
uint8_t *mask,
IndexType totalElements,
accscalar_t p)
{
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_fused_dropout_cuda(scalar_t const *inputs, scalar_t *outputs,
uint8_t *mask, IndexType totalElements,
accscalar_t p) {
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size -1)/block_size);
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
//number of times random will be generated per thread, to offset philox counter in the random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
// number of times random will be generated per thread, to offset philox
// counter in the random state
int64_t counter_offset =
((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
......@@ -215,36 +183,39 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);
rng_engine_inputs =
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
counter_offset);
#endif
}
apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, p, rng_engine_inputs);
apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, outputs, mask, totalElements, p, rng_engine_inputs);
C10_CUDA_CHECK(cudaGetLastError());
}
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
void apex_dropout_add_cuda(scalar_t const *inputs,
scalar_t const *add_inputs,
scalar_t *outputs,
uint8_t *mask,
IndexType totalElements,
accscalar_t p)
{
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_dropout_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
scalar_t *outputs, uint8_t *mask,
IndexType totalElements, accscalar_t p) {
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size -1)/block_size);
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
//number of times random will be generated per thread, to offset philox counter in the random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
// number of times random will be generated per thread, to offset philox
// counter in the random state
int64_t counter_offset =
((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
......@@ -253,54 +224,56 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
#else
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);
rng_engine_inputs =
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
counter_offset);
#endif
}
apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, mask, totalElements, p, rng_engine_inputs);
apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, add_inputs, outputs, mask, totalElements, p,
rng_engine_inputs);
C10_CUDA_CHECK(cudaGetLastError());
}
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
void apex_add_cuda(scalar_t const *inputs,
scalar_t const *add_inputs,
scalar_t *outputs,
IndexType totalElements
)
{
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
scalar_t *outputs, IndexType totalElements) {
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size -1)/block_size);
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
apex_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, totalElements);
apex_add_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, add_inputs, outputs, totalElements);
C10_CUDA_CHECK(cudaGetLastError());
}
template<typename scalar_t,
typename accscalar_t,
typename IndexType
>
void apex_masked_scale_cuda(scalar_t const *inputs,
scalar_t *outputs,
uint8_t const *mask,
IndexType totalElements,
accscalar_t scale
)
{
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_masked_scale_cuda(scalar_t const *inputs, scalar_t *outputs,
uint8_t const *mask, IndexType totalElements,
accscalar_t scale) {
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size -1)/block_size);
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
dim3 grid((totalElements + block_size - 1) / block_size);
unsigned int blocks_per_sm =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
->multiProcessorCount *
blocks_per_sm,
grid.x);
apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, scale);
apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
inputs, outputs, mask, totalElements, scale);
C10_CUDA_CHECK(cudaGetLastError());
}
......@@ -5,103 +5,79 @@ namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_q_results,
torch::Tensor const& input_lin_kv_results,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
);
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
input_weights_q, input_weights_kv, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_q_results,
torch::Tensor const& input_lin_kv_results,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
......@@ -115,35 +91,35 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob
);
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_q_results, input_lin_kv_results,
inputs_q, inputs_kv, input_weights_q, input_weights_kv,
output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemm_ex
......
#include <vector>
#include <math.h>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob) {
const int embed_dim = inputs_q.size(2);
const int sequences = inputs_q.size(1);
const int q_seq_len = inputs_q.size(0);
......@@ -48,7 +39,7 @@ std::vector<torch::Tensor> fwd_cuda(
const int output_lin_kv_dim = 2 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 *head_dim;
const int lead_dim_kv = attn_batches * 2 * head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
......@@ -62,25 +53,34 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs_q.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor input_lin_q_results =
torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor input_lin_kv_results =
torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs_q, act_options);
// Input Linear Results Pointers to Q, K, and V of interviewed activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_q_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(input_lin_kv_results.data_ptr());
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim);
void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
void *k_lin_results_ptr =
static_cast<void *>(input_lin_kv_results.data_ptr());
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
......@@ -140,8 +140,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
a_layout_t,
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
......@@ -166,46 +165,35 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
assert(softmax_success);
if (is_training) {
apex_fused_dropout_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(softmax_results.data_ptr()),
static_cast<at::Half*>(dropout_results.data_ptr()),
static_cast<uint8_t*>(dropout_mask.data_ptr()),
dropout_elems,
apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
(1.0f - dropout_prob));
}
// Matmul2
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
......@@ -253,34 +241,24 @@ std::vector<torch::Tensor> fwd_cuda(
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_q_results,
return {input_lin_q_results,
input_lin_kv_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs
};
outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_q_results,
torch::Tensor const& input_lin_kv_results,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
torch::Tensor const &dropout_mask, float dropout_prob) {
const int embed_dim = inputs_q.size(2);
const int sequences = inputs_q.size(1);
const int q_seq_len = inputs_q.size(0);
......@@ -292,7 +270,7 @@ std::vector<torch::Tensor> bwd_cuda(
const int output_lin_kv_dim = 2 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 *head_dim;
const int lead_dim_kv = attn_batches * 2 * head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
......@@ -316,15 +294,20 @@ std::vector<torch::Tensor> bwd_cuda(
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);
at::Tensor input_lin_kv_output_grads =
torch::empty_like(input_lin_kv_results);
auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr());
auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim;
auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
auto v_lin_results_ptr =
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_q_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim;
auto q_lin_grads_ptr =
static_cast<half *>(input_lin_q_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
......@@ -386,8 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( state,
a_layout_t,
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
......@@ -409,8 +391,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
......@@ -442,17 +423,14 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
......@@ -474,8 +452,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
......@@ -612,3 +589,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace encdec
} // end namespace multihead_attn
......@@ -5,66 +5,49 @@ namespace multihead_attn {
namespace encdec_norm_add {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q,
torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights,
const uint8_t *pad_mask,
float dropout_prob);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_q_results,
torch::Tensor const& input_lin_kv_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
);
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const &dropout_add_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
......@@ -73,58 +56,48 @@ std::vector<torch::Tensor> fwd(
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q,
input_weights_kv, output_weights,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_q_results,
torch::Tensor const& input_lin_kv_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
)
{
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const &dropout_add_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
......@@ -144,47 +117,49 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float,
"Only FLOAT is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob
);
return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
softmax_results, input_lin_q_results, input_lin_kv_results,
lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q,
inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
input_weights_q, input_weights_kv, output_weights,
dropout_mask, dropout_add_mask, dropout_prob);
}
} // end namespace cublas_gemmex
......@@ -195,4 +170,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
#include <vector>
#include <math.h>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace encdec_norm_add {
......@@ -64,7 +61,8 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs_q.options().requires_grad(false);
auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
auto mask_options = act_options.dtype(torch::kUInt8);
......@@ -73,23 +71,31 @@ std::vector<torch::Tensor> fwd_cuda(
torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options);
torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options);
torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor input_lin_q_results =
torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
torch::Tensor input_lin_kv_results =
torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
torch::Tensor softmax_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options);
torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options);
torch::Tensor outputs = torch::empty_like(inputs_q, act_options);
// Input Linear Results Pointers to Q, K, and V of interviewed activations
void* q_lin_results_ptr = static_cast<void*>(input_lin_q_results.data_ptr());
void* k_lin_results_ptr = static_cast<void*>(input_lin_kv_results.data_ptr());
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim);
void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
void *k_lin_results_ptr =
static_cast<void *>(input_lin_kv_results.data_ptr());
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
char a_layout_t{'t'};
char a_layout_n{'n'};
......@@ -97,16 +103,15 @@ std::vector<torch::Tensor> fwd_cuda(
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
static_cast<float*>(lyr_nrm_mean.data_ptr()),
static_cast<float*>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half*>(inputs_q.data_ptr()),
HostApplyLayerNorm<at::Half, float>(
static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
static_cast<float *>(lyr_nrm_mean.data_ptr()),
static_cast<float *>(lyr_nrm_invvar.data_ptr()),
static_cast<const at::Half *>(inputs_q.data_ptr()),
static_cast<int>(batches_q), // n1
static_cast<int>(embed_dim), // n2
1.0e-5,
static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
......@@ -161,8 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state,
a_layout_t,
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
......@@ -187,46 +191,35 @@ std::vector<torch::Tensor> fwd_cuda(
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
} else {
if (use_time_mask) {
softmax_success = dispatch_time_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
q_seq_len);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(softmax_results_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
k_seq_len, k_seq_len, attn_batches * q_seq_len,
attn_batches * q_seq_len / sequences);
}
}
assert(softmax_success);
if (is_training) {
apex_fused_dropout_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(softmax_results.data_ptr()),
static_cast<at::Half*>(dropout_results.data_ptr()),
static_cast<uint8_t*>(dropout_mask.data_ptr()),
dropout_elems,
apex_fused_dropout_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(softmax_results.data_ptr()),
static_cast<at::Half *>(dropout_results.data_ptr()),
static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
(1.0f - dropout_prob));
}
// Matmul2
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
......@@ -276,25 +269,22 @@ std::vector<torch::Tensor> fwd_cuda(
// End-of-block Dropout-Add
if (is_training) {
apex_dropout_add_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs_q.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()),
static_cast<uint8_t*>(dropout_add_mask.data_ptr()),
total_tokens_q,
apex_dropout_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs_q.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()),
static_cast<uint8_t *>(dropout_add_mask.data_ptr()), total_tokens_q,
(1.0f - dropout_prob));
} else {
apex_add_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_lin_results.data_ptr()),
static_cast<at::Half const*>(inputs_q.data_ptr()),
static_cast<at::Half*>(outputs.data_ptr()),
total_tokens_q);
apex_add_cuda<at::Half, float, uint32_t>(
static_cast<at::Half const *>(output_lin_results.data_ptr()),
static_cast<at::Half const *>(inputs_q.data_ptr()),
static_cast<at::Half *>(outputs.data_ptr()), total_tokens_q);
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
lyr_nrm_results,
return {lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
input_lin_q_results,
......@@ -304,33 +294,22 @@ std::vector<torch::Tensor> fwd_cuda(
dropout_mask,
matmul2_results,
dropout_add_mask,
outputs
};
outputs};
}
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_q_results,
torch::Tensor const& input_lin_kv_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
)
{
int heads, torch::Tensor const &output_grads,
torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
torch::Tensor const &softmax_results,
torch::Tensor const &input_lin_q_results,
torch::Tensor const &input_lin_kv_results,
torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q,
torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights,
torch::Tensor const &lyr_nrm_beta_weights,
torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
torch::Tensor const &dropout_add_mask, float dropout_prob) {
const int embed_dim = inputs_q.size(2);
const int sequences = inputs_q.size(1);
const int q_seq_len = inputs_q.size(0);
......@@ -343,7 +322,7 @@ std::vector<torch::Tensor> bwd_cuda(
const int output_lin_kv_dim = 2 * embed_dim;
const int attn_batches = heads * sequences;
const int lead_dim_q = attn_batches * head_dim;
const int lead_dim_kv = attn_batches * 2 *head_dim;
const int lead_dim_kv = attn_batches * 2 * head_dim;
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
......@@ -370,16 +349,21 @@ std::vector<torch::Tensor> bwd_cuda(
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
at::Tensor matmul2_grads = torch::empty_like(dropout_results);
at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);
at::Tensor input_lin_kv_output_grads =
torch::empty_like(input_lin_kv_results);
at::Tensor input_lin_q_grads = torch::empty_like(inputs_q);
auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());
auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr());
auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim;
auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
auto v_lin_results_ptr =
static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;
auto q_lin_grads_ptr = static_cast<half*>(input_lin_q_output_grads.data_ptr());
auto k_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim;
auto q_lin_grads_ptr =
static_cast<half *>(input_lin_q_output_grads.data_ptr());
auto k_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr());
auto v_lin_grads_ptr =
static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;
char a_layout_n{'n'};
char a_layout_t{'t'};
......@@ -449,8 +433,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( state,
a_layout_t,
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
......@@ -472,8 +455,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul2 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
......@@ -505,17 +487,14 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad
bool softmax_success = false;
softmax_success = dispatch_softmax_backward<half, half, float>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
// Matmul1 Dgrad1
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
......@@ -537,8 +516,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Matmul1 Dgrad2
gemm_switch_fp32accum( state,
a_layout_n,
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
......@@ -683,17 +661,12 @@ std::vector<torch::Tensor> bwd_cuda(
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
input_kv_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads
};
return {input_q_grads, input_kv_grads, lyr_nrm_gamma_grads,
lyr_nrm_beta_grads, input_weight_q_grads, input_weight_kv_grads,
output_weight_grads};
}
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
......@@ -4,14 +4,8 @@
#include <cuda.h>
#include <cuda_runtime.h>
template<typename U> __device__
void cuWelfordOnlineSum(
const U curr,
U& mu,
U& sigma2,
U& count)
{
template <typename U>
__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) {
count = count + U(1);
U delta = curr - mu;
U lmean = mu + delta / count;
......@@ -20,15 +14,9 @@ void cuWelfordOnlineSum(
sigma2 = sigma2 + delta * delta2;
}
template<typename U> __device__
void cuChanOnlineSum(
const U muB,
const U sigma2B,
const U countB,
U& mu,
U& sigma2,
U& count)
{
template <typename U>
__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
U &mu, U &sigma2, U &count) {
U delta = muB - mu;
U nA = count;
U nB = countB;
......@@ -37,7 +25,7 @@ void cuChanOnlineSum(
if (nX > U(0)) {
nA = nA / nX;
nB = nB / nX;
mu = nA*mu + nB*muB;
mu = nA * mu + nB * muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else {
mu = U(0);
......@@ -45,16 +33,10 @@ void cuChanOnlineSum(
}
}
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
const T* __restrict__ vals,
const int n1,
const int n2,
const int i1,
U& mu,
U& sigma2,
U* buf)
{
template <typename T, typename U>
__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
const int n2, const int i1, U &mu, U &sigma2,
U *buf) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
......@@ -62,7 +44,7 @@ void cuWelfordMuSigma2(
//
// compute variance and mean over n2
U count = U(0);
mu= U(0);
mu = U(0);
sigma2 = U(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
......@@ -70,17 +52,17 @@ void cuWelfordMuSigma2(
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T* lvals = vals + i1*n2;
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
const T *lvals = vals + i1 * n2;
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l+k]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
U curr = static_cast<U>(lvals[l + k]);
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
......@@ -93,23 +75,24 @@ void cuWelfordMuSigma2(
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
U* ubuf = (U*)buf;
U* ibuf = (U*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
U *ubuf = (U *)buf;
U *ibuf = (U *)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ubuf[2 * wrt_y] = mu;
ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2*threadIdx.y];
U sigma2B = ubuf[2*threadIdx.y+1];
U muB = ubuf[2 * threadIdx.y];
U sigma2B = ubuf[2 * threadIdx.y + 1];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
}
__syncthreads();
}
......@@ -120,7 +103,7 @@ void cuWelfordMuSigma2(
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/U(n2);
sigma2 = ubuf[1] / U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0, 32);
......@@ -129,16 +112,10 @@ void cuWelfordMuSigma2(
}
}
template<> __device__
void cuWelfordMuSigma2(
const at::Half* __restrict__ vals,
const int n1,
const int n2,
const int i1,
float& mu,
float& sigma2,
float* buf)
{
template <>
__device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals,
const int n1, const int n2, const int i1,
float &mu, float &sigma2, float *buf) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
......@@ -146,7 +123,7 @@ void cuWelfordMuSigma2(
//
// compute variance and mean over n2
float count = 0.0f;
mu= float(0);
mu = float(0);
sigma2 = float(0);
if (i1 < n1) {
......@@ -155,28 +132,28 @@ void cuWelfordMuSigma2(
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const at::Half* lvals = vals + i1*n2;
int l = 8*thrx;
if ((((size_t)lvals)&3) != 0) {
const at::Half *lvals = vals + i1 * n2;
int l = 8 * thrx;
if ((((size_t)lvals) & 3) != 0) {
// 16 bit alignment
// first thread consumes first point
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
cuWelfordOnlineSum(curr, mu, sigma2, count);
}
++l;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for (; l+7 < n2; l+=8*numx) {
for (int k = 0; k < 8; k+=2) {
float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
cuWelfordOnlineSum(curr.x,mu,sigma2,count);
cuWelfordOnlineSum(curr.y,mu,sigma2,count);
for (; l + 7 < n2; l += 8 * numx) {
for (int k = 0; k < 8; k += 2) {
float2 curr = __half22float2(*((__half2 *)(lvals + l + k)));
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
cuWelfordOnlineSum(curr, mu, sigma2, count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
......@@ -189,23 +166,24 @@ void cuWelfordMuSigma2(
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
float* ubuf = (float*)buf;
float* ibuf = (float*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
float *ubuf = (float *)buf;
float *ibuf = (float *)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ubuf[2 * wrt_y] = mu;
ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2*threadIdx.y];
float sigma2B = ubuf[2*threadIdx.y+1];
float muB = ubuf[2 * threadIdx.y];
float sigma2B = ubuf[2 * threadIdx.y + 1];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
}
__syncthreads();
}
......@@ -216,7 +194,7 @@ void cuWelfordMuSigma2(
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/float(n2);
sigma2 = ubuf[1] / float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0, 32);
......@@ -246,8 +224,9 @@ template<> double rsqrt(double v) {
}
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// This is the un-specialized struct. Note that we prevent instantiation of
// this struct by putting an undefined symbol in the function body so it won't
// compile.
// template <typename T>
// struct SharedMemory
// {
......@@ -260,64 +239,50 @@ namespace {
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;
template <typename T> struct SharedMemory;
template <>
struct SharedMemory <float>
{
__device__ float *getPointer()
{
template <> struct SharedMemory<float> {
__device__ float *getPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
template <>
struct SharedMemory <double>
{
__device__ double *getPointer()
{
template <> struct SharedMemory<double> {
__device__ double *getPointer() {
extern __shared__ double s_double[];
return s_double;
}
};
}
} // namespace
template<typename T, typename U> __global__
void cuApplyLayerNorm(
T* __restrict__ output_vals,
U* __restrict__ mean,
U* __restrict__ invvar,
const T* __restrict__ vals,
const int n1,
const int n2,
const U epsilon,
const T* __restrict__ gamma,
const T* __restrict__ beta
)
{
template <typename T, typename U>
__global__ void
cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean,
U *__restrict__ invvar, const T *__restrict__ vals,
const int n1, const int n2, const U epsilon,
const T *__restrict__ gamma, const T *__restrict__ beta) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu,sigma2;
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
const T* lvals = vals + i1*n2;
T* ovals = output_vals + i1*n2;
U *buf = shared.getPointer();
U mu, sigma2;
cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
const T *lvals = vals + i1 * n2;
T *ovals = output_vals + i1 * n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i+=numx) {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = gamma[i] * static_cast<T>(c_invvar * (curr - mu)) + beta[i];
}
} else {
for (int i = thrx; i < n2; i+=numx) {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<T>(c_invvar * (curr - mu));
}
......@@ -329,36 +294,26 @@ void cuApplyLayerNorm(
}
}
template<typename T, typename U> __device__
void cuLoadWriteStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const T* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
template <typename T, typename U>
__device__ void cuLoadWriteStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
const T *input, const T *dout, const int i1_end, const int n2,
const U *__restrict__ mean, const U *__restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
warp_buf2[write_idx] =
curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
......@@ -366,78 +321,71 @@ void cuLoadWriteStridedInputs(
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template<typename T, typename U> __device__
void cuLoadAddStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const T* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
template <typename T, typename U>
__device__ void cuLoadAddStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
const T *input, const T *dout, const int i1_end, const int n2,
const U *__restrict__ mean, const U *__restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
warp_buf2[write_idx] +=
curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template<typename T, typename U> __global__
void cuComputePartGradGammaBeta(
const T* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
U* part_grad_gamma,
U* part_grad_beta)
{
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
template <typename T, typename U>
__global__ void cuComputePartGradGammaBeta(
const T *__restrict__ dout, const T *__restrict__ input, const int n1,
const int n2, const U *__restrict__ mean, const U *__restrict__ invvar,
U epsilon, U *part_grad_gamma, U *part_grad_beta) {
const int numsegs_n1 =
(n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
const int i1_beg_plus_one =
(blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x+1;
const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
const int row_stride = blockDim.x + 1;
const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
const int thr_load_row_off =
(threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
U *buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y *
// blockDim.y + (blockDim.y -
// 1)*(blockDim.x/blockDim.y) elements
U *warp_buf1 = (U *)buf;
U *warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off,
row_stride, warp_buf1, warp_buf2, input, dout,
i1_end, n2, mean, invvar);
for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
i1_block += blockDim.y * blockDim.y) {
cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off,
row_stride, warp_buf1, warp_buf2, input, dout,
i1_end, n2, mean, invvar);
}
__syncthreads();
// inter-warp reductions
......@@ -445,21 +393,21 @@ void cuComputePartGradGammaBeta(
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k*blockDim.y;
int idx1 = row1*row_stride + threadIdx.x;
int row1 = threadIdx.y + k * blockDim.y;
int idx1 = row1 * row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
......@@ -469,53 +417,51 @@ void cuComputePartGradGammaBeta(
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template<typename T, typename U> __global__
void cuComputeGradGammaBeta(
const U* part_grad_gamma,
const U* part_grad_beta,
const int part_size,
const int n1,
const int n2,
T* grad_gamma,
T* grad_beta)
{
template <typename T, typename U>
__global__ void
cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta,
const int part_size, const int n1, const int n2,
T *grad_gamma, T *grad_beta) {
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
U *buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
sum_beta += part_grad_beta_ptr[warp_offset*n2];
const U *part_grad_gamma_ptr =
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U *part_grad_beta_ptr =
part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions;
++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
sum_beta += part_grad_beta_ptr[warp_offset * n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx+nbsize3] = sum_beta;
buf[write_idx + nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx+nbsize3];
sum_beta += buf[read_idx + nbsize3];
}
__syncthreads();
}
......@@ -527,51 +473,46 @@ void cuComputeGradGammaBeta(
}
}
template<typename T, typename U> __global__
void cuComputeGradInput(
const T* __restrict__ dout,
const T* __restrict__ dout_resid,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
const T* gamma,
T* grad_input)
{
for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
template <typename T, typename U>
__global__ void
cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid,
const T *__restrict__ input, const int n1, const int n2,
const U *__restrict__ mean, const U *__restrict__ invvar,
U epsilon, const T *gamma, T *grad_input) {
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T* k_input = input + i1*n2;
const T* k_dout = dout + i1*n2;
const T* k_dout_resid = dout_resid + i1*n2;
const T *k_input = input + i1 * n2;
const T *k_dout = dout + i1 * n2;
const T *k_dout_resid = dout_resid + i1 * n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss * static_cast<U>(gamma[l+k]);
sum_loss2 += c_loss * static_cast<U>(gamma[l+k]) * (c_h - c_mean) * c_invvar;
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * static_cast<U>(gamma[l + k]);
sum_loss2 +=
c_loss * static_cast<U>(gamma[l + k]) * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * static_cast<U>(gamma[l]);
sum_loss2 += c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;
sum_loss2 +=
c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
......@@ -591,161 +532,121 @@ void cuComputeGradInput(
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
U *buf = shared.getPointer();
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2*wrt_i] = sum_loss1;
buf[2*wrt_i+1] = sum_loss2;
buf[2 * wrt_i] = sum_loss1;
buf[2 * wrt_i + 1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2*read_i];
sum_loss2 += buf[2*read_i+1];
sum_loss1 += buf[2 * read_i];
sum_loss2 += buf[2 * read_i + 1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2*threadIdx.x] = sum_loss1;
buf[2*threadIdx.x+1] = sum_loss2;
buf[2 * threadIdx.x] = sum_loss1;
buf[2 * threadIdx.x + 1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y !=0) {
sum_loss1 = buf[2*threadIdx.x];
sum_loss2 = buf[2*threadIdx.x+1];
if (threadIdx.y != 0) {
sum_loss1 = buf[2 * threadIdx.x];
sum_loss2 = buf[2 * threadIdx.x + 1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1*n2;
T *k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l+=numx) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
const T c_resid= static_cast<T>(k_dout_resid[l]);
const T c_resid = static_cast<T>(k_dout_resid[l]);
U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input)+c_resid;
k_grad_input[l] = static_cast<T>(f_grad_input) + c_resid;
}
} else {
for (int l = thrx; l < n2; l+=numx) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
const T c_resid= static_cast<T>(k_dout_resid[l]);
const T c_resid = static_cast<T>(k_dout_resid[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input)+c_resid;
k_grad_input[l] = static_cast<T>(f_grad_input) + c_resid;
}
}
}
}
template<typename T, typename U>
void HostApplyLayerNorm(
T* output,
U* mean,
U* invvar,
const T* input,
int n1,
int n2,
double epsilon,
const T* gamma,
const T* beta
)
{
template <typename T, typename U>
void HostApplyLayerNorm(T *output, U *mean, U *invvar, const T *input, int n1,
int n2, double epsilon, const T *gamma, const T *beta) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32,4,1);
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 threads(32, 4, 1);
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0;
threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
output,
mean,
invvar,
input,
n1,n2,
U(epsilon),
gamma,beta);
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}
template<typename T, typename U>
void HostLayerNormGradient(
const T* dout,
const T* dout_resid,
const U* mean,
const U* invvar,
const at::Tensor& input,
int n1,
int n2,
const T* gamma,
const T* beta,
double epsilon,
T* grad_input,
T* grad_gamma,
T* grad_beta
)
{
template <typename T, typename U>
void HostLayerNormGradient(const T *dout, const T *dout_resid, const U *mean,
const U *invvar, const at::Tensor &input, int n1,
int n2, const T *gamma, const T *beta,
double epsilon, T *grad_input, T *grad_gamma,
T *grad_beta) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const dim3 threads2(32, 4, 1);
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
const int nshared2_a =
2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
at::Tensor part_grad_gamma = at::empty(
{part_size, n2},
input.options().dtype(input.scalar_type() == at::ScalarType::Half
? at::ScalarType::Float
: input.scalar_type()));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout,
static_cast<T*>(input.data_ptr()),
n1,n2,
mean,
invvar,
U(epsilon),
static_cast<U*>(part_grad_gamma.data_ptr()),
static_cast<U*>(part_grad_beta.data_ptr()));
dout, static_cast<T *>(input.data_ptr()), n1, n2, mean, invvar,
U(epsilon), static_cast<U *>(part_grad_gamma.data_ptr()),
static_cast<U *>(part_grad_beta.data_ptr()));
const dim3 threads3(32,8,1);
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
const dim3 threads3(32, 8, 1);
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
static_cast<U*>(part_grad_gamma.data_ptr()),
static_cast<U*>(part_grad_beta.data_ptr()),
part_size,
n1,n2,
grad_gamma,
grad_beta);
static_cast<U *>(part_grad_gamma.data_ptr()),
static_cast<U *>(part_grad_beta.data_ptr()), part_size, n1, n2,
grad_gamma, grad_beta);
}
// compute grad_input
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32,4,1);
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
0;
const dim3 threads1(32, 4, 1);
int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout,
dout_resid,
static_cast<T*>(input.data_ptr()),
n1,n2,
mean,
invvar,
U(epsilon),
gamma,
grad_input);
dout, dout_resid, static_cast<T *>(input.data_ptr()), n1, n2, mean,
invvar, U(epsilon), gamma, grad_input);
}
......@@ -5,81 +5,66 @@ namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const uint8_t *pad_mask,
float dropout_prob);
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
const uint8_t *padding_mask,
float dropout_prob
);
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
const uint8_t *padding_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool is_training,
int heads,
torch::Tensor const& input,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
torch::Tensor const &input,
torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
"Only BYTE is supported");
}
return fwd_cuda(
is_training,
heads,
input,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
return fwd_cuda(is_training, heads, input,
use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
torch::Tensor bwd(
bool use_mask,
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
torch::Tensor const& padding_mask,
float dropout_prob
)
{
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask,
torch::Tensor const &padding_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
softmax_results,
dropout_mask,
use_mask ? static_cast<const uint8_t*>(padding_mask.data_ptr()) : nullptr,
dropout_prob
);
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
use_mask
? static_cast<const uint8_t *>(padding_mask.data_ptr())
: nullptr,
dropout_prob);
}
} // end namespace mask_softmax_dropout
......@@ -87,7 +72,8 @@ torch::Tensor bwd(
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
"Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
"Self Multihead Attention masked softmax dropout -- Backward.");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment