Unverified commit c654e4fe authored by Jan Bielak, committed by GitHub

Fuse linear+scale+add (#2042)



* Add `nvte_cublas_gemm_scaled`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Support use of `alpha` and `beta` in `tex.generic_gemm`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Support use of `alpha` and `beta` in `general_gemm`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Support use of `alpha` and `beta` in `BasicLinear._functional_forward` and `BasicLinear._functional_backward`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Add `ForwardLinearScaleAdd` fusion
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Add `BackwardLinearScale` fusion
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Apply suggestions from code review
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Remove calls to `validate_gemm_scale` from `BasicLinear`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 92f431bf
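Taken together, these changes let a `Linear -> ConstantScale -> AddExtraInput` chain run as a single GEMM, folding the scale into `alpha` and performing the add by accumulating into the extra input with `beta = 1`. A minimal usage sketch (hypothetical sizes, following the tests below):

```python
import torch
import transformer_engine.pytorch.ops as te_ops

# y = linear(x1) * scale + x2, executed as one GEMM:
# D = alpha * (x1 @ W.T) + beta * x2 with alpha = scale, beta = 1
model = te_ops.Sequential(
    te_ops.Linear(32, 32, bias=False, device="cuda", dtype=torch.bfloat16),
    te_ops.ConstantScale(-2.5),
    te_ops.AddExtraInput(in_place=True),  # in-place add is required for the fusion
)
x1 = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
x2 = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = model(x1, x2)
y.backward(torch.ones_like(y))
```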
@@ -22,8 +22,10 @@ import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
BackwardActivationBias,
BackwardLinearAdd,
BackwardLinearScale,
ForwardLinearBiasActivation,
ForwardLinearBiasAdd,
ForwardLinearScaleAdd,
)
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
@@ -2008,6 +2010,109 @@ class TestFusedOps:
db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_forward_linear_scale_add(
self,
*,
scale: float,
weight_shape: tuple[int, int] = (32, 32),
in_shape: Iterable[int] = (32, -1),
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool = False,
) -> None:
"""Forward GEMM + scale + add"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
x2_ref, x2_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x1_ref, w_ref) * scale + x2_ref
y_ref.backward(dy_ref)
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
out_features,
bias=False,
device=device,
dtype=dtype,
),
te_ops.ConstantScale(scale),
te_ops.AddExtraInput(in_place=True),
te_ops.Quantize(),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x1_test, x2_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
assert len(forward_ops) == 2
assert isinstance(forward_ops[0][0], ForwardLinearScaleAdd)
assert isinstance(forward_ops[1][0], te_ops.Quantize)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx1_test, x1_ref.grad, **tols)
torch.testing.assert_close(dx2_test, x2_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("activation", ("relu", "gelu"))
@pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@@ -2202,6 +2307,99 @@ class TestFusedOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_linear_scale(
self,
*,
scale: float,
weight_shape: tuple[int, int] = (32, 32),
in_shape: Iterable[int] = (32, -1),
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool = False,
) -> None:
"""Backward dgrad GEMM + scale"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref) * scale
y_ref.backward(dy_ref)
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
out_features,
bias=False,
device=device,
dtype=dtype,
),
te_ops.ConstantScale(scale),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x_test)
(y_test * dy_test).sum().backward()
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(backward_ops[0][0], BackwardLinearScale)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
class TestCheckpointing:
"""Tests for checkpointing"""
......
@@ -238,8 +238,9 @@ using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublas
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
-bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split,
-int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) {
+float alpha, float beta, bool use_split_accumulator, int math_sm_count,
+int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
+cudaStream_t stream) {
// Tensor dims in row-major order
const int A0 = inputA->flat_first_dim();
const int A1 = inputA->flat_last_dim();
@@ -295,13 +296,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"fp8 Aux output for gemm + gelu fusion not supported!");
}
if (is_fp8_dtype(outputD->data.dtype)) {
-NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
+NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!");
}
-float one = 1.0;
-float zero = 0.0;
-float beta = (accumulate) ? one : zero;
cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
cublasLtMatmulDesc_t operationDesc = nullptr;
@@ -586,7 +583,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
-static_cast<const void *>(&one), /* alpha */
+static_cast<const void *>(&alpha), /* alpha */
param.A, /* A */
Adesc, param.B, /* B */
Bdesc, static_cast<const void *>(&beta), /* beta */
@@ -629,7 +626,26 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
-accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
+1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false,
+nullptr, stream);
}
void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_gemm_scaled);
using namespace transformer_engine;
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
Tensor *outputD = convertNVTETensor(D);
const Tensor *biasTensor = convertNVTETensor(bias);
Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
Tensor *wspace = convertNVTETensor(workspace);
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
@@ -671,8 +687,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Atomic GEMM only supports delayed scaling.");
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
-accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
-inputCounter, stream);
+1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split,
+n_split, gemm_producer, inputCounter, stream);
}
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
......
@@ -44,6 +44,36 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream);
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations,
* with support for scaling factors on the GEMM result (alpha) and the accumulation input (beta).
*
* Computes:
* - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = alpha*AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(alpha*AB + bias)` if neither `bias` nor `pre_gelu_out` is empty
*
* \param[in] A The A matrix.
* \param[in] B The B matrix.
* \param[in,out] D Output matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_gelu_out Output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the
* gradient computation.
* \param[out] workspace Workspace tensor.
* \param[in] alpha Scaling factor applied to the result of the GEMM
* \param[in] beta Scaling factor applied to the original value of D when
* accumulating into it; beta = 0 means no accumulation.
* \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
bool use_split_accumulator, int math_sm_count, cudaStream_t stream);
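For the plain case (empty `bias` and `pre_gelu_out`), the computation reduces to `D = alpha * op(A) @ op(B) + beta * D`. A reference sketch of the arithmetic in PyTorch (shapes are illustrative):

```python
import torch

A = torch.randn(64, 32)
B = torch.randn(32, 48)
D = torch.randn(64, 48)
alpha, beta = -2.5, 1.0

# beta = 0.0 matches plain nvte_cublas_gemm with accumulate=false;
# beta = 1.0 matches accumulate=true
D = alpha * (A @ B) + beta * D
```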
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
* \warning Cublas atomic gemm uses a beta API and is not tested for all use cases.
......
@@ -21,6 +21,15 @@ __all__ = [
]
def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
"""Validate whether a GEMM scaling factor is consistent with its usage"""
if required:
return scale if scale is not None else 1.0
if scale not in (0.0, None):
raise ValueError("scale must be zero or None when not required")
return 0.0
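A few illustrative calls (alpha is always required, so it defaults to 1.0; beta is only meaningful when accumulating):

```python
validate_gemm_scale(None, required=True)   # -> 1.0 (alpha defaults to identity)
validate_gemm_scale(-2.5, required=True)   # -> -2.5 (any required scale passes through)
validate_gemm_scale(None, required=False)  # -> 0.0 (beta unused without accumulation)
validate_gemm_scale(3.5, required=False)   # raises ValueError
```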
def general_gemm(
A: torch.Tensor,
B: torch.Tensor,
@@ -29,6 +38,8 @@
quantization_params: Optional[Quantizer] = None,
gelu: bool = False,
gelu_in: torch.Tensor = None,
alpha: float = 1.0,
beta: Optional[float] = None,
accumulate: bool = False,
layout: str = "TN",
out: Optional[torch.Tensor] = None,
@@ -47,6 +58,9 @@
transb = layout[1] == "T"
# assert quantization_params is None, "FP8 output not supported yet"
alpha = validate_gemm_scale(alpha, True)
beta = validate_gemm_scale(beta, accumulate)
if ub_type is not None:
assert ub is not None, (
f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires"
@@ -108,6 +122,8 @@
"comm_type": ub_type,
"extra_output": extra_output,
"bulk_overlap": bulk_overlap,
"alpha": alpha,
"beta": beta,
}
out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
......
@@ -122,7 +122,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
std::optional<CommOverlapType> comm_type = std::nullopt,
-MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
+MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false,
+float alpha = 1.0f, std::optional<float> beta = std::nullopt);
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
......
@@ -92,7 +92,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore* comm_overlap,
std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
-bool bulk_overlap) {
+bool bulk_overlap, float alpha, std::optional<float> beta) {
// Input tensors
NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
@@ -110,6 +110,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");
// Check scaling factors
if (accumulate) {
if (!beta) {
beta = 1.0f;
}
} else {
if (!beta) {
beta = 0.0f;
}
NVTE_CHECK(beta == 0.0, "Trying to use a non-zero beta while not accumulating ",
"into the D tensor. Beta has nothing to apply to.");
}
// Output tensor
TensorWrapper D_tensor;
if (D.is_none()) {
@@ -238,9 +251,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
} else {
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE({
-nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
-te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
-accumulate, use_split_accumulator, num_math_sms, main_stream);
+nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(),
+bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
+te_workspace.data(), alpha, *beta, use_split_accumulator,
+num_math_sms, main_stream);
});
}
} else {
......
@@ -111,7 +111,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);
m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
......
@@ -350,10 +350,12 @@ class BasicLinear(BasicOperation):
input: torch.Tensor, # pylint: disable=redefined-builtin
weight: torch.Tensor,
*,
alpha: float = 1.0,
bias: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None, # pylint: disable=unused-argument
dtype: Optional[torch.dtype] = None,
out: Optional[torch.Tensor] = None,
beta: Optional[float] = None,
accumulate_into_out: bool = False,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
@@ -373,6 +375,8 @@
Input tensor
weight: torch.Tensor
Weight tensor
alpha: float, default = 1.0
Scaling factor applied to the result of the GEMM
bias: torch.Tensor, optional
Bias tensor
device: torch.device, default = default CUDA device
@@ -381,6 +385,8 @@
Tensor datatype
out: torch.Tensor, optional
Output tensor
beta: float, optional
Scaling factor applied to the original value of out when accumulating into it
accumulate_into_out: bool, default = `False`
Add result to output tensor instead of overwriting
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
@@ -530,6 +536,8 @@
get_workspace(),
out_dtype=dtype,
quantization_params=output_quantizer,
alpha=alpha,
beta=beta,
accumulate=accumulate_into_out,
out=y,
bias=bias,
@@ -567,13 +575,17 @@
input: Optional[torch.Tensor], # pylint: disable=redefined-builtin
weight: Optional[torch.Tensor],
*,
grad_input_alpha: Optional[float] = None,
input_requires_grad: bool = True,
grad_weight_alpha: Optional[float] = None,
weight_requires_grad: bool = True,
device: Optional[torch.device] = None, # pylint: disable=unused-argument
dtype: Optional[torch.dtype] = None,
grad_weight: Optional[torch.Tensor] = None,
grad_weight_beta: Optional[float] = None,
accumulate_into_grad_weight: bool = False,
grad_input: Optional[torch.Tensor] = None,
grad_input_beta: Optional[float] = None,
accumulate_into_grad_input: bool = False,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
@@ -596,8 +608,12 @@
weight: torch.Tensor, optional
Weight tensor. Required to compute loss gradient w.r.t.
input.
grad_input_alpha: float, optional
Scaling factor applied to the result of the dgrad GEMM
input_requires_grad: bool
Whether to compute loss gradient w.r.t. input tensor
grad_weight_alpha: float, optional
Scaling factor applied to the result of the wgrad GEMM
weight_requires_grad: bool
Whether to compute loss gradient w.r.t. weight tensor
device: torch.device, default = default CUDA device
@@ -606,10 +622,14 @@
Tensor datatype
grad_weight: torch.Tensor, optional
Loss gradient w.r.t. weight tensor
grad_weight_beta: float, optional
Scaling factor applied to the original value of grad_weight when accumulating into it
accumulate_into_grad_weight: bool, default = `False`
Add result to weight grad instead of overwriting
grad_input: torch.Tensor, optional
Loss gradient w.r.t. input tensor
grad_input_beta: float, optional
Scaling factor applied to the original value of grad_input when accumulating into it
accumulate_into_grad_input: bool, default = `False`
Add result to input grad instead of overwriting
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
@@ -806,6 +826,8 @@
get_workspace(),
out_dtype=dtype,
quantization_params=grad_input_quantizer,
alpha=grad_input_alpha,
beta=grad_input_beta,
accumulate=accumulate_into_grad_input,
layout="NN",
out=dx,
@@ -856,6 +878,8 @@
dy,
get_workspace(),
out_dtype=dw_dtype,
alpha=grad_weight_alpha,
beta=grad_weight_beta,
accumulate=accumulate_into_grad_weight,
layout="NT",
out=dw,
......
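The new alpha/beta arguments thread straight through to the dgrad and wgrad GEMMs. A minimal sketch of the private API under simple assumptions (eager mode, no quantization, no tensor parallelism):

```python
import torch
from transformer_engine.pytorch.ops.basic import BasicLinear

x = torch.randn(32, 32, device="cuda")
w = torch.randn(32, 32, device="cuda")
dy = torch.randn(32, 32, device="cuda")

# dx = 2.0 * (dy @ w) from the dgrad GEMM,
# dw = 2.0 * (dy.T @ x) from the wgrad GEMM
dx, dw = BasicLinear._functional_backward(
    grad_output=dy,
    input=x,
    weight=w,
    grad_input_alpha=2.0,
    grad_weight_alpha=2.0,
)
```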
@@ -12,6 +12,10 @@ from .backward_linear_add import (
BackwardLinearAdd,
fuse_backward_linear_add,
)
from .backward_linear_scale import (
BackwardLinearScale,
fuse_backward_linear_scale,
)
from .forward_linear_bias_activation import (
ForwardLinearBiasActivation,
fuse_forward_linear_bias_activation,
@@ -20,6 +24,10 @@ from .forward_linear_bias_add import (
ForwardLinearBiasAdd,
fuse_forward_linear_bias_add,
)
from .forward_linear_scale_add import (
ForwardLinearScaleAdd,
fuse_forward_linear_scale_add,
)
from .userbuffers_backward_linear import (
UserbuffersBackwardLinear,
fuse_userbuffers_backward_linear,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused backward dgrad GEMM + scale."""
from __future__ import annotations
from typing import Optional
import torch
from ..basic import BasicLinear, ConstantScale
from ..op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data
class BackwardLinearScale(FusedOperation):
"""Fused backward dgrad GEMM + scale
Column tensor parallelism is not supported since that requires
communication immediately after the dgrad GEMM.
"""
def __init__(
self,
*,
scale: ConstantScale,
linear: BasicLinear,
) -> None:
super().__init__((linear, scale))
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
list[tuple[Optional[torch.Tensor], ...]],
list[tuple[()]],
]:
# Get basic operations
linear_op = self.basic_ops[0]
linear_op_ctx = basic_op_ctxs[1]
scale_op = self.basic_ops[1]
# Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"):
linear_op.weight.main_grad = linear_op.weight.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
grad_weight = linear_op.weight.main_grad.detach()
else:
accumulate_into_main_grad = False
# Linear backward pass
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=w,
input_requires_grad=linear_op_ctx.input_requires_grad,
grad_input_alpha=scale_op.scale,
weight_requires_grad=linear_op_ctx.weight_requires_grad,
grad_weight_alpha=scale_op.scale,
dtype=linear_op_ctx.dtype,
grad_weight=grad_weight,
accumulate_into_grad_weight=accumulate_into_main_grad,
tensor_parallel_mode=linear_op.tensor_parallel_mode,
tensor_parallel_group=linear_op.tensor_parallel_group,
sequence_parallel=linear_op.sequence_parallel,
with_quantized_compute=linear_op_ctx.with_quantized_compute,
input_quantizer=linear_op_ctx.input_quantizer,
weight_quantizer=linear_op_ctx.weight_quantizer,
grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
)
if accumulate_into_main_grad:
grad_weight = None
# Clear input tensor if possible
clear_tensor_data(x_local)
return grad_input, [(), (grad_weight,)], [(), ()]
def fuse_backward_linear_scale(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dgrad GEMM + constant scale
Parameters
----------
ops: list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
out.extend(window)
# Check if first op is constant scale
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, ConstantScale):
continue
# Check if second op is linear
op, _ = ops[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication after the dgrad GEMM
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardLinearScale(
scale=window[0][0],
linear=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
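The pattern this pass matches is a `ConstantScale` that immediately follows a `BasicLinear` in module order (the backward op list visits them in reverse), for example:

```python
import transformer_engine.pytorch.ops as te_ops

# The backward pass fuses into a single BackwardLinearScale op: the constant
# scale is folded into the alpha of both the dgrad and wgrad GEMMs.
model = te_ops.Sequential(
    te_ops.Linear(32, 32, bias=False),
    te_ops.ConstantScale(3.5),
)
```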
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused operation for forward GEMM + scale + add."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import torch
from ...fp8 import FP8GlobalStateManager
from ..basic import AddExtraInput, BasicLinear, ConstantScale
from ..op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...tensor import Quantizer
class ForwardLinearScaleAdd(FusedOperation):
"""Fused forward GEMM + scale + add
Row tensor parallelism is not supported since that requires
communication immediately after the GEMM.
"""
def __init__(
self,
*,
linear: BasicLinear,
scale: ConstantScale,
add: AddExtraInput,
) -> None:
super().__init__((linear, scale, add))
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
# Get basic operations
linear_op = self.basic_ops[0]
linear_op_ctx = basic_op_ctxs[0]
scale_op = self.basic_ops[1]
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
output_quantizer = None
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
# Get extra input tensor for add operation
extra_input = basic_op_extra_inputs[2][0]
# Get autocast dtype if needed
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = linear_op.weight.dtype
# Linear forward
output, x_local, w = BasicLinear._functional_forward(
input=input_,
weight=linear_op.weight,
alpha=scale_op.scale,
dtype=dtype,
out=extra_input,
accumulate_into_out=True,
tensor_parallel_mode=linear_op.tensor_parallel_mode,
tensor_parallel_group=linear_op.tensor_parallel_group,
sequence_parallel=linear_op.sequence_parallel,
with_quantized_compute=with_quantized_compute,
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
)
# Save state for backward pass
if linear_op_ctx.requires_grad:
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
return output, [() for _ in range(len(self.basic_ops))]
def fuse_forward_linear_scale_add(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fuse forward GEMM + scale + add
Parameters
----------
ops: list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 3:
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
linear = op
op, _ = ops[0]
# Check if next op is constant scale
if not isinstance(op, ConstantScale):
continue
scale = op
window.extend(ops[:1])
ops = ops[1:]
op, _ = ops[0]
# Check if next op is in-place add extra input
if not isinstance(op, AddExtraInput):
continue
if not op._in_place:
continue
add = op
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearScaleAdd(
linear=linear,
scale=scale,
add=add,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
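In reference form, the fused forward computes `y = scale * (x1 @ w.T) + x2` with a single GEMM that accumulates into the extra input's buffer. For 2-D tensors this is the same arithmetic as `torch.addmm` (an illustrative equivalence, not the actual code path):

```python
import torch

x1 = torch.randn(32, 32)
w = torch.randn(32, 32)
x2 = torch.randn(32, 32)
scale = -2.5

# out = beta * x2 + alpha * (x1 @ w.T), with alpha = scale and beta = 1
y = torch.addmm(x2, x1, w.t(), beta=1.0, alpha=scale)
```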
@@ -20,8 +20,10 @@ from transformer_engine.pytorch.ops.op import (
from transformer_engine.pytorch.ops.fused import (
fuse_backward_activation_bias,
fuse_backward_linear_add,
fuse_backward_linear_scale,
fuse_forward_linear_bias_activation,
fuse_forward_linear_bias_add,
fuse_forward_linear_scale_add,
fuse_userbuffers_backward_linear,
fuse_userbuffers_forward_linear,
)
@@ -355,6 +357,7 @@ class OperationFuser:
ops = fuse_userbuffers_forward_linear(ops)
ops = fuse_forward_linear_bias_add(ops)
ops = fuse_forward_linear_bias_activation(ops)
ops = fuse_forward_linear_scale_add(ops)
return ops
@classmethod
@@ -366,6 +369,7 @@
"""Attempt to fuse operations in backward pass"""
ops = fuse_userbuffers_backward_linear(ops)
ops = fuse_backward_linear_add(ops)
ops = fuse_backward_linear_scale(ops)
ops = fuse_backward_activation_bias(ops, recipe)
return ops
......