Unverified Commit 5b4d89c3 authored by Jan Bielak, committed by GitHub

Add backward RMSNorm+Add fusion (#2028)



* Add rmsnorm_bwd_add
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Add BackwardAddRMSNorm fused operation
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Try to optimize register usage in kernels
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Add separate BackwardAdd stage for the fused backward add
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
parent 1d075c06
......@@ -27,10 +27,19 @@ namespace {
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) {
const NormType norm_type, const bool use_cudnn,
const bool zero_centered_gamma_in_weight_dtype, const bool fused_bwd_add) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
return;
}
if (norm_type == LayerNorm && fused_bwd_add) {
GTEST_SKIP() << "Fused LN backward+add not currently supported";
}
if (fused_bwd_add && zero_centered_gamma_in_weight_dtype) {
GTEST_SKIP() << "zero_centered_gamma_in_weight_dtype not currently supported "
<< "in fused norm backward+add";
}
if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) {
......@@ -45,7 +54,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
if ((itype == DType::kBFloat16 && otype == DType::kFloat16) ||
(itype == DType::kFloat16 && otype == DType::kBFloat16)) {
GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16";
return;
}
Tensor input("input", std::vector<size_t>{ N, H }, itype);
......@@ -55,6 +63,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
Tensor mu("mu", std::vector<size_t>{ N }, DType::kFloat32);
Tensor rsigma("rsigma", std::vector<size_t>{ N }, DType::kFloat32);
Tensor dz("dz", std::vector<size_t>{ N, H }, wtype);
Tensor bwd_add("bwd_add", std::vector<size_t>{ N, H }, wtype);
Tensor dx("dx", std::vector<size_t>{ N, H }, itype);
Tensor dgamma("dgamma", std::vector<size_t>{ H }, wtype);
Tensor dbeta("dbeta", std::vector<size_t>{ H }, wtype);
......@@ -65,6 +74,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
fillUniform(&beta);
setRandomScale(&z);
fillUniform(&dz);
if (fused_bwd_add) {
fillUniform(&bwd_add);
} else {
fillCase<WeightType>(&bwd_add, zeros);
}
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N);
......@@ -85,7 +99,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
nvte_enable_cudnn_norm_fwd(true);
nvte_enable_cudnn_norm_bwd(true);
// Zero-centered gamma in weight dtype only supported by CuDNN backend currently
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(true);
......@@ -125,15 +138,23 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
z.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(),
workspace_bwd.data(),
if (fused_bwd_add) {
nvte_rmsnorm_bwd_add(dz.data(), input.data(), bwd_add.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(),
workspace_bwd.data(),
nvte_rmsnorm_bwd_add(dz.data(), input.data(), bwd_add.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
} else {
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount,
zero_centered_gamma, 0);
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount,
zero_centered_gamma, 0);
}
}
if (use_cudnn){
......@@ -167,6 +188,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
use_cudnn,
zero_centered_gamma_in_weight_dtype);
compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(),
bwd_add.rowwise_cpu_dptr<WeightType>(),
input.rowwise_cpu_dptr<InputType>(),
mu.rowwise_cpu_dptr<float>(), rsigma.rowwise_cpu_dptr<float>(),
gamma.rowwise_cpu_dptr<WeightType>(),
......@@ -214,11 +236,12 @@ std::vector<std::pair<size_t, size_t>> test_cases = {
} // namespace
class NormTestSuite : public ::testing::TestWithParam<std::tuple<bool,
NormType,
transformer_engine::DType,
NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool,
bool,
bool>> {};
TEST_P(NormTestSuite, TestNorm) {
......@@ -231,11 +254,20 @@ TEST_P(NormTestSuite, TestNorm) {
const DType output_type = std::get<3>(GetParam());
const auto size = std::get<4>(GetParam());
const bool zero_centered_gamma = std::get<5>(GetParam());
const bool cudnn_zero_centered_gamm_in_weight_dtype = std::get<6>(GetParam());
const bool cudnn_zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam());
const bool fused_bwd_add = std::get<7>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn, cudnn_zero_centered_gamm_in_weight_dtype);
performTest<InputType, OutputType>(
size.first,
size.second,
zero_centered_gamma,
norm_type,
use_cudnn,
cudnn_zero_centered_gamma_in_weight_dtype,
fused_bwd_add
);
);
);
}
......@@ -250,6 +282,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(false, true),
::testing::Values(false, true),
::testing::Values(false, true)),
[](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn";
......@@ -261,6 +294,7 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<4>(info.param).first) + "X" +
std::to_string(std::get<4>(info.param).second) + "X" +
std::to_string(std::get<5>(info.param)) + "X" +
std::to_string(std::get<6>(info.param));
std::to_string(std::get<6>(info.param)) + "X" +
std::to_string(std::get<7>(info.param));
return name;
});
......@@ -126,7 +126,8 @@ void compute_ref_output(NormType norm_type,
template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data,
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad,
const OutputType *add, const InputType *data,
const float *mu, const float *rsigma,
const InputType *gamma,
InputType *data_grad,
......@@ -165,7 +166,8 @@ void compute_ref_backward(const NormType norm_type, const OutputType *output_gra
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
const compute_t a = static_cast<compute_t>(add[i * H + j]);
const compute_t dx = a + rsigma[i] * (dy - mdyy * y - mdy);
data_grad[i * H + j] = static_cast<InputType>(dx);
}
}
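For reference, the per-row formula implemented by the loop above, including the fused `add` term `a`, appears to reduce to the following (for RMSNorm the `mu` and `mdy` terms are zero; they apply only to LayerNorm):

```latex
y_j = x_j\,r,\qquad r = \frac{1}{\mathrm{RMS}_\varepsilon(x)},\qquad dy_j = \gamma_j\,dz_j,
\qquad dx_j = a_j + r\Bigl(dy_j - y_j\cdot\tfrac{1}{H}\sum_{k=1}^{H} dy_k\,y_k\Bigr)
```

with gamma_j replaced by gamma_j + 1 when zero_centered_gamma is set.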
......
......@@ -844,9 +844,18 @@ void fillCase(Tensor *t, const InputsFillCase fill_case) {
}
}
template void fillCase<byte>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int32>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int64>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp32>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<bf16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e4m3>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e5m2>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp32>(Tensor *t, const InputsFillCase fill_case);
#if FP4_TYPE_SUPPORTED
template void fillCase<fp4e2m1>(Tensor *t, const InputsFillCase fill_case);
#endif
void setRandomScale(Tensor *t) {
std::uniform_real_distribution<> dis(-2.0, 1.0);
......
......@@ -21,6 +21,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
BackwardActivationBias,
BackwardAddRMSNorm,
BackwardLinearAdd,
BackwardLinearScale,
ForwardLinearBiasActivation,
......@@ -2206,6 +2207,94 @@ class TestFusedOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("weight_shape", ((19,), (64,)))
@pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("zero_centered_gamma", (False, True))
def test_backward_add_rmsnorm(
self,
*,
weight_shape: Iterable[int],
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
eps: float = 0.3,
zero_centered_gamma: bool,
) -> None:
"""Fused backward RMNorm + add"""
# Make input and weight shapes consistent
in_shape = list(in_shape)[:-1] + list(weight_shape)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
weight_shape,
test_dtype=dtype,
test_device=device,
)
dy1_ref, dy1_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
dy2_ref, dy2_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape)))
var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape)
if zero_centered_gamma:
y1_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref)
else:
y1_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref
y2_ref = x_ref
(y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()
# Implementation with fusible operations
model = te_ops.Sequential(
te_ops.MakeExtraOutput(),
te_ops.RMSNorm(
weight_shape,
eps=eps,
device=device,
dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
),
)
with torch.no_grad():
model[1].weight.copy_(w_test)
del w_test
y1_test, y2_test = model(x_test)
(y1_test * dy1_test + y2_test * dy2_test).sum().backward()
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(backward_ops[0][0], BackwardAddRMSNorm)
# Expected numerical error
tols = dtype_tols(dtype)
# Check results
y1_test = y1_test.to(dtype=torch.float64, device="cpu")
y2_test = y2_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y1_test, y1_ref, **tols)
torch.testing.assert_close(y2_test, y2_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_linear_add(
......
......@@ -24,7 +24,7 @@ extern "C" {
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta
* @f]
*
* Calling this function with workspace set to empty tensor will not perform the operation,
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] x Input tensor of shape [N, H].
......@@ -55,8 +55,8 @@ void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETe
* else
* with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$.
*
* Calling this function with workspace set to empty tensor will not perform the operation,
* but instead set the shape and type of these tensors to the required values.
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
......@@ -90,9 +90,8 @@ void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETenso
* RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
* @f]
*
* Calling this function with workspace and barrier set to empty tensor will not
* perform the operation, but instead set the shape and type of the workspace
* and barrier tensors to the required values.
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] x Input tensor of shape [N, H].
* \param[in] gamma Gamma tensor of shape [H].
......@@ -121,9 +120,8 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep
* @f]
* with respect to \f$x\f$ and \f$\gamma\f$.
*
* Calling this function with workspace, barrier, dgamma_part set
* to empty tensor will not perform the operation, but instead set the shape and type
* of these tensors to the required values.
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
......@@ -142,6 +140,29 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream);
/*! \brief Compute the backward pass of RMSNorm and add an additional tensor to the output gradient
*
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
* \param[in] add Additional tensor of shape [N, H] to add to the output gradient.
* \param[in] rsigma Reciprocal of the root mean square of the input
* calculated over the last dimension. Shape: [N].
* \param[in] gamma Gamma tensor of shape [H].
* \param[out] dx Output gradient of shape [N, H].
* \param[out] dgamma Gradient for gamma tensor of shape [H].
* \param[out] workspace Workspace tensor.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_rmsnorm_bwd_add(const NVTETensor dz, const NVTETensor x, const NVTETensor add,
const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx,
NVTETensor dgamma, NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream);
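To make the workspace convention concrete, here is a minimal hedged sketch of the query-then-execute pattern for the new function. The include path, helper name, and surrounding setup are assumptions; tensor construction is elided:

```cpp
#include <transformer_engine/normalization.h>  // assumed install path for this header

// Sketch only: all NVTETensor handles are presumed to be already created and
// shaped as documented above ([N, H] for dz/x/add/dx, [N] for rsigma,
// [H] for gamma/dgamma).
void rmsnorm_bwd_add_example(NVTETensor dz, NVTETensor x, NVTETensor add,
                             NVTETensor rsigma, NVTETensor gamma, NVTETensor dx,
                             NVTETensor dgamma, NVTETensor workspace,
                             int sm_count, cudaStream_t stream) {
  // First call with an empty workspace tensor: no computation happens; the
  // required workspace shape and dtype are recorded on `workspace`.
  nvte_rmsnorm_bwd_add(dz, x, add, rsigma, gamma, dx, dgamma, workspace,
                       sm_count, /*zero_centered_gamma=*/false, stream);
  // ...the caller allocates a buffer of the recorded shape/dtype and binds it
  // to `workspace` (see the PyTorch wrapper later in this diff)...
  // Second call with the populated workspace: the fused backward+add runs.
  nvte_rmsnorm_bwd_add(dz, x, add, rsigma, gamma, dx, dgamma, workspace,
                       sm_count, /*zero_centered_gamma=*/false, stream);
}
```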
/*! \brief Helper to enable cuDNN backend for normalization
*
* \param[in] bool Enable if True
......
......@@ -156,7 +156,7 @@ void TeNormalizationPlan<KernelParamsType>::_set_workspace() {
template <>
void TeNormalizationPlan<ForwardKernelParams>::execute(void* x_dptr, void* gamma_dptr,
void* mean_dptr, void* rsigma_dptr,
void* dx_dptr, void* dz_dptr,
void* dx_dptr, void* dz_dptr, void* add_dptr,
void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) {
NVTE_ERROR("Forward normalization should not call the backward execute function!");
......@@ -166,8 +166,9 @@ template <>
void TeNormalizationPlan<BackwardKernelParams>::execute(void* x_dptr, void* gamma_dptr,
void* mean_dptr, void* rsigma_dptr,
void* dx_dptr, void* dz_dptr,
void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) {
void* add_dptr, void* dbeta_dptr,
void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) {
_launch_params.stream = stream;
auto& kernel_params = _launch_params.params;
......@@ -177,6 +178,7 @@ void TeNormalizationPlan<BackwardKernelParams>::execute(void* x_dptr, void* gamm
kernel_params.rs = rsigma_dptr;
kernel_params.dx = dx_dptr;
kernel_params.dz = dz_dptr;
kernel_params.add = add_dptr;
kernel_params.dgamma = dgamma_dptr;
if (_is_layernorm) {
......@@ -447,8 +449,11 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr,
void* rsigma_dptr, void* dx_dptr, void* dz_dptr,
void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) {
void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) {
// cuDNN does not currently support fused backward+add
NVTE_CHECK(add_dptr == nullptr);
// Binding data pointers to graph tensors
_variant_pack = {
{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
......
......@@ -126,6 +126,9 @@ struct BackwardKernelParams : public KernelParamsBase {
// Input: gradient wrt. LN FWD output.
void* dz;
// Input: extra tensor to add for fused backward+add
void* add;
// Workspace for Wgrad pre-reduction.
void* dbeta_part;
void* dgamma_part;
......@@ -137,8 +140,10 @@ struct BackwardKernelParams : public KernelParamsBase {
void* dgamma;
};
using BackwardAddKernelParams = BackwardKernelParams;
enum class NVTE_Norm_Backend { Te, Cudnn };
enum class NVTE_Norm_Stage { Forward, Backward };
enum class NVTE_Norm_Stage { Forward, Backward, BackwardAdd };
using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>;
struct TupleHash {
......@@ -221,8 +226,8 @@ class NormalizationPlanBase {
cudaStream_t stream) = 0;
virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr,
void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) = 0;
void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr,
void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) = 0;
private:
virtual void _build() = 0;
......@@ -241,8 +246,8 @@ class TeNormalizationPlan : public NormalizationPlanBase {
cudaStream_t stream) override;
void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr,
void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) override;
void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) override;
private:
void _set_workspace();
......@@ -270,8 +275,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
cudaStream_t stream) override;
void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr,
void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) override;
void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) override;
private:
void _build() override;
......
......@@ -185,7 +185,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
} else {
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr, workspace->data.dptr, stream);
dz.data.dptr, nullptr /*add*/, dbeta->data.dptr, dgamma->data.dptr,
workspace->data.dptr, stream);
}
return;
}
......
......@@ -162,7 +162,74 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
} else {
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr, workspace->data.dptr, stream);
dz.data.dptr, nullptr /*add*/, nullptr /*dbeta*/, dgamma->data.dptr,
workspace->data.dptr, stream);
}
return;
}
void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const Tensor &rsigma,
const Tensor &gamma, Tensor *dx, Tensor *dgamma, Tensor *workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(dz.data.dtype == gamma.data.dtype);
NVTE_CHECK(add.data.dtype == gamma.data.dtype);
NVTE_CHECK(rsigma.data.dtype == DType::kFloat32);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
NVTE_CHECK(add.data.shape == x.data.shape);
NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]);
NVTE_CHECK(dx->data.shape == x.data.shape);
NVTE_CHECK(dx->data.dtype == x.data.dtype);
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) {
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(add, "add");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
}
// cuDNN does not currently support fused backward+add
NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te;
// TE backend does not currently support zero_centered_gamma_in_weight_dtype
NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(),
"zero_centered_gamma_in_weight_dtype is currently not supported for rmsnorm_bwd_add");
bool is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, dgamma->data.dptr, add.data.dptr);
bool gamma_in_weight_dtype = false;
auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::BackwardAdd,
gamma.data.dtype, // wtype
x.data.dtype, // itype
gamma.data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, add.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr,
workspace->data.dptr, stream);
}
return;
}
......@@ -195,3 +262,19 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
convertNVTETensor(dx), convertNVTETensor(dgamma), convertNVTETensor(workspace),
multiprocessorCount, zero_centered_gamma, stream);
}
void nvte_rmsnorm_bwd_add(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size
const NVTETensor add, // Nxhidden_size
const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
NVTE_API_CALL(nvte_rmsnorm_bwd_add);
using namespace transformer_engine;
rmsnorm_bwd_add(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x),
*convertNVTETensorCheck(add), *convertNVTETensorCheck(rsigma),
*convertNVTETensorCheck(gamma), convertNVTETensor(dx), convertNVTETensor(dgamma),
convertNVTETensor(workspace), multiprocessorCount, zero_centered_gamma, stream);
}
......@@ -7,13 +7,31 @@
#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_
#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_
#include <type_traits>
#include "../../utils.cuh"
#include "../common.h"
namespace transformer_engine {
namespace normalization {
struct maybe_not_t {};
template <typename T, bool Enabled>
using maybe_t = std::conditional_t<Enabled, T, maybe_not_t>;
template <typename Ivec, typename Ovec, bool FusedAdd>
union dx_add_t {
using add_t = maybe_t<Ovec, FusedAdd>;
using dx_t = Ivec;
struct {
char _padding[sizeof(dx_t) > sizeof(add_t) ? sizeof(dx_t) - sizeof(add_t) : 0];
add_t add;
};
dx_t dx;
};
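The union above lets the vector that loads the extra `add` input share storage (and hence registers) with the vector that stages the `dx` output: the leading padding right-aligns `add` so that, as `dx` elements are written front to back, no not-yet-consumed `add` element is overwritten, and with `FusedAdd == false` the `add` member collapses to the empty `maybe_not_t`. A minimal sketch of the size behavior, assuming the definition above is in scope and the compiler accepts its anonymous struct (as CUDA's toolchain does); the vector types here are hypothetical stand-ins for `Ivec`/`Ovec`:

```cpp
// Hypothetical stand-ins for the kernel's vector types.
struct Vec16 { float data[4]; };  // 16-byte "Ivec" (e.g. fp32 x4)
struct Vec8  { float data[2]; };  //  8-byte "Ovec" (e.g. fp16 x4)

// Fusion disabled: add_t is the empty maybe_not_t, so the union costs
// exactly one dx vector worth of storage.
static_assert(sizeof(dx_add_t<Vec16, Vec8, false>) == sizeof(Vec16), "");
// Fusion enabled: add occupies the tail of the same storage, so the fused
// path still costs only one dx vector worth of storage.
static_assert(sizeof(dx_add_t<Vec16, Vec8, true>) == sizeof(Vec16), "");
```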
template <typename Ktraits>
template <typename Ktraits, bool FusedAdd>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel(
BackwardKernelParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
......@@ -111,10 +129,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke
}
}
dx_add_t<Ivec, Ovec, FusedAdd> temp[LDGS];
if constexpr (FusedAdd) {
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
temp[it].add.load_from(params.add, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
}
reduce_t result = reducer.allreduce({0, mdyy_local}, sum);
mdyy_local = Get<1>::of<reduce_t, compute_t>(result) * rn;
Ivec dx[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
......@@ -123,9 +150,13 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp));
dx[it].data.elt[jt] = dx_tmp;
if constexpr (FusedAdd) {
compute_t add_tmp = temp[it].add.data.elt[jt];
dx_tmp += add_tmp;
}
dx[it].store_to(params.dx, idx);
temp[it].dx.data.elt[jt] = dx_tmp;
}
temp[it].dx.store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
......@@ -274,7 +305,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_fi
}
}
template <typename Ktraits>
template <typename Ktraits, bool FusedAdd>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel(
BackwardKernelParams params) {
enum { LDGS = Ktraits::LDGS };
......@@ -379,14 +410,22 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec dx;
dx_add_t<Ivec, Ovec, FusedAdd> temp;
if constexpr (FusedAdd) {
temp.add.load_from_elts(params.add, row * params.cols + col, params.cols - col);
}
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_ij = dy[it].data.elt[jt];
compute_t y_ij = y[it].data.elt[jt];
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij));
compute_t dx_ij = rs * (dy_ij - (mdyy * y_ij));
if constexpr (FusedAdd) {
compute_t add_ij = temp.add.data.elt[jt];
dx_ij += add_ij;
}
temp.dx.data.elt[jt] = dx_ij;
}
dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
temp.dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
}
}
......
......@@ -12,17 +12,17 @@ using namespace transformer_engine::normalization;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL, bool FUSED_ADD = false>
void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_tuned_kernel<Kernel_traits>;
auto kernel = &rmsnorm_bwd_tuned_kernel<Kernel_traits, FUSED_ADD>;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
int ctas_per_sm = 0;
NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
......@@ -52,9 +52,9 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
stream);
NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_),
Kernel_traits::SMEM_BYTES, stream));
}
using Kernel_traits_f =
......@@ -69,7 +69,7 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL>
int BYTES_PER_LDG_FINAL, bool FUSED_ADD = false>
void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
......@@ -77,7 +77,7 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
// Instantiate kernel
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_general_kernel<Kernel_traits>;
auto kernel = &rmsnorm_bwd_general_kernel<Kernel_traits, FUSED_ADD>;
// Configure kernel params
const int rows = launch_params.params.rows;
......@@ -85,9 +85,9 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
int ctas_per_col = launch_params.params.ctas_per_col;
int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) {
int ctas_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel,
Kernel_traits::THREADS_PER_CTA, 0);
int ctas_per_sm = 0;
NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
......@@ -112,8 +112,8 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream));
}
// Launch finalization kernel
......@@ -143,7 +143,7 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \
} // namespace
// Create rmsnorm tuned launch function and register. Macro signature:
// Create rmsnorm bwd tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
......@@ -171,7 +171,7 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// Create rmsnorm general launch function and register. Macro signature:
// Create rmsnorm bwd general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
......@@ -204,3 +204,108 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp16, fp16, fp32,
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
// Create fused rmsnorm bwd + add tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4,
true);
// Create fused rmsnorm bwd + add general launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4,
true);
REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4,
true);
......@@ -208,6 +208,11 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &rsigma, const at::Tensor &gamma,
const int sm_margin, const bool zero_centered_gamma);
std::vector<py::object> rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &add, const at::Tensor &rsigma,
const at::Tensor &gamma, const int sm_margin,
const bool zero_centered_gamma);
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
py::object ln_out, py::handle quantizer, DType otype,
const int sm_margin, const bool zero_centered_gamma);
......
......@@ -199,6 +199,52 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
return {py::cast(dx), py::cast(dgamma)};
}
std::vector<py::object> rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &add, const at::Tensor &rsigma,
const at::Tensor &gamma, const int sm_margin,
const bool zero_centered_gamma) {
const auto &dz_ = dz.contiguous();
const auto &x_ = x.contiguous();
const auto &add_ = add.contiguous();
const auto &rsigma_ = rsigma.contiguous();
const auto &gamma_ = gamma.contiguous();
auto dx = at::empty_like(x_);
auto dgamma = at::empty_like(gamma_);
TensorWrapper workspace;
auto dz_cu = makeTransformerEngineTensor(dz_);
auto x_cu = makeTransformerEngineTensor(x_);
auto add_cu = makeTransformerEngineTensor(add_);
auto rsigma_cu = makeTransformerEngineTensor(rsigma_);
auto gamma_cu = makeTransformerEngineTensor(gamma_);
auto dx_cu = makeTransformerEngineTensor(dx);
auto dgamma_cu = makeTransformerEngineTensor(dgamma);
// This call populates tensors with the required config.
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(),
gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace =
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Actual call to bwd kernel.
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(),
gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
return {py::cast(dx), py::cast(dgamma)};
}
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
py::object out, py::handle quantizer, DType out_dtype,
const int sm_margin, const bool zero_centered_gamma) {
......
......@@ -202,6 +202,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add,
"Fused backward of RMSNorm + add");
m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize,
"Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
......
......@@ -8,6 +8,10 @@ from .backward_activation_bias import (
BackwardActivationBias,
fuse_backward_activation_bias,
)
from .backward_add_rmsnorm import (
BackwardAddRMSNorm,
fuse_backward_add_rmsnorm,
)
from .backward_linear_add import (
BackwardLinearAdd,
fuse_backward_linear_add,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused backward RMNorm + add."""
from __future__ import annotations
from typing import Optional
import math
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.ops.basic import MakeExtraOutput, RMSNorm
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data
from .._common import maybe_dequantize
class BackwardAddRMSNorm(FusedOperation):
"""Fused backward RMNorm + add"""
def __init__(self, *, add: MakeExtraOutput, rmsnorm: RMSNorm):
super().__init__((add, rmsnorm))
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
list[tuple[Optional[torch.Tensor], ...]],
list[tuple[()]],
]:
# Get basic operations
rmsnorm_op = self.basic_ops[1]
rmsnorm_op_ctx = basic_op_ctxs[0]
# Saved tensors from forward pass
x, rstdevs = rmsnorm_op_ctx.saved_tensors
# Tensor dims
weight_dims = rmsnorm_op.weight.size()
inner_dim = math.prod(weight_dims)
# Check input tensors
dtype = rmsnorm_op_ctx.dtype
extra_grad = basic_op_grad_extra_outputs[1][0]
dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,))
add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size())
# Compute RMSNorm backward pass
dx, dw = tex.rmsnorm_bwd_add(
dy,
x,
add,
rstdevs,
w,
rmsnorm_op._sm_margins["backward"],
rmsnorm_op.zero_centered_gamma,
)
# Clear saved tensors if possible
clear_tensor_data(x)
clear_tensor_data(rstdevs)
# Reshape results
grad_input = dx.view(grad_output.size())
grad_weight = dw.view(weight_dims)
return grad_input, [(grad_weight,), ()], [(), ()]
def fuse_backward_add_rmsnorm(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward RMNorm + add
Parameters
----------
ops: list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
out.extend(window)
# Check if first op is RMSNorm
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, RMSNorm):
continue
# Check if second op is "make extra output"
op, _ = ops[0]
if not isinstance(op, MakeExtraOutput):
continue
if op._in_place:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardAddRMSNorm(
rmsnorm=window[0][0],
add=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
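For orientation, a minimal hedged sketch of the module pattern this pass fuses, mirroring the new test earlier in this diff (sizes and dtype are illustrative):

```python
import torch
import transformer_engine.pytorch.ops as te_ops

# MakeExtraOutput forwards its input as a second output; in the backward
# pass, the gradient arriving on that extra output becomes the `add` tensor
# fused into the RMSNorm backward kernel.
model = te_ops.Sequential(
    te_ops.MakeExtraOutput(),
    te_ops.RMSNorm((64,), eps=1e-5, device="cuda", dtype=torch.bfloat16),
)
x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y, extra = model(x)
(y.sum() + extra.sum()).backward()  # backward runs the fused BackwardAddRMSNorm
```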
......@@ -19,6 +19,7 @@ from transformer_engine.pytorch.ops.op import (
)
from transformer_engine.pytorch.ops.fused import (
fuse_backward_activation_bias,
fuse_backward_add_rmsnorm,
fuse_backward_linear_add,
fuse_backward_linear_scale,
fuse_forward_linear_bias_activation,
......@@ -371,6 +372,7 @@ class OperationFuser:
ops = fuse_backward_linear_add(ops)
ops = fuse_backward_linear_scale(ops)
ops = fuse_backward_activation_bias(ops, recipe)
ops = fuse_backward_add_rmsnorm(ops)
return ops
def maybe_fuse_ops(
......