Commit ae137ed1 authored by Tri Dao

[LayerNorm] Fuse LayerScale

parent 8c6609ae
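This commit folds LayerScale (the learned per-feature output scaling popularized by CaiT-style transformers) into the fused dropout + residual-add + LayerNorm kernels, as a new optional `colscale` vector of length `hidden_size` next to the existing per-row `rowscale`. A reference sketch of the fused forward semantics, inferred from the kernel and wrapper changes below (the function and its name are illustrative, not part of the extension API):

```python
import torch
import torch.nn.functional as F

def dropout_add_ln_ref(x0, x1, gamma, beta, rowscale, colscale, p, eps):
    """Reference semantics of the fused op; x0 has shape (rows, hidden_size)."""
    if rowscale is not None:
        x0 = x0 * rowscale[:, None]            # per-row scale (e.g. stochastic depth)
    x0 = F.dropout(x0, p)                      # keeps the 1/(1-p) scaling of the kernel
    if colscale is not None:
        x0 = x0 * colscale                     # LayerScale: learned per-feature scale
    x = x0 + x1 if x1 is not None else x0      # residual add
    z = F.layer_norm(x, (x.shape[-1],), gamma, beta, eps)
    return z, x                                # x is also returned on the prenorm path
```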
@@ -40,6 +40,8 @@ struct ParamsBase {
         , mu(nullptr)
         , rs(nullptr)
         , gamma(nullptr)
+        , rowscale(nullptr)
+        , colscale(nullptr)
         , dropout_keep_p(1.f)
         , dropout_scale(1.f)
         , workspace(nullptr)
@@ -63,6 +65,7 @@ struct ParamsBase {
     void *rs;
     void *gamma;
     void *rowscale;
+    void *colscale;

     float inverse_cols;
@@ -106,10 +109,12 @@ struct BwdParams : public ParamsBase {
         , dx(nullptr)
         , dbeta_part(nullptr)
         , dgamma_part(nullptr)
+        , dcolscale_part(nullptr)
         , dx0(nullptr)
         , dx1(nullptr)
         , dbeta(nullptr)
         , dgamma(nullptr)
+        , dcolscale(nullptr)
     {
     }
@@ -121,6 +126,7 @@ struct BwdParams : public ParamsBase {
     // Workspace for Wgrad pre-reduction.
     void *dbeta_part;
     void *dgamma_part;
+    void *dcolscale_part;

     // Output: Dgrad.
     void *dx0;
@@ -128,13 +134,14 @@ struct BwdParams : public ParamsBase {
     // Output: Wgrad.
     void *dbeta;
     void *dgamma;
+    void *dcolscale;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
-using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool, const bool)>;
+using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
 using FunctionKey = uint64_t;
 using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
 using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
...
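Note that `BwdFunction` loses its second `const bool`: the prenorm flag is no longer threaded through the registry but derived inside the backward launcher from which pointers are set. Roughly (a sketch of the logic that appears in the `launch_` hunk further down):

```python
def launcher_flags(params):
    # Sketch of how the backward launch_ now derives its kernel specializations
    # from BwdParams alone instead of taking a prenorm argument.
    return {
        'prenorm':      params.dx is not None,        # dx requested => prenorm backward
        'has_residual': params.dx1 is not None,
        'has_colscale': params.colscale is not None,
    }
```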
@@ -84,6 +84,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
                                            const at::Tensor &gamma,   // hidden_size
                                            const at::Tensor &beta,    // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,  // BxS
+                                           c10::optional<const at::Tensor> &colscale_,  // hidden_size
                                            const float dropout_p,
                                            const float epsilon,
                                            c10::optional<at::Generator> gen_,
@@ -124,7 +125,15 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
         TORCH_CHECK(rowscale.is_cuda())
         TORCH_CHECK(rowscale.is_contiguous());
         TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
-        TORCH_CHECK(rowscale.scalar_type() == itype);
+        TORCH_CHECK(rowscale.dtype() == itype);
+    }
+
+    if (colscale_.has_value()) {
+        auto colscale = colscale_.value();
+        TORCH_CHECK(colscale.is_cuda())
+        TORCH_CHECK(colscale.is_contiguous());
+        TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
+        TORCH_CHECK(colscale.dtype() == wtype);
     }

     TORCH_CHECK(gamma.sizes() == beta.sizes());
@@ -135,7 +144,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     auto opts = x0.options();

-    bool save_x = x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
+    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || (itype != rtype);
     at::Tensor x;
     if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
     at::Tensor dmask;
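`save_x` decides whether the pre-LayerNorm activation `x` is materialized for the backward pass; with `rowscale` or `colscale` present, `x` no longer equals `x0` even without dropout or a residual, so it must be saved. The same condition as a Python sketch (names illustrative):

```python
def needs_saved_x(x1, dropout_p, rowscale, colscale, itype, rtype):
    # Mirrors the updated C++ save_x condition above.
    return (x1 is not None) or (dropout_p > 0.) or (rowscale is not None) \
        or (colscale is not None) or (itype != rtype)
```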
@@ -153,6 +162,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
     launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
+    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;

     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -212,12 +222,15 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidden_size
+                                           c10::optional<const at::Tensor> &dx_,     // BxSxhidden_size
                                            const at::Tensor &x,      // BxSxhidden_size
+                                           c10::optional<const at::Tensor> &x0_,     // BxSxhidden_size
                                            c10::optional<const at::Tensor> &dmask_,  // BxSxhidden_size
                                            const at::Tensor &mu,     // BxS, FP32!
                                            const at::Tensor &rsigma, // BxS, FP32!
                                            const at::Tensor &gamma,  // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,  // BxS
+                                           c10::optional<const at::Tensor> &colscale_,  // hidden_size
                                            const float dropout_p,
                                            const bool has_residual
 ) {
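`dropout_add_ln_bwd` now also covers the former `dropout_add_ln_prenorm_bwd`: passing `dx_` selects the prenorm path, and passing `x0_` (required whenever `colscale_` is given) enables the LayerScale gradient, which is appended to the outputs. A call sketch in the argument order above (assuming the forward-pass tensors are in scope):

```python
import dropout_layer_norm  # the compiled extension, as imported by the Python wrapper below

def backward_call(dz, x, dmask, mu, rsigma, gamma, dropout_p, has_residual,
                  dx=None, x0=None, rowscale=None, colscale=None):
    # One entry point now serves both post-norm (dx=None) and prenorm backward.
    out = dropout_layer_norm.dropout_add_ln_bwd(
        dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
        dropout_p, has_residual)
    dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part, *rest = out
    # rest == [dcolscale, dcolscale_part] when colscale was given, else []
    return dx0, dx1, dgamma, dbeta, (rest[0] if colscale is not None else None)
```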
@@ -250,132 +263,13 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,  // BxSxhidden_size
     auto rows = sizes[0];
     auto cols = sizes[1];

-    if (dmask_.has_value()) {
-        auto dmask = dmask_.value();
-        TORCH_CHECK(dmask.dtype() == mtype);
-        TORCH_CHECK(dmask.is_cuda());
-        TORCH_CHECK(dmask.is_contiguous());
-        TORCH_CHECK(dmask.sizes() == sizes);
-    }
-
-    if (rowscale_.has_value()) {
-        auto rowscale = rowscale_.value();
-        TORCH_CHECK(rowscale.is_cuda())
-        TORCH_CHECK(rowscale.is_contiguous());
-        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
-        TORCH_CHECK(rowscale.scalar_type() == itype);
-    }
-
-    auto hidden_size = gamma.numel();
-    TORCH_CHECK(hidden_size == cols);
-    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
-    TORCH_CHECK(mu.numel() == rows);
-    TORCH_CHECK(mu.sizes() == rsigma.sizes());
-    TORCH_CHECK(gamma.numel() == cols);
-
-    auto opts = x.options();
-
-    auto dx0 = torch::empty_like(x, opts.dtype(itype));
-    at::Tensor dx1;
-    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
-    auto dgamma = torch::empty_like(gamma);
-    auto dbeta = torch::empty_like(gamma);
-
-    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
-    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
-    launch_params.props = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(dropout_p < 1.f);
-    launch_params.params.dropout_keep_p = 1.f - dropout_p;
-    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
-    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
-
-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
-    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
-
-    launcher(launch_params, true, /*prenorm=*/false);
-
-    auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
-    auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
-    at::Tensor workspace, barrier;
-
-    layer_norm::BwdParams &params = launch_params.params;
-    params.rows = rows;
-    params.cols = cols;
-    params.x = x.data_ptr();
-    params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
-    params.mu = mu.data_ptr();
-    params.rs = rsigma.data_ptr();
-    params.gamma = gamma.data_ptr();
-    params.dz = dz.data_ptr();
-    params.dx0 = dx0.data_ptr();
-    params.dbeta = dbeta.data_ptr();
-    params.dgamma = dgamma.data_ptr();
-    params.dbeta_part = dbeta_part.data_ptr();
-    params.dgamma_part = dgamma_part.data_ptr();
-    params.dropout_scale = 1.f / (1.f - dropout_p);
-    params.inverse_cols = 1.f / float(params.cols);
-
-    if( launch_params.barrier_size > 0 ) {
-        // TODO Any way to avoid this?
-        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
-        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
-        params.workspace = workspace.data_ptr();
-        params.barrier = barrier.data_ptr<int>();
-    }
-
-    launcher(launch_params, false, /*prenorm=*/false);
-
-    return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,     // BxSxhidden_size
-                                                   const at::Tensor &dx,     // BxSxhidden_size
-                                                   const at::Tensor &x,      // BxSxhidden_size
-                                                   c10::optional<const at::Tensor> &dmask_,  // BxSxhidden_size
-                                                   const at::Tensor &mu,     // BxS, FP32!
-                                                   const at::Tensor &rsigma, // BxS, FP32!
-                                                   const at::Tensor &gamma,  // hidden_size
-                                                   c10::optional<const at::Tensor> &rowscale_,  // BxS
-                                                   const float dropout_p,
-                                                   const bool has_residual
-) {
-    auto itype = dz.scalar_type();
-    auto rtype = x.scalar_type();
-    auto wtype = gamma.scalar_type();
-    auto otype = itype;
-    auto ctype = torch::kFloat32;
-    auto mtype = torch::kUInt8;
-
-    if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
-
-    TORCH_CHECK(dz.dtype() == otype);
-    TORCH_CHECK(dx.dtype() == rtype);
-    TORCH_CHECK(mu.dtype() == ctype);
-    TORCH_CHECK(rsigma.dtype() == ctype);
-
-    TORCH_CHECK(x.is_cuda());
-    TORCH_CHECK(dz.is_cuda());
-    TORCH_CHECK(dx.is_cuda());
-    TORCH_CHECK(mu.is_cuda());
-    TORCH_CHECK(rsigma.is_cuda());
-    TORCH_CHECK(gamma.is_cuda());
-
-    TORCH_CHECK(x.is_contiguous());
-    TORCH_CHECK(dz.is_contiguous());
-    TORCH_CHECK(dx.is_contiguous());
-
-    auto sizes = x.sizes();
-    TORCH_CHECK(sizes.size() == 2);
-    TORCH_CHECK(dz.sizes() == sizes);
-    TORCH_CHECK(dx.sizes() == sizes);
-    auto rows = sizes[0];
-    auto cols = sizes[1];
+    if (dx_.has_value()) {
+        auto dx = dx_.value();
+        TORCH_CHECK(dx.dtype() == rtype);
+        TORCH_CHECK(dx.is_cuda())
+        TORCH_CHECK(dx.is_contiguous());
+        TORCH_CHECK(dx.sizes() == sizes);
+    }

     if (dmask_.has_value()) {
         auto dmask = dmask_.value();
@@ -390,7 +284,22 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,  //
         TORCH_CHECK(rowscale.is_cuda())
         TORCH_CHECK(rowscale.is_contiguous());
         TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
-        TORCH_CHECK(rowscale.scalar_type() == itype);
+        TORCH_CHECK(rowscale.dtype() == itype);
+    }
+
+    if (colscale_.has_value()) {
+        auto colscale = colscale_.value();
+        TORCH_CHECK(colscale.is_cuda())
+        TORCH_CHECK(colscale.is_contiguous());
+        TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
+        TORCH_CHECK(colscale.dtype() == wtype);
+
+        TORCH_CHECK(x0_.has_value());
+        auto x0 = x0_.value();
+        TORCH_CHECK(x0.is_cuda())
+        TORCH_CHECK(x0.is_contiguous());
+        TORCH_CHECK(x0.sizes() == sizes);
+        TORCH_CHECK(x0.dtype() == itype);
     }

     auto hidden_size = gamma.numel();
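Requiring `x0` here is what makes the fused LayerScale gradient possible: the forward computes `x = colscale ⊙ dropped(x0·rowscale) + x1`, so, matching the accumulation in the backward kernel further down,

```latex
\frac{\partial L}{\partial \mathrm{colscale}_j}
  \;=\; \sum_{i=1}^{\text{rows}} \frac{\partial L}{\partial x_{ij}}
        \cdot \mathrm{mask}_{ij} \cdot \frac{\mathrm{rowscale}_i}{1-p} \cdot x0_{ij},
```

where ∂L/∂x_ij is the gradient w.r.t. the pre-LayerNorm activation (`dx_tmp_res` in the kernel) and the mask, rowscale, and 1/(1-p) factors drop out when the corresponding feature is disabled.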
@@ -409,6 +318,10 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,  //
     if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
     auto dgamma = torch::empty_like(gamma);
     auto dbeta = torch::empty_like(gamma);
+    at::Tensor dcolscale;
+    if (colscale_.has_value()) {
+        dcolscale = torch::empty_like(colscale_.value());
+    }
     layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
     launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
@@ -417,32 +330,40 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,  //
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
     launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
+    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;

     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
     auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));

-    launcher(launch_params, true, /*prenorm=*/true);
+    launcher(launch_params, true);

     auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
     auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    at::Tensor dcolscale_part;
+    if (colscale_.has_value()) {
+        dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    }
     at::Tensor workspace, barrier;

     layer_norm::BwdParams &params = launch_params.params;
     params.rows = rows;
     params.cols = cols;
     params.x = x.data_ptr();
+    params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
     params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
     params.mu = mu.data_ptr();
     params.rs = rsigma.data_ptr();
     params.gamma = gamma.data_ptr();
     params.dz = dz.data_ptr();
-    params.dx = dx.data_ptr();
+    params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
     params.dx0 = dx0.data_ptr();
     params.dbeta = dbeta.data_ptr();
     params.dgamma = dgamma.data_ptr();
+    params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
     params.dbeta_part = dbeta_part.data_ptr();
     params.dgamma_part = dgamma_part.data_ptr();
+    params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
     params.dropout_scale = 1.f / (1.f - dropout_p);
     params.inverse_cols = 1.f / float(params.cols);
@@ -454,9 +375,14 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,  //
         params.barrier = barrier.data_ptr<int>();
     }

-    launcher(launch_params, false, /*prenorm=*/true);
+    launcher(launch_params, false);

-    return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+    std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+    if (colscale_.has_value()) {
+        result.push_back(dcolscale);
+        result.push_back(dcolscale_part);
+    }
+    return result;
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -464,5 +390,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "CUDA DropoutAddLayerNorm";
     m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
     m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
-    m.def("dropout_add_ln_prenorm_bwd", &dropout_add_ln_prenorm_bwd, "Run Dropout + Add + LayerNorm (PreNorm version) backward kernel");
 }
...
@@ -7,7 +7,7 @@
 namespace layer_norm {

-template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Is_even_cols>
+template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
 void ln_bwd_kernel(layer_norm::BwdParams params) {
@@ -53,9 +53,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
     Cvec dzy_sum[LDGS];
     Cvec dz_sum[LDGS];
+    Cvec dcolscale_sum[LDGS];

     memset(dzy_sum, 0, sizeof(dzy_sum));
     memset(dz_sum, 0, sizeof(dz_sum));
+    if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }

     compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
     char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
@@ -68,11 +70,13 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
         ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;

     Wvec gamma[LDGS];
+    Wvec colscale[LDGS];
     index_t idx = c;
     #pragma unroll
     for( int it = 0; it < LDGS; it++ ) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             gamma[it].load_from(params.gamma, idx);
+            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
             idx += Ktraits::VEC_COLS_PER_LDG;
         }
     }
@@ -131,6 +135,8 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 Ivec dx0;
                 Rvec dx1;
+                Ivec x0;
+                if (Has_colscale) { x0.load_from(params.x0, idx); }
                 #pragma unroll
                 for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                     compute_t dy_tmp = dy[it * NUM_ELTS + jt];
@@ -140,11 +146,22 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                     if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }

                     compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
                     if (Is_dropout) {
-                        dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
+                        dx0_tmp_res *= params.dropout_scale;
+                        if (Has_colscale) {
+                            dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
+                            dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
+                        } else {
+                            dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
+                        }
                     } else {
+                        if (Has_colscale) {
+                            dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
+                            dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
+                        } else {
                             dx0.data.elt[jt] = dx0_tmp_res;
+                        }
                     }
                 }
                 if (Has_residual) { dx1.store_to(params.dx1, idx); }
                 dx0.store_to(params.dx0, idx);
                 idx += Ktraits::VEC_COLS_PER_LDG;
@@ -160,6 +177,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             dz_sum[it].store_to(params.dbeta_part, idx);
             dzy_sum[it].store_to(params.dgamma_part, idx);
+            if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }
             idx += Ktraits::VEC_COLS_PER_LDG;
         }
     }
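`dcolscale` follows the same two-stage weight-gradient reduction as `dgamma`/`dbeta`: each CTA writes one row of fp32 partial sums, and the finalize kernel reduces over the CTA dimension. Conceptually (a torch sketch with illustrative sizes):

```python
import torch

ctas_per_col, hidden_size = 4, 1024  # illustrative
# Stage 1 (ln_bwd_kernel): each CTA writes one row of fp32 partial sums.
dcolscale_part = torch.randn(ctas_per_col, hidden_size)
# Stage 2 (ln_bwd_finalize_kernel) amounts to a reduction over the CTA dimension:
dcolscale = dcolscale_part.sum(dim=0)  # -> (hidden_size,)
```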
@@ -203,23 +221,46 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             }
         }

+        compute_t cta_dcolscale_sum[NUM_RES];
+        if (Has_colscale) {
+            __syncthreads();
+            idx = warp_m * Ktraits::VEC_COLS + tid_r;
+            #pragma unroll
+            for( int it = 0; it < LDGS; it++ ) {
+                dcolscale_sum[it].store_to(smem_wgrad, idx);
+                idx += THREADS_PER_ROW;
+            }
+            __syncthreads();
+            memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);
+            for( int it = 0; it < ROWS_PER_CTA; it++ ) {
+                for( int jt = 0; jt < NUM_RES; jt++ ) {
+                    cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
+                }
+            }
+        }

         const index_t num_valid_writes
             = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
         compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
         compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
+        compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;
         for( int jt = 0; jt < NUM_RES; jt++ ) {
             if (Is_even_cols || (jt < num_valid_writes)) {
                 *dgamma_part = cta_dzy_sum[jt];
                 dgamma_part += Ktraits::THREADS_PER_CTA;
                 *dbeta_part = cta_dz_sum[jt];
                 dbeta_part += Ktraits::THREADS_PER_CTA;
+                if (Has_colscale) {
+                    *dcolscale_part = cta_dcolscale_sum[jt];
+                    dcolscale_part += Ktraits::THREADS_PER_CTA;
+                }
             }
         }
     }
 }
-template<typename Kernel_traits, bool Is_even_cols>
+template<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>
 __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
 void ln_bwd_finalize_kernel(BwdParams params)
 {
@@ -250,26 +291,29 @@ void ln_bwd_finalize_kernel(BwdParams params)
     constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
     for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
         // Each thread sums over NUM_ELT columns.
-        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
+        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;
         memset(&dgamma_local, 0, sizeof(dgamma_local));
         memset(&dbeta_local, 0, sizeof(dbeta_local));
+        if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
         if (Is_even_cols || col < params.cols) {
             for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
-                // index_t idx = row * Kernel_traits::COLS + col;
                 index_t idx = row * params.cols + col;

-                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
+                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;
                 dbeta_part.load_from(params.dbeta_part, idx);
                 dgamma_part.load_from(params.dgamma_part, idx);
+                if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }
                 #pragma unroll
                 for( int it = 0; it < NUM_ELT; it++ ) {
                     dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
                     dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
+                    if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }
                 }
             }
         }

         void * smem_gamma = smem_;
         void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
+        void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];

         const int write_row = warp;
         const int write_col = lane ^ write_row;
@@ -277,12 +321,14 @@ void ln_bwd_finalize_kernel(BwdParams params)
         dgamma_local.store_to(smem_gamma, write_idx);
         dbeta_local.store_to(smem_beta, write_idx);
+        if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }
         __syncthreads();

         // It would be probably safe to reuse the first row of smem_beta and smem_gamma
-        void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
-        void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
+        void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];
+        void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
+        void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];

         // More than one iter iff ROWS_PER_CTA < 32.
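With `Has_colscale`, the transpose- and output-staging buffers grow from two factors (gamma, beta) to three; `NUM_FACTORS` comes from `Kernel_traits_finalize` (see the kernel-traits hunk near the end). The shared-memory budget, as a sketch:

```python
def finalize_smem_bytes(smem_bytes_transpose, smem_bytes_output, has_colscale):
    # Mirrors Kernel_traits_finalize: one transpose-staging buffer and one
    # output-staging buffer per reduced factor (gamma, beta, optionally colscale).
    num_factors = 3 if has_colscale else 2   # == Kernel_traits::NUM_FACTORS
    return num_factors * smem_bytes_transpose + num_factors * smem_bytes_output
```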
@@ -293,11 +339,13 @@ void ln_bwd_finalize_kernel(BwdParams params)
             memset(&dbeta_local, 0, sizeof(dbeta_local));
             memset(&dgamma_local, 0, sizeof(dgamma_local));
+            if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }

             // Load beta and gamma transposed
             if(read_row < Kernel_traits::ROWS_PER_CTA){
                 dbeta_local.load_from(smem_beta, read_idx);
                 dgamma_local.load_from(smem_gamma, read_idx);
+                if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
             }

             // Call reducer on the loaded value(s) and convert.
@@ -310,12 +358,18 @@ void ln_bwd_finalize_kernel(BwdParams params)
                 dgamma_local.data.elt[it] = g_i;
                 dbeta_local.data.elt[it] = b_i;
+                if (Has_colscale) {
+                    compute_t cs_i = dcolscale_local.data.elt[it];
+                    cs_i = reducer.allreduce(cs_i, sum);
+                    dcolscale_local.data.elt[it] = cs_i;
+                }
             }

             // Leader stores the result at the current column.
             if(lane == 0){
                 dgamma_local.store_to(smem_gamma_out, w);
                 dbeta_local.store_to(smem_beta_out, w);
+                if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
             }
         }
@@ -329,19 +383,21 @@ void ln_bwd_finalize_kernel(BwdParams params)
             using src_t = typename TypeToVec2<compute_t>::Type;
             using dst_t = typename TypeToVec2<weight_t>::Type;
-            Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
-            Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
+            Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
+            Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;

             dgamma_vec2.load_from(smem_gamma_out, lane);
             dbeta_vec2.load_from(smem_beta_out, lane);
+            if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
             #pragma unroll
             for( int it = 0; it < NUM_ELT; it++ ) {
                 dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
                 dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
+                if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }
             }

             dgamma_out2.store_to(params.dgamma, col_out);
             dbeta_out2.store_to(params.dbeta, col_out);
+            if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
         }
     }
 }
@@ -364,7 +420,7 @@ template<
     int BYTES_PER_LDG_MAIN,
     int BYTES_PER_LDG_FINAL
 >
-void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
+void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
     using Kernel_traits = Kernel_traits<weight_t,
                                         input_t,
@@ -378,14 +434,17 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
                                         WARPS_N,
                                         BYTES_PER_LDG_MAIN
                                         >;
+    bool prenorm = launch_params.params.dx != nullptr;
     bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
     bool has_residual = launch_params.params.dx1 != nullptr;
+    bool has_colscale = launch_params.params.colscale != nullptr;
     bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
     BOOL_SWITCH(prenorm, PrenormConst, [&] {
         BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
             BOOL_SWITCH(has_residual, HasResidualConst, [&] {
+                BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
                     BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                        auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
+                        auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
                         if( configure_params ) {
                             int ctas_per_sm;
                             CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
@@ -426,13 +485,15 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
                                                                 output_t,
                                                                 compute_t,
                                                                 index_t,
+                                                                HasColscaleConst,
                                                                 32 * 32,  // THREADS_PER_CTA
                                                                 BYTES_PER_LDG_FINAL>;
-                        auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>;
+                        auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
                         kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
                     });
+                });
             });
         });
     });
 }
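Every runtime flag is promoted to a template parameter through `BOOL_SWITCH`, so with `HasColscaleConst` added the backward now instantiates 2^5 = 32 variants of `ln_bwd_kernel` per hidden size (prenorm × dropout × residual × colscale × even-cols); that is the compile-time cost of keeping every branch out of the inner loop. A sketch of the combination count:

```python
from itertools import product

flags = ('PrenormConst', 'IsDropoutConst', 'HasResidualConst',
         'HasColscaleConst', 'IsEvenColsConst')
variants = list(product((False, True), repeat=len(flags)))
assert len(variants) == 32  # kernel instantiations per hidden size
```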
...
@@ -16,7 +16,7 @@
 namespace layer_norm {

-template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Is_even_cols>
+template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
 void ln_fwd_kernel(FwdParams params) {
@@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
     using Stats = typename Ktraits::Stats;
     using stats_t = typename Stats::stats_t;

-    constexpr bool save_x = Has_residual || Is_dropout || !(std::is_same<input_t, residual_t>::value);
+    const bool save_x = Has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || !(std::is_same<input_t, residual_t>::value);

     extern __shared__ char smem_[];
@@ -80,12 +80,14 @@ void ln_fwd_kernel(FwdParams params) {
     Wvec gamma[LDGS];
     Wvec beta[LDGS];
+    Wvec colscale[LDGS];
     index_t idx = c;
     #pragma unroll
     for( int it = 0; it < LDGS; it++ ) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             gamma[it].load_from(params.gamma, idx);
             beta[it].load_from(params.beta, idx);
+            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
             idx += VEC_COLS_PER_LDG;
         }
     }
@@ -109,13 +111,9 @@ void ln_fwd_kernel(FwdParams params) {
                 // the more efficient curand_uniform4.
                 mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
                 compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
-                compute_t x_ij;
-                if (Has_residual) {
-                    compute_t x1_ij = compute_t(x1.data.elt[jt]);
-                    x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
-                } else {
-                    x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
-                }
+                x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
+                if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
+                compute_t x_ij = Has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
                 if (save_x) { x.data.elt[jt] = x_ij; }
                 xf[it * NUM_ELTS + jt] = x_ij;
                 if (Is_dropout) { dmask.data.elt[jt] = keep; }
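The rewritten branch computes the same result with `colscale` folded in; per element, the forward is now

```latex
x_{ij} \;=\; \mathrm{colscale}_j \cdot \mathrm{mask}_{ij} \cdot
             \frac{\mathrm{rowscale}_i \, x0_{ij}}{1-p} \;+\; x1_{ij},
\qquad
z_{ij} \;=\; \gamma_j \,(x_{ij}-\mu_i)\, r\sigma_i + \beta_j,
```

with mask_ij ∈ {0,1}; each factor defaults to the identity (and x1 to zero) when the corresponding feature is off, and rσ_i = 1/σ_i is the saved rsigma.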
@@ -130,8 +128,8 @@ void ln_fwd_kernel(FwdParams params) {
         const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
         const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
         const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
-        // Need to convert to int, otherwise the subtraction will wrap around.
         auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
+            // Need to convert to int, otherwise the subtraction will wrap around.
             const index_t valid_partial_vecs_in_warp =
                 std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
                          int(THREADS_PER_WARP));
@@ -206,11 +204,13 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
                                         BYTES_PER_LDG
                                         >;
     bool has_residual = launch_params.params.x1 != nullptr;
+    bool has_colscale = launch_params.params.colscale != nullptr;
     bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
     BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
         BOOL_SWITCH(has_residual, HasResidualConst, [&] {
+            BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
                 BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                    auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
+                    auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
                     if( configure_params ) {
                         int ctas_per_sm;
                         CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
@@ -248,4 +248,5 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
             });
         });
     });
+    });
 }
...
@@ -38,6 +38,7 @@ template<
     typename output_t_,
     typename compute_t_,
     typename index_t_,
+    bool Has_colscale,
     uint32_t THREADS_PER_CTA_,
     uint32_t BYTES_PER_LDG_,
     typename Base = Kernel_traits_base<HIDDEN_SIZE_,
@@ -69,7 +70,8 @@ struct Kernel_traits_finalize : public Base {
     // Shared memory size to coalesce the CTA result.
     enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
     // Shared memory requirement per CTA.
-    enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
+    static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
+    enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };
     // The type of the reducer.
     using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
...
@@ -45,7 +45,7 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
 #define REGISTER_BWD_LAUNCHER( \
     HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
     void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
-                                                                                const bool configure_params, const bool prenorm) { \
+                                                                                const bool configure_params) { \
         launch_<WTYPE, \
                 ITYPE, \
                 RTYPE, \
@@ -57,7 +57,7 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
                 WARPS_M, \
                 WARPS_N, \
                 BYTES_PER_LDG, \
-                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params, prenorm); \
+                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
     } \
     static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
         ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
...
+# Copyright (c) 2022, Tri Dao.
 # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
 import torch
 from torch.nn import init

 import dropout_layer_norm


-def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon,
+def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                                     residual_in_fp32):
     """ Assume that arguments are contiguous
     """
@@ -14,133 +16,98 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, ep
     x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32
+        x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32
     )
     # dmask is None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


-def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p,
-                                     has_residual):
-    """ Assume that arguments are contiguous
-    """
-    # dmask is None if dropout_p == 0.0
-    hidden_size = gamma.numel()
-    xmat = x.view((-1, hidden_size))
-    dzmat = dz.view(xmat.shape)
-    rowscale = rowscale.view(-1) if rowscale is not None else None
-    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd(
-        dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
-    )
-    # dx1mat is None if not has_residual
-    return dx0mat, dx1mat, dgamma, dbeta
-
-
-def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale,
-                                             dropout_p, has_residual):
+def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
+                                     dropout_p, has_residual):
     """ Assume that arguments are contiguous
+    dx == None means that it was a post-norm architecture
+    (x = drop(x0) + x1 was not returned in the fwd).
+    x0 must not be None if we have colscale.
     """
     hidden_size = gamma.numel()
     xmat = x.view((-1, hidden_size))
     dzmat = dz.view(xmat.shape)
-    dxmat = dx.view(xmat.shape)
+    dxmat = dx.view(xmat.shape) if dx is not None else None
+    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
-    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd(
-        dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
-    )
-    # dx1mat is None if not has_residual
-    return dx0mat, dx1mat, dgamma, dbeta
-
-
-class DropoutAddLayerNormFN(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
-                return_dmask=False):
-        x0 = x0.contiguous()
-        x1 = x1.contiguous() if x1 is not None else None
-        gamma = gamma.contiguous()
-        beta = beta.contiguous()
-        rowscale = rowscale.contiguous() if rowscale is not None else None
-        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
-        )
-        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
-        ctx.dropout_p = dropout_p
-        ctx.has_residual = x1 is not None
-        if not return_dmask:
-            return zmat.view(x0.shape)
-        else:
-            dmask = (dmask.view(x0.shape) if dropout_p > 0.
-                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
-            ctx.mark_non_differentiable(dmask)
-            return zmat.view(x0.shape), dmask
-
-    @staticmethod
-    def backward(ctx, dz, *args):
-        # assert dz.is_contiguous()
-        dz = dz.contiguous()  # this happens!
-        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
-        dropout_p = ctx.dropout_p
-        has_residual = ctx.has_residual
-        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward(
-            dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
-        )
-        dx0 = dx0mat.view(x.shape)
-        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
-        return dx0, dx1, dgamma, dbeta, None, None, None, None, None
+    colscale = colscale.view(-1) if colscale is not None else None
+    if colscale is not None:
+        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
+    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p,
+        has_residual
+    )
+    # dx1mat is None if not has_residual
+    if colscale is None:
+        return dx0mat, dx1mat, dgamma, dbeta
+    else:
+        dcolscale = rest[0]
+        return dx0mat, dx1mat, dgamma, dbeta, dcolscale


-class DropoutAddLayerNormPrenormFN(torch.autograd.Function):
+class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
-                return_dmask=False):
+    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
+                prenorm=False, return_dmask=False):
         x0 = x0.contiguous()
         x1 = x1.contiguous() if x1 is not None else None
         gamma = gamma.contiguous()
         beta = beta.contiguous()
         rowscale = rowscale.contiguous() if rowscale is not None else None
+        colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
+            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
         )
-        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
+        # Only need to save x0 if we need to compute gradient wrt colscale
+        x0_saved = x0 if colscale is not None else None
+        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale,
+                              colscale)
+        ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
         ctx.has_residual = x1 is not None
         if not return_dmask:
-            return zmat.view(x0.shape), xmat.view(x0.shape)
+            return (zmat.view(x0.shape) if not prenorm
+                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
         else:
             dmask = (dmask.view(x0.shape) if dropout_p > 0.
                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
             ctx.mark_non_differentiable(dmask)
-            return zmat.view(x0.shape), xmat.view(x0.shape), dmask
+            return ((zmat.view(x0.shape), dmask) if not prenorm
+                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))

     @staticmethod
-    def backward(ctx, dz, dx, *args):
+    def backward(ctx, dz, *args):
         # assert dz.is_contiguous()
         dz = dz.contiguous()  # this happens!
-        dx = dx.contiguous()  # this happens!
-        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
+        dx = args[0].contiguous() if ctx.prenorm else None
+        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
+        # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
-        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward(
-            dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
+        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
+            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual
         )
         dx0 = dx0mat.view(x.shape)
         dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
-        return dx0, dx1, dgamma, dbeta, None, None, None, None, None
+        dcolscale = rest[0] if colscale is not None else None
+        return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None


-def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
+def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
                            prenorm=False, residual_in_fp32=False,
                            return_dropout_mask=False):
     """residual_in_fp32 only has an effect if x1 is None.
     Otherwise residual dtype is x1.dtype.
     """
-    args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32,
-            return_dropout_mask)
-    if not prenorm:
-        return DropoutAddLayerNormFN.apply(*args)
-    else:
-        return DropoutAddLayerNormPrenormFN.apply(*args)
+    return DropoutAddLayerNormFn.apply(
+        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
+        return_dropout_mask
+    )


 class DropoutAddLayerNorm(torch.nn.Module):
...
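End to end, the new `layerscale` argument is used as below (a usage sketch against the updated `dropout_add_layer_norm` signature; shapes, dtypes, and the small 1e-4 init, conventional for LayerScale, are illustrative):

```python
import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm

hidden_size = 768
x0 = torch.randn(8, 512, hidden_size, device='cuda', dtype=torch.float16)
residual = torch.randn_like(x0)
weight = torch.ones(hidden_size, device='cuda')   # gamma; layerscale must match its dtype
bias = torch.zeros(hidden_size, device='cuda')    # beta
layerscale = torch.full((hidden_size,), 1e-4, device='cuda', requires_grad=True)

out = dropout_add_layer_norm(x0, residual, weight, bias, 0.1, 1e-5,
                             layerscale=layerscale, return_dropout_mask=False)
out.sum().backward()   # layerscale.grad comes from the fused backward kernel
```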
...@@ -11,6 +11,7 @@ from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_nor ...@@ -11,6 +11,7 @@ from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_nor
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False]) @pytest.mark.parametrize('has_rowscale', [True, False])
# @pytest.mark.parametrize('has_rowscale', [True]) # @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize('has_residual', [True, False]) @pytest.mark.parametrize('has_residual', [True, False])
...@@ -26,12 +27,9 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 ...@@ -26,12 +27,9 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)]) # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144]) @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype, def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale): dropout_p, has_residual, has_rowscale, has_colscale):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported pytest.skip() # Not supported
# Backward numerical error is high, and this case isn't used
if has_rowscale and not has_residual:
pytest.skip()
device = 'cuda' device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4) # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4) rtol, atol = (1e-3, 1e-4)
@@ -43,6 +41,12 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                         requires_grad=True)
     x0 = x0_pt.detach().clone().requires_grad_()
     x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_colscale:
+        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+        colscale_pt = colscale.detach().clone().requires_grad_()
+        colscale_ref = colscale.detach().clone().float().requires_grad_()
+    else:
+        colscale = None
     if has_residual:
         x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
         x1 = x1_pt.detach().clone().requires_grad_()
@@ -59,6 +63,9 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
         rowscale = None
         x0_scaled_pt = x0_pt
         x0_scaled_ref = x0_ref
+    if has_colscale:
+        x0_scaled_pt = x0_scaled_pt * colscale_pt
+        x0_scaled_ref = x0_scaled_ref * colscale_ref
     model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
     torch.nn.init.normal_(model_pt.bias)
@@ -71,7 +78,7 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
     model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
-                                        model.epsilon, rowscale=rowscale,
+                                        model.epsilon, rowscale=rowscale, layerscale=colscale,
                                         residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
     assert out.dtype == input_dtype
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
@@ -94,6 +101,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
     assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
     assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
+    if has_colscale:
+        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4


 @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
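All the gradient checks in this suite follow one pattern: the fused kernel may deviate from the fp32 reference by at most a small factor times the deviation of PyTorch's own low-precision path, plus an absolute floor. A hypothetical helper making that pattern explicit (name and defaults are ours, not the test suite's; the actual tests inline it with factors of 2 or 4 and floors between 3e-5 and 2e-4):

def assert_error_bounded(out, out_pt, out_ref, factor=2.0, floor=3e-5):
    # out: fused-kernel result; out_pt: native PyTorch result;
    # out_ref: fp32 reference. Compare by maximum absolute error.
    err = (out - out_ref).abs().max()
    err_pt = (out_pt - out_ref).abs().max()
    assert err <= factor * err_pt + floor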
@@ -139,6 +148,7 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype,
     assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4


+@pytest.mark.parametrize('has_colscale', [True, False])
 @pytest.mark.parametrize('has_rowscale', [True, False])
 @pytest.mark.parametrize('has_residual', [True, False])
 @pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@@ -147,20 +157,17 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype,
                          [(torch.float16, torch.float16), (torch.float16, torch.float32),
                           (torch.float32, torch.float32)]
                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+# @pytest.mark.parametrize('has_colscale', [True])
 # @pytest.mark.parametrize('has_rowscale', [False])
-# @pytest.mark.parametrize('has_residual', [True])
+# @pytest.mark.parametrize('has_residual', [False])
 # @pytest.mark.parametrize('dropout_p', [0.0])
 # @pytest.mark.parametrize('weight_dtype', [torch.float32])
 # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
-# @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
 @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
 def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
-                                             dropout_p, has_residual, has_rowscale):
+                                             dropout_p, has_residual, has_rowscale, has_colscale):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
         pytest.skip()  # Not supported
-    # Backward numerical error is high, and this case isn't used
-    if has_rowscale and not has_residual:
-        pytest.skip()
     device = 'cuda'
     # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
     rtol, atol = (1e-3, 2e-4)
@@ -172,6 +179,12 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                         requires_grad=True)
     x0 = x0_pt.detach().clone().requires_grad_()
     x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_colscale:
+        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+        colscale_pt = colscale.detach().clone().requires_grad_()
+        colscale_ref = colscale.detach().clone().float().requires_grad_()
+    else:
+        colscale = None
     if has_residual:
         x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
         x1 = x1_pt.detach().clone().requires_grad_()
@@ -188,6 +201,9 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
         rowscale = None
         x0_scaled_pt = x0_pt
         x0_scaled_ref = x0_ref
+    if has_colscale:
+        x0_scaled_pt = x0_scaled_pt * colscale_pt
+        x0_scaled_ref = x0_scaled_ref * colscale_ref
     model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
     model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
     model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
@@ -199,7 +215,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
     model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
-                                                  model.epsilon, rowscale=rowscale, prenorm=True,
+                                                  model.epsilon, rowscale=rowscale,
+                                                  layerscale=colscale, prenorm=True,
                                                   residual_in_fp32=residual_in_fp32,
                                                   return_dropout_mask=True)
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
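As the call above shows, the prenorm variant returns the un-normalized residual alongside the normalized output, so a pre-norm Transformer block can pass the residual stream on to the next layer; the dropout mask is a third return value only when return_dropout_mask=True. A sketch of the intended unpacking, with illustrative variable names:

# prenorm=True, return_dropout_mask=True: three outputs.
out, residual, dmask = dropout_add_layer_norm(
    x0, x1, weight, bias, dropout_p, epsilon,
    rowscale=rowscale, layerscale=colscale,
    prenorm=True, residual_in_fp32=residual_in_fp32,
    return_dropout_mask=True,
)
# prenorm=True without the mask would yield (out, residual);
# the default prenorm=False yields just out.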
@@ -225,6 +242,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
     assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
     assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
+    if has_colscale:
+        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4


 @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
...