Commit 79906517 authored by hubertlu-tw

Merge remote-tracking branch 'upstream/master' into IFU-master-2021-12-08

parents cc92a4b4 aa756cec
import logging
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
import warnings
if torch.distributed.is_available():
from . import parallel
@@ -22,3 +24,19 @@ from . import pyprof
# Common utilities to run tests on ROCm.
from . import testing
from . import transformer
# Logging utilities mainly for apex.transformer module
class RankInfoFormatter(logging.Formatter):
def format(self, record):
from apex.transformer.parallel_state import get_rank_info
record.rank_info = get_rank_info()
return super().format(record)
_library_root_logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(RankInfoFormatter("%(asctime)s - %(name)s - %(levelname)s - %(rank_info)s - %(message)s"))
_library_root_logger.addHandler(handler)
_library_root_logger.propagate = False
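For readers unfamiliar with this logging pattern, here is a minimal self-contained sketch of the same idea with a stand-in `get_rank_info` helper (hypothetical; the real one lives in `apex.transformer.parallel_state`): the formatter injects an extra attribute onto every record so the format string can reference `%(rank_info)s`.

```python
import logging

def get_rank_info():
    # Stand-in for apex.transformer.parallel_state.get_rank_info().
    return "tp=0, pp=0, dp=0"

class RankInfoFormatter(logging.Formatter):
    def format(self, record):
        # Attach the parallel-rank info to every record before formatting.
        record.rank_info = get_rank_info()
        return super().format(record)

logger = logging.getLogger("demo")
handler = logging.StreamHandler()
handler.setFormatter(RankInfoFormatter("%(asctime)s - %(name)s - %(levelname)s - %(rank_info)s - %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("hello")  # e.g. "2021-12-08 ... - demo - INFO - tp=0, pp=0, dp=0 - hello"
```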
from typing import Optional
import torch
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
if not torch.is_autocast_enabled():
return dtype or torch.float
else:
return torch.get_autocast_gpu_dtype()
def _cast_if_autocast_enabled(*args):
if not torch.is_autocast_enabled():
return args
......
#pragma once
#include <unordered_map>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include <cstdint>
#include <functional>
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Params>
struct LaunchParams{
size_t workspace_bytes;
size_t barrier_size;
cudaDeviceProp * props;
cudaStream_t stream;
Params params;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ParamsBase {
ParamsBase()
: ctas_per_col(0)
, rows(0)
, cols(0)
, x(nullptr)
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, workspace(nullptr)
, barrier(nullptr)
{
}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// Input is interpreted as matrix. We normalize across columns.
int rows;
int cols;
// Common data pointers.
void *x;
void *mu;
void *rs;
void *gamma;
// Multi-CTA workspace in gmem.
void *workspace;
// Multi-CTA sync barriers in gmem.
int *barrier;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FwdParams : public ParamsBase {
FwdParams()
: ParamsBase()
, z(nullptr)
, beta(nullptr)
, epsilon(0.f)
{
}
// Output of LN FWD.
void *z;
void *beta;
float epsilon;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct BwdParams : public ParamsBase {
BwdParams()
: ParamsBase()
, dz(nullptr)
, dbeta_part(nullptr)
, dgamma_part(nullptr)
, dx(nullptr)
, dbeta(nullptr)
, dgamma(nullptr)
{
}
// Input: gradient wrt. LN FWD output.
void *dz;
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
// Output: Dgrad.
void *dx;
// Output: Wgrad.
void *dbeta;
void *dgamma;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
extern FwdRegistry FWD_FUNCS;
extern BwdRegistry BWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeId{};
template<>
struct TypeId<fp16>{
constexpr static uint32_t Value = 0;
};
template<>
struct TypeId<bf16>{
constexpr static uint32_t Value = 1;
};
template<>
struct TypeId<fp32>{
constexpr static uint32_t Value = 2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int S>
struct Type2Key{
constexpr static uint32_t Value = TypeId<T>::Value << S;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};
template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};
template<typename T>
struct OutputType2Key : public Type2Key<T, 4>{};
template<typename T>
struct ComputeType2Key : public Type2Key<T, 6>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C>
struct Types2Key{
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size){
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar{
FwdRegistrar(FwdFunction f){
uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE);
FWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdRegistrar{
BwdRegistrar(BwdFunction f){
uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE);
BWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
void ln_fwd_cuda(at::Tensor &y, at::Tensor &mu, at::Tensor &rsigma,
const at::Tensor &x, const at::Tensor &gamma,
const at::Tensor &beta, const float epsilon, const int rows, const int cols,
cudaStream_t stream);
#include "ln.h"
void ln_bwd_cuda(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,
const at::Tensor &dw, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int rows, const int cols, cudaStream_t stream);
/*
Supported Type combinations:
input compute weights output
=======================================
fp32 fp32 fp32 fp32
fp16 fp32 fp16 fp16
bf16 fp32 bf16 bf16
fp32 fp32 fp16 fp16
fp32 fp32 bf16 bf16
Remarks:
Output type = Weight type
Compute always in FP32
*/
namespace layer_norm {
// Create registries and provide runtime versions of config hash functions.
FwdRegistry FWD_FUNCS;
BwdRegistry BWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
uint32_t get_type_id(torch::Dtype dtype){
if( dtype == torch::kFloat16 ) {
return TypeId<fp16>::Value;
} else if( dtype == torch::kBFloat16 ) {
return TypeId<bf16>::Value;
} else if( dtype == torch::kFloat32 ) {
return TypeId<fp32>::Value;
} else {
TORCH_CHECK(false, "Type not supported: ", dtype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
using namespace layer_norm;
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6);
uint64_t launcher_key = (type_key << 32) | hidden_size;
return launcher_key;
}
} // namespace layer_norm
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));
if( iter != layer_norm::FWD_FUNCS.end() ) {
return iter->second;
} else {
TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));
if( iter != layer_norm::BWD_FUNCS.end() ) {
return iter->second;
} else {
TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype);
}
}
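As a sanity check on the registry lookup above, a small Python sketch of the same key layout (illustrative only, mirroring `TypeId`, `Types2Key`, and `get_key`): each type gets a 2-bit ID (fp16=0, bf16=1, fp32=2), the four IDs are packed in weight/input/output/compute order into bits 0-7, and the hidden size occupies the low 32 bits of the final 64-bit key.

```python
TYPE_ID = {"fp16": 0, "bf16": 1, "fp32": 2}

def make_key(wtype, itype, otype, ctype, hidden_size):
    # Mirrors layer_norm::get_key: 2 bits per type at shifts 0/2/4/6,
    # then the packed type key is shifted above the 32-bit hidden size.
    type_key = (TYPE_ID[wtype]
                | (TYPE_ID[itype] << 2)
                | (TYPE_ID[otype] << 4)
                | (TYPE_ID[ctype] << 6))
    return (type_key << 32) | hidden_size

# fp16 weights/input/output with fp32 compute, hidden size 16384:
assert make_key("fp16", "fp16", "fp16", "fp32", 16384) == (0b10000000 << 32) | 16384
```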
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size
const at::Tensor &gamma, // hidden_size
const at::Tensor &beta, // hidden_size
const float epsilon
) {
auto itype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = wtype;
auto ctype = torch::kFloat32;
TORCH_CHECK(beta.scalar_type() == wtype);
TORCH_CHECK(x.is_cuda())
TORCH_CHECK(gamma.is_cuda())
@@ -28,79 +99,148 @@ std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size
const int rows = sizes[0];
const int cols = sizes[1];
auto hidden_size = gamma.numel();
TORCH_CHECK(gamma.sizes() == beta.sizes());
TORCH_CHECK(hidden_size == cols);
TORCH_CHECK(epsilon >= 0.f);
auto opts = x.options();
auto z = torch::empty(sizes, opts.dtype(otype));
auto mu = torch::empty({ rows }, opts.dtype(ctype));
auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
launch_params.props = at::cuda::getCurrentDeviceProperties();
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
// Request the kernel launcher.
auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
at::Tensor workspace, barrier;
// Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.beta = beta.data_ptr();
params.z = z.data_ptr();
params.epsilon = epsilon;
if( launch_params.barrier_size > 0 ) {
auto options = x.options();
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
// Launch the kernel.
launcher(launch_params, false);
return { z, mu, rsigma };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> ln_bwd(const at::Tensor &dz, // BxSxhidden_size
const at::Tensor &x, // BxSxhidden_size
const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma // hidden_size
) {
auto itype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = wtype;
auto ctype = torch::kFloat32;
TORCH_CHECK(dz.dtype() == otype);
TORCH_CHECK(mu.dtype() == ctype);
TORCH_CHECK(rsigma.dtype() == ctype);
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(dz.is_cuda());
TORCH_CHECK(mu.is_cuda());
TORCH_CHECK(rsigma.is_cuda());
TORCH_CHECK(gamma.is_cuda());
TORCH_CHECK(x.is_contiguous());
TORCH_CHECK(dz.is_contiguous());
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(dz.sizes() == sizes);
auto rows = sizes[0];
auto cols = sizes[1];
auto hidden_size = gamma.numel();
TORCH_CHECK(mu.numel() == rows);
TORCH_CHECK(mu.sizes() == rsigma.sizes());
TORCH_CHECK(gamma.numel() == cols);
auto options = x.options();
auto dx = torch::empty_like(x);
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
launch_params.props = at::cuda::getCurrentDeviceProperties();
auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size);
launcher(launch_params, true);
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype));
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype));
at::Tensor workspace, barrier;
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.dz = dz.data_ptr();
params.dx = dx.data_ptr();
params.dbeta = dbeta.data_ptr();
params.dgamma = dgamma.data_ptr();
params.dbeta_part = dbeta_part.data_ptr();
params.dgamma_part = dgamma_part.data_ptr();
if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this?
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
launcher(launch_params, false);
return { dx, dgamma, dbeta, dgamma_part, dbeta_part };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "CUDA LayerNorm"; // optional module docstring
m.doc() = "CUDA LayerNorm";
m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel");
m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel");
}
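For context, a hedged usage sketch of the bindings defined above (assumes apex was built with the fast layer norm extension so that `fast_layer_norm` is importable, and uses a hidden size covered by the registered launchers; shapes follow the 2-D checks in `ln_fwd`/`ln_bwd`):

```python
import torch
import fast_layer_norm as fln

hidden = 1024
x = torch.randn(512 * 32, hidden, dtype=torch.float16, device="cuda")  # (rows, cols), contiguous
gamma = torch.randn(hidden, dtype=torch.float16, device="cuda")
beta = torch.randn(hidden, dtype=torch.float16, device="cuda")

# Forward returns the normalized output plus per-row mean and rsigma (FP32).
z, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)

# Backward takes the output gradient and returns dx, dgamma, dbeta plus the
# per-CTA partial weight gradients consumed by the finalize kernel.
dz = torch.randn_like(z)
dx, dgamma, dbeta, dgamma_part, dbeta_part = fln.ln_bwd(dz, x, mu, rsigma, gamma)
```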
#pragma once
namespace layer_norm {
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum;
constexpr float rn = 1.f / float(COLS);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
Ivec x[LDGS];
Ovec dz[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz[it].load_from(params.dz, idx);
x[it].load_from(params.x, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
}
}
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
dx[it].data.elt[jt] = dx_tmp;
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
if( WARPS_M == 1 ) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dz_sum[NUM_RES];
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
for( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dzy_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
for( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) {
*dbeta_part = cta_dz_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA;
}
}
}
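For readability, the math that `ln_bwd_kernel` implements (it matches the `backward_` reference in the unit test further below); overbars denote means over the hidden dimension of a row, and the per-column sums over rows are what get written to the `dgamma_part`/`dbeta_part` workspaces:

```latex
y = (x - \mu)\, r_s, \qquad dy = dz \odot \gamma \\
dx = r_s \,\bigl( dy - \overline{dy} - y \odot \overline{dy \odot y} \bigr) \\
d\gamma = \sum_{\text{rows}} dz \odot y, \qquad d\beta = \sum_{\text{rows}} dz
```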
template<typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
void ln_bwd_finalize_kernel(BwdParams params)
{
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
Sum<reduce_t> sum;
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local));
for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
dbeta_part.load_from(params.dbeta_part, idx);
dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
}
}
void * smem_gamma = smem_;
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
const int write_row = warp;
const int write_col = lane ^ write_row;
const int write_idx = write_row * THREADS_PER_WARP + write_col;
dgamma_local.store_to(smem_gamma, write_idx);
dbeta_local.store_to(smem_beta, write_idx);
__syncthreads();
// It would probably be safe to reuse the first row of smem_beta and smem_gamma
void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
// More than one iter iff ROWS_PER_CTA < 32.
for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
memset(&dbeta_local, 0, sizeof(dbeta_local));
memset(&dgamma_local, 0, sizeof(dgamma_local));
// Load beta and gamma transposed
if(read_row < Kernel_traits::ROWS_PER_CTA){
dbeta_local.load_from(smem_beta, read_idx);
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
compute_t b_i = dbeta_local.data.elt[it];
compute_t g_i = dgamma_local.data.elt[it];
b_i = reducer.allreduce(b_i, sum);
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
dbeta_local.data.elt[it] = b_i;
}
// Leader stores the result at the current column.
if(lane == 0){
dgamma_local.store_to(smem_gamma_out, w);
dbeta_local.store_to(smem_beta_out, w);
}
}
// All writes done.
__syncthreads();
// Pack and store: 2-wide stores with half the threads.
if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
dbeta_vec2.load_from(smem_beta_out, lane);
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
}
dgamma_out2.store_to(params.dgamma, col_out);
dbeta_out2.store_to(params.dbeta, col_out);
}
}
}
} // namespace layer_norm
#include "utils.cuh"
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ATen/cuda/CUDAContext.h"
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(
void *__restrict__ y_, void *__restrict__ mu_, void *__restrict__ rsigma_,
const void *__restrict__ x_, const void *__restrict__ gamma_,
const void *__restrict__ beta_, const float epsilon, int rows) {
using Vec = typename Ktraits::Vec;
using base_t = typename Ktraits::base_t;
using compute_t = typename Ktraits::compute_t;
enum { NUM_ELTS = Vec::NUM_ELTS };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { BYTES_PER_LDG = Ktraits::BYTES_PER_LDG };
static_assert(BYTES_PER_LDG == 16, "");
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = BYTES_PER_ROW / Ktraits::BYTES_PER_ROW_PER_CTA };
static_assert(LDGS * Ktraits::BYTES_PER_ROW_PER_CTA == BYTES_PER_ROW, "");
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP;
const int warp = tidx / THREADS_PER_WARP;
const int warp_n = warp % WARPS_N;
const int warp_m = warp / WARPS_N;
const int c = warp_n * THREADS_PER_WARP + lane;
const int r = bidx * ROWS_PER_CTA + warp_m;
const char *x_ptr = static_cast<const char *>(x_);
const char *g_ptr = static_cast<const char *>(gamma_);
const char *b_ptr = static_cast<const char *>(beta_);
char *y_ptr = static_cast<char *>(y_);
compute_t *mu_ptr = static_cast<compute_t *>(mu_);
compute_t *rs_ptr = static_cast<compute_t *>(rsigma_);
Vec gamma[LDGS];
Vec beta[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
gamma[it].load_from(g_ptr + col * BYTES_PER_LDG);
beta[it].load_from(b_ptr + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
x[it].load_from(x_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
xf[it * NUM_ELTS + jt] = compute_t(x[it].data.elt[jt]);
}
}
compute_t mu_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
mu_local += xf[it * NUM_ELTS + jt];
}
#include "ln_fwd_kernels.cuh"
using namespace layer_norm;
template<
typename weight_t,
typename input_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG
>
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG
>;
auto kernel = &ln_fwd_kernel<Kernel_traits>;
if( configure_params ) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
* 2;
}
return;
}
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
}
mu_local *= rn;
if(lane == 0){
mu_ptr[row] = mu_local;
}
compute_t var_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t diff = xf[it * NUM_ELTS + jt] - mu_local;
var_local += diff * diff;
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
}
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
}
compute_t rsigma = rsqrtf(var_local * rn + epsilon);
if(lane == 0){
rs_ptr[row] = rsigma;
}
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
base_t tmp = (rsigma * (xf[it * NUM_ELTS + jt] - mu_local));
x[it].data.elt[jt] = gamma[it].data.elt[jt] * tmp + beta[it].data.elt[jt];
}
}
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
x[it].store_to(y_ptr + row * BYTES_PER_ROW + col * BYTES_PER_LDG);
col += THREADS_PER_ROW;
}
}
}
template<typename scalar_t>
void launch(
at::Tensor & y, // BxSxhidden_size
at::Tensor & mu,
at::Tensor & rsigma,
const at::Tensor & x, // BxSxhidden_size
const at::Tensor & gamma,
const at::Tensor & beta,
const float epsilon,
const int rows,
const int cols,
const int max_gridx,
cudaStream_t stream
){
if (cols == 1024) {
using Ktraits = Kernel_traits<scalar_t, 1024, 4, 1>;
const int grid =
std::min<int>(DIVUP(rows, Ktraits::ROWS_PER_CTA), max_gridx);
ln_fwd_kernel<Ktraits><<<grid, Ktraits::THREADS_PER_CTA, 0, stream>>>(
y.data_ptr(), mu.data_ptr(), rsigma.data_ptr(), x.data_ptr(),
gamma.data_ptr(), beta.data_ptr(), epsilon, rows);
} else {
assert(false && "Not implemented");
}
AT_CUDA_CHECK(cudaPeekAtLastError());
}
void ln_fwd_cuda(
at::Tensor & y, // BxSxhidden_size
at::Tensor & mu,
at::Tensor & rsigma,
const at::Tensor & x, // BxSxhidden_size
const at::Tensor & gamma,
const at::Tensor & beta,
const float epsilon,
const int rows, const int cols,
cudaStream_t stream
){
const auto dtype = x.scalar_type();
const auto props = at::cuda::getCurrentDeviceProperties();
const int max_gridx = props->maxGridSize[0];
//TODO
// - Using dispatch macro costs 1% perf wtf?!?!
// - Tune FP32 warps
// - Add more sizes
if (dtype == torch::kFloat16) {
launch<half>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream);
} else if (dtype == torch::kFloat32) {
launch<float>(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, max_gridx, stream);
} else {
assert(false && "Not implemented");
}
REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);
}
#pragma once
#include "ln.h"
namespace layer_norm {
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
Wvec gamma[LDGS];
Wvec beta[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
beta[it].load_from(params.beta, idx);
idx += VEC_COLS_PER_LDG;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
Ivec x[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
x[it].load_from(params.x, idx);
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_ij = compute_t(x[it].data.elt[jt]);
xf[it * NUM_ELTS + jt] = x_ij;
}
idx += VEC_COLS_PER_LDG;
}
stats_t s = stats.compute(xf, rn);
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
mu_ptr[row] = mu;
}
compute_t rs = rsqrtf(rn * m2 + params.epsilon);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
rs_ptr[row] = rs;
}
Ovec z[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
output_t g_ij = gamma[it].data.elt[jt];
output_t b_ij = beta[it].data.elt[jt];
z[it].data.elt[jt] = (g_ij * y_ij + b_ij);
}
z[it].store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
}
}
}
} // namespace layer_norm
#pragma once
constexpr uint32_t THREADS_PER_WARP = 32;
template <typename dtype, int COLS_, int WARPS_M_, int WARPS_N_,
int BYTES_PER_LDG_ = 16>
struct Kernel_traits {
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = COLS_ };
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
using Vec = Vec<dtype, BYTES_PER_LDG>;
using vec_t = typename Vec::vec_t;
using base_t = typename Vec::base_t;
using packed_t = typename Vec::packed_t;
using compute_t = typename Vec::compute_t;
using packed_compute_t = typename Vec::packed_compute_t;
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
enum { ROWS_PER_CTA = WARPS_M };
enum { BYTES_PER_ROW = COLS * sizeof(base_t) };
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
enum { SMEM_BYTES = ROWS_PER_CTA * COLS * sizeof(compute_t) };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace layer_norm {
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_
>
struct Kernel_traits_base {
using weight_t = weight_t_;
using input_t = input_t_;
using output_t = output_t_;
using compute_t = compute_t_;
using index_t = index_t_;
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
enum { THREADS_PER_WARP = 32 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_,
uint32_t BYTES_PER_LDG_,
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
weight_t_,
input_t_,
output_t_,
compute_t_,
index_t_,
THREADS_PER_CTA_>
>
struct Kernel_traits_finalize : public Base {
enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
// Bytes per global load from the input.
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
// Number of elements fetched by a global load.
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
// Bytes per global store of the weights.
enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
static_assert(sizeof(compute_t_) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
// Shared memory size to transpose the CTA result.
enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
// Shared memory size to coalesce the CTA result.
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
// Shared memory requirement per CTA.
enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
// The type of the reducer.
using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
// Condition for the whole CTA to participate in syncthreads.
static_assert(COLS % Base::THREADS_PER_WARP == 0);
enum { CTAS = COLS / Base::THREADS_PER_WARP };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename weight_t_,
typename input_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t HIDDEN_SIZE_,
uint32_t CTAS_PER_ROW_,
uint32_t WARPS_M_,
uint32_t WARPS_N_,
uint32_t BYTES_PER_LDG_ = 16,
typename Base = Kernel_traits_base<
HIDDEN_SIZE_,
weight_t_,
input_t_,
output_t_,
compute_t_,
index_t_,
WARPS_M_*WARPS_N_*THREADS_PER_WARP
>
>
struct Kernel_traits : public Base {
using input_t = typename Base::input_t;
using weight_t = typename Base::weight_t;
using compute_t = typename Base::compute_t;
using output_t = typename Base::output_t;
using index_t = typename Base::index_t;
enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = HIDDEN_SIZE_ };
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
enum { ROWS_PER_CTA = WARPS_M };
enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
// Assume that each thread can handle the same number of elements in the output and weights as in the input.
static_assert(sizeof(input_t) >= sizeof(output_t));
static_assert(sizeof(input_t) >= sizeof(weight_t));
// The number of columns fetched per load from input: one per thread.
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
// The total number of vectorized loads/stores per hidden vector.
enum { VEC_COLS = COLS / ELTS_PER_LDG };
// The number of loads per thread for the input.
enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
//static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
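To make the traits arithmetic above concrete, a short sketch with illustrative parameters (not tied to any particular registered configuration):

```python
# Hypothetical config: fp16 input, hidden size 16384, 1x4 warps per CTA, 2 CTAs per row, 16B loads.
THREADS_PER_WARP = 32
WARPS_M, WARPS_N = 1, 4
CTAS_PER_ROW = 2
BYTES_PER_LDG = 16
HIDDEN_SIZE = 16384
bytes_per_elt = 2  # sizeof(fp16)

THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP       # 128 threads cooperate on one row per CTA
ELTS_PER_LDG = BYTES_PER_LDG // bytes_per_elt      # 8 elements per vectorized load
VEC_COLS = HIDDEN_SIZE // ELTS_PER_LDG             # 2048 vectorized columns per row
VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW  # 256 columns covered per load step
LDGS = VEC_COLS // VEC_COLS_PER_LDG                # 8 loads per thread per row
assert LDGS * VEC_COLS_PER_LDG == VEC_COLS
```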
#pragma once
#include "torch/extension.h"
#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
[&] { \
const auto &the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
template <int Bytes> struct Vec_type {};
template <> struct Vec_type<16> {
using Type = uint4;
static __device__ inline Type zero() { return make_uint4(0, 0, 0, 0); }
};
template <> struct Vec_type<8> {
using Type = uint2;
static __device__ inline Type zero() { return make_uint2(0, 0); }
};
template <> struct Vec_type<4> {
using Type = uint32_t;
static __device__ inline Type zero() { return 0; }
};
template <> struct Vec_type<2> {
using Type = uint16_t;
static __device__ inline Type zero() { return 0; }
};
template <typename T> struct TypeInfo {
using base_t = T;
using packed_t = T;
using compute_t = float;
using packed_compute_t = float;
};
template <> struct TypeInfo<half> {
using base_t = half;
using packed_t = half2;
using compute_t = float;
using packed_compute_t = float2;
};
template <typename dtype, int Bytes> struct Vec {
using base_t = typename TypeInfo<dtype>::base_t;
using packed_t = typename TypeInfo<dtype>::packed_t;
using compute_t = typename TypeInfo<dtype>::compute_t;
using packed_compute_t = typename TypeInfo<dtype>::packed_compute_t;
static_assert(Bytes % sizeof(base_t) == 0, "");
static_assert(Bytes % sizeof(packed_t) == 0, "");
enum { BYTES_PER_THREAD = Bytes };
enum { NUM_ELTS = Bytes / sizeof(base_t) };
enum { NUM_PACKED = Bytes / sizeof(packed_t) };
using vec_t = typename Vec_type<Bytes>::Type;
using store_t = union {
vec_t raw;
base_t elt[NUM_ELTS];
packed_t packed[NUM_PACKED];
};
store_t data;
__device__ Vec() { data.raw = Vec_type<Bytes>::zero(); }
__device__ inline void load_from(const char *ptr) {
data.raw = *reinterpret_cast<const vec_t *>(ptr);
}
__device__ inline void load_or_zero(const char *ptr, const bool is_valid) {
data.raw = is_valid ? *reinterpret_cast<const vec_t *>(ptr)
: Vec_type<Bytes>::zero();
}
__device__ inline void store_to(char *ptr) const {
*reinterpret_cast<vec_t *>(ptr) = data.raw;
}
__device__ inline void store_valid(char *ptr, const bool is_valid) const {
if (is_valid)
*reinterpret_cast<vec_t *>(ptr) = data.raw;
}
};
@@ -10,6 +10,10 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
......
import torch
from torch.nn import init
from apex._autocast_utils import _cast_if_autocast_enabled
import fast_layer_norm
class FastLayerNormFN(torch.autograd.Function):
@staticmethod
def forward(ctx, x, gamma, beta, epsilon):
@@ -14,23 +16,30 @@ class FastLayerNormFN(torch.autograd.Function):
ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon)
ctx.save_for_backward(x, gamma, mu, rsigma)
return ymat.view(x.shape)
@staticmethod
def backward(ctx, dy):
# assert dy.is_contiguous()
dy = dy.contiguous()  # this happens!
x, gamma, mu, rsigma = ctx.saved_tensors
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dymat = dy.view(xmat.shape)
dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma)
dx = dxmat.view(x.shape)
return dx, dgamma, dbeta, None
def _fast_layer_norm(x, weight, bias, epsilon):
args = _cast_if_autocast_enabled(x, weight, bias, epsilon)
with torch.cuda.amp.autocast(enabled=False):
return FastLayerNormFN.apply(*args)
class FastLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5):
super().__init__()
self.epsilon = eps
self.weight = torch.nn.Parameter(torch.Tensor(hidden_size))
self.bias = torch.nn.Parameter(torch.Tensor(hidden_size))
@@ -41,4 +50,4 @@ class FastLayerNorm(torch.nn.Module):
init.zeros_(self.bias)
def forward(self, x):
return _fast_layer_norm(x, self.weight, self.bias, self.epsilon)
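A short usage sketch of the module above, mirroring how the unit tests drive it (the hidden size and tensor shapes are illustrative):

```python
import torch
from apex.contrib.layer_norm import FastLayerNorm

hidden_size = 1024
ln = FastLayerNorm(hidden_size).half().cuda()

# Any leading dimensions are allowed; the op flattens the input to (-1, hidden_size).
x = torch.randn(512, 32, hidden_size, dtype=torch.float16, device="cuda", requires_grad=True)
y = ln(x)
y.sum().backward()
print(x.grad.shape, ln.weight.grad.shape, ln.bias.grad.shape)
```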
import unittest

import torch
import torch.nn.functional as F
import numpy as np

import fast_layer_norm as fln
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
class GPUTimer:
@@ -14,146 +14,262 @@ class GPUTimer:
self.start_ = torch.cuda.Event(enable_timing=True)
self.stop_ = torch.cuda.Event(enable_timing=True)
self.stream_ = stream
def start(self):
self.stream_.record_event(self.start_)
def stop(self):
self.stream_.record_event(self.stop_)
def sync(self):
self.stream_.synchronize()
def millis(self):
return self.start_.elapsed_time(self.stop_)
def size_in_bytes(t):
return torch.numel(t) * t.element_size()
def abs_err(x, y):
xf = x.float()
yf = y.float()
return ((xf-yf).abs().sum() / yf.abs().sum()).item()
def metrics(y_ref, y, epsilon=1e-6):
y_ref = y_ref.float()
y = y.float()
relerr, mse = (
(y_ref - y).abs().sum() / (y_ref.abs().sum() + epsilon),
(y_ref - y).square().mean(),
)
return relerr.item(), mse.item()
device = torch.device("cuda")
fp32 = torch.float32
fp16 = torch.float16
bf16 = torch.bfloat16
def backward_(dz, x, mu, rs, gamma):
wtype = gamma.dtype
itype = x.dtype
otype = dz.dtype
ctype = mu.dtype
mu = mu.unsqueeze(1)
rs = rs.unsqueeze(1)
hidden_size = gamma.numel()
y = rs * (x.to(ctype) - mu)
dbeta = dz.view(-1, hidden_size).sum(0, dtype=ctype)
dgamma = (dz * y).view(-1, hidden_size).sum(0, dtype=ctype)
dy = dz.view(-1, hidden_size).to(ctype) * gamma.unsqueeze(0).to(ctype)
mdy = dy.mean(1, keepdim=True, dtype=ctype)
mdyy = (dy * y).mean(1, keepdim=True, dtype=ctype)
dx = rs * (dy - mdyy * y - mdy)
return dx.to(itype), dgamma.to(wtype), dbeta.to(wtype)
def benchmark_(S, B, hidden_size, itype, wtype, runs=100):
epsilon = 1e-5
x = torch.randn((S * B, hidden_size), dtype=itype, device=device)
beta = torch.randn(hidden_size, dtype=wtype, device=device)
gamma = torch.randn(hidden_size, dtype=wtype, device=device)
dz = torch.randn(x.shape, dtype=wtype, device=device)
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
timer = GPUTimer(stream)
# warmup
for r in range(runs):
z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
timer.start()
for r in range(runs):
z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)
timer.stop()
timer.sync()
total_bytes_fwd = sum([size_in_bytes(t) for t in [x, z, gamma, beta, mu, rsigma]])
ms_fwd = timer.millis() / runs
print(
"[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format(
ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd
)
)
timer.start()
for r in range(runs):
dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, x, mu, rsigma, gamma)
timer.stop()
timer.sync()
total_bytes_bwd = sum(
[
size_in_bytes(t)
for t in [dz, x, mu, rsigma, gamma, dx, dgamma, dbeta, dbp, dbp, dgp, dgp]
]
)
ms_bwd = timer.millis() / runs
print(
"[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format(
ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd
)
)
def test_(S, B, hidden_size, itype, wtype, ctype=fp32):
seed = 1243
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
otype = wtype
print("========================================================")
print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype}")
print("--------------------------------------------------------")
x = torch.randn(S * B, hidden_size, dtype=itype, device=device)
gamma = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2
beta = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2
epsilon = 1e-5
x.requires_grad = True
gamma.requires_grad = True
beta.requires_grad = True
mu_ref = x.mean(1, dtype=ctype, keepdim=True)
v = torch.square(x - mu_ref).mean(1, dtype=ctype, keepdim=True)
rs_ref = torch.rsqrt(v + epsilon)
y_ref = rs_ref * (x.to(ctype) - mu_ref)
z_ref = (gamma.unsqueeze(0) * (y_ref).to(otype) + beta.unsqueeze(0)).to(otype)
mu_ref = mu_ref.flatten()
rs_ref = rs_ref.flatten()
dz = torch.randn_like(z_ref)
# z_ref.backward(dz)
# dx_ref = x.grad
# dgamma_ref = gamma.grad
# dbeta_ref = beta.grad
dx_ref, dg_ref, db_ref = backward_(dz, x, mu_ref, rs_ref, gamma)
z, mu, rs = fln.ln_fwd(x, gamma, beta, epsilon)
dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma)
re_z, mse_z = metrics(z_ref, z)
re_mu, mse_mu = metrics(mu_ref, mu)
re_rs, mse_rs = metrics(rs_ref, rs)
re_dx, mse_dx = metrics(dx_ref, dx)
re_dg, mse_dg = metrics(dg_ref, dg)
re_db, mse_db = metrics(db_ref, db)
print(f" z: relerr={re_z :.4e} mse={mse_z :.4e}")
print(f"mu: relerr={re_mu:.4e} mse={mse_mu:.4e}")
print(f"rs: relerr={re_mu:.4e} mse={mse_mu:.4e}")
print(f"dx: relerr={re_dx:.4e} mse={mse_dx:.4e}")
print(f"dg: relerr={re_dg:.4e} mse={mse_dg:.4e}")
print(f"db: relerr={re_db:.4e} mse={mse_db:.4e}")
def check_err(x, relerr):
tol = 1e-3 if x.dtype == torch.float16 else 5e-6
return relerr < tol
return [
check_err(x, re)
for x, re in zip([z, mu, rs, dx, dg, db], [re_z, re_mu, re_rs, re_dx, re_dg, re_db])
]
class TestFastLayerNorm(unittest.TestCase):
def setUp(self, seed=1234):
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def test_ln_fp32(self):
self.run_test_layer_norm(torch.float32, atol=1e-5)
def test_ln_fp16(self):
self.run_test_layer_norm(torch.float16, atol=1e-2, rtol=1e-3)
def run_test_layer_norm(self, dtype, atol, rtol=1e-5):
device = torch.device('cuda')
s = 512
b = 32
hidden_size = 1024
epsilon = 1e-5
x = torch.randn((s,b,hidden_size), dtype=dtype, device=device)
beta = torch.randn(hidden_size, dtype=dtype, device=device)
gamma = torch.randn(hidden_size, dtype=dtype, device=device)
x.requires_grad = True
beta.requires_grad = True
gamma.requires_grad = True
x2 = x.clone().detach()
beta2 = beta.clone().detach()
gamma2 = gamma.clone().detach()
x2.requires_grad = True
beta2.requires_grad = True
gamma2.requires_grad = True
dummy_label = torch.randn_like(x)
y = F.layer_norm(x, [hidden_size], gamma, beta, epsilon)
diff = y-dummy_label
l = (diff * diff).sum() / b
l.backward()
fln = FastLayerNorm(hidden_size).cuda()
fln.load_state_dict({'bias': beta2, 'weight':gamma2})
if dtype == torch.float16:
fln = fln.half()
y2 = fln(x2)
diff2 = (y2 - dummy_label)
l2 = (diff2 * diff2).sum() / b
l2.backward()
self.assertTrue(torch.allclose(y2, y, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(x2.grad, x.grad, atol=atol,rtol=rtol))
self.assertTrue(torch.allclose(fln.bias.grad, beta.grad, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(fln.weight.grad, gamma.grad, atol=atol, rtol=rtol))
def test_performance(self):
print()
runs = 1000
device = torch.device('cuda')
dtype =torch.float16
s = 512
b = 32
hidden_size = 1024
epsilon = 1e-5
x = torch.randn((s*b,hidden_size), dtype=dtype, device=device)
beta = torch.randn(hidden_size, dtype=dtype, device=device)
gamma = torch.randn(hidden_size, dtype=dtype, device=device)
dy = torch.randn_like(x)
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
timer = GPUTimer(stream)
#warmup
for r in range(runs):
y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)
timer.start()
for r in range(runs):
y, mu, rsigma = fln.ln_fwd(x, gamma, beta, 1e-5)
timer.stop()
timer.sync()
total_bytes_fwd = (size_in_bytes(x)
+ size_in_bytes(y)
+ size_in_bytes(gamma)
+ size_in_bytes(beta)
+ size_in_bytes(mu)
+ size_in_bytes(rsigma)
)
ms_fwd = timer.millis() / runs
print('[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd ))
timer.start()
for r in range(runs):
dx, dgamma, dbeta = fln.ln_bwd(dy, x, mu, rsigma, gamma)
timer.stop()
timer.sync()
total_bytes_bwd = (size_in_bytes(x)
+ size_in_bytes(dx)
+ size_in_bytes(dy)
+ size_in_bytes(gamma)
+ size_in_bytes(dgamma)
+ size_in_bytes(dbeta)
+ size_in_bytes(mu)
+ size_in_bytes(rsigma)
)
ms_bwd = timer.millis() / runs
print('[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec'.format(ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd ))
def assertAll(self, l):
if not all(l):
print(l)
for x in l:
self.assertTrue(x)
def test_all_configs(self):
hidden_sizes = [
768,
1024,
1536,
2048,
2304,
3072,
3840,
4096,
5120,
6144,
8192,
10240,
12288,
12800,
15360,
16384,
18432,
20480,
24576,
25600,
30720,
32768,
40960,
49152,
65536,
]
for h in hidden_sizes:
with self.subTest(f"hidden_size={h}"):
self.assertAll(test_(256, 2, h, fp32, fp32))
self.assertAll(test_(256, 2, h, fp16, fp16))
self.assertAll(test_(256, 2, h, fp32, fp16))
self.assertAll(test_(256, 2, h, bf16, bf16))
self.assertAll(test_(256, 2, h, fp32, bf16))
def test_run_benchmark(self):
for (S, B, hidden_size, runs) in (
(512, 32, 768, 1000),
(512, 32, 1024, 1000),
(512, 8, 4096, 1000),
(512, 8, 5120, 1000),
(512, 8, 6144, 1000),
(256, 2, 20480, 500),
(256, 2, 25600, 500),
(256, 2, 40960, 250),
(256, 2, 65536, 250),
):
with self.subTest(f"(S, B, hidden_size)=({S}, {B}, {hidden_size})"):
benchmark_(S, B, hidden_size, fp16, fp16, runs)
def test_compat_with_autocast(self):
autocast_dtypes = (
(torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
)
input_shape = (512, 32, 768)
layer_norm = FastLayerNorm(input_shape[-1]).cuda()
input = torch.randn(input_shape).cuda()
for dtype in autocast_dtypes:
layer_norm.zero_grad(set_to_none=True)
with self.subTest(f"autocast_dtype={dtype}"):
with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
out = layer_norm(input)
self.assertEqual(dtype, out.dtype)
grad = torch.randn_like(out)
out.backward(grad)
self.assertEqual(torch.float32, layer_norm.weight.grad.dtype)
if __name__ == "__main__":
unittest.main()
......@@ -2,4 +2,80 @@
`apex.transformer` is a module which enables efficient large Transformer models at scale.
Both `apex.transformer.tensor_parallel` and `apex.transformer.pipeline_parallel` are based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM):
the former on its `megatron.mpu` module, the latter on `megatron.schedules` and `megatron.p2p_communication`.
## Tensor Model Parallel (TP)
APEX's tensor model parallel utilities provide several `torch.nn.Module` subclasses, custom fused kernels, and PRNG state handling.
See Appendix B.2 of [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) for the details of
PRNG state handling.
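For illustration, a minimal sketch of using the tensor-parallel layers might look like the following. The layer and function names (`ColumnParallelLinear`, `RowParallelLinear`, `model_parallel_cuda_manual_seed`) follow the Megatron-LM convention; exact constructor arguments can differ between versions, and `torch.distributed` is assumed to be initialized already (e.g. via `torchrun`).
```python
import torch
import torch.nn.functional as F

from apex.transformer import parallel_state, tensor_parallel

# Assumes torch.distributed.init_process_group has already been called.
parallel_state.initialize_model_parallel(2)  # tensor model parallel size = 2
tensor_parallel.model_parallel_cuda_manual_seed(1234)  # per-rank PRNG state handling

hidden_size = 1024
# ColumnParallelLinear partitions the output features across tensor-parallel ranks;
# RowParallelLinear partitions the input features and reduces the partial outputs.
fc1 = tensor_parallel.ColumnParallelLinear(hidden_size, 4 * hidden_size, gather_output=False).cuda()
fc2 = tensor_parallel.RowParallelLinear(4 * hidden_size, hidden_size, input_is_parallel=True).cuda()

x = torch.randn(8, hidden_size, device="cuda")
# Megatron-style parallel linear layers return an (output, bias) pair.
h, _ = fc1(x)
y, _ = fc2(F.gelu(h))
```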
## Pipeline Model Parallel (PP)
APEX's pipeline model parallel functions require models to implement `.set_input_tensor` because
the input tensor of the `.forward` method can be `None`.
The following is a casual sketch of a training script using apex pipeline parallelism.
```python
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func
class Model(nn.Module):
...
def __init__(self, *args, **kwargs):
super().__init__()
pre_process = kwargs.pop("pre_process")
post_process = kwargs.pop("post_process")
def set_input_tensor(self, tensor):
self.input_tensor = tensor
def forward(self, x, ...):
if parallel_state.is_pipeline_first_stage():
input = x
else:
input = self.input_tensor
...
def model_provider_func(*args, **kwargs):
return Model(*args, **kwargs)
def loss_func(pred, label):
loss = ...
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'nice_loss': averaged_loss}
def forward_step_func(batch, model):
input, label = process_batch(batch)
out = model(input)
return out, partial(loss_func, label)
forward_backward_func = get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
parallel_state.initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size,
)
# The following line basically is equivalent to `build_model(Model, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)`
model = build_model(model_provider_func, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)
optimizer = ...
data_loader = ...
for epoch in range(num_epochs):
for batch in data_loader:
        forward_backward_func(forward_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape)
optimizer.step()
```
from . import tensor_parallel
from . import functional
from .enums import LayerType
from .enums import AttnType
from .enums import AttnMaskType
from .parallel_state import (
is_unitialized,
destroy_model_parallel,
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_embedding_group,
get_model_parallel_group,
get_tensor_model_parallel_group,
get_pipeline_model_parallel_group,
get_tensor_model_parallel_rank,
set_tensor_model_parallel_rank,
get_pipeline_model_parallel_rank,
set_pipeline_model_parallel_rank,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_tensor_model_parallel_src_rank,
get_pipeline_model_parallel_first_rank,
get_pipeline_model_parallel_last_rank,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_tensor_model_parallel_world_size,
set_tensor_model_parallel_world_size,
get_pipeline_model_parallel_world_size,
set_pipeline_model_parallel_world_size,
get_virtual_pipeline_model_parallel_rank,
set_virtual_pipeline_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from apex.transformer import amp
from apex.transformer import functional
from apex.transformer import parallel_state
from apex.transformer import pipeline_parallel
from apex.transformer import tensor_parallel
from apex.transformer import utils
from apex.transformer.enums import LayerType
from apex.transformer.enums import AttnType
from apex.transformer.enums import AttnMaskType
__all__ = [
"amp",
"functional",
"parallel_state",
"pipeline_parallel",
"tensor_parallel",
"utils",
# enums.py
"LayerType",
"AttnType",
"AttnMaskType",
]
from apex.transformer._data._batchsampler import MegatronPretrainingRandomSampler
from apex.transformer._data._batchsampler import MegatronPretrainingSampler
__all__ = [
"MegatronPretrainingRandomSampler",
"MegatronPretrainingSampler",
]
"""BatchSampler implementations for POC of dynamic batch size or rampup_batch_size support.
Implementations are based on https://github.com/NVIDIA/Megatron-LM/blob/bcd605f8570ebeeb0436c115ebbfafc3c5a40ae5/megatron/data/data_samplers.py.
""" # NOQA
import abc
import torch
__all__ = [
"MegatronPretrainingSampler",
"MegatronPretrainingRandomSampler",
]
class _Base(abc.ABC):
"""Base class for Megatron style BatchSampler."""
@abc.abstractmethod
def __len__(self) -> int:
...
@abc.abstractmethod
def __iter__(self):
...
@property
@abc.abstractmethod
def local_minibatch_size(self) -> int:
...
@local_minibatch_size.setter
    @abc.abstractmethod
    def local_minibatch_size(self, new_local_minibatch_size) -> None:
...
class MegatronPretrainingSampler(_Base):
def __init__(
self,
total_samples: int,
consumed_samples: int,
local_minibatch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
):
# Sanity checks.
        if total_samples <= 0:
            raise RuntimeError('no sample to consume: {}'.format(total_samples))
        if consumed_samples >= total_samples:
            raise RuntimeError('no samples left to consume: {}, {}'.format(consumed_samples, total_samples))
        if local_minibatch_size <= 0:
            raise RuntimeError(f"local minibatch size must be greater than 0: {local_minibatch_size}")
        if data_parallel_size <= 0:
            raise RuntimeError(f"data parallel size must be greater than 0: {data_parallel_size}")
        if data_parallel_rank >= data_parallel_size:
            raise RuntimeError('data_parallel_rank should be smaller than data parallel size: {}, {}'.format(data_parallel_rank, data_parallel_size))
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self._local_minibatch_size = local_minibatch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * data_parallel_size
self.drop_last = drop_last
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.local_minibatch_size
end_idx = start_idx + self.local_minibatch_size
return start_idx, end_idx
@property
def local_minibatch_size(self) -> int:
return self._local_minibatch_size
@local_minibatch_size.setter
def local_minibatch_size(self, new_local_minibatch_size) -> None:
self._local_minibatch_size = new_local_minibatch_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
def __iter__(self):
batch = []
        # Last batch will be dropped if drop_last is not set to False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
            if len(batch) == self.local_minibatch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
        # Yield the last partial batch if drop_last is not set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
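# Illustrative usage sketch (the helper name and sizes below are hypothetical, not part of this
# module): because `__iter__` yields lists of dataset indices, the sampler plugs directly into
# `torch.utils.data.DataLoader` via the `batch_sampler` argument.
def _example_pretraining_dataloader(dataset, consumed_samples: int = 0):
    from torch.utils.data import DataLoader
    sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        local_minibatch_size=16,
        data_parallel_rank=0,   # in practice: parallel_state.get_data_parallel_rank()
        data_parallel_size=8,   # in practice: parallel_state.get_data_parallel_world_size()
    )
    return DataLoader(dataset, batch_sampler=sampler, num_workers=0)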
class MegatronPretrainingRandomSampler(_Base):
"""Megatron style Random Batch Sampler.
    The major difference is that `__iter__` yields a local minibatch, not a microbatch.
    A local minibatch consists of `global_batch_size / data_parallel_size` samples.

    Args:
        total_samples: The number of data samples, i.e. ``len(dataset)``.
        consumed_samples: The number of samples already consumed in pretraining.
        local_minibatch_size: The number of samples in each batch returned from `__iter__`. Basically,
            `local_minibatch_size = global_batch_size / data_parallel_size`.
        data_parallel_rank: The rank of this process in the data parallel group.
        data_parallel_size: The world size of the data parallel group.
"""
def __init__(
self,
total_samples: int,
consumed_samples: int,
local_minibatch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
) -> None:
if total_samples <= 0:
raise ValueError(f"no sample to consume: total_samples of {total_samples}")
if local_minibatch_size <= 0:
raise ValueError(f"Invalid local_minibatch_size: {local_minibatch_size}")
if data_parallel_size <= 0:
raise ValueError(f"Invalid data_parallel_size: {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise ValueError(
f"data_parallel_rank should be smaller than data parallel size: {data_parallel_rank} < {data_parallel_size}"
)
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self._local_minibatch_size = local_minibatch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
self.last_batch_size = self.total_samples % self.local_minibatch_times_data_parallel_size
def __len__(self) -> int:
return self.total_samples
@property
def local_minibatch_size(self) -> int:
return self._local_minibatch_size
@local_minibatch_size.setter
def local_minibatch_size(self, new_local_minibatch_size) -> None:
self._local_minibatch_size = new_local_minibatch_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
# note(mkozuki): might be better to uncomment
# assert current_epoch_samples % (self.data_parallel_size * apex.transformer.pipeline_parallel.utils.get_micro_batch_size()) == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.local_minibatch_times_data_parallel_size) * self.local_minibatch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.local_minibatch_size:
self.consumed_samples += self.local_minibatch_times_data_parallel_size
yield batch
batch = []
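# Illustrative sketch (hypothetical helper, not part of this module): with a global batch size
# of 512 and a data-parallel group of size 8, each rank draws a local minibatch of
# 512 // 8 == 64 indices per step; a non-zero `consumed_samples` resumes mid-epoch.
def _example_random_sampler():
    global_batch_size, data_parallel_size = 512, 8
    sampler = MegatronPretrainingRandomSampler(
        total_samples=100_000,
        consumed_samples=0,
        local_minibatch_size=global_batch_size // data_parallel_size,
        data_parallel_rank=0,
        data_parallel_size=data_parallel_size,
    )
    first_batch = next(iter(sampler))
    assert len(first_batch) == 64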
from apex.transformer.amp.grad_scaler import GradScaler
__all__ = [
"GradScaler",
]
from collections import defaultdict
import torch
from apex.transformer import parallel_state
class GradScaler(torch.cuda.amp.GradScaler):
"""
    Gradient scaler for model-parallel inf checks. Infs in gradients are checked across model-parallel
    ranks in (1) executing the optimizer step and (2) the gradient scaler update.
"""
def __init__(
self, init_scale=2.0 ** 16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True
):
super().__init__(
init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
enabled=enabled,
)
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
retval = None
found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())])
# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
)
if found_inf.item() == 0:
retval = optimizer.step(*args, **kwargs)
return retval
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device=_scale.device, non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf_combined, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
)
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf = found_infs[i]
# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
)
found_inf_combined += found_inf
torch._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(torch.cuda.amp.grad_scaler._refresh_per_optimizer_state)
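# Usage follows the standard ``torch.cuda.amp`` pattern; the only behavioral difference from the
# stock ``torch.cuda.amp.GradScaler`` is the extra all-reduce of the found-inf flag across the
# model-parallel group. Illustrative sketch (model, optimizer, and loss_fn are placeholders):
def _example_training_step(model, optimizer, loss_fn, batch, scaler: GradScaler):
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(batch))
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # the step is skipped this iteration if any rank saw inf/nan grads
    scaler.update()         # grows or backs off the scale based on the combined inf flag
    return loss.detach()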