Unverified commit e2083df5 authored by yjk21, committed by GitHub

fast layer norm (#1037)

parent a78ccf0b
@@ -2,4 +2,5 @@ apex.egg-info
 dist
 build
 docs/build
 *~
+__pycache__
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
void ln_fwd_cuda(at::Tensor &y, at::Tensor &mu, at::Tensor &rsigma,
const at::Tensor &x, const at::Tensor &gamma,
const at::Tensor &beta, const float epsilon, const int rows, const int cols,
cudaStream_t stream);
void ln_bwd_cuda(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,
const at::Tensor &dw, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int rows, const int cols, cudaStream_t stream);
std::vector<at::Tensor> ln_fwd(const at::Tensor &x,     // (B*S) x hidden_size
                               const at::Tensor &gamma, // hidden_size
                               const at::Tensor &beta,  // hidden_size
                               const float epsilon
) {
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(gamma.is_cuda());
  TORCH_CHECK(beta.is_cuda());
TORCH_CHECK(x.is_contiguous());
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
const int rows = sizes[0];
const int cols = sizes[1];
auto dtype = x.scalar_type();
TORCH_CHECK(gamma.dtype() == dtype);
TORCH_CHECK(beta.dtype() == dtype);
TORCH_CHECK(gamma.sizes() == beta.sizes());
TORCH_CHECK(gamma.numel() == cols);
TORCH_CHECK(epsilon >= 0.f);
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto y = torch::empty_like(x);
auto opts = x.options();
auto mu = torch::empty({rows}, opts.dtype(torch::kFloat32));
auto rsigma = torch::empty({rows}, opts.dtype(torch::kFloat32));
ln_fwd_cuda(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, stream);
return {y, mu, rsigma};
}
std::vector<at::Tensor> ln_bwd(const at::Tensor &dw,     // (B*S) x hidden_size
                               const at::Tensor &x,      // (B*S) x hidden_size
                               const at::Tensor &mu,     // (B*S,), FP32!
                               const at::Tensor &rsigma, // (B*S,), FP32!
const at::Tensor &gamma // hidden_size
) {
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(dw.is_cuda());
TORCH_CHECK(mu.is_cuda());
TORCH_CHECK(rsigma.is_cuda());
TORCH_CHECK(gamma.is_cuda());
TORCH_CHECK(x.is_contiguous());
TORCH_CHECK(dw.is_contiguous());
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(dw.sizes() == sizes);
auto rows = sizes[0];
auto cols = sizes[1];
auto dtype = x.scalar_type();
TORCH_CHECK(dw.dtype() == dtype);
TORCH_CHECK(gamma.dtype() == dtype);
TORCH_CHECK(mu.dtype() == torch::kFloat32);
TORCH_CHECK(rsigma.dtype() == torch::kFloat32);
TORCH_CHECK(mu.sizes() == rsigma.sizes());
TORCH_CHECK(mu.numel() == rows);
TORCH_CHECK(gamma.numel() == cols);
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto dx = torch::empty_like(x);
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);
ln_bwd_cuda(dx, dgamma, dbeta, dw, x, mu, rsigma, gamma, rows, cols, stream);
return {dx, dgamma, dbeta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "CUDA LayerNorm"; // optional module docstring
m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel");
m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel");
}
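For orientation, here is a minimal sketch of how these bindings are driven from Python, mirroring the autograd function and the unit test further below; the tensor shapes are illustrative assumptions, not part of this file.

import torch
import fast_layer_norm  # extension module built by this commit

hidden = 1024
x = torch.randn(32 * 512, hidden, dtype=torch.float16, device='cuda')  # rows = B*S, cols = hidden
gamma = torch.randn(hidden, dtype=torch.float16, device='cuda')
beta = torch.randn(hidden, dtype=torch.float16, device='cuda')

# Forward returns the normalized output plus per-row mean and inverse stddev (FP32).
y, mu, rsigma = fast_layer_norm.ln_fwd(x, gamma, beta, 1e-5)

# Backward consumes the incoming gradient together with the saved statistics.
dy = torch.randn_like(y)
dx, dgamma, dbeta = fast_layer_norm.ln_bwd(dy, x, mu, rsigma, gamma)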
#include "utils.cuh"
#include "ln_kernel_traits.h"
#include "ATen/cuda/CUDAContext.h"
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(void * __restrict__ dx_,
void * __restrict__ dg_,
void * __restrict__ db_,
const void * __restrict__ dw_,
const void * __restrict__ x_,
const void * __restrict__ mu_,
const void * __restrict__ rs_,
const void * __restrict__ g_,
const int rows
){
using Vec = typename Ktraits::Vec;
enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = BYTES_PER_ROW / Ktraits::BYTES_PER_ROW_PER_CTA };
static_assert(LDGS * Ktraits::BYTES_PER_ROW_PER_CTA == BYTES_PER_ROW, "");
enum { NUM_ELTS = Vec::NUM_ELTS };
using vec_t = typename Ktraits::vec_t;
using base_t = typename Ktraits::base_t;
using compute_t = typename Ktraits::compute_t;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP;
const int warp = tidx / THREADS_PER_WARP;
const int warp_m = warp / Ktraits::WARPS_N;
const int warp_n = warp % Ktraits::WARPS_N;
const int tid_r = warp_n * THREADS_PER_WARP + lane;
const int r = bidx * Ktraits::ROWS_PER_CTA + warp_m;
const int c = warp_n * THREADS_PER_WARP + lane;
const char *dw_ptr = static_cast<const char *>(dw_);
const char *x_ptr = static_cast<const char *>(x_);
const char *g_ptr = static_cast<const char *>(g_);
char *dx_ptr = static_cast<char *>(dx_);
const compute_t *mu_ptr = static_cast<const compute_t *>(mu_);
const compute_t *rs_ptr = static_cast<const compute_t *>(rs_);
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS, "");
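  // Layout: each row of COLS elements is covered by THREADS_PER_ROW threads;
  // every thread issues LDGS vector loads of NUM_ELTS elements, strided by
  // THREADS_PER_ROW vectors across the row.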
  // Dynamic smem for the final per-CTA reduction of the dgamma/dbeta partials.
  extern __shared__ compute_t smem_[];
  // With the grid-stride loop we can assign multiple rows to each thread by
  // launching fewer CTAs than rows / ROWS_PER_CTA. The per-column partial sums
  // for dbeta (dw_sum) and dgamma (dwy_sum) are accumulated in registers here,
  // since shared memory is too small to hold one copy per row.
  compute_t dwy_sum[LDGS * NUM_ELTS];
  compute_t dw_sum[LDGS * NUM_ELTS];
  memset(dwy_sum, 0, sizeof(compute_t) * LDGS * NUM_ELTS);
  memset(dw_sum, 0, sizeof(compute_t) * LDGS * NUM_ELTS);
__shared__ compute_t smem_mdy[ROWS_PER_CTA * WARPS_N];
__shared__ compute_t smem_mdyy[ROWS_PER_CTA * WARPS_N];
compute_t *mdy_shared = &smem_mdy[warp_m * WARPS_N];
compute_t *mdyy_shared = &smem_mdyy[warp_m * WARPS_N];
constexpr float rn = 1.f / float(COLS);
Vec gamma[LDGS];
int col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
gamma[it].load_from(g_ptr + col * BYTES_PER_LDG);
col += Ktraits::THREADS_PER_ROW;
}
  // TODO: if ROWS_PER_CTA does not divide rows, we might get divergence in the
  // last CTAs at the __syncthreads() below!
  // Grid-stride loop over rows.
#pragma unroll 1
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
const compute_t mu_r = mu_ptr[row];
const compute_t rs_r = rs_ptr[row];
Vec dw[LDGS], x[LDGS], dx[LDGS];
int col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dw[it].load_from(dw_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
x[it].load_from(x_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
// local reductions
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < Vec::NUM_ELTS; jt++) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = gamma[it].data.elt[jt] * dw[it].data.elt[jt];
compute_t dw_tmp = dw[it].data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dwy_sum[it * NUM_ELTS + jt] += dw_tmp * y_tmp;
dw_sum[it * NUM_ELTS + jt] += dw_tmp;
}
}
// reduction across row for mdy, mdyy
if (WARPS_N == 1) { // no need to go through smem!
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mdy_local += __shfl_xor_sync(uint32_t(-1), mdy_local, it);
mdyy_local += __shfl_xor_sync(uint32_t(-1), mdyy_local, it);
}
mdy_local *= rn;
mdyy_local *= rn;
} else {
#pragma unroll
for (int it = 16; it > 0; it /= 2) {
mdy_local += __shfl_down_sync(uint32_t(-1), mdy_local, it);
mdyy_local += __shfl_down_sync(uint32_t(-1), mdyy_local, it);
} // lane 0 holds the result!
if (lane == 0) {
mdy_shared[warp_n] = mdy_local;
mdyy_shared[warp_n] = mdyy_local;
}
__syncthreads();
if (warp_n == 0 && lane == 0) {
mdy_local = 0.f;
mdyy_local = 0.f;
for (int it = 0; it < WARPS_N; it++) {
mdy_local += mdy_shared[it];
mdyy_local += mdyy_shared[it];
}
mdy_shared[0] = mdy_local;
mdyy_shared[0] = mdyy_local;
}
__syncthreads();
mdy_local = mdy_shared[0] * rn;
mdyy_local = mdyy_shared[0] * rn;
}
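    // LayerNorm input gradient:
    //   dx = rsigma * (dy - mean(dy * y) * y - mean(dy)),
    // with dy = gamma * dw and y = rsigma * (x - mu) the normalized input;
    // mdy_local and mdyy_local hold the two row means.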
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp =
compute_t(rs_r) * (dy_tmp - mdyy_local * y_tmp - mdy_local);
dx[it].data.elt[jt] = dx_tmp;
}
}
col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
dx[it].store_to(dx_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += Ktraits::THREADS_PER_ROW;
}
} // end: grid stride loop
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
compute_t *smem_write;
smem_write = &smem_[warp_m * COLS + tid_r * NUM_ELTS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
smem_write[jt] = dw_sum[it * NUM_ELTS + jt];
}
smem_write += THREADS_PER_ROW * NUM_ELTS;
}
__syncthreads();
compute_t cta_dw_sum[NUM_RES];
memset(cta_dw_sum, 0, sizeof(compute_t) * NUM_RES);
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dw_sum[jt] += smem_[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
smem_write = &smem_[warp_m * COLS + tid_r * NUM_ELTS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
smem_write[jt] = dwy_sum[it * NUM_ELTS + jt];
}
smem_write += THREADS_PER_ROW * NUM_ELTS;
}
__syncthreads();
compute_t cta_dwy_sum[NUM_RES];
memset(cta_dwy_sum, 0, sizeof(compute_t) * NUM_RES);
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dwy_sum[jt] +=
smem_[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
compute_t *dgamma_part = static_cast<compute_t *>(dg_) + bidx * COLS + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dgamma_part = cta_dwy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
compute_t *dbeta_part = static_cast<compute_t *>(db_) + bidx * COLS + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dbeta_part = cta_dw_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA;
}
}
template<typename Ktraits, typename out_t>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_finalize_kernel(void * __restrict__ dg_,
void * __restrict__ db_,
const void * __restrict__ dg_part_,
const void * __restrict__ db_part_,
const int rows
){
using Vec = typename Ktraits::Vec;
enum { NUM_ELTS = Vec::NUM_ELTS };
using vec_t = typename Ktraits::vec_t;
using base_t = typename Ktraits::base_t;
using compute_t = typename Ktraits::compute_t;
enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
  enum { VEC_COLS = BYTES_PER_ROW / BYTES_PER_LDG };
  static_assert(VEC_COLS == COLS / NUM_ELTS, "");
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP;
const int warp = tidx / THREADS_PER_WARP;
const int warp_m = warp / Ktraits::WARPS_N;
const int warp_n = warp % Ktraits::WARPS_N;
const int tid_c = warp_n * THREADS_PER_WARP + lane;
  const int c = bidx * THREADS_PER_ROW + tid_c;
  const int r = warp_m;
  __shared__ compute_t smem_[(WARPS_M - 1) * THREADS_PER_ROW * NUM_ELTS];
  // The launcher instantiates this kernel with WARPS_N = 1, BYTES_PER_LDG = 4
  // (NUM_ELTS = 1 FP32 partial per load), WARPS_M = 16 and a grid of
  // 1024 / 32 = 32 CTAs for the 1024-column case.
for(int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW){
const char* dg_part_ptr = static_cast<const char*>(dg_part_) + r * BYTES_PER_ROW + col * BYTES_PER_LDG;
const char* db_part_ptr = static_cast<const char*>(db_part_) + r * BYTES_PER_ROW + col * BYTES_PER_LDG;
compute_t dg_sum[NUM_ELTS];
compute_t db_sum[NUM_ELTS];
memset(dg_sum, 0, sizeof(compute_t) * NUM_ELTS);
memset(db_sum, 0, sizeof(compute_t) * NUM_ELTS);
#pragma unroll
for(int row = r; row < rows;row += ROWS_PER_CTA){
Vec dg;
Vec db;
dg.load_from(dg_part_ptr);
db.load_from(db_part_ptr);
dg_part_ptr += ROWS_PER_CTA * BYTES_PER_ROW;
db_part_ptr += ROWS_PER_CTA * BYTES_PER_ROW;
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dg_sum[jt] += dg.data.elt[jt];
db_sum[jt] += db.data.elt[jt];
}
}
// Finalize the reduction across rows of the CTA
compute_t * smem_write;
smem_write = smem_ + (warp_m -1) *THREADS_PER_ROW * NUM_ELTS + tid_c;
if (warp_m > 0) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
*smem_write = dg_sum[jt];
smem_write+=THREADS_PER_ROW;
}
}
__syncthreads();
compute_t *smem_read ;
smem_read = smem_ + tid_c ;
if (warp_m == 0) {
#pragma unroll
for (int it = 0; it < WARPS_M - 1; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dg_sum[jt] += *smem_read;
smem_read += THREADS_PER_ROW;
}
}
}
__syncthreads();
smem_write = smem_ + (warp_m -1) *THREADS_PER_ROW * NUM_ELTS + tid_c;
if (warp_m > 0) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
*smem_write = db_sum[jt];
smem_write+=THREADS_PER_ROW;
}
}
__syncthreads();
smem_read = smem_ + tid_c;
if (warp_m == 0) {
#pragma unroll
for (int it = 0; it < WARPS_M - 1; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
db_sum[jt] += *smem_read;
smem_read += THREADS_PER_ROW;
}
}
using vout_t = typename Vec_type<sizeof(out_t) * NUM_ELTS>::Type;
union {
vout_t raw;
out_t elt[NUM_ELTS];
} dg_out, db_out;
// out_t dg_out[NUM_ELTS], db_out[NUM_ELTS];
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
dg_out.elt[jt] = dg_sum[jt];
db_out.elt[jt] = db_sum[jt];
}
vout_t *dg_ptr = reinterpret_cast<vout_t *>(dg_) + col ;
vout_t *db_ptr = reinterpret_cast<vout_t *>(db_) + col ;
*dg_ptr = dg_out.raw;
*db_ptr = db_out.raw;
}
}
}
template<typename scalar_t>
void launch(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,
at::Tensor &dgamma_part, at::Tensor &dbeta_part,
const at::Tensor &dw, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int rows, const int cols, const int gridx, cudaStream_t stream){
if (cols == 1024) {
using Ktraits = Kernel_traits<scalar_t, 1024, 4, 1>;
if (Ktraits::SMEM_BYTES >= 48 * 1024) {
AT_CUDA_CHECK(cudaFuncSetAttribute(
ln_bwd_kernel<Ktraits>, cudaFuncAttributeMaxDynamicSharedMemorySize,
Ktraits::SMEM_BYTES));
}
ln_bwd_kernel<Ktraits>
<<<gridx, Ktraits::THREADS_PER_CTA, Ktraits::SMEM_BYTES, stream>>>(
dx.data_ptr(), dgamma_part.data_ptr(), dbeta_part.data_ptr(),
dw.data_ptr(), x.data_ptr(), mu.data_ptr(), rsigma.data_ptr(),
gamma.data_ptr(), rows);
using Ktraits2 = Kernel_traits<float, 1024, 16, 1, 4>;
constexpr int grid2 =
DIVUP(1024, Ktraits2::THREADS_PER_ROW * Ktraits2::Vec::NUM_ELTS);
ln_bwd_finalize_kernel<Ktraits2, scalar_t>
<<<grid2, Ktraits2::THREADS_PER_CTA, 0, stream>>>(
dgamma.data_ptr(), dbeta.data_ptr(), dgamma_part.data_ptr(),
dbeta_part.data_ptr(), gridx);
} else {
assert(false && "Not implemented");
}
AT_CUDA_CHECK(cudaPeekAtLastError());
}
void ln_bwd_cuda(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,
const at::Tensor &dw, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int rows, const int cols, cudaStream_t stream) {
const auto dtype = x.scalar_type();
const auto props = at::cuda::getCurrentDeviceProperties();
const int smCount = props->multiProcessorCount;
  // Launch 2 CTAs per SM.
  const int grid = 2 * smCount;
  // Request workspace for the two-step dgamma/dbeta reduction: each CTA writes
  // one FP32 partial sum per column, and the finalize kernel then reduces over
  // the `grid` partial rows and casts to the output dtype.
auto opts = x.options();
auto dbeta_part = torch::empty({grid, cols}, opts.dtype(torch::kFloat32));
auto dgamma_part = torch::empty({grid, cols}, opts.dtype(torch::kFloat32));
if (dtype == torch::kFloat16) {
launch<half>(dx, dgamma, dbeta, dgamma_part, dbeta_part, dw, x, mu, rsigma, gamma, rows, cols, grid, stream);
} else if (dtype == torch::kFloat32) {
launch<float>(dx, dgamma, dbeta, dgamma_part, dbeta_part, dw, x, mu, rsigma, gamma, rows, cols, grid, stream);
} else {
assert(false && "Not implemented");
}
}
#include "utils.cuh"
#include "ln_kernel_traits.h"
#include "ATen/cuda/CUDAContext.h"
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(
void *__restrict__ y_, void *__restrict__ mu_, void *__restrict__ rsigma_,
const void *__restrict__ x_, const void *__restrict__ gamma_,
const void *__restrict__ beta_, const float epsilon, int rows) {
using Vec = typename Ktraits::Vec;
using base_t = typename Ktraits::base_t;
using compute_t = typename Ktraits::compute_t;
enum { NUM_ELTS = Vec::NUM_ELTS };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };
static_assert(BYTES_PER_LDG == 16, "");
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = BYTES_PER_ROW / Ktraits::BYTES_PER_ROW_PER_CTA };
static_assert(LDGS * Ktraits::BYTES_PER_ROW_PER_CTA == BYTES_PER_ROW, "");
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP;
const int warp = tidx / THREADS_PER_WARP;
const int warp_n = warp % WARPS_N;
const int warp_m = warp / WARPS_N;
const int c = warp_n * THREADS_PER_WARP + lane;
const int r = bidx * ROWS_PER_CTA + warp_m;
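  // r: first row processed by this thread's warp (one warp owns a full row
  // when WARPS_N == 1); c: starting vector column of this thread in the row.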
const char *x_ptr = static_cast<const char *>(x_);
const char *g_ptr = static_cast<const char *>(gamma_);
const char *b_ptr = static_cast<const char *>(beta_);
char *y_ptr = static_cast<char *>(y_);
compute_t *mu_ptr = static_cast<compute_t *>(mu_);
compute_t *rs_ptr = static_cast<compute_t *>(rsigma_);
Vec gamma[LDGS];
Vec beta[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
gamma[it].load_from(g_ptr + col * BYTES_PER_LDG);
beta[it].load_from(b_ptr + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
x[it].load_from(x_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
xf[it * NUM_ELTS + jt] = compute_t(x[it].data.elt[jt]);
}
}
compute_t mu_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
mu_local += xf[it * NUM_ELTS + jt];
}
}
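    // Butterfly (xor) reduction of the partial sums within the warp; with
    // WARPS_N == 1 a single warp owns the whole row, so no smem is needed.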
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
}
mu_local *= rn;
if(lane == 0){
mu_ptr[row] = mu_local;
}
compute_t var_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t diff = xf[it * NUM_ELTS + jt] - mu_local;
var_local += diff * diff;
}
}
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
}
compute_t rsigma = rsqrtf(var_local * rn + epsilon);
if(lane == 0){
rs_ptr[row] = rsigma;
}
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
base_t tmp = (rsigma * (xf[it * NUM_ELTS + jt] - mu_local));
x[it].data.elt[jt] = gamma[it].data.elt[jt] * tmp + beta[it].data.elt[jt];
}
}
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
x[it].store_to(y_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
}
}
template<typename scalar_t>
void launch(
at::Tensor & y, // BxSxhidden_size
at::Tensor & mu,
at::Tensor & rsigma,
const at::Tensor & x, // BxSxhidden_size
const at::Tensor & gamma,
const at::Tensor & beta,
const float epsilon,
const int rows,
const int cols,
const int max_gridx,
cudaStream_t stream
){
if (cols == 1024) {
using Ktraits = Kernel_traits<scalar_t, 1024, 4, 1>;
const int grid =
std::min<int>(DIVUP(rows, Ktraits::ROWS_PER_CTA), max_gridx);
ln_fwd_kernel<Ktraits><<<grid, Ktraits::THREADS_PER_CTA, 0, stream>>>(
y.data_ptr(), mu.data_ptr(), rsigma.data_ptr(), x.data_ptr(),
gamma.data_ptr(), beta.data_ptr(), epsilon, rows);
} else {
assert(false && "Not implemented");
}
AT_CUDA_CHECK(cudaPeekAtLastError());
}
void ln_fwd_cuda(
at::Tensor & y, // BxSxhidden_size
at::Tensor & mu,
at::Tensor & rsigma,
const at::Tensor & x, // BxSxhidden_size
const at::Tensor & gamma,
const at::Tensor & beta,
const float epsilon,
const int rows, const int cols,
cudaStream_t stream
){
const auto dtype = x.scalar_type();
const auto props = at::cuda::getCurrentDeviceProperties();
const int max_gridx = props->maxGridSize[0];
  // TODO:
  // - Using the type dispatch macro costs ~1% perf; investigate why.
// - Tune FP32 warps
// - Add more sizes
if (dtype == torch::kFloat16) {
launch<half>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream);
} else if (dtype == torch::kFloat32) {
launch<float>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream);
} else {
assert(false && "Not implemented");
}
}
#pragma once
constexpr uint32_t THREADS_PER_WARP = 32;
template <typename dtype, int COLS_, int WARPS_M_, int WARPS_N_,
int BYTES_PER_LDG_ = 16>
struct Kernel_traits {
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = COLS_ };
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
using Vec = Vec<dtype, BYTES_PER_LDG>;
using vec_t = typename Vec::vec_t;
using base_t = typename Vec::base_t;
using packed_t = typename Vec::packed_t;
using compute_t = typename Vec::compute_t;
using packed_compute_t = typename Vec::packed_compute_t;
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
enum { ROWS_PER_CTA = WARPS_M };
enum { BYTES_PER_ROW = COLS * sizeof(base_t) };
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
enum {SMEM_BYTES = ROWS_PER_CTA * COLS * sizeof(compute_t)};
};
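// For reference, the configurations instantiated by the launchers in this
// commit work out as follows for the 1024-column path:
//   Kernel_traits<half, 1024, 4, 1> (FP16 fwd/bwd main kernels):
//     NUM_ELTS = 8, THREADS_PER_ROW = 32, THREADS_PER_CTA = 128,
//     ROWS_PER_CTA = 4, BYTES_PER_ROW = 2048, BYTES_PER_ROW_PER_CTA = 512,
//     LDGS = 4, SMEM_BYTES = 16 KB.
//   Kernel_traits<float, 1024, 16, 1, 4> (bwd finalize kernel, FP32 partials):
//     NUM_ELTS = 1, THREADS_PER_ROW = 32, THREADS_PER_CTA = 512,
//     ROWS_PER_CTA = 16, BYTES_PER_ROW = 4096.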
#pragma once
#include "torch/extension.h"
#include <ATen/cuda/Exceptions.h> // for AT_CUDA_CHECK
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
[&] { \
const auto &the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
template <int Bytes> struct Vec_type {};
template <> struct Vec_type<16> {
using Type = uint4;
static __device__ inline Type zero() { return make_uint4(0, 0, 0, 0); }
};
template <> struct Vec_type<8> {
using Type = uint2;
static __device__ inline Type zero() { return make_uint2(0, 0); }
};
template <> struct Vec_type<4> {
using Type = uint32_t;
static __device__ inline Type zero() { return 0; }
};
template <> struct Vec_type<2> {
using Type = uint16_t;
static __device__ inline Type zero() { return 0; }
};
template <typename T> struct TypeInfo {
using base_t = T;
using packed_t = T;
using compute_t = float;
using packed_compute_t = float;
};
template <> struct TypeInfo<half> {
using base_t = half;
using packed_t = half2;
using compute_t = float;
using packed_compute_t = float2;
};
template <typename dtype, int Bytes> struct Vec {
using base_t = typename TypeInfo<dtype>::base_t;
using packed_t = typename TypeInfo<dtype>::packed_t;
using compute_t = typename TypeInfo<dtype>::compute_t;
using packed_compute_t = typename TypeInfo<dtype>::packed_compute_t;
static_assert(Bytes % sizeof(base_t) == 0, "");
static_assert(Bytes % sizeof(packed_t) == 0, "");
enum { BYTES_PER_THREAD = Bytes };
enum { NUM_ELTS = Bytes / sizeof(base_t) };
enum { NUM_PACKED = Bytes / sizeof(packed_t) };
using vec_t = typename Vec_type<Bytes>::Type;
using store_t = union {
vec_t raw;
base_t elt[NUM_ELTS];
packed_t packed[NUM_PACKED];
};
store_t data;
__device__ Vec() { data.raw = Vec_type<Bytes>::zero(); }
__device__ inline void load_from(const char *ptr) {
data.raw = *reinterpret_cast<const vec_t *>(ptr);
}
__device__ inline void load_or_zero(const char *ptr, const bool is_valid) {
data.raw = is_valid ? *reinterpret_cast<const vec_t *>(ptr)
: Vec_type<Bytes>::zero();
}
__device__ inline void store_to(char *ptr) const {
*reinterpret_cast<vec_t *>(ptr) = data.raw;
}
__device__ inline void store_valid(char *ptr, const bool is_valid) const {
if (is_valid)
*reinterpret_cast<vec_t *>(ptr) = data.raw;
}
};
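// Example: Vec<half, 16> loads/stores one uint4 (16 bytes) at a time, i.e.
// NUM_ELTS = 8 halfs (NUM_PACKED = 4 half2), with compute_t = float.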
from .layer_norm import FastLayerNorm
import torch
from torch.nn import init
import fast_layer_norm
class FastLayerNormFN(torch.autograd.Function):
@staticmethod
def forward(ctx, x, gamma, beta, epsilon):
x = x.contiguous()
gamma = gamma.contiguous()
beta = beta.contiguous()
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon)
ctx.save_for_backward(x, gamma, mu, rsigma)
return ymat.view(x.shape)
@staticmethod
def backward(ctx, dy):
        # dy is not guaranteed to be contiguous here, so make it contiguous.
        dy = dy.contiguous()
x, gamma, mu, rsigma = ctx.saved_tensors
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dymat = dy.view(xmat.shape)
dxmat, dgamma, dbeta = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma)
dx = dxmat.view(x.shape)
return dx, dgamma, dbeta, None
class FastLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5):
super(FastLayerNorm, self).__init__()
self.epsilon = eps
self.weight = torch.nn.Parameter(torch.Tensor(hidden_size))
self.bias = torch.nn.Parameter(torch.Tensor(hidden_size))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, x):
return FastLayerNormFN.apply(x, self.weight, self.bias, self.epsilon)
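A minimal usage sketch of the module, assuming an FP16 input of shape (seq, batch, 1024) as exercised by the test below:

import torch
from apex.contrib.layer_norm import FastLayerNorm

ln = FastLayerNorm(1024).cuda().half()  # hidden_size must match the last input dim
x = torch.randn(512, 32, 1024, dtype=torch.float16, device='cuda', requires_grad=True)
y = ln(x)                               # same shape and dtype as x
y.backward(torch.ones_like(y))          # fills x.grad, ln.weight.grad, ln.bias.grad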
import torch
import unittest
import numpy as np
import torch.nn.functional as F
from apex.contrib.layer_norm import FastLayerNorm
import fast_layer_norm as fln
class GPUTimer:
def __init__(self, stream):
self.start_ = torch.cuda.Event(enable_timing=True)
self.stop_ = torch.cuda.Event(enable_timing=True)
self.stream_ = stream
def start(self):
self.stream_.record_event(self.start_)
def stop(self):
self.stream_.record_event(self.stop_)
def sync(self):
self.stream_.synchronize()
def millis(self):
return self.start_.elapsed_time(self.stop_)
def size_in_bytes(t):
return torch.numel(t) * t.element_size()
def abs_err(x, y):
xf = x.float()
yf = y.float()
return ((xf-yf).abs().sum() / yf.abs().sum()).item()
class TestFastLayerNorm(unittest.TestCase):
def setUp(self, seed=1234):
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def test_ln_fp32(self):
self.run_test_layer_norm(torch.float32, atol=1e-5)
def test_ln_fp16(self):
self.run_test_layer_norm(torch.float16, atol=1e-2, rtol=1e-3)
def run_test_layer_norm(self, dtype, atol, rtol=1e-5):
device = torch.device('cuda')
s = 512
b = 32
hidden_size = 1024
epsilon = 1e-5
x = torch.randn((s,b,hidden_size), dtype=dtype, device=device)
beta = torch.randn(hidden_size, dtype=dtype, device=device)
gamma = torch.randn(hidden_size, dtype=dtype, device=device)
x.requires_grad = True
beta.requires_grad = True
gamma.requires_grad = True
x2 = x.clone().detach()
beta2 = beta.clone().detach()
gamma2 = gamma.clone().detach()
x2.requires_grad = True
beta2.requires_grad = True
gamma2.requires_grad = True
dummy_label = torch.randn_like(x)
y = F.layer_norm(x, [hidden_size], gamma, beta, epsilon)
diff = y-dummy_label
l = (diff * diff).sum() / b
l.backward()
fln = FastLayerNorm(hidden_size).cuda()
fln.load_state_dict({'bias': beta2, 'weight':gamma2})
if dtype == torch.float16:
fln = fln.half()
y2 = fln(x2)
diff2 = (y2 - dummy_label)
l2 = (diff2 * diff2).sum() / b
l2.backward()
self.assertTrue(torch.allclose(y2, y, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(x2.grad, x.grad, atol=atol,rtol=rtol))
self.assertTrue(torch.allclose(fln.bias.grad, beta.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(fln.weight.grad, gamma.grad, atol=atol, rtol=rtol))
def test_performance(self):
print()
runs = 1000
device = torch.device('cuda')
        dtype = torch.float16
s = 512
b = 32
hidden_size = 1024
epsilon = 1e-5
x = torch.randn((s*b,hidden_size), dtype=dtype, device=device)
beta = torch.randn(hidden_size, dtype=dtype, device=device)
gamma = torch.randn(hidden_size, dtype=dtype, device=device)
dy = torch.randn_like(x)
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
timer = GPUTimer(stream)
#warmup
for r in range(runs):
y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)
timer.start()
for r in range(runs):
y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)
timer.stop()
timer.sync()
total_bytes_fwd = (size_in_bytes(x)
+ size_in_bytes(y)
+ size_in_bytes(gamma)
+ size_in_bytes(beta)
+ size_in_bytes(mu)
+ size_in_bytes(rsigma)
)
ms_fwd = timer.millis() / runs
print('[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd ))
timer.start()
for r in range(runs):
dx, dgamma, dbeta = fln.ln_bwd(dy, x, mu, rsigma, gamma)
timer.stop()
timer.sync()
total_bytes_bwd = (size_in_bytes(x)
+ size_in_bytes(dx)
+ size_in_bytes(dy)
+ size_in_bytes(gamma)
+ size_in_bytes(dgamma)
+ size_in_bytes(dbeta)
+ size_in_bytes(mu)
+ size_in_bytes(rsigma)
)
ms_bwd = timer.millis() / runs
print('[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd ))
if __name__ == '__main__':
unittest.main()
@@ -297,6 +297,38 @@ torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
    generator_flag = ['-DOLD_GENERATOR']
if "--fast_layer_norm" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_layer_norm")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
        # If CUDA >= 11 is installed, also build for compute capability 8.0 (A100)
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
ext_modules.append(
CUDAExtension(name='fast_layer_norm',
sources=['apex/contrib/csrc/layer_norm/ln_api.cpp',
'apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu',
'apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu',
],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-I./apex/contrib/csrc/layer_norm/',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
if "--fast_multihead_attn" in sys.argv: if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
...
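Presumably the extension is built like the other apex contrib extensions, by forwarding the new flag to setup.py, e.g. pip install -v --no-cache-dir --global-option="--fast_layer_norm" ./ from the repository root; this command follows the existing pattern and is an assumption, not part of the diff.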