Unverified commit db92ee13 authored by Jithun Nair, committed by GitHub

Merge pull request #64 from ROCmSoftwarePlatform/IFU-master-2021-12-08

IFU-master-2021-12-08
parents d150afdc 68364b49
import logging
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
import warnings
if torch.distributed.is_available():
    from . import parallel
@@ -22,3 +24,19 @@ from . import pyprof
# Common utilities to run tests on ROCm.
from . import testing
from . import transformer from . import transformer
# Logging utilities mainly for apex.transformer module
class RankInfoFormatter(logging.Formatter):
    def format(self, record):
        from apex.transformer.parallel_state import get_rank_info
        record.rank_info = get_rank_info()
        return super().format(record)
_library_root_logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(RankInfoFormatter("%(asctime)s - %(name)s - %(levelname)s - %(rank_info)s - %(message)s"))
_library_root_logger.addHandler(handler)
_library_root_logger.propagate = False
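A minimal sketch (not part of this diff) of how a submodule picks up the rank-aware handler: child loggers under the "apex" namespace propagate their records up to _library_root_logger, whose StreamHandler formats them with the %(rank_info)s field injected by RankInfoFormatter; propagation is cut off only above the "apex" root. The child logger name below is illustrative.
# Illustrative usage only; the child logger name is hypothetical.
import logging
_logger = logging.getLogger("apex.transformer.tensor_parallel")
_logger.setLevel(logging.INFO)
# The record bubbles up to the "apex" logger configured above and is printed
# as "<time> - <name> - INFO - <rank info> - initialized model parallel".
_logger.info("initialized model parallel")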
from typing import Optional
import torch
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
    if not torch.is_autocast_enabled():
        # Prefer an explicitly requested dtype; fall back to FP32 outside autocast.
        return dtype or torch.float
    else:
        return torch.get_autocast_gpu_dtype()
def _cast_if_autocast_enabled(*args):
    if not torch.is_autocast_enabled():
        return args
...
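A short sketch (assuming a CUDA build with AMP available; not part of the diff) of how _get_current_dtype behaves: outside an autocast region it falls back to FP32, while inside one it reports the autocast GPU dtype.
# Illustrative only; assumes torch.cuda.amp is usable on this build.
assert _get_current_dtype() == torch.float32        # autocast disabled
with torch.cuda.amp.autocast(dtype=torch.float16):
    assert _get_current_dtype() == torch.float16     # autocast dtype wins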
#pragma once
#include <unordered_map>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Params>
struct LaunchParams{
size_t workspace_bytes;
size_t barrier_size;
cudaDeviceProp * props;
cudaStream_t stream;
Params params;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ParamsBase {
ParamsBase()
: ctas_per_col(0)
, rows(0)
, cols(0)
, x(nullptr)
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, workspace(nullptr)
, barrier(nullptr)
{
}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// Input is interpreted as matrix. We normalize across columns.
int rows;
int cols;
// Common data pointers.
void *x;
void *mu;
void *rs;
void *gamma;
// Multi-CTA workspace in gmem.
void *workspace;
// Multi-CTA sync barriers in gmem.
int *barrier;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FwdParams : public ParamsBase {
FwdParams()
: ParamsBase()
, z(nullptr)
, beta(nullptr)
, epsilon(0.f)
{
}
// Output of LN FWD.
void *z;
void *beta;
float epsilon;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct BwdParams : public ParamsBase {
BwdParams()
: ParamsBase()
, dz(nullptr)
, dbeta_part(nullptr)
, dgamma_part(nullptr)
, dx(nullptr)
, dbeta(nullptr)
, dgamma(nullptr)
{
}
// Input: gradient wrt. LN FWD output.
void *dz;
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
// Output: Dgrad.
void *dx;
// Output: Wgrad.
void *dbeta;
void *dgamma;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
extern FwdRegistry FWD_FUNCS;
extern BwdRegistry BWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeId{};
template<>
struct TypeId<fp16>{
constexpr static uint32_t Value = 0;
};
template<>
struct TypeId<bf16>{
constexpr static uint32_t Value = 1;
};
template<>
struct TypeId<fp32>{
constexpr static uint32_t Value = 2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int S>
struct Type2Key{
constexpr static uint32_t Value = TypeId<T>::Value << S;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};
template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};
template<typename T>
struct OutputType2Key : public Type2Key<T, 4>{};
template<typename T>
struct ComputeType2Key : public Type2Key<T, 6>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C>
struct Types2Key{
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size){
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar{
FwdRegistrar(FwdFunction f){
uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE);
FWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdRegistrar{
BwdRegistrar(BwdFunction f){
uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE);
BWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
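For orientation, a sketch (not part of the diff) of how the pieces above compose a registry key: the four 2-bit type ids are packed into the upper 32 bits and the hidden size occupies the lower 32 bits, so an fp16 weight/input/output, fp32 compute configuration at hidden_size 1024 hashes as follows.
// Sketch only; assumes this header is included as "ln.h".
#include "ln.h"
static_assert(layer_norm::Types2Key<layer_norm::fp16, layer_norm::fp16, layer_norm::fp16, layer_norm::fp32>::get(1024)
              == ((uint64_t(2u << 6) << 32) | 1024),
              "fp16/fp16/fp16/fp32 @ hidden_size 1024");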
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include "ln.h"
/*
Supported Type combinations:
input compute weights output
=======================================
fp32 fp32 fp32 fp32
fp16 fp32 fp16 fp16
bf16 fp32 bf16 bf16
fp32 fp32 fp16 fp16
fp32 fp32 bf16 bf16
Remarks:
Output type = Weight type
Compute always in FP32
*/
namespace layer_norm {
// Create registries and provide runtime versions of config hash functions.
FwdRegistry FWD_FUNCS;
BwdRegistry BWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
uint32_t get_type_id(torch::Dtype dtype){
if( dtype == torch::kFloat16 ) {
return TypeId<fp16>::Value;
} else if( dtype == torch::kBFloat16 ) {
return TypeId<bf16>::Value;
} else if( dtype == torch::kFloat32 ) {
return TypeId<fp32>::Value;
} else {
TORCH_CHECK(false, "Type not supported: ", dtype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
using namespace layer_norm;
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6);
uint64_t launcher_key = (type_key << 32) | hidden_size;
return launcher_key;
}
} // namespace layer_norm
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));
if( iter != layer_norm::FWD_FUNCS.end() ) {
return iter->second;
} else {
TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));
if( iter != layer_norm::BWD_FUNCS.end() ) {
return iter->second;
} else {
TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size
const at::Tensor &gamma, // hidden_size
const at::Tensor &beta, // hidden_size
const float epsilon
) {
auto itype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = wtype;
auto ctype = torch::kFloat32;
TORCH_CHECK(beta.scalar_type() == wtype);
TORCH_CHECK(x.is_cuda())
TORCH_CHECK(gamma.is_cuda())
@@ -28,79 +99,148 @@ std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size
const int rows = sizes[0];
const int cols = sizes[1];
auto hidden_size = gamma.numel();
TORCH_CHECK(gamma.sizes() == beta.sizes());
TORCH_CHECK(hidden_size == cols);
TORCH_CHECK(epsilon >= 0.f);
auto opts = x.options();
auto z = torch::empty(sizes, opts.dtype(otype));
auto mu = torch::empty({ rows }, opts.dtype(ctype));
auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
launch_params.props = at::cuda::getCurrentDeviceProperties();
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
// Request the kernel launcher.
auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
at::Tensor workspace, barrier;
// Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.beta = beta.data_ptr();
params.z = z.data_ptr();
params.epsilon = epsilon;
if( launch_params.barrier_size > 0 ) {
auto options = x.options();
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
// Launch the kernel.
launcher(launch_params, false);
return { z, mu, rsigma };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> ln_bwd(const at::Tensor &dz, // BxSxhidden_size
const at::Tensor &x, // BxSxhidden_size
const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma // hidden_size
) {
auto itype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = wtype;
auto ctype = torch::kFloat32;
TORCH_CHECK(dz.dtype() == otype);
TORCH_CHECK(mu.dtype() == ctype);
TORCH_CHECK(rsigma.dtype() == ctype);
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(dz.is_cuda());
TORCH_CHECK(mu.is_cuda());
TORCH_CHECK(rsigma.is_cuda());
TORCH_CHECK(gamma.is_cuda());
TORCH_CHECK(x.is_contiguous());
TORCH_CHECK(dz.is_contiguous());
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(dz.sizes() == sizes);
auto rows = sizes[0];
auto cols = sizes[1];
auto hidden_size = gamma.numel();
TORCH_CHECK(mu.numel() == rows);
TORCH_CHECK(mu.sizes() == rsigma.sizes());
TORCH_CHECK(gamma.numel() == cols);
auto options = x.options();
auto dx = torch::empty_like(x);
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
launch_params.props = at::cuda::getCurrentDeviceProperties();
auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size);
launcher(launch_params, true);
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype));
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype));
at::Tensor workspace, barrier;
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.dz = dz.data_ptr();
params.dx = dx.data_ptr();
params.dbeta = dbeta.data_ptr();
params.dgamma = dgamma.data_ptr();
params.dbeta_part = dbeta_part.data_ptr();
params.dgamma_part = dgamma_part.data_ptr();
if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this?
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
launcher(launch_params, false);
return { dx, dgamma, dbeta, dgamma_part, dbeta_part };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "CUDA LayerNorm";
m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel");
m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel");
}
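For orientation, a sketch of how this extension is driven from Python once built (the module name fast_layer_norm, the shapes, and the epsilon value are illustrative assumptions, not fixed by this diff): ln_fwd returns the normalized output together with the per-row mean and reciprocal standard deviation, which ln_bwd consumes along with the incoming gradient.
# Illustrative only; "fast_layer_norm" stands for whatever name the extension is built under.
import torch
import fast_layer_norm
rows, hidden = 4096, 16384   # hidden_size must be one of the registered sizes
x = torch.randn(rows, hidden, device="cuda", dtype=torch.float16)
gamma = torch.ones(hidden, device="cuda", dtype=torch.float16)
beta = torch.zeros(hidden, device="cuda", dtype=torch.float16)
z, mu, rsigma = fast_layer_norm.ln_fwd(x, gamma, beta, 1e-5)
dz = torch.randn_like(z)
dx, dgamma, dbeta, dgamma_part, dbeta_part = fast_layer_norm.ln_bwd(dz, x, mu, rsigma, gamma)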
#pragma once
namespace layer_norm {
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum;
constexpr float rn = 1.f / float(COLS);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
}
}
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
if( WARPS_M == 1 ) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dz_sum[NUM_RES];
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
for( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dzy_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
for( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) {
*dbeta_part = cta_dz_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA;
}
}
}
template<typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
void ln_bwd_finalize_kernel(BwdParams params)
{
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
Sum<reduce_t> sum;
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local));
for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
dbeta_part.load_from(params.dbeta_part, idx);
dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
}
}
void * smem_gamma = smem_;
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
const int write_row = warp;
const int write_col = lane ^ write_row;
const int write_idx = write_row * THREADS_PER_WARP + write_col;
dgamma_local.store_to(smem_gamma, write_idx);
dbeta_local.store_to(smem_beta, write_idx);
__syncthreads();
// It would be probably safe to reuse the first row of smem_beta and smem_gamma
void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
// More than one iter iff ROWS_PER_CTA < 32.
for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
memset(&dbeta_local, 0, sizeof(dbeta_local));
memset(&dgamma_local, 0, sizeof(dgamma_local));
// Load beta and gamma transposed
if(read_row < Kernel_traits::ROWS_PER_CTA){
dbeta_local.load_from(smem_beta, read_idx);
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
compute_t b_i = dbeta_local.data.elt[it];
compute_t g_i = dgamma_local.data.elt[it];
b_i = reducer.allreduce(b_i, sum);
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
dbeta_local.data.elt[it] = b_i;
}
// Leader stores the result at the current column.
if(lane == 0){
dgamma_local.store_to(smem_gamma_out, w);
dbeta_local.store_to(smem_beta_out, w);
}
}
// All writes done.
__syncthreads();
// Pack and store: 2-wide stores with half the threads.
if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
dbeta_vec2.load_from(smem_beta_out, lane);
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
}
dgamma_out2.store_to(params.dgamma, col_out);
dbeta_out2.store_to(params.dbeta, col_out);
}
}
}
} // namespace layer_norm
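For reference, the per-element arithmetic in ln_bwd_kernel above matches the standard LayerNorm backward (a sketch; H is the hidden size, i.e. COLS, and the sums over j run across one row):
\[
dy_i = \gamma_i\,dz_i, \qquad y_i = r_\sigma\,(x_i - \mu),
\]
\[
dx_i = r_\sigma\Big(dy_i - \tfrac{1}{H}\sum_j dy_j \;-\; y_i\,\tfrac{1}{H}\sum_j dy_j\,y_j\Big), \qquad d\gamma_c = \sum_{\text{rows}} dz_c\,y_c, \qquad d\beta_c = \sum_{\text{rows}} dz_c .
\]
The two means correspond to mdy_local and mdyy_local, and the per-column sums to dzy_sum / dz_sum, which ln_bwd_finalize_kernel then reduces across CTAs.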
#include "utils.cuh" #include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h" #include "ln_kernel_traits.h"
#include "ATen/cuda/CUDAContext.h" #include "ln_fwd_kernels.cuh"
template <typename Ktraits> using namespace layer_norm;
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(
void *__restrict__ y_, void *__restrict__ mu_, void *__restrict__ rsigma_, template<
const void *__restrict__ x_, const void *__restrict__ gamma_, typename weight_t,
const void *__restrict__ beta_, const float epsilon, int rows) { typename input_t,
typename output_t,
using Vec = typename Ktraits::Vec; typename compute_t,
typename index_t,
using base_t = typename Ktraits::base_t; int HIDDEN_SIZE,
using compute_t = typename Ktraits::compute_t; int CTAS_PER_ROW,
enum { NUM_ELTS = Vec::NUM_ELTS }; int WARPS_M,
enum { WARPS_N = Ktraits::WARPS_N }; int WARPS_N,
enum { WARPS_M = Ktraits::WARPS_M }; int BYTES_PER_LDG
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; >
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG }; using Kernel_traits = Kernel_traits<weight_t,
static_assert(BYTES_PER_LDG == 16, ""); input_t,
output_t,
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; compute_t,
enum { LDGS = BYTES_PER_ROW / Ktraits::BYTES_PER_ROW_PER_CTA }; index_t,
static_assert(LDGS * Ktraits::BYTES_PER_ROW_PER_CTA == BYTES_PER_ROW, ""); HIDDEN_SIZE,
CTAS_PER_ROW,
const int tidx = threadIdx.x; WARPS_M,
const int bidx = blockIdx.x; WARPS_N,
const int lane = tidx % THREADS_PER_WARP; BYTES_PER_LDG
const int warp = tidx / THREADS_PER_WARP; >;
const int warp_n = warp % WARPS_N; auto kernel = &ln_fwd_kernel<Kernel_traits>;
const int warp_m = warp / WARPS_N;
if( configure_params ) {
const int c = warp_n * THREADS_PER_WARP + lane; int ctas_per_sm;
const int r = bidx * ROWS_PER_CTA + warp_m; cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
const char *x_ptr = static_cast<const char *>(x_); launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
const char *g_ptr = static_cast<const char *>(gamma_); launch_params.workspace_bytes = 0;
const char *b_ptr = static_cast<const char *>(beta_); if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
char *y_ptr = static_cast<char *>(y_); launch_params.workspace_bytes = launch_params.params.ctas_per_col
compute_t *mu_ptr = static_cast<compute_t *>(mu_); * Kernel_traits::WARPS_M
compute_t *rs_ptr = static_cast<compute_t *>(rsigma_); * Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
Vec gamma[LDGS]; * 2;
Vec beta[LDGS]; }
#pragma unroll return;
for (int it = 0, col = c; it < LDGS; it++) {
gamma[it].load_from(g_ptr + col * BYTES_PER_LDG);
beta[it].load_from(b_ptr + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
x[it].load_from(x_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
xf[it * NUM_ELTS + jt] = compute_t(x[it].data.elt[jt]);
}
}
compute_t mu_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
mu_local += xf[it * NUM_ELTS + jt];
}
} }
#pragma unroll if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
for (int it = 1; it < THREADS_PER_WARP; it *= 2) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
} }
mu_local *= rn; auto stream = launch_params.stream;
if(lane == 0){ auto ctas_per_col = launch_params.params.ctas_per_col;
mu_ptr[row] = mu_local;
} if( Kernel_traits::CTAS_PER_ROW == 1 ) {
compute_t var_local = 0.f; kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
#pragma unroll dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
for (int it = 0; it < LDGS; it++) { dim3 block(Kernel_traits::THREADS_PER_CTA);
#pragma unroll void *params_ = (void *)&launch_params.params;
for (int jt = 0; jt < NUM_ELTS; jt++) { cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
compute_t diff = xf[it * NUM_ELTS + jt] - mu_local;
var_local += diff * diff;
}
} }
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
}
compute_t rsigma = rsqrtf(var_local * rn + epsilon);
if(lane == 0){
rs_ptr[row] = rsigma;
}
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
base_t tmp = (rsigma * (xf[it * NUM_ELTS + jt] - mu_local));
x[it].data.elt[jt] = gamma[it].data.elt[jt] * tmp + beta[it].data.elt[jt];
}
}
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
x[it].store_to(y_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
}
}
template<typename scalar_t>
void launch(
at::Tensor & y, // BxSxhidden_size
at::Tensor & mu,
at::Tensor & rsigma,
const at::Tensor & x, // BxSxhidden_size
const at::Tensor & gamma,
const at::Tensor & beta,
const float epsilon,
const int rows,
const int cols,
const int max_gridx,
cudaStream_t stream
){
if (cols == 1024) {
using Ktraits = Kernel_traits<scalar_t, 1024, 4, 1>;
const int grid =
std::min<int>(DIVUP(rows, Ktraits::ROWS_PER_CTA), max_gridx);
ln_fwd_kernel<Ktraits><<<grid, Ktraits::THREADS_PER_CTA, 0, stream>>>(
y.data_ptr(), mu.data_ptr(), rsigma.data_ptr(), x.data_ptr(),
gamma.data_ptr(), beta.data_ptr(), epsilon, rows);
} else {
assert(false && "Not implemented");
}
AT_CUDA_CHECK(cudaPeekAtLastError());
} }
void ln_fwd_cuda( REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
at::Tensor & y, // BxSxhidden_size REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
at::Tensor & mu, REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
at::Tensor & rsigma, REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
const at::Tensor & x, // BxSxhidden_size REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
const at::Tensor & gamma,
const at::Tensor & beta, REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
const float epsilon, REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
const int rows, const int cols, REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
cudaStream_t stream REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
){ REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
const auto dtype = x.scalar_type(); REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
const auto props = at::cuda::getCurrentDeviceProperties(); REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
const int max_gridx = props->maxGridSize[0]; REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
//TODO REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// - Using dispatch macro costs 1% perf wtf?!?!
// - Tune FP32 warps REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// - Add more sizes REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
if (dtype == torch::kFloat16) { REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
launch<half>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream); REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
} else if (dtype == torch::kFloat32) { REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
launch<float>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream);
} else { REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
assert(false && "Not implemented"); REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
} REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);
#pragma once
#include "ln.h"
namespace layer_norm {
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
Wvec gamma[LDGS];
Wvec beta[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
beta[it].load_from(params.beta, idx);
idx += VEC_COLS_PER_LDG;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
Ivec x[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
x[it].load_from(params.x, idx);
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_ij = compute_t(x[it].data.elt[jt]);
xf[it * NUM_ELTS + jt] = x_ij;
}
idx += VEC_COLS_PER_LDG;
}
stats_t s = stats.compute(xf, rn);
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
mu_ptr[row] = mu;
}
compute_t rs = rsqrtf(rn * m2 + params.epsilon);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
rs_ptr[row] = rs;
}
Ovec z[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
output_t g_ij = gamma[it].data.elt[jt];
output_t b_ij = beta[it].data.elt[jt];
z[it].data.elt[jt] = (g_ij * y_ij + b_ij);
}
z[it].store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
}
}
}
} // namespace layer_norm
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace layer_norm {
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_
>
struct Kernel_traits_base {
using weight_t = weight_t_;
using input_t = input_t_;
using output_t = output_t_;
using compute_t = compute_t_;
using index_t = index_t_;
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
enum { THREADS_PER_WARP = 32 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_,
uint32_t BYTES_PER_LDG_,
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
weight_t_,
input_t_,
output_t_,
compute_t_,
index_t_,
THREADS_PER_CTA_>
>
struct Kernel_traits_finalize : public Base {
enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
// Bytes per global load from the input.
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
// Number of elements fetched by a global load.
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
// Bytes per global store of the weights.
enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
// Shared memory size to transpose the CTA result.
enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
// Shared memory size to coalesce the CTA result.
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
// Shared memory requirement per CTA.
enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
// The type of the reducer.
using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
// Condition for the whole CTA to participate in syncthreads.
static_assert(COLS % Base::THREADS_PER_WARP == 0);
enum { CTAS = COLS / Base::THREADS_PER_WARP };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t HIDDEN_SIZE_,
uint32_t CTAS_PER_ROW_,
uint32_t WARPS_M_,
uint32_t WARPS_N_,
uint32_t BYTES_PER_LDG_ = 16,
typename Base = Kernel_traits_base<
HIDDEN_SIZE_,
weight_t_,
input_t_,
output_t_,
compute_t_,
index_t_,
WARPS_M_*WARPS_N_*THREADS_PER_WARP
>
>
struct Kernel_traits : public Base {
using input_t = typename Base::input_t;
using weight_t = typename Base::weight_t;
using compute_t = typename Base::compute_t;
using output_t = typename Base::output_t;
using index_t = typename Base::index_t;
enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = HIDDEN_SIZE_ };
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
enum { ROWS_PER_CTA = WARPS_M };
enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
// Assume that each thread can handle the same number of elements in the output and weights as in the input.
static_assert(sizeof(input_t) >= sizeof(output_t));
static_assert(sizeof(input_t) >= sizeof(weight_t));
// The number of columns fetched per load from input: one per thread.
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
// The total number of vectorized loads/stores per hidden vector.
enum { VEC_COLS = COLS / ELTS_PER_LDG };
// The number of loads per thread for the input.
enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
//static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
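To make the derived load geometry above concrete, a sketch (not part of the diff) of what one plausible forward configuration works out to, assuming fp16 input/weights/output, fp32 compute, HIDDEN_SIZE = 16384, CTAS_PER_ROW = 2, WARPS_M = 1, WARPS_N = 4 and BYTES_PER_LDG = 16; the constants below simply replay the enum arithmetic of Kernel_traits by hand.
// Sketch only: hand-evaluated Kernel_traits arithmetic for one assumed configuration.
constexpr int kNumElts       = 16 / 2;                    // BYTES_PER_LDG / sizeof(fp16) = 8
constexpr int kThreadsPerRow = 4 * 32;                    // WARPS_N * THREADS_PER_WARP = 128
constexpr int kVecCols       = 16384 / kNumElts;          // vectorized columns per row = 2048
constexpr int kVecColsPerLdg = 2 * kThreadsPerRow;        // CTAS_PER_ROW * THREADS_PER_ROW = 256
constexpr int kLdgs          = kVecCols / kVecColsPerLdg; // loads per thread per row = 8
static_assert(kLdgs * kVecColsPerLdg == kVecCols, "mirrors the LDGS static_assert above");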
#pragma once
#include "torch/extension.h"
#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
[&] { \
const auto &the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
template <int Bytes> struct Vec_type {};
template <> struct Vec_type<16> {
using Type = uint4;
static __device__ inline Type zero() { return make_uint4(0, 0, 0, 0); }
};
template <> struct Vec_type<8> {
using Type = uint2;
static __device__ inline Type zero() { return make_uint2(0, 0); }
};
template <> struct Vec_type<4> {
using Type = uint32_t;
static __device__ inline Type zero() { return 0; }
};
template <> struct Vec_type<2> {
using Type = uint16_t;
static __device__ inline Type zero() { return 0; }
};
template <typename T> struct TypeInfo {
using base_t = T;
using packed_t = T;
using compute_t = float;
using packed_compute_t = float;
};
template <> struct TypeInfo<half> {
using base_t = half;
using packed_t = half2;
using compute_t = float;
using packed_compute_t = float2;
};
template <typename dtype, int Bytes> struct Vec {
using base_t = typename TypeInfo<dtype>::base_t;
using packed_t = typename TypeInfo<dtype>::packed_t;
using compute_t = typename TypeInfo<dtype>::compute_t;
using packed_compute_t = typename TypeInfo<dtype>::packed_compute_t;
static_assert(Bytes % sizeof(base_t) == 0, "");
static_assert(Bytes % sizeof(packed_t) == 0, "");
enum { BYTES_PER_THREAD = Bytes };
enum { NUM_ELTS = Bytes / sizeof(base_t) };
enum { NUM_PACKED = Bytes / sizeof(packed_t) };
using vec_t = typename Vec_type<Bytes>::Type;
using store_t = union {
vec_t raw;
base_t elt[NUM_ELTS];
packed_t packed[NUM_PACKED];
};
store_t data;
__device__ Vec() { data.raw = Vec_type<Bytes>::zero(); }
__device__ inline void load_from(const char *ptr) {
data.raw = *reinterpret_cast<const vec_t *>(ptr);
}
__device__ inline void load_or_zero(const char *ptr, const bool is_valid) {
data.raw = is_valid ? *reinterpret_cast<const vec_t *>(ptr)
: Vec_type<Bytes>::zero();
}
__device__ inline void store_to(char *ptr) const {
*reinterpret_cast<vec_t *>(ptr) = data.raw;
}
__device__ inline void store_valid(char *ptr, const bool is_valid) const {
if (is_valid)
*reinterpret_cast<vec_t *>(ptr) = data.raw;
}
};
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
torch::Tensor const &input,
const half *pad_mask, float dropout_prob);
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x); \
  CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
torch::Tensor const &input,
torch::Tensor const &pad_mask,
float dropout_prob) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
"Only BYTE is supported");
}
return fwd_cuda(is_training, heads, input,
use_mask ? static_cast<const half *>(pad_mask.data_ptr())
: nullptr,
dropout_prob);
}
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
torch::Tensor const &softmax_results,
torch::Tensor const &dropout_mask, float dropout_prob) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
"Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
//            "Only BYTE is supported");
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
dropout_prob);
}
} // namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
"Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
"Self Multihead Attention masked softmax dropout -- Backward.");
}
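A hedged sketch of calling these bindings from Python (the import name and shapes are assumptions; apex builds and imports this extension under its own name): forward expects a half-precision 3D input of shape (batch*heads, q_len, k_len) plus a 2D half-precision additive pad mask, and returns the dropout output, the dropout mask, and the softmax results that backward consumes.
# Illustrative only; the import name is hypothetical.
import torch
import additive_mask_softmax_dropout_cuda as fused
batch, heads, seq = 8, 16, 128
inputs = torch.randn(batch * heads, seq, seq, device="cuda", dtype=torch.float16)
pad_mask = torch.zeros(batch, seq, device="cuda", dtype=torch.float16)  # additive mask
dropout_out, dropout_mask, softmax_out = fused.forward(True, True, heads, inputs, pad_mask, 0.1)
grad_input = fused.backward(True, heads, torch.randn_like(dropout_out), softmax_out, dropout_mask, 0.1)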
#include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "softmax.h"
#include "dropout.h" #include "dropout.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs // symbol to be automatically resolved by PyTorch libs
extern THCState *state;
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const half *pad_mask, float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2
  return {dropout_results, dropout_mask, softmax_results};
}
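// Illustrative sketch only (not part of this extension): at::_fused_dropout is
// called above with keep probability (1.0f - dropout_prob); it zeroes dropped
// elements and scales the survivors by 1/(1 - dropout_prob), so the expected
// value of each activation is preserved and no rescaling is needed at
// inference time. The helper below is hypothetical and exists purely to spell
// out that convention.
static inline float apply_inverted_dropout(float x, bool keep,
                                           float dropout_prob) {
  const float keep_prob = 1.0f - dropout_prob;
  // E[result] == x when P(keep) == keep_prob
  return keep ? x / keep_prob : 0.0f;
}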
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half *>(output_grads.data_ptr()),
      static_cast<half *>(output_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);

  // backward pass is completely in-place
  return output_grads;
}
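// Illustrative sketch only: the fused call above folds the dropout mask, the
// 1/(1 - dropout_prob) rescale and the softmax Jacobian-vector product into a
// single kernel. For one softmax row y = softmax(x) the reference gradient is
//   dx_i = y_i * (dy_i - sum_j dy_j * y_j)
// The scalar helper below is hypothetical and only restates that formula.
static inline void softmax_backward_row_reference(const float *y,
                                                  const float *dy, float *dx,
                                                  int n) {
  float dot = 0.f;
  for (int j = 0; j < n; ++j)
    dot += dy[j] * y[j];
  for (int i = 0; i < n; ++i)
    dx[i] = y[i] * (dy[i] - dot);
}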
} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
...@@ -5,145 +5,121 @@ namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs_q,
                                    torch::Tensor const &inputs_kv,
                                    torch::Tensor const &input_weights_q,
                                    torch::Tensor const &input_weights_kv,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob);
// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)
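// Hypothetical usage sketch (not part of this file): CHECK_INPUT expands to
// both checks above, so an entry point can validate a tensor before handing
// its data pointer to a kernel. `validate_example` is illustrative only.
static void validate_example(torch::Tensor const &t) {
  CHECK_INPUT(t); // asserts t is a contiguous CUDA tensor
}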
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
    torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");

  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
                  input_weights_q, input_weights_kv, output_weights,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}
std::vector<torch::Tensor>
bwd(int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
             "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results,
                  softmax_results, input_lin_q_results, input_lin_kv_results,
                  inputs_q, inputs_kv, input_weights_q, input_weights_kv,
                  output_weights, dropout_mask, dropout_prob);
}
} // end namespace rocblas_gemmex
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"
namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs_q,
                                    torch::Tensor const &inputs_kv,
                                    torch::Tensor const &input_weights_q,
                                    torch::Tensor const &input_weights_kv,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int embed_dim = inputs_q.size(2);
  const int sequences = inputs_q.size(1);
  const int q_seq_len = inputs_q.size(0);
  const int k_seq_len = inputs_kv.size(0);
  const int batches_q = sequences * q_seq_len;
  const int batches_kv = sequences * k_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_q_dim = embed_dim;
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim_q = attn_batches * head_dim;
  const int lead_dim_kv = attn_batches * 2 * head_dim;
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs_q.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_q_results =
      torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
  torch::Tensor input_lin_kv_results =
      torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs_q, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
  void *k_lin_results_ptr =
      static_cast<void *>(input_lin_kv_results.data_ptr());
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};
...@@ -140,8 +140,7 @@ std::vector<torch::Tensor> fwd_cuda(
                                     flags));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum( a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
...@@ -166,46 +165,35 @@ std::vector<torch::Tensor> fwd_cuda(
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr), k_seq_len,
        k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr), pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()), dropout_elems,
        (1.0f - dropout_prob));
  }

  // Matmul2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
...@@ -253,78 +241,73 @@ std::vector<torch::Tensor> fwd_cuda(
                                     flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_q_results,
          input_lin_kv_results,
          softmax_results,
          dropout_results,
          dropout_mask,
          matmul2_results,
          outputs};
}
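// Worked shape example (illustrative only; the concrete values below are
// hypothetical, not taken from this repository): with embed_dim = 1024,
// heads = 16, sequences = 32, q_seq_len = k_seq_len = 64, the sizes computed
// at the top of fwd_cuda come out as
//   head_dim      = 1024 / 16       = 64
//   attn_batches  = 16 * 32         = 512
//   lead_dim_q    = 512 * 64        = 32768
//   lead_dim_kv   = 512 * 2 * 64    = 65536
//   dropout_elems = 512 * 64 * 64   = 2097152
// i.e. the attention scores are laid out as [attn_batches, q_seq_len,
// k_seq_len] while activations keep the sequence-first [seq_len, sequences,
// embed_dim] layout.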
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results,
    torch::Tensor const &input_lin_q_results,
    torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q,
    torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, float dropout_prob) {
  const int embed_dim = inputs_q.size(2);
  const int sequences = inputs_q.size(1);
  const int q_seq_len = inputs_q.size(0);
  const int k_seq_len = inputs_kv.size(0);
  const int batches_q = sequences * q_seq_len;
  const int batches_kv = sequences * k_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_q_dim = embed_dim;
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim_q = attn_batches * head_dim;
  const int lead_dim_kv = attn_batches * 2 * head_dim;
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_q_grads = torch::empty_like(inputs_q);
  torch::Tensor input_kv_grads = torch::empty_like(inputs_kv);
  torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q);
  torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);

  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
  at::Tensor input_lin_kv_output_grads =
      torch::empty_like(input_lin_kv_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
  auto v_lin_results_ptr =
      static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;

  auto q_lin_grads_ptr =
      static_cast<half *>(input_lin_q_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast<half *>(input_lin_kv_output_grads.data_ptr());
  auto v_lin_grads_ptr =
      static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
...@@ -386,8 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
                                     flags));

  // MatMul2 Dgrad1
  gemm_switch_fp32accum( a_layout_t,
                         b_layout_n,
                         k_seq_len,
                         q_seq_len,
...@@ -409,8 +391,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
...@@ -442,17 +423,14 @@ std::vector<torch::Tensor> bwd_cuda(
  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()), k_seq_len,
      k_seq_len, attn_batches * q_seq_len);
  assert(softmax_success);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_n,
                         head_dim,
                         q_seq_len,
...@@ -474,8 +452,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         attn_batches);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum( a_layout_n,
                         b_layout_t,
                         head_dim,
                         k_seq_len,
...@@ -612,3 +589,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace encdec
} // end namespace multihead_attn
...@@ -5,81 +5,66 @@ namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const uint8_t *pad_mask,
                                    float dropout_prob);

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask,
                       const uint8_t *padding_mask, float dropout_prob);
// C++ interface

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
                               torch::Tensor const &input,
                               torch::Tensor const &pad_mask,
                               float dropout_prob) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }

  return fwd_cuda(is_training, heads, input,
                  use_mask ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                           : nullptr,
                  dropout_prob);
}
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
                  torch::Tensor const &softmax_results,
                  torch::Tensor const &dropout_mask,
                  torch::Tensor const &padding_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  // AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
  //            "Only BYTE is supported");

  return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
                  use_mask
                      ? static_cast<const uint8_t *>(padding_mask.data_ptr())
                      : nullptr,
                  dropout_prob);
}
} // end namespace mask_softmax_dropout
...@@ -87,7 +72,8 @@ torch::Tensor bwd(
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
        "Self Multihead Attention masked softmax dropout -- Forward.");
  m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
        "Self Multihead Attention masked softmax dropout -- Backward.");
}