"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "3f0e74b4a06a512a9b2abac170d631db474d884e"
Commit 5db33051 authored by Tri Dao

[LayerNorm] Support taking subset of input or subset of output

parent ae137ed1
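For orientation: this commit lets the fused dropout + residual-add + LayerNorm op read only a subset of the input rows (x0_subset) and write only a subset of the output rows (z_subset / out_subset), which is what stochastic-depth / token-dropping architectures need. The sketch below restates the intended semantics in plain PyTorch as a reference; the helper name and the unfused computation are illustrative assumptions, not the extension's API, and colscale/layerscale is omitted for brevity. The subset maps are int32 over the rows of the full residual stream, 1-indexed, with 0 meaning "row not present"; when subsets are used, the per-row rowscale tensor is replaced by the scalar rowscale_const.

import torch
import torch.nn.functional as F

def dropout_add_ln_subset_reference(x0, x1, gamma, beta, x0_subset, z_subset,
                                    dropout_p, epsilon, rowscale_const, z_numrows):
    # Reference only (hypothetical helper). Shapes:
    #   x0: (num_kept_rows, hidden)   -- subset of the input rows
    #   x1: (rows, hidden) or None    -- full residual stream
    #   x0_subset, z_subset: (rows,) int32, 1-indexed, 0 = row absent
    rows, hidden = x0_subset.shape[0], gamma.numel()
    has_x0 = x0_subset > 0
    x0_full = torch.zeros(rows, hidden, dtype=torch.float32, device=gamma.device)
    # Scatter the kept x0 rows back into the full residual layout, scaled by rowscale_const
    x0_full[has_x0] = x0.float()[x0_subset[has_x0].long() - 1] * rowscale_const
    x = F.dropout(x0_full, p=dropout_p) + (x1.float() if x1 is not None else 0.)
    z_full = F.layer_norm(x, (hidden,), gamma.float(), beta.float(), epsilon)
    # Only the rows selected by z_subset are materialized in the output
    has_z = z_subset > 0
    z = torch.zeros(z_numrows, hidden, device=gamma.device)
    z[z_subset[has_z].long() - 1] = z_full[has_z]
    return z, x  # x is the pre-norm residual that the prenorm path also returns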
@@ -66,11 +66,14 @@ struct ParamsBase {
     void *gamma;
     void *rowscale;
     void *colscale;
+    void *x0_subset;
+    void *z_subset;
     float inverse_cols;
     float dropout_keep_p;
     float dropout_scale;
+    float rowscale_const;
     // Multi-CTA workspace in gmem.
     void *workspace;
......
@@ -84,9 +84,13 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
                                            const at::Tensor &gamma,    // hidden_size
                                            const at::Tensor &beta,     // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,   // BxS
-                                           c10::optional<const at::Tensor> &colscale_,   // BxS
+                                           c10::optional<const at::Tensor> &colscale_,   // hidden_size
+                                           c10::optional<const at::Tensor> &x0_subset_,  // BxS
+                                           c10::optional<const at::Tensor> &z_subset_,   // BxS
                                            const float dropout_p,
                                            const float epsilon,
+                                           const float rowscale_const,
+                                           const int64_t z_numrows,
                                            c10::optional<at::Generator> gen_,
                                            bool residual_in_fp32
 ) {
@@ -99,14 +103,19 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     auto ctype = torch::kFloat32;
     auto mtype = torch::kUInt8;
-    TORCH_CHECK(beta.scalar_type() == wtype);
+    TORCH_CHECK(beta.dtype() == wtype);
     TORCH_CHECK(x0.is_cuda())
     TORCH_CHECK(gamma.is_cuda())
     TORCH_CHECK(beta.is_cuda())
     TORCH_CHECK(x0.is_contiguous());
-    auto sizes = x0.sizes();
+    // c10::IntArrayRef does not own the storage, so we need to construct a vector.
+    // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
+    // blah is then deallocated.
+    std::vector<int64_t> sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)};
+    auto sizes = c10::IntArrayRef(sizes_vec);
+    TORCH_CHECK(x0.dim() == 2);
     TORCH_CHECK(sizes.size() == 2);
     const int rows = sizes[0];
@@ -124,7 +133,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
         auto rowscale = rowscale_.value();
         TORCH_CHECK(rowscale.is_cuda())
         TORCH_CHECK(rowscale.is_contiguous());
-        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
+        TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
         TORCH_CHECK(rowscale.dtype() == itype);
     }
@@ -132,10 +141,25 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
         auto colscale = colscale_.value();
         TORCH_CHECK(colscale.is_cuda())
         TORCH_CHECK(colscale.is_contiguous());
-        TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
+        TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
         TORCH_CHECK(colscale.dtype() == wtype);
     }
+    if (x0_subset_.has_value()) {
+        auto x0_subset = x0_subset_.value();
+        TORCH_CHECK(x0_subset.is_cuda())
+        TORCH_CHECK(x0_subset.is_contiguous());
+        TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
+        TORCH_CHECK(z_subset_.has_value());
+        auto z_subset = z_subset_.value();
+        TORCH_CHECK(z_subset.is_cuda());
+        TORCH_CHECK(z_subset.is_contiguous());
+        TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(z_subset.dtype() == torch::kInt32);
+    }
     TORCH_CHECK(gamma.sizes() == beta.sizes());
     TORCH_CHECK(hidden_size == cols);
     TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
@@ -144,12 +168,12 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     auto opts = x0.options();
-    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || (itype != rtype);
+    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
     at::Tensor x;
     if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
     at::Tensor dmask;
-    if (dropout_p > 0.f) { dmask = torch::empty(sizes, opts.dtype(mtype)); };
+    if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); };
-    auto z = torch::empty(sizes, opts.dtype(otype));
+    auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype));
     auto mu = torch::empty({ rows }, opts.dtype(ctype));
     auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
@@ -163,6 +187,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
    launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
+   launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
+   launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -192,6 +218,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
    params.epsilon = epsilon;
    params.dropout_scale = 1.f / (1.f - dropout_p);
    params.inverse_cols = 1.f / float(params.cols);
+   params.rowscale_const = rowscale_const;
    if (dropout_p > 0.f) {
        // number of times random will be generated per thread, to offset philox counter in thc random
@@ -230,8 +257,12 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
                                            const at::Tensor &rsigma,   // BxS, FP32!
                                            const at::Tensor &gamma,    // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,   // BxS
-                                           c10::optional<const at::Tensor> &colscale_,   // BxS
+                                           c10::optional<const at::Tensor> &colscale_,   // hidden_size
+                                           c10::optional<const at::Tensor> &x0_subset_,  // BxS
+                                           c10::optional<const at::Tensor> &z_subset_,   // BxS
                                            const float dropout_p,
+                                           const float rowscale_const,
+                                           const int64_t x0_numrows,
                                            const bool has_residual
 ) {
@@ -259,9 +290,16 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
     auto sizes = x.sizes();
     TORCH_CHECK(sizes.size() == 2);
-    TORCH_CHECK(dz.sizes() == sizes);
     auto rows = sizes[0];
     auto cols = sizes[1];
+    TORCH_CHECK(dz.dim() == 2);
+    TORCH_CHECK(dz.size(1) == cols);
+    // c10::IntArrayRef does not own the storage, so we need to construct a vector.
+    // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
+    // blah is then deallocated.
+    std::vector<int64_t> x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols};
+    auto x0_sizes = c10::IntArrayRef(x0_sizes_vec);
     if (dx_.has_value()) {
         auto dx = dx_.value();
@@ -276,14 +314,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
         TORCH_CHECK(dmask.dtype() == mtype);
         TORCH_CHECK(dmask.is_cuda());
         TORCH_CHECK(dmask.is_contiguous());
-        TORCH_CHECK(dmask.sizes() == sizes);
+        TORCH_CHECK(dmask.sizes() == x0_sizes);
     }
     if (rowscale_.has_value()) {
         auto rowscale = rowscale_.value();
         TORCH_CHECK(rowscale.is_cuda())
         TORCH_CHECK(rowscale.is_contiguous());
-        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
+        TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
         TORCH_CHECK(rowscale.dtype() == itype);
     }
@@ -291,17 +329,32 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
         auto colscale = colscale_.value();
         TORCH_CHECK(colscale.is_cuda())
         TORCH_CHECK(colscale.is_contiguous());
-        TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
+        TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
         TORCH_CHECK(colscale.dtype() == wtype);
         TORCH_CHECK(x0_.has_value());
         auto x0 = x0_.value();
         TORCH_CHECK(x0.is_cuda())
         TORCH_CHECK(x0.is_contiguous());
-        TORCH_CHECK(x0.sizes() == sizes);
+        TORCH_CHECK(x0.sizes() == x0_sizes);
         TORCH_CHECK(x0.dtype() == itype);
     }
+    if (x0_subset_.has_value()) {
+        auto x0_subset = x0_subset_.value();
+        TORCH_CHECK(x0_subset.is_cuda())
+        TORCH_CHECK(x0_subset.is_contiguous());
+        TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
+        TORCH_CHECK(z_subset_.has_value());
+        auto z_subset = z_subset_.value();
+        TORCH_CHECK(z_subset.is_cuda());
+        TORCH_CHECK(z_subset.is_contiguous());
+        TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(z_subset.dtype() == torch::kInt32);
+    }
     auto hidden_size = gamma.numel();
     TORCH_CHECK(hidden_size == cols);
     TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
@@ -313,7 +366,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
     auto opts = x.options();
-    auto dx0 = torch::empty_like(x, opts.dtype(itype));
+    auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
     at::Tensor dx1;
     if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
     auto dgamma = torch::empty_like(gamma);
@@ -331,6 +384,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
+   launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
+   launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
@@ -366,6 +421,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
    params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
    params.dropout_scale = 1.f / (1.f - dropout_p);
    params.inverse_cols = 1.f / float(params.cols);
+   params.rowscale_const = rowscale_const;
    if( launch_params.barrier_size > 0 ) {
        // TODO Any way to avoid this?
......
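Taken together, the checks above define the shape contract in subset mode. A compact restatement as a purely illustrative Python helper (names mirror the C++ arguments; this helper is not part of the extension):

import torch

def check_subset_shapes(x0, x, dmask, z, x0_subset, z_subset, hidden_size, z_numrows):
    # Illustrative restatement of the TORCH_CHECKs above; not part of the extension.
    rows = x.shape[0]                                   # rows of the full residual stream
    assert x0_subset.shape == (rows,) and x0_subset.dtype == torch.int32
    assert z_subset.shape == (rows,) and z_subset.dtype == torch.int32
    assert x0.shape == (int(x0_subset.gt(0).sum()), hidden_size)  # only the kept input rows
    assert dmask is None or dmask.shape == x0.shape     # the dropout mask follows x0, not x
    assert z.shape == (z_numrows, hidden_size)          # only the requested output rows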
@@ -7,7 +7,7 @@
 namespace layer_norm {
-template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
+template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
 void ln_bwd_kernel(layer_norm::BwdParams params) {
@@ -37,6 +37,9 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
     extern __shared__ char smem_[];
+    const bool has_residual = params.dx1 != nullptr;
+    const bool prenorm = params.dx != nullptr;
     const index_t tidx = threadIdx.x;
     const index_t bidn = blockIdx.x % CTAS_PER_ROW;
     const index_t bidm = blockIdx.x / CTAS_PER_ROW;
@@ -51,6 +54,10 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
     static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
+    const input_t *rowscale = static_cast<input_t *>(params.rowscale);
+    const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
+    const index_t *z_subset = static_cast<index_t *>(params.z_subset);
     Cvec dzy_sum[LDGS];
     Cvec dz_sum[LDGS];
     Cvec dcolscale_sum[LDGS];
@@ -87,25 +94,34 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
     for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
         const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
         const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
-        const compute_t rowscale_val =
-            params.rowscale == nullptr ? 1.0f : compute_t(static_cast<const input_t *>(params.rowscale)[row]);
+        const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
+        const int row_z = !Has_subset ? row + 1 : z_subset[row];
+        const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
+        const bool load_dz = !Has_subset || row_z > 0;
+        const bool save_dx0 = !Has_subset || row_x0 > 0;
         Mvec dmask[LDGS];
         Rvec dx[LDGS];
         compute_t dy[LDGS * NUM_ELTS];
         compute_t y[LDGS * NUM_ELTS];
         compute_t mdy_local = 0.f;
         compute_t mdyy_local = 0.f;
-        index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        // If dz is not loaded, then dy should be 0 and we don't care about the value of y.
+        if (load_dz) {
+            index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+            index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
+            index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
         #pragma unroll
         for( int it = 0; it < LDGS; it++ ) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 Rvec x;
                 Ovec dz;
-                dz.load_from(params.dz, idx);
+                dz.load_from(params.dz, !Has_subset ? idx_x : idx_z);
-                if (Prenorm) { dx[it].load_from(params.dx, idx); }
+                if (prenorm) { dx[it].load_from(params.dx, idx_x); }
-                x.load_from(params.x, idx);
+                x.load_from(params.x, idx_x);
-                if (Is_dropout) { dmask[it].load_from(params.dmask, idx); }
+                if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
-                idx += Ktraits::VEC_COLS_PER_LDG;
+                idx_x += Ktraits::VEC_COLS_PER_LDG;
+                idx_z += Ktraits::VEC_COLS_PER_LDG;
+                idx_x0 += Ktraits::VEC_COLS_PER_LDG;
                 #pragma unroll
                 for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                     compute_t x_tmp = x.data.elt[jt];
@@ -124,26 +140,46 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                 }
             }
         }
+        } else {
+            index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+            index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
+            #pragma unroll
+            for( int it = 0; it < LDGS; it++ ) {
+                if (Is_even_cols || (it < num_valid_ldgs)) {
+                    if (prenorm) { dx[it].load_from(params.dx, idx_x); }
+                    if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
+                    idx_x += Ktraits::VEC_COLS_PER_LDG;
+                    idx_x0 += Ktraits::VEC_COLS_PER_LDG;
+                }
+            }
+        }
         reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
         mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
         mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;
-        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
         #pragma unroll
         for( int it = 0; it < LDGS; it++ ) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 Ivec dx0;
                 Rvec dx1;
                 Ivec x0;
-                if (Has_colscale) { x0.load_from(params.x0, idx); }
+                if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
                 #pragma unroll
                 for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                    compute_t dx_tmp_res;
+                    if (load_dz) {
                     compute_t dy_tmp = dy[it * NUM_ELTS + jt];
                     compute_t y_tmp = y[it * NUM_ELTS + jt];
                     compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
-                    compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
+                        dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
-                    if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
+                    } else {
+                        dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
+                    }
+                    if (has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
+                    if (save_dx0) {
                     compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
                     if (Is_dropout) {
                         dx0_tmp_res *= params.dropout_scale;
@@ -162,9 +198,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                         }
                     }
                 }
-                if (Has_residual) { dx1.store_to(params.dx1, idx); }
-                dx0.store_to(params.dx0, idx);
-                idx += Ktraits::VEC_COLS_PER_LDG;
+                }
+                if (has_residual) { dx1.store_to(params.dx1, idx_x); }
+                if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
+                idx_x += Ktraits::VEC_COLS_PER_LDG;
+                idx_x0 += Ktraits::VEC_COLS_PER_LDG;
             }
         }
@@ -434,17 +472,15 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
                                     WARPS_N,
                                     BYTES_PER_LDG_MAIN
                                     >;
-    bool prenorm = launch_params.params.dx != nullptr;
     bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
-    bool has_residual = launch_params.params.dx1 != nullptr;
     bool has_colscale = launch_params.params.colscale != nullptr;
+    bool has_subset = launch_params.params.x0_subset != nullptr;
     bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
-    BOOL_SWITCH(prenorm, PrenormConst, [&] {
     BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
-        BOOL_SWITCH(has_residual, HasResidualConst, [&] {
         BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+            BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
                 BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                    auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
+                    auto kernel = &ln_bwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
                     if( configure_params ) {
                         int ctas_per_sm;
                         CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
@@ -495,5 +531,4 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
                 });
             });
         });
-    });
 }
@@ -16,7 +16,7 @@
 namespace layer_norm {
-template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
+template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
 void ln_fwd_kernel(FwdParams params) {
@@ -46,7 +46,8 @@ void ln_fwd_kernel(FwdParams params) {
     using Stats = typename Ktraits::Stats;
     using stats_t = typename Stats::stats_t;
-    const bool save_x = Has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || !(std::is_same<input_t, residual_t>::value);
+    const bool has_residual = params.x1 != nullptr;
+    const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);
     extern __shared__ char smem_[];
@@ -67,6 +68,8 @@ void ln_fwd_kernel(FwdParams params) {
     compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
     const input_t *rowscale = static_cast<input_t *>(params.rowscale);
+    const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
+    const index_t *z_subset = static_cast<index_t *>(params.z_subset);
     // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
     curandStatePhilox4_32_10_t state;
@@ -93,8 +96,12 @@ void ln_fwd_kernel(FwdParams params) {
     }
     for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
-        const compute_t rowscale_val = params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row]);
+        const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
-        index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
+        const int row_z = !Has_subset ? row + 1 : z_subset[row];
+        const bool load_x0 = !Has_subset || row_x0 > 0;
+        index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
         compute_t xf[LDGS * NUM_ELTS];
         #pragma unroll
         for( int it = 0; it < LDGS; it++ ) {
@@ -103,24 +110,30 @@ void ln_fwd_kernel(FwdParams params) {
                 Rvec x1;
                 Rvec x;
                 Mvec dmask;
-                x0.load_from(params.x0, idx);
+                if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
-                if (Has_residual) { x1.load_from(params.x1, idx); }
+                if (has_residual) { x1.load_from(params.x1, idx_x); }
                 #pragma unroll
                 for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                     // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
                     // the more efficient curand_uniform4.
+                    compute_t x_ij;
+                    if (load_x0) {
                     mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
+                        if (Is_dropout) { dmask.data.elt[jt] = keep; }
                     compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
                     x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
                     if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
-                    compute_t x_ij = Has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
+                        x_ij = has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
+                    } else {
+                        x_ij = has_residual ? compute_t(x1.data.elt[jt]) : 0.f;
+                    }
                     if (save_x) { x.data.elt[jt] = x_ij; }
                     xf[it * NUM_ELTS + jt] = x_ij;
-                    if (Is_dropout) { dmask.data.elt[jt] = keep; }
                 }
-                if (save_x) { x.store_to(params.x, idx); }
+                if (save_x) { x.store_to(params.x, idx_x); }
-                if (Is_dropout) { dmask.store_to(params.dmask, idx); }
+                if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); }
-                idx += VEC_COLS_PER_LDG;
+                idx_x += VEC_COLS_PER_LDG;
+                idx_x0 += VEC_COLS_PER_LDG;
             }
         }
@@ -152,7 +165,9 @@ void ln_fwd_kernel(FwdParams params) {
             rs_ptr[row] = rs;
         }
-        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        const bool save_z = !Has_subset || row_z > 0;
+        if (save_z) {
+            index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c;
         #pragma unroll
         for( int it = 0; it < LDGS; it++ ) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
@@ -164,8 +179,9 @@ void ln_fwd_kernel(FwdParams params) {
                     compute_t b_ij = beta[it].data.elt[jt];
                     z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
                 }
-                z.store_to(params.z, idx);
+                z.store_to(params.z, idx_z);
-                idx += VEC_COLS_PER_LDG;
+                idx_z += VEC_COLS_PER_LDG;
+            }
             }
         }
@@ -203,14 +219,14 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
                                     WARPS_N,
                                     BYTES_PER_LDG
                                     >;
-    bool has_residual = launch_params.params.x1 != nullptr;
     bool has_colscale = launch_params.params.colscale != nullptr;
+    bool has_subset = launch_params.params.x0_subset != nullptr;
     bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
     BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
-        BOOL_SWITCH(has_residual, HasResidualConst, [&] {
         BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+            BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
                 BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                    auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
+                    auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
                     if( configure_params ) {
                         int ctas_per_sm;
                         CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
......
@@ -16,7 +16,7 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro
     x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32
+        x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
+        1.0, 0, None, residual_in_fp32
     )
     # dmask is None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
@@ -36,12 +37,59 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
     dxmat = dx.view(xmat.shape) if dx is not None else None
     x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
-    colscale = colscale.view(-1) if colscale is not None else None
     if colscale is not None:
         assert x0 is not None, 'x0 is required to compute the gradient of colscale'
     dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
-        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p,
-        has_residual
+        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
+        dropout_p, 1.0, 0, has_residual
+    )
+    # dx1mat is None if not has_residual
+    if colscale is None:
+        return dx0mat, dx1mat, dgamma, dbeta
+    else:
+        dcolscale = rest[0]
+        return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+
+
+def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
+                                           dropout_p, epsilon, rowscale_const, out_numrows,
+                                           residual_in_fp32):
+    """ Assume that arguments are contiguous
+    """
+    hidden_size = gamma.numel()
+    x0mat = x0.view((-1, hidden_size))
+    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
+    out_subset = out_subset.view(-1) if out_subset is not None else None
+    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
+        x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
+        rowscale_const, out_numrows, None, residual_in_fp32
+    )
+    # dmask is None if dropout_p == 0.0
+    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
+
+
+def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
+                                            x0_subset, out_subset, dropout_p, rowscale_const,
+                                            x0_numrows, has_residual):
+    """ Assume that arguments are contiguous
+    dx == None means that it was a post-norm architecture
+    (x = drop(x0) + x1 was not returned in the fwd).
+    x0 must not be None if we have colscale.
+    """
+    hidden_size = gamma.numel()
+    xmat = x.view((-1, hidden_size))
+    dzmat = dz.view(-1, hidden_size)
+    dxmat = dx.view(xmat.shape) if dx is not None else None
+    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
+    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
+    out_subset = out_subset.view(-1) if out_subset is not None else None
+    if colscale is not None:
+        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
+    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
+        dropout_p, rowscale_const, x0_numrows, has_residual
     )
     # dx1mat is None if not has_residual
     if colscale is None:
@@ -98,6 +146,60 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None
+
+
+class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+                rowscale_const, out_numrows, residual_in_fp32, prenorm=False, return_dmask=False):
+        x0 = x0.contiguous()
+        x1 = x1.contiguous() if x1 is not None else None
+        gamma = gamma.contiguous()
+        beta = beta.contiguous()
+        colscale = colscale.contiguous() if colscale is not None else None
+        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
+            x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+            rowscale_const, out_numrows, residual_in_fp32
+        )
+        # Only need to save x0 if we need to compute gradient wrt colscale
+        x0_saved = x0 if colscale is not None else None
+        x_shape = (-1, *x0.shape[1:])
+        ctx.save_for_backward(xmat.view(x_shape), x0, dmask, gamma, mu, rsigma, colscale,
+                              x0_subset, out_subset)
+        ctx.prenorm = prenorm
+        ctx.dropout_p = dropout_p
+        ctx.rowscale_const = rowscale_const
+        ctx.x0_numrows = x0.shape[:-1].numel()
+        ctx.has_residual = x1 is not None
+        z_shape = (-1, *x0.shape[1:])
+        if not return_dmask:
+            return (zmat.view(z_shape) if not prenorm
+                    else (zmat.view(z_shape), xmat.view(x0.shape)))
+        else:
+            z = zmat.view(z_shape)
+            dmask = (dmask.view(x0.shape) if dropout_p > 0.
+                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            ctx.mark_non_differentiable(dmask)
+            return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))
+
+    @staticmethod
+    def backward(ctx, dz, *args):
+        # assert dz.is_contiguous()
+        dz = dz.contiguous()  # this happens!
+        dx = args[0].contiguous() if ctx.prenorm else None
+        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
+        # x0 is None if colscale is None
+        dropout_p = ctx.dropout_p
+        has_residual = ctx.has_residual
+        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
+            dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
+            ctx.rowscale_const, ctx.x0_numrows, has_residual
+        )
+        dx0 = dx0mat.view(-1, *x.shape[1:])
+        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dcolscale = rest[0] if colscale is not None else None
+        return (dx0, dx1, dgamma, dbeta, dcolscale, None, None, None, None, None, None, None,
+                None, None)
+
+
 def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
                            prenorm=False, residual_in_fp32=False,
                            return_dropout_mask=False):
@@ -110,6 +212,19 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No
     )
+
+
+def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
+                                  x0_subset=None, out_subset=None, rowscale_const=1.0,
+                                  out_numrows=0, prenorm=False, residual_in_fp32=False,
+                                  return_dropout_mask=False):
+    """residual_in_fp32 only has an effect if x1 is None.
+    Otherwise residual dtype is x1.dtype.
+    """
+    return DropoutAddLayerNormSubsetFn.apply(
+        x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
+        rowscale_const, out_numrows, residual_in_fp32, prenorm, return_dropout_mask
+    )
+
+
 class DropoutAddLayerNorm(torch.nn.Module):
     def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):
......
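A minimal usage sketch of the new entry point, mirroring the tests added below. The sizes, dtype, and dropout rate are placeholders; the API names come from the diff above. The 1-indexed subset maps are built with a cumulative sum over a boolean keep-mask, exactly as the test helper does; here the same map is reused for input and output, but x0_subset and out_subset may differ.

import torch
from einops import rearrange, repeat
from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm_subset

batch, seqlen, hidden = 8, 512, 1024
drop_path_rate = 0.4
ln = DropoutAddLayerNorm(hidden, p=0.1, device='cuda', dtype=torch.float16)

# Keep a random subset of sequences (stochastic depth). Build the map on CPU so the
# .sum() does not force a GPU sync, then move it to the device.
keep = torch.rand(batch) < 1 - drop_path_rate
numrows = int(keep.sum()) * seqlen
keep_rows = repeat(keep, 'b -> (b s)', s=seqlen)
subset = torch.cumsum(keep_rows, dim=0, dtype=torch.int32).masked_fill_(~keep_rows, 0)
subset = rearrange(subset, '(b s) -> b s', b=batch).cuda()

x1 = torch.randn(batch, seqlen, hidden, device='cuda', dtype=torch.float16)  # full residual
x0 = torch.randn(numrows, hidden, device='cuda', dtype=torch.float16)        # kept rows only

out, residual, dmask = dropout_add_layer_norm_subset(
    x0, x1, ln.weight, ln.bias, ln.p, ln.epsilon, layerscale=None,
    x0_subset=subset, out_subset=subset,
    rowscale_const=1 / (1 - drop_path_rate), out_numrows=numrows,
    prenorm=True, residual_in_fp32=False, return_dropout_mask=True)
# out:      (numrows, hidden)         -- LayerNorm output for the selected rows only
# residual: (batch * seqlen, hidden)  -- full pre-norm residual stream
# dmask:    same shape as x0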
@@ -4,9 +4,10 @@ import torch
 import torch.nn.functional as F
 import pytest
-from einops import rearrange
+from einops import rearrange, repeat
 from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
+from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@@ -130,6 +131,8 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
     x1 = x1_pt.detach().clone().requires_grad_()
     x1_ref = x1_pt.detach().clone().float().requires_grad_()
     model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    torch.nn.init.normal_(model_pt.weight)
+    torch.nn.init.normal_(model_pt.bias)
     model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
     model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
     with torch.no_grad():
@@ -148,22 +151,23 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
     assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4


-@pytest.mark.parametrize('has_colscale', [True, False])
-@pytest.mark.parametrize('has_rowscale', [True, False])
-@pytest.mark.parametrize('has_residual', [True, False])
-@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
-@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
-@pytest.mark.parametrize('input_dtype,residual_dtype',
-                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
-                          (torch.float32, torch.float32)]
-                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
-# @pytest.mark.parametrize('has_colscale', [True])
-# @pytest.mark.parametrize('has_rowscale', [False])
-# @pytest.mark.parametrize('has_residual', [False])
-# @pytest.mark.parametrize('dropout_p', [0.0])
-# @pytest.mark.parametrize('weight_dtype', [torch.float32])
-# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
-@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('has_colscale', [True, False])
+# @pytest.mark.parametrize('has_rowscale', [True, False])
+# @pytest.mark.parametrize('has_residual', [True, False])
+# @pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+# @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+# @pytest.mark.parametrize('input_dtype,residual_dtype',
+#                          [(torch.float16, torch.float16), (torch.float16, torch.float32),
+#                           (torch.float32, torch.float32)]
+#                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+@pytest.mark.parametrize('has_colscale', [True])
+@pytest.mark.parametrize('has_rowscale', [False])
+@pytest.mark.parametrize('has_residual', [True])
+@pytest.mark.parametrize('dropout_p', [0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32])
+@pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
+# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+@pytest.mark.parametrize('hidden_size', [256])
 def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                              dropout_p, has_residual, has_rowscale, has_colscale):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
@@ -205,6 +209,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
         x0_scaled_pt = x0_scaled_pt * colscale_pt
         x0_scaled_ref = x0_scaled_ref * colscale_ref
     model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    torch.nn.init.normal_(model_pt.weight)
+    torch.nn.init.normal_(model_pt.bias)
     model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
     model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                 dtype=weight_dtype)
@@ -271,6 +277,8 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
     x1 = x1_pt.detach().clone().requires_grad_()
     x1_ref = x1_pt.detach().clone().float().requires_grad_()
     model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    torch.nn.init.normal_(model_pt.weight)
+    torch.nn.init.normal_(model_pt.bias)
     model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                 dtype=weight_dtype)
     model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
@@ -289,3 +297,245 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
     out_ref = model_ref(residual_ref)
     assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
     assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
+
+
+@pytest.mark.parametrize('has_colscale', [True, False])
+@pytest.mark.parametrize('has_residual', [True, False])
+@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+# @pytest.mark.parametrize('has_colscale', [True])
+# @pytest.mark.parametrize('has_residual', [True])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+# @pytest.mark.parametrize('weight_dtype', [torch.float32])
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [256])
+def test_dropout_layer_norm_subset_training(
+        hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
+        has_residual, has_colscale):
+    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
+        pytest.skip()  # Not supported
+    device = 'cuda'
+    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
+    rtol, atol = (1e-3, 2e-4)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    drop_path_rate = 0.4
+    drop_path_scale = 1 / (1 - drop_path_rate)
+
+    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
+        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
+        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
+        numrows = (mask_batch).sum().item() * seqlen
+        mask_batch = mask_batch.to(device=device, non_blocking=True)
+        mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
+        subset = torch.cumsum(mask_batch_seqlen, dim=0,
+                              dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
+        return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)
+
+    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen,
+                                                                   drop_path_rate, device)
+    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
+                                                                      drop_path_rate, device)
+
+    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                        requires_grad=True)
+    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
+    x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_colscale:
+        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+        colscale_pt = colscale.detach().clone().requires_grad_()
+        colscale_ref = colscale.detach().clone().float().requires_grad_()
+    else:
+        colscale = None
+    if has_residual:
+        x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
+        x1 = x1_pt.detach().clone().requires_grad_()
+        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    else:
+        x1 = None
+    if has_colscale:
+        x0_scaled_pt = x0_pt * colscale_pt
+        x0_scaled_ref = x0_ref * colscale_ref
+    else:
+        x0_scaled_pt = x0_pt
+        x0_scaled_ref = x0_ref
+    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    torch.nn.init.normal_(model_pt.weight)
+    torch.nn.init.normal_(model_pt.bias)
+    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
+    model = DropoutAddLayerNorm(hidden_size, prenorm=False, p=dropout_p, device=device,
+                                dtype=weight_dtype)
+    with torch.no_grad():
+        model.weight.copy_(model_pt.weight)
+        model.bias.copy_(model_pt.bias)
+        model_ref.weight.copy_(model_pt.weight)
+        model_ref.bias.copy_(model_pt.bias)
+    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
+    out, dmask = dropout_add_layer_norm_subset(
+        x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
+        x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
+        out_numrows = out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
+        return_dropout_mask=True)
+    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
+    x0_scaled_pt = x0_scaled_pt.masked_fill(
+        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
+    ) * drop_path_scale
+    x0_scaled_ref = x0_scaled_ref.masked_fill(
+        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
+    ) * drop_path_scale
+    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
+    dmask_expanded[x0_mask_batch] = dmask
+    if has_residual:
+        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
+    else:
+        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
+    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
+    out_ref = model_ref(residual_ref)[out_mask_batch]
+    assert out.dtype == input_dtype
+    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
+
+    g = torch.randn_like(out) / batch_size
+    out_pt.backward(g)
+    out.backward(g)
+    out_ref.backward(g)
+    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
+    if has_residual:
+        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
+    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
+    if has_colscale:
+        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
+
+
+@pytest.mark.parametrize('has_colscale', [True, False])
+@pytest.mark.parametrize('has_residual', [True, False])
+@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+# @pytest.mark.parametrize('has_colscale', [True])
+# @pytest.mark.parametrize('has_residual', [True])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+# @pytest.mark.parametrize('weight_dtype', [torch.float32])
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [256])
+def test_dropout_layer_norm_subset_prenorm_training(
+        hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
+        has_residual, has_colscale):
+    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
+        pytest.skip()  # Not supported
+    device = 'cuda'
+    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
+    rtol, atol = (1e-3, 2e-4)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    drop_path_rate = 0.4
+    drop_path_scale = 1 / (1 - drop_path_rate)
+
+    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
+        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
+        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
+        numrows = (mask_batch).sum().item() * seqlen
+        mask_batch = mask_batch.to(device=device, non_blocking=True)
+        mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
+        subset = torch.cumsum(mask_batch_seqlen, dim=0,
+                              dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
+        return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)
+
+    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen,
+                                                                   drop_path_rate, device)
+    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
+                                                                      drop_path_rate, device)
+
+    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
+                        requires_grad=True)
+    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
+    x0_ref = x0_pt.detach().clone().float().requires_grad_()
+    if has_colscale:
+        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+        colscale_pt = colscale.detach().clone().requires_grad_()
+        colscale_ref = colscale.detach().clone().float().requires_grad_()
+    else:
+        colscale = None
+    if has_residual:
+        x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
+        x1 = x1_pt.detach().clone().requires_grad_()
+        x1_ref = x1_pt.detach().clone().float().requires_grad_()
+    else:
+        x1 = None
+    if has_colscale:
+        x0_scaled_pt = x0_pt * colscale_pt
+        x0_scaled_ref = x0_ref * colscale_ref
+    else:
+        x0_scaled_pt = x0_pt
+        x0_scaled_ref = x0_ref
+    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    torch.nn.init.normal_(model_pt.weight)
+    torch.nn.init.normal_(model_pt.bias)
+    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
+    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
+                                dtype=weight_dtype)
+    with torch.no_grad():
+        model.weight.copy_(model_pt.weight)
+        model.bias.copy_(model_pt.bias)
+        model_ref.weight.copy_(model_pt.weight)
+        model_ref.bias.copy_(model_pt.bias)
+    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
+    out, residual, dmask = dropout_add_layer_norm_subset(
+        x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
+        x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
+        out_numrows = out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
+        return_dropout_mask=True)
+    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
+    x0_scaled_pt = x0_scaled_pt.masked_fill(
+        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
+    ) * drop_path_scale
+    x0_scaled_ref = x0_scaled_ref.masked_fill(
+        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
+    ) * drop_path_scale
+    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
+    dmask_expanded[x0_mask_batch] = dmask
+    if has_residual:
+        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
+    else:
+        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
+        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
+    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
+    out_ref = model_ref(residual_ref)[out_mask_batch]
+    assert out.dtype == input_dtype
+    assert residual.dtype == residual_dtype
+    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
+    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
+
+    g = torch.randn_like(out) / batch_size
+    (out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(g)
+    (out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
+    (out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype)) + residual_ref.mean(0, keepdim=True)).backward(g)
+    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
+    if has_residual:
+        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
+    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
+    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
+    if has_colscale:
+        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4