"vscode:/vscode.git/clone" did not exist on "f85b3ea8ebdb4292d3490f7d18b8019ba93b6787"
Commit 6738d947 authored by Tri Dao

[LayerNorm] Implement RMS Norm

parent a1f49a2b
@@ -2,6 +2,7 @@ This CUDA extension implements fused dropout + residual + LayerNorm, building on
 Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
 We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
 We also make it work for more hidden dimensions (all dimensions divisible by 8, up to 6144).
+We also implement RMSNorm as an option.
 If you want to use it for dimensions larger than 6k, please file an issue.
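
(Annotation, not part of the commit.) For orientation, a minimal usage sketch of the Python wrappers that sit on top of this extension, using the names imported by the tests further down (dropout_add_layer_norm and the new dropout_add_rms_norm); it assumes the extension and the flash_attn package are built and a CUDA device is available:

import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm
from flash_attn.ops.rms_norm import dropout_add_rms_norm

hidden = 1024
x0 = torch.randn(2, 512, hidden, device='cuda', dtype=torch.float16)  # current branch
x1 = torch.randn_like(x0)                                             # residual branch
weight = torch.ones(hidden, device='cuda', dtype=torch.float16)
bias = torch.zeros(hidden, device='cuda', dtype=torch.float16)

# Fused LayerNorm(dropout(x0) + x1) with dropout_p=0.1 and epsilon=1e-5
out_ln = dropout_add_layer_norm(x0, x1, weight, bias, 0.1, 1e-5)
# Same, but with the RMSNorm option added by this commit; bias may be None here.
out_rms = dropout_add_rms_norm(x0, x1, weight, None, 0.1, 1e-5)
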
...
@@ -44,6 +44,7 @@ struct ParamsBase {
         , colscale(nullptr)
         , dropout_keep_p(1.f)
         , dropout_scale(1.f)
+        , is_rms_norm(false)
         , workspace(nullptr)
         , barrier(nullptr)
     {
@@ -75,6 +76,8 @@ struct ParamsBase {
     float dropout_scale;
     float rowscale_const;
+    bool is_rms_norm;
+
     // Multi-CTA workspace in gmem.
     void *workspace;
...
@@ -83,7 +83,7 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
 std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input: BxSxhidden_size
                                            c10::optional<const at::Tensor> &x1_,      // Residual: BxSxhidden_size
                                            const at::Tensor &gamma,   // hidden_size
-                                           const at::Tensor &beta,   // hidden_size
+                                           c10::optional<const at::Tensor> &beta_,   // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,      // BxS
                                            c10::optional<const at::Tensor> &colscale_,      // hidden_size
                                            c10::optional<const at::Tensor> &x0_subset_,      // BxS
@@ -93,7 +93,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
                                            const float rowscale_const,
                                            const int64_t z_numrows,
                                            c10::optional<at::Generator> gen_,
-                                           bool residual_in_fp32
+                                           bool residual_in_fp32=false,
+                                           bool is_rms_norm=false
 ) {
     auto itype = x0.scalar_type();
     auto rtype = x1_.has_value()
@@ -104,11 +105,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     auto ctype = torch::kFloat32;
     auto mtype = torch::kUInt8;
-    TORCH_CHECK(beta.dtype() == wtype);
     TORCH_CHECK(x0.is_cuda())
     TORCH_CHECK(gamma.is_cuda())
-    TORCH_CHECK(beta.is_cuda())
     TORCH_CHECK(x0.is_contiguous());
     // c10::IntArrayRef does not own the storage, so we need to construct a vector.
@@ -123,6 +121,14 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     const int cols = sizes[1];
     auto hidden_size = gamma.numel();
+    if (beta_.has_value()) {
+        auto beta = beta_.value();
+        TORCH_CHECK(beta.dtype() == wtype);
+        TORCH_CHECK(beta.is_cuda())
+        TORCH_CHECK(beta.is_contiguous());
+        TORCH_CHECK(gamma.sizes() == beta.sizes());
+    }
     if (x1_.has_value()) {
         auto x1 = x1_.value();
         TORCH_CHECK(x1.is_cuda())
@@ -161,7 +167,6 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
         TORCH_CHECK(z_subset.dtype() == torch::kInt32);
     }
-    TORCH_CHECK(gamma.sizes() == beta.sizes());
     TORCH_CHECK(hidden_size == cols);
     TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
@@ -218,12 +223,13 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     params.mu = mu.data_ptr();
     params.rs = rsigma.data_ptr();
     params.gamma = gamma.data_ptr();
-    params.beta = beta.data_ptr();
+    params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr;
     params.z = z.data_ptr();
     params.epsilon = epsilon;
     params.dropout_scale = 1.f / (1.f - dropout_p);
     params.inverse_cols = 1.f / float(params.cols);
     params.rowscale_const = rowscale_const;
+    params.is_rms_norm = is_rms_norm;
     if (dropout_p > 0.f) {
         // number of times random will be generated per thread, to offset philox counter in thc random
@@ -268,7 +274,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
                                            const float dropout_p,
                                            const float rowscale_const,
                                            const int64_t x0_numrows,
-                                           const bool has_residual
+                                           const bool has_residual,
+                                           bool is_rms_norm=false
 ) {
     auto itype = dz.scalar_type();
@@ -431,6 +438,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     params.dropout_scale = 1.f / (1.f - dropout_p);
     params.inverse_cols = 1.f / float(params.cols);
     params.rowscale_const = rowscale_const;
+    params.is_rms_norm = is_rms_norm;
     if( launch_params.barrier_size > 0 ) {
         // TODO Any way to avoid this?
@@ -453,6 +461,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "CUDA DropoutAddLayerNorm";
-    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
-    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
+    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
+          py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"),
+          py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
+          py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
+          py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
+    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
+          py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
+          py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
+          py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
+          py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
 }
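
(Annotation, not part of the commit.) With the argument names and defaults now registered through py::arg, the raw extension can be exercised directly from Python. A hedged sketch, mirroring the positional order of the binding above and the 5-tuple that the Python wrappers further down unpack; the tensor shapes follow the comments in the signature, with the input flattened to (rows, hidden):

import torch
import dropout_layer_norm  # the compiled extension from this directory

rows, hidden = 8 * 512, 1024
x0 = torch.randn(rows, hidden, device='cuda', dtype=torch.float16)
gamma = torch.ones(hidden, device='cuda', dtype=torch.float16)

# beta=None and is_rms_norm=True exercise the new RMSNorm path; the other optional
# arguments (residual, rowscale, colscale, subsets, generator) are left unset.
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
    x0, None, gamma, None, None, None, None, None,
    0.1, 1e-5, 1.0, 0, None, False, True,
)
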
@@ -125,7 +125,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             #pragma unroll
             for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                 compute_t x_tmp = x.data.elt[jt];
-                compute_t y_tmp = rs_r * (x_tmp - mu_r);
+                compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));
                 compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
                 compute_t dz_tmp = dz.data.elt[jt];
@@ -173,7 +173,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                 if (load_dz) {
                     compute_t dy_tmp = dy[it * NUM_ELTS + jt];
                     compute_t y_tmp = y[it * NUM_ELTS + jt];
-                    compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
+                    compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));
                     dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
                 } else {
                     dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
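
(Annotation, not part of the commit.) Writing $y_i$ for the normalized value, $dy_i = \gamma_i\, dz_i$, $\overline{dy} = \tfrac{1}{N}\sum_j dy_j$ (mdy_local), and $\overline{dy\,y} = \tfrac{1}{N}\sum_j dy_j y_j$ (mdyy_local), the input gradient computed above is

$$dx_i = r_s\bigl(dy_i - \overline{dy\,y}\;y_i - \overline{dy}\bigr)\quad\text{(LayerNorm)},\qquad dx_i = r_s\bigl(dy_i - \overline{dy\,y}\;y_i\bigr)\quad\text{(RMSNorm)},$$

i.e. RMSNorm has no mean subtraction, so only the $\overline{dy}$ term drops out of the LayerNorm backward; the $\overline{dy\,y}$ term is unchanged.
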
...
@@ -89,7 +89,11 @@ void ln_fwd_kernel(FwdParams params) {
     for( int it = 0; it < LDGS; it++ ) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             gamma[it].load_from(params.gamma, idx);
-            beta[it].load_from(params.beta, idx);
+            if (params.beta != nullptr) {
+                beta[it].load_from(params.beta, idx);
+            } else {
+                beta[it].zero_();
+            }
             if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
             idx += VEC_COLS_PER_LDG;
         }
@@ -159,7 +163,7 @@ void ln_fwd_kernel(FwdParams params) {
             mu_ptr[row] = mu;
         }
-        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon);
+        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));
         if( bidn == 0 && warp_n == 0 && lane == 0 ) {
             rs_ptr[row] = rs;
@@ -174,7 +178,7 @@ void ln_fwd_kernel(FwdParams params) {
             Ovec z;
             #pragma unroll
             for( int jt = 0; jt < NUM_ELTS; jt++ ) {
-                compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu));
+                compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
                 compute_t g_ij = gamma[it].data.elt[jt];
                 compute_t b_ij = beta[it].data.elt[jt];
                 z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
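
(Annotation, not part of the commit.) The kernel already accumulates the row mean $\mu$ and the sum of squared deviations $m_2$, and since $\tfrac{1}{N}\sum_i x_i^2 = \tfrac{m_2}{N} + \mu^2$, the RMSNorm statistic falls out of the same accumulators:

$$r_s = \begin{cases}\bigl(\tfrac{m_2}{N} + \varepsilon\bigr)^{-1/2}, & \text{LayerNorm}\\[3pt] \bigl(\tfrac{m_2}{N} + \mu^2 + \varepsilon\bigr)^{-1/2} = \bigl(\tfrac{1}{N}\textstyle\sum_i x_i^2 + \varepsilon\bigr)^{-1/2}, & \text{RMSNorm}\end{cases}\qquad y_i = \begin{cases}(x_i - \mu)\,r_s, & \text{LayerNorm}\\ x_i\,r_s, & \text{RMSNorm.}\end{cases}$$
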
...
@@ -308,6 +308,13 @@ struct Vec {
         }
     }
+    inline __device__ void zero_() {
+        #pragma unroll
+        for( int it = 0; it < NUM_ELT; it++ ) {
+            this->data.elt[it] = Elt_type(0.f);
+        }
+    }
+
     inline __device__ void load_from(const void *base_ptr, const size_t idx) {
         this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
     }
...
@@ -8,7 +8,7 @@ import dropout_layer_norm
 def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
-                                    residual_in_fp32):
+                                    residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()
@@ -17,7 +17,7 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
         x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
-        1.0, 0, None, residual_in_fp32
+        1.0, 0, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
@@ -25,7 +25,7 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro
 def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
-                                     dropout_p, has_residual):
+                                     dropout_p, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + x1 was not returned in the fwd).
@@ -41,7 +41,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
     assert x0 is not None, 'x0 is required to compute the gradient of colscale'
     dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
-        dropout_p, 1.0, 0, has_residual
+        dropout_p, 1.0, 0, has_residual, is_rms_norm
     )
     # dx1mat is None if not has_residual
     if colscale is None:
@@ -53,7 +53,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
 def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
                                            dropout_p, epsilon, rowscale_const, out_numrows,
-                                           residual_in_fp32):
+                                           residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()
@@ -63,7 +63,7 @@ def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_sub
     out_subset = out_subset.view(-1) if out_subset is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
         x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
-        rowscale_const, out_numrows, None, residual_in_fp32
+        rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
@@ -72,7 +72,7 @@ def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_sub
 def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                             x0_subset, out_subset, dropout_p, rowscale_const,
-                                            x0_numrows, has_residual):
+                                            x0_numrows, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + x1 was not returned in the fwd).
@@ -89,7 +89,7 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
     assert x0 is not None, 'x0 is required to compute the gradient of colscale'
     dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
-        dropout_p, rowscale_const, x0_numrows, has_residual
+        dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
     )
     # dx1mat is None if not has_residual
     if colscale is None:
@@ -101,16 +101,17 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
 class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
-                prenorm=False, return_dmask=False):
+    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
         x1 = x1.contiguous() if x1 is not None else None
         gamma = gamma.contiguous()
-        beta = beta.contiguous()
+        beta = beta.contiguous() if beta is not None else None
        rowscale = rowscale.contiguous() if rowscale is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
+            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+            residual_in_fp32, is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None
@@ -118,6 +119,8 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
         ctx.has_residual = x1 is not None
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_beta = beta is not None
         if not return_dmask:
             return (zmat.view(x0.shape) if not prenorm
                     else (zmat.view(x0.shape), xmat.view(x0.shape)))
@@ -138,26 +141,29 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
         dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
-            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual
+            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
+            ctx.is_rms_norm
         )
         dx0 = dx0mat.view(x.shape)
         dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None
+        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None,
+                None, None, None, None)
 class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
-                rowscale_const, out_numrows, residual_in_fp32, prenorm=False, return_dmask=False):
+                rowscale_const, out_numrows, residual_in_fp32=False,
+                prenorm=False, is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
         x1 = x1.contiguous() if x1 is not None else None
         gamma = gamma.contiguous()
-        beta = beta.contiguous()
+        beta = beta.contiguous() if beta is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
             x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
-            rowscale_const, out_numrows, residual_in_fp32
+            rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None
@@ -169,6 +175,8 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         ctx.rowscale_const = rowscale_const
         ctx.x0_numrows = x0.shape[:-1].numel()
         ctx.has_residual = x1 is not None
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_beta = beta is not None
         z_shape = (-1, *x0.shape[1:])
         if not return_dmask:
             return (zmat.view(z_shape) if not prenorm
@@ -191,13 +199,13 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         has_residual = ctx.has_residual
         dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
             dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
-            ctx.rowscale_const, ctx.x0_numrows, has_residual
+            ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
         )
         dx0 = dx0mat.view(-1, *x.shape[1:])
         dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return (dx0, dx1, dgamma, dbeta, dcolscale, None, None, None, None, None, None, None,
-                None, None)
+        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None,
+                None, None, None, None, None, None, None)
 def layer_norm(x, weight, bias, epsilon):
@@ -212,7 +220,7 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No
     """
     return DropoutAddLayerNormFn.apply(
         x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
-        return_dropout_mask
+        False, return_dropout_mask
     )
@@ -225,7 +233,7 @@ def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, laye
     """
     return DropoutAddLayerNormSubsetFn.apply(
         x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
-        rowscale_const, out_numrows, residual_in_fp32, prenorm, return_dropout_mask
+        rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
     )
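
(Annotation, not part of the commit.) The backward() methods in this file return a few more trailing Nones after this change because torch.autograd.Function requires exactly one gradient (or None for non-differentiable inputs) per argument of forward(), and forward() gained extra arguments such as is_rms_norm. A minimal sketch of that contract, with hypothetical names:

import torch

class ScaleFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha, some_flag=False):
        # some_flag plays the role of a switch like is_rms_norm: it changes
        # behaviour but never receives a gradient.
        ctx.alpha = alpha
        return x * alpha

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward() argument:
        # grad for x, None for alpha, None for some_flag.
        return grad_out * ctx.alpha, None, None

x = torch.randn(4, requires_grad=True)
ScaleFn.apply(x, 2.0, True).sum().backward()
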
...
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn


def rms_norm(x, weight, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
                                       False, True)


def dropout_add_rms_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
                         prenorm=False, residual_in_fp32=False, return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
        True, return_dropout_mask
    )


def dropout_add_rms_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
                                x0_subset=None, out_subset=None, rowscale_const=1.0,
                                out_numrows=0, prenorm=False, residual_in_fp32=False,
                                return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
    )


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, x1=None):
        return dropout_add_rms_norm(x0, x1, self.weight, None,
                                    self.p if self.training else 0.0, self.epsilon,
                                    prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
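
(Annotation, not part of the commit; the tests below use Apex's FusedRMSNorm as the reference instead.) A plain-PyTorch sanity check of what rms_norm computes in the no-dropout, no-residual case:

import torch
from flash_attn.ops.rms_norm import rms_norm

def rms_norm_ref(x, weight, eps):
    # RMSNorm: scale by the reciprocal root-mean-square; no mean subtraction, no bias.
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rstd).to(x.dtype) * weight

x = torch.randn(2, 512, 1024, device='cuda', dtype=torch.float16)
w = torch.randn(1024, device='cuda', dtype=torch.float16)
out = rms_norm(x, w, 1e-5)
out_ref = rms_norm_ref(x, w, 1e-5)
print((out - out_ref).abs().max())  # expect only a small fp16-level difference
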
@@ -8,11 +8,20 @@ from einops import rearrange, repeat
 from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
 from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
+from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm
+from flash_attn.ops.rms_norm import dropout_add_rms_norm_subset
+try:
+    from apex.normalization import FusedRMSNorm
+except:
+    FusedRMSNorm = None
 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
+@pytest.mark.parametrize('is_rms_norm', [False, True])
 @pytest.mark.parametrize('has_colscale', [True, False])
+# @pytest.mark.parametrize('has_colscale', [False])
 @pytest.mark.parametrize('has_rowscale', [True, False])
 # @pytest.mark.parametrize('has_rowscale', [True])
 @pytest.mark.parametrize('has_residual', [True, False])
@@ -26,11 +35,17 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
                           (torch.float32, torch.float32)]
                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
 # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
-@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+@pytest.mark.parametrize('hidden_size', [256])
 def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
-                                     dropout_p, has_residual, has_rowscale, has_colscale):
+                                     dropout_p, has_residual, has_rowscale, has_colscale, is_rms_norm):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
         pytest.skip()  # Not supported
+    if is_rms_norm and FusedRMSNorm is None:
+        pytest.skip()  # We need Apex's FusedRMSNorm to test
+    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
+    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
+    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
     device = 'cuda'
     # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
     rtol, atol = (1e-3, 1e-4)
@@ -67,20 +82,22 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
     if has_colscale:
         x0_scaled_pt = x0_scaled_pt * colscale_pt
         x0_scaled_ref = x0_scaled_ref * colscale_ref
-    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
-    torch.nn.init.normal_(model_pt.bias)
-    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
-    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
+    if not is_rms_norm:
+        torch.nn.init.normal_(model_pt.bias)
+    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
+    model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
     with torch.no_grad():
         model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
         model_ref.weight.copy_(model_pt.weight)
-        model_ref.bias.copy_(model_pt.bias)
+        if not is_rms_norm:
+            model.bias.copy_(model_pt.bias)
+            model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
-                                        model.epsilon, rowscale=rowscale, layerscale=colscale,
-                                        residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
+    out, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
+                                     model.epsilon, rowscale=rowscale, layerscale=colscale,
+                                     residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
     assert out.dtype == input_dtype
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
     if has_residual:
@@ -101,7 +118,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
     if has_residual:
         assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
-    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
+    if not is_rms_norm:
+        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
     if has_colscale:
         assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
@@ -151,27 +169,34 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
     assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
-# @pytest.mark.parametrize('has_colscale', [True, False])
-# @pytest.mark.parametrize('has_rowscale', [True, False])
-# @pytest.mark.parametrize('has_residual', [True, False])
-# @pytest.mark.parametrize('dropout_p', [0.37, 0.0])
-# @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
-# @pytest.mark.parametrize('input_dtype,residual_dtype',
-#                          [(torch.float16, torch.float16), (torch.float16, torch.float32),
-#                           (torch.float32, torch.float32)]
-#                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
-@pytest.mark.parametrize('has_colscale', [True])
-@pytest.mark.parametrize('has_rowscale', [False])
-@pytest.mark.parametrize('has_residual', [True])
-@pytest.mark.parametrize('dropout_p', [0.0])
-@pytest.mark.parametrize('weight_dtype', [torch.float32])
-@pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
-# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
-@pytest.mark.parametrize('hidden_size', [256])
+@pytest.mark.parametrize('is_rms_norm', [False, True])
+@pytest.mark.parametrize('has_colscale', [True, False])
+@pytest.mark.parametrize('has_rowscale', [True, False])
+@pytest.mark.parametrize('has_residual', [True, False])
+@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+# @pytest.mark.parametrize('has_colscale', [True])
+# @pytest.mark.parametrize('has_rowscale', [False])
+# @pytest.mark.parametrize('has_residual', [True])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+# @pytest.mark.parametrize('weight_dtype', [torch.float32])
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [256])
 def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
-                                             dropout_p, has_residual, has_rowscale, has_colscale):
+                                             dropout_p, has_residual, has_rowscale, has_colscale,
+                                             is_rms_norm):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
         pytest.skip()  # Not supported
+    if is_rms_norm and FusedRMSNorm is None:
+        pytest.skip()  # We need Apex's FusedRMSNorm to test
+    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
+    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
+    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
     device = 'cuda'
     # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
     rtol, atol = (1e-3, 2e-4)
@@ -208,23 +233,25 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
     if has_colscale:
         x0_scaled_pt = x0_scaled_pt * colscale_pt
         x0_scaled_ref = x0_scaled_ref * colscale_ref
-    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
-    torch.nn.init.normal_(model_pt.bias)
-    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
-    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
-                                dtype=weight_dtype)
+    if not is_rms_norm:
+        torch.nn.init.normal_(model_pt.bias)
+    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
+    model = our_layer_norm_cls(hidden_size, prenorm=True, p=dropout_p, device=device,
+                               dtype=weight_dtype)
     with torch.no_grad():
         model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
         model_ref.weight.copy_(model_pt.weight)
-        model_ref.bias.copy_(model_pt.bias)
+        if not is_rms_norm:
+            model.bias.copy_(model_pt.bias)
+            model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
-                                                  model.epsilon, rowscale=rowscale,
-                                                  layerscale=colscale, prenorm=True,
-                                                  residual_in_fp32=residual_in_fp32,
-                                                  return_dropout_mask=True)
+    out, residual, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
+                                               model.epsilon, rowscale=rowscale,
+                                               layerscale=colscale, prenorm=True,
+                                               residual_in_fp32=residual_in_fp32,
+                                               return_dropout_mask=True)
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
     if has_residual:
         residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
@@ -247,7 +274,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
     if has_residual:
         assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
-    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
+    if not is_rms_norm:
+        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
     if has_colscale:
         assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
...