Commit eb33e587 authored by Tri Dao

[LayerNorm] Rename x1 -> residual

parent f68d41ec
@@ -59,7 +59,7 @@ struct ParamsBase {
 // Common data pointers.
 void *x0;
-void *x1;
+void *residual;
 void *x;
 void *dmask;
 void *mu;
@@ -117,7 +117,7 @@ struct BwdParams : public ParamsBase {
 , dgamma_part(nullptr)
 , dcolscale_part(nullptr)
 , dx0(nullptr)
-, dx1(nullptr)
+, dresidual(nullptr)
 , dbeta(nullptr)
 , dgamma(nullptr)
 , dcolscale(nullptr)
@@ -136,7 +136,7 @@ struct BwdParams : public ParamsBase {
 // Output: Dgrad.
 void *dx0;
-void *dx1;
+void *dresidual;
 // Output: Wgrad.
 void *dbeta;
 void *dgamma;
...
@@ -81,7 +81,7 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size
-c10::optional<const at::Tensor> &x1_, // Residual: BxSxhidden_size
+c10::optional<const at::Tensor> &residual_, // Residual: BxSxhidden_size
 const at::Tensor &gamma, // hidden_size
 c10::optional<const at::Tensor> &beta_, // hidden_size
 c10::optional<const at::Tensor> &rowscale_, // BxS
@@ -97,8 +97,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
 bool is_rms_norm=false
 ) {
 auto itype = x0.scalar_type();
-auto rtype = x1_.has_value()
-? x1_.value().scalar_type()
+auto rtype = residual_.has_value()
+? residual_.value().scalar_type()
 : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
 auto wtype = gamma.scalar_type();
 auto otype = itype;
@@ -129,11 +129,11 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
 TORCH_CHECK(gamma.sizes() == beta.sizes());
 }
-if (x1_.has_value()) {
-auto x1 = x1_.value();
-TORCH_CHECK(x1.is_cuda())
-TORCH_CHECK(x1.is_contiguous());
-TORCH_CHECK(x1.sizes() == sizes);
+if (residual_.has_value()) {
+auto residual = residual_.value();
+TORCH_CHECK(residual.is_cuda())
+TORCH_CHECK(residual.is_contiguous());
+TORCH_CHECK(residual.sizes() == sizes);
 }
 if (rowscale_.has_value()) {
@@ -178,7 +178,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
 auto opts = x0.options();
-bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
+bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
 at::Tensor x;
 if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
 at::Tensor dmask;
@@ -194,7 +194,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
 launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
 TORCH_CHECK(dropout_p < 1.f);
 launch_params.params.dropout_keep_p = 1.f - dropout_p;
-launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
+launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
 launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
 launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
 launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
@@ -383,8 +383,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
 auto opts = x.options();
 auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
-at::Tensor dx1;
-if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
+at::Tensor dresidual;
+if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
 auto dgamma = torch::empty_like(gamma);
 auto dbeta = torch::empty_like(gamma);
 at::Tensor dcolscale;
@@ -397,7 +397,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
 launch_params.props = at::cuda::getCurrentDeviceProperties();
 TORCH_CHECK(dropout_p < 1.f);
 launch_params.params.dropout_keep_p = 1.f - dropout_p;
-launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
+launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
 launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
 launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
 launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
@@ -450,7 +450,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
 launcher(launch_params, false);
-std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
 if (colscale_.has_value()) {
 result.push_back(dcolscale);
 result.push_back(dcolscale_part);
@@ -462,7 +462,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 m.doc() = "CUDA DropoutAddLayerNorm";
 m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
-py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"),
+py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta"),
 py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
 py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
 py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
...
@@ -37,7 +37,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
 extern __shared__ char smem_[];
-const bool has_residual = params.dx1 != nullptr;
+const bool has_residual = params.dresidual != nullptr;
 const bool prenorm = params.dx != nullptr;
 const index_t tidx = threadIdx.x;
@@ -164,7 +164,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
 for( int it = 0; it < LDGS; it++ ) {
 if (Is_even_cols || (it < num_valid_ldgs)) {
 Ivec dx0;
-Rvec dx1;
+Rvec dresidual;
 Ivec x0;
 if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
 #pragma unroll
@@ -178,7 +178,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
 } else {
 dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
 }
-if (has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
+if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
 if (save_dx0) {
 compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
 if (Is_dropout) {
@@ -199,7 +199,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
 }
 }
 }
-if (has_residual) { dx1.store_to(params.dx1, idx_x); }
+if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
 if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
 idx_x += Ktraits::VEC_COLS_PER_LDG;
 idx_x0 += Ktraits::VEC_COLS_PER_LDG;
...
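The backward hunks above route one combined upstream gradient into two outputs: the gradient of the residual stream (now written through `params.dresidual`) and the gradient of `x0`. A hedged, unfused sketch of that routing in plain PyTorch follows; the helper name and argument layout are hypothetical, LayerNorm's own backward is delegated to autograd rather than the kernel's hand-written math, and the `dgamma`/`dbeta` reductions are omitted. Inputs are assumed to be viewed as (rows, hidden_size).

```python
import torch
import torch.nn.functional as F

def dropout_add_layer_norm_bwd_ref(dz, dx, x, dmask, gamma, beta, eps,
                                   rowscale, colscale, keep_p, prenorm):
    # Gradient of z = layer_norm(x) w.r.t. x, obtained via autograd.
    x = x.detach().float().requires_grad_()
    z = F.layer_norm(x, (x.shape[-1],), gamma.float(), beta.float(), eps)
    (dx_total,) = torch.autograd.grad(z, x, grad_outputs=dz.float())
    if prenorm:                        # dx is the gradient of the returned pre-norm x
        dx_total = dx_total + dx.float()
    dresidual = dx_total               # what the kernel stores through params.dresidual
    dx0 = dx_total
    if rowscale is not None:
        dx0 = dx0 * rowscale[..., None]
    dx0 = dx0 * dmask / keep_p         # undo dropout (dmask is the saved keep mask)
    if colscale is not None:
        dx0 = dx0 * colscale
    return dx0, dresidual
```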
@@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
 using Stats = typename Ktraits::Stats;
 using stats_t = typename Stats::stats_t;
-const bool has_residual = params.x1 != nullptr;
+const bool has_residual = params.residual != nullptr;
 const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);
 extern __shared__ char smem_[];
@@ -111,11 +111,11 @@ void ln_fwd_kernel(FwdParams params) {
 for( int it = 0; it < LDGS; it++ ) {
 if (Is_even_cols || (it < num_valid_ldgs)) {
 Ivec x0;
-Rvec x1;
+Rvec residual;
 Rvec x;
 Mvec dmask;
 if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
-if (has_residual) { x1.load_from(params.x1, idx_x); }
+if (has_residual) { residual.load_from(params.residual, idx_x); }
 #pragma unroll
 for( int jt = 0; jt < NUM_ELTS; jt++ ) {
 // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
@@ -127,9 +127,9 @@ void ln_fwd_kernel(FwdParams params) {
 compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
 x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
 if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
-x_ij = has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
+x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
 } else {
-x_ij = has_residual ? compute_t(x1.data.elt[jt]) : 0.f;
+x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
 }
 if (save_x) { x.data.elt[jt] = x_ij; }
 xf[it * NUM_ELTS + jt] = x_ij;
...
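The forward hunk above spells out the fused per-element pipeline that feeds the normalization: row scaling, dropout, optional column scaling, then the residual add. A hedged, unfused PyTorch sketch of the same math is below; the function name is hypothetical, the final LayerNorm step is implied by the kernel rather than shown in these hunks, and tensors are assumed to be viewed as (rows, hidden_size) as in the Python wrappers later in this diff.

```python
import torch
import torch.nn.functional as F

def dropout_add_layer_norm_ref(x0, residual, gamma, beta, p, eps,
                               rowscale=None, colscale=None):
    # Same per-element pipeline as the kernel: rowscale -> dropout -> colscale
    # -> add residual, then LayerNorm over the last dimension.
    x = x0.float()
    if rowscale is not None:
        x = x * rowscale[..., None]
    x = F.dropout(x, p)                      # scales kept elements by 1/(1-p)
    if colscale is not None:
        x = x * colscale
    if residual is not None:                 # the tensor formerly called x1
        x = x + residual.float()
    z = F.layer_norm(x, (x.shape[-1],), gamma.float(), beta.float(), eps)
    return z, x                              # (normalized output, pre-norm x)
```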
@@ -292,7 +292,7 @@ class GPTModel(GPTPreTrainedModel):
 residual = (dropped + residual) if residual is not None else dropped
 hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
 else:
-# Set prenorm=False here since we don't need to the residual
+# Set prenorm=False here since we don't need the residual
 hidden_states = dropout_add_layer_norm(
 hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
 self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
@@ -359,7 +359,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
 # Previous: Attn / MLP -> Dropout -> Add -> LN
 # Current: Dropout -> Add -> LN -> Attn / MLP
 if 'transformer.ln_0.weight' in state_dict:
-n_layers = self.config.num_hidden_layers
+n_layers = len(self.transformer.layers)
 ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
 ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
 state_dict['transformer.ln_f.weight'] = ln_weight
...
@@ -7,20 +7,20 @@ from torch.nn import init
 import dropout_layer_norm
-def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
-residual_in_fp32=False, is_rms_norm=False):
+def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
+epsilon, residual_in_fp32=False, is_rms_norm=False):
 """ Assume that arguments are contiguous
 """
 hidden_size = gamma.numel()
 x0mat = x0.view((-1, hidden_size))
-x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+residualmat = residual.view((-1, hidden_size)) if residual is not None else None
 rowscale = rowscale.view(-1) if rowscale is not None else None
 zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
+x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
 1.0, 0, None, residual_in_fp32, is_rms_norm
 )
 # dmask is None if dropout_p == 0.0
-# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
 return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
@@ -28,7 +28,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
 dropout_p, has_residual, is_rms_norm=False):
 """ Assume that arguments are contiguous
 dx == None means that it was a post-norm architecture
-(x = drop(x0) + x1 was not returned in the fwd).
+(x = drop(x0) + residual was not returned in the fwd).
 x0 must not be None if we have colscale.
 """
 hidden_size = gamma.numel()
@@ -39,34 +39,34 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
 rowscale = rowscale.view(-1) if rowscale is not None else None
 if colscale is not None:
 assert x0 is not None, 'x0 is required to compute the gradient of colscale'
-dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
 dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
 dropout_p, 1.0, 0, has_residual, is_rms_norm
 )
-# dx1mat is None if not has_residual
+# dresidualmat is None if not has_residual
 if colscale is None:
-return dx0mat, dx1mat, dgamma, dbeta
+return dx0mat, dresidualmat, dgamma, dbeta
 else:
 dcolscale = rest[0]
-return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
-def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
-dropout_p, epsilon, rowscale_const, out_numrows,
-residual_in_fp32=False, is_rms_norm=False):
+def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
+out_subset, dropout_p, epsilon, rowscale_const,
+out_numrows, residual_in_fp32=False, is_rms_norm=False):
 """ Assume that arguments are contiguous
 """
 hidden_size = gamma.numel()
 x0mat = x0.view((-1, hidden_size))
-x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+residualmat = residual.view((-1, hidden_size)) if residual is not None else None
 x0_subset = x0_subset.view(-1) if x0_subset is not None else None
 out_subset = out_subset.view(-1) if out_subset is not None else None
 zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
+x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
 rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
 )
 # dmask is None if dropout_p == 0.0
-# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
 return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
@@ -75,7 +75,7 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
 x0_numrows, has_residual, is_rms_norm=False):
 """ Assume that arguments are contiguous
 dx == None means that it was a post-norm architecture
-(x = drop(x0) + x1 was not returned in the fwd).
+(x = drop(x0) + residual was not returned in the fwd).
 x0 must not be None if we have colscale.
 """
 hidden_size = gamma.numel()
@@ -87,30 +87,30 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
 out_subset = out_subset.view(-1) if out_subset is not None else None
 if colscale is not None:
 assert x0 is not None, 'x0 is required to compute the gradient of colscale'
-dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
 dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
 dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
 )
-# dx1mat is None if not has_residual
+# dresidualmat is None if not has_residual
 if colscale is None:
-return dx0mat, dx1mat, dgamma, dbeta
+return dx0mat, dresidualmat, dgamma, dbeta
 else:
 dcolscale = rest[0]
-return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
 class DropoutAddLayerNormFn(torch.autograd.Function):
 @staticmethod
-def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
 residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
 x0 = x0.contiguous()
-x1 = x1.contiguous() if x1 is not None else None
+residual = residual.contiguous() if residual is not None else None
 gamma = gamma.contiguous()
 beta = beta.contiguous() if beta is not None else None
 rowscale = rowscale.contiguous() if rowscale is not None else None
 colscale = colscale.contiguous() if colscale is not None else None
 zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
 residual_in_fp32, is_rms_norm
 )
 # Only need to save x0 if we need to compute gradient wrt colscale
@@ -118,7 +118,7 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
 ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
 ctx.prenorm = prenorm
 ctx.dropout_p = dropout_p
-ctx.has_residual = x1 is not None
+ctx.has_residual = residual is not None
 ctx.is_rms_norm = is_rms_norm
 ctx.has_beta = beta is not None
 if not return_dmask:
@@ -140,29 +140,29 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
 # x0 is None if colscale is None
 dropout_p = ctx.dropout_p
 has_residual = ctx.has_residual
-dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
+dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
 dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
 ctx.is_rms_norm
 )
 dx0 = dx0mat.view(x.shape)
-dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
 dcolscale = rest[0] if colscale is not None else None
-return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None,
-None, None, None, None)
+return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
+None, None, None, None, None)
 class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
 @staticmethod
-def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
 rowscale_const, out_numrows, residual_in_fp32=False,
 prenorm=False, is_rms_norm=False, return_dmask=False):
 x0 = x0.contiguous()
-x1 = x1.contiguous() if x1 is not None else None
+residual = residual.contiguous() if residual is not None else None
 gamma = gamma.contiguous()
 beta = beta.contiguous() if beta is not None else None
 colscale = colscale.contiguous() if colscale is not None else None
 zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
-x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
 rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
 )
 # Only need to save x0 if we need to compute gradient wrt colscale
@@ -174,7 +174,7 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
 ctx.dropout_p = dropout_p
 ctx.rowscale_const = rowscale_const
 ctx.x0_numrows = x0.shape[:-1].numel()
-ctx.has_residual = x1 is not None
+ctx.has_residual = residual is not None
 ctx.is_rms_norm = is_rms_norm
 ctx.has_beta = beta is not None
 z_shape = (-1, *x0.shape[1:])
@@ -197,42 +197,42 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
 # x0 is None if colscale is None
 dropout_p = ctx.dropout_p
 has_residual = ctx.has_residual
-dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
+dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
 dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
 ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
 )
 dx0 = dx0mat.view(-1, *x.shape[1:])
-dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
 dcolscale = rest[0] if colscale is not None else None
-return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None,
-None, None, None, None, None, None, None)
+return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
+None, None, None, None, None, None, None, None)
 def layer_norm(x, weight, bias, epsilon):
 return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
-def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
-prenorm=False, residual_in_fp32=False,
+def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
+layerscale=None, prenorm=False, residual_in_fp32=False,
 return_dropout_mask=False):
-"""residual_in_fp32 only has an effect if x1 is None.
-Otherwise residual dtype is x1.dtype.
+"""residual_in_fp32 only has an effect if residual is None.
+Otherwise residual dtype is residual.dtype.
 """
 return DropoutAddLayerNormFn.apply(
-x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
+x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
 False, return_dropout_mask
 )
-def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
+def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
 x0_subset=None, out_subset=None, rowscale_const=1.0,
 out_numrows=0, prenorm=False, residual_in_fp32=False,
 return_dropout_mask=False):
-"""residual_in_fp32 only has an effect if x1 is None.
-Otherwise residual dtype is x1.dtype.
+"""residual_in_fp32 only has an effect if residual is None.
+Otherwise residual dtype is residual.dtype.
 """
 return DropoutAddLayerNormSubsetFn.apply(
-x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
+x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
 rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
 )
@@ -254,7 +254,7 @@ class DropoutAddLayerNorm(torch.nn.Module):
 init.ones_(self.weight)
 init.zeros_(self.bias)
-def forward(self, x0, x1=None):
-return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
+def forward(self, x0, residual=None):
+return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
 self.p if self.training else 0.0, self.epsilon,
 prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
@@ -12,26 +12,27 @@ def rms_norm(x, weight, epsilon):
 False, True)
-def dropout_add_rms_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
-prenorm=False, residual_in_fp32=False, return_dropout_mask=False):
-"""residual_in_fp32 only has an effect if x1 is None.
-Otherwise residual dtype is x1.dtype.
+def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
+layerscale=None, prenorm=False, residual_in_fp32=False,
+return_dropout_mask=False):
+"""residual_in_fp32 only has an effect if residual is None.
+Otherwise residual dtype is residual.dtype.
 """
 return DropoutAddLayerNormFn.apply(
-x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
+x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
 True, return_dropout_mask
 )
-def dropout_add_rms_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
+def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
 x0_subset=None, out_subset=None, rowscale_const=1.0,
 out_numrows=0, prenorm=False, residual_in_fp32=False,
 return_dropout_mask=False):
-"""residual_in_fp32 only has an effect if x1 is None.
-Otherwise residual dtype is x1.dtype.
+"""residual_in_fp32 only has an effect if residual is None.
+Otherwise residual dtype is residual.dtype.
 """
 return DropoutAddLayerNormSubsetFn.apply(
-x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
+x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
 rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
 )
@@ -52,7 +53,7 @@ class DropoutAddRMSNorm(torch.nn.Module):
 def reset_parameters(self):
 init.ones_(self.weight)
-def forward(self, x0, x1=None):
-return dropout_add_rms_norm(x0, x1, self.weight, None,
+def forward(self, x0, residual=None):
+return dropout_add_rms_norm(x0, residual, self.weight, None,
 self.p if self.training else 0.0, self.epsilon,
 prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
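After this rename, every public entry point takes the residual stream as `residual` rather than `x1`. A minimal usage sketch of the functional API is below; the import path, tensor shapes, and the prenorm return convention are assumptions based on the surrounding code rather than something this diff shows, and running it requires the compiled `dropout_layer_norm` CUDA extension on a GPU.

```python
import torch
# Assumed import path; the file names are not shown in the diff above.
from flash_attn.ops.layer_norm import dropout_add_layer_norm

hidden = 1024
x0 = torch.randn(2, 512, hidden, device='cuda', dtype=torch.float16)
residual = torch.randn_like(x0)          # previously passed as the `x1` argument
weight = torch.ones(hidden, device='cuda', dtype=torch.float16)
bias = torch.zeros(hidden, device='cuda', dtype=torch.float16)

# With prenorm=True the call is expected to return the normalized output
# together with the updated residual stream x = dropout(x0) + residual,
# matching the prenorm=False comment in the gpt.py hunk above.
out, new_residual = dropout_add_layer_norm(
    x0, residual, weight, bias, 0.1, 1e-5,
    prenorm=True, residual_in_fp32=False,
)
```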