Unverified Commit 3774aa37 authored by Tim Moon, committed by GitHub

[PyTorch] Add ops for MoE grouped MLP (#2664)



* Add ops for MoE grouped MLP
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move testing utility functions to util submodule
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak docs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Change order of tensor compatibility checks in noop_cat

Review suggestion from @ptrendx.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for GLU interleaving in clamped SwiGLU
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 93d51c82
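As orientation for the diff below, here is a rough usage sketch of how the new fusible ops compose into a grouped (per-expert) MLP. It mirrors the new test_grouped_mlp test; the tensor sizes are placeholders and a CUDA device is assumed.

import torch
from transformer_engine.pytorch import ops as te_ops

# Illustrative sketch (not part of the commit): one GroupedLinear per expert
# group, a post-scaled SwiGLU in between, and per-group split sizes passed in
# as extra inputs, as exercised by test_grouped_mlp below.
group_size, hidden_size = 4, 256
mlp = te_ops.Sequential(
    te_ops.GroupedLinear(group_size, hidden_size, 2 * hidden_size),
    te_ops.ScaledSwiGLU(),
    te_ops.GroupedLinear(group_size, hidden_size, hidden_size),
)
split_sizes = torch.tensor([256, 512, 768, 512], dtype=torch.int, device="cuda")  # tokens per expert
x = torch.randn(int(split_sizes.sum()), hidden_size, device="cuda")
probs = torch.rand(x.size(0), device="cuda")  # per-token router weights
y = mlp(x, split_sizes, probs, split_sizes)   # extra inputs for fc1, ScaledSwiGLU, fc2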
@@ -5,8 +5,10 @@
from __future__ import annotations
from collections.abc import Iterable
import functools
import io
import math
import random
from typing import Optional
import pytest
@@ -36,7 +38,14 @@ from transformer_engine.pytorch import (
import transformer_engine_torch as tex
# Import utility functions
from utils import (
assert_close,
assert_close_grads,
dtype_tols,
make_recipe,
quantization_tols,
reset_rng_states,
)
# Check for supported quantization schemes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
@@ -107,6 +116,9 @@ def maybe_skip_quantization(
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
*,
min: float = 0.0,
max: float = 1.0,
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
@@ -127,7 +139,8 @@ def make_reference_and_test_tensors(
"""
# Random reference tensor
ref = torch.empty(shape, dtype=ref_dtype, device=ref_device)
ref.uniform_(min, max)
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
@@ -1680,6 +1693,7 @@ class TestBasicOps:
quantization: Optional[str],
quantize_forward: bool,
quantize_backward: bool,
glu_interleave_size: Optional[int] = None,
):
# Tensor dimensions
@@ -1706,7 +1720,17 @@ class TestBasicOps:
)
# Plain PyTorch implementation
x = x_ref
if glu_interleave_size is not None:
x = x.reshape(
*in_shape[:-1],
in_shape[-1] // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(-3, -2)
x = x.reshape(in_shape)
x1, x2 = x.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
y_ref.backward(dy_ref)
@@ -1714,7 +1738,7 @@ class TestBasicOps:
recipe = make_recipe(quantization)
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantize_backward),
te_ops.SwiGLU(glu_interleave_size=glu_interleave_size),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.autocast(enabled=quantized_compute, recipe=recipe):
@@ -1727,10 +1751,19 @@ class TestBasicOps:
tols = quantization_tols(quantization)
# Check results
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)

def test_interleaved_swiglu(self):
"""SwiGLU with block interleaved input format"""
self.test_swiglu(
out_shape=(32, 192),
dtype=torch.float32,
quantization=None,
quantize_forward=False,
quantize_backward=False,
glu_interleave_size=32,
)
@pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantization", _quantization_list)
...@@ -1740,6 +1773,7 @@ class TestBasicOps: ...@@ -1740,6 +1773,7 @@ class TestBasicOps:
self, self,
*, *,
out_shape: Iterable[int] = (32, 32), out_shape: Iterable[int] = (32, 32),
glu_interleave_size: Optional[int] = None,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device = "cuda", device: torch.device = "cuda",
quantization: Optional[str], quantization: Optional[str],
...@@ -1748,7 +1782,7 @@ class TestBasicOps: ...@@ -1748,7 +1782,7 @@ class TestBasicOps:
limit: float = 0.75, limit: float = 0.75,
alpha: float = 1.702, alpha: float = 1.702,
): ):
# Test SwiGLU variant used in GPT OSS. """SwiGLU variant used in GPT-OSS"""
# Tensor dimensions
in_shape = list(out_shape)
in_shape[-1] *= 2
@@ -1773,7 +1807,17 @@ class TestBasicOps:
)
# Plain PyTorch implementation
x = x_ref
if glu_interleave_size is not None:
x = x.reshape(
*in_shape[:-1],
in_shape[-1] // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(-3, -2)
x = x.reshape(in_shape)
x_glu, x_linear = x.chunk(2, dim=-1)
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
@@ -1785,7 +1829,11 @@ class TestBasicOps:
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantize_backward),
te_ops.ClampedSwiGLU(
limit=limit,
alpha=alpha,
glu_interleave_size=glu_interleave_size,
),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.autocast(enabled=quantized_compute, recipe=recipe):
@@ -1801,10 +1849,19 @@ class TestBasicOps:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)

def test_interleaved_clamped_swiglu(self):
"""GPT-OSS SwiGLU with block interleaved input format"""
self.test_clamped_swiglu(
out_shape=(32, 192),
dtype=torch.float32,
quantization=None,
quantize_forward=False,
quantize_backward=False,
glu_interleave_size=32,
)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2))) @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
...@@ -1924,6 +1981,231 @@ class TestBasicOps: ...@@ -1924,6 +1981,231 @@ class TestBasicOps:
abs(z_score) < 2.5758 abs(z_score) < 2.5758
), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})" ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("weight_requires_grad", (False, True))
def test_grouped_linear(
self,
*,
group_size: int = 4,
bias: bool,
weight_shape: tuple[int, int] = (128, 128),
split_alignment: int = 128,
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_compute: bool,
quantized_weight: bool,
input_requires_grad: bool,
weight_requires_grad: bool,
) -> None:
"""Grouped GEMM"""
# Split sizes
split_sizes = [split_alignment * i for i in range(group_size)]
random.shuffle(split_sizes)
split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device)
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = (split_sizes.sum().item(), in_features)
out_shape = (in_shape[0], out_features)
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization is None and (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not specified")
if quantization is not None and not (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not used")
if quantization is not None and dtype not in (torch.bfloat16, torch.float16):
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=input_requires_grad,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
ws_ref, ws_test = [], []
bs_ref, bs_test = [], []
for _ in range(group_size):
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=weight_requires_grad,
)
b_ref, b_test = None, None
if bias:
b_ref, b_test = make_reference_and_test_tensors(
out_features,
test_dtype=dtype,
test_device=device,
requires_grad=weight_requires_grad,
)
ws_ref.append(w_ref)
ws_test.append(w_test)
bs_ref.append(b_ref)
bs_test.append(b_test)
# Plain PyTorch implementation
xs_ref = torch.split(x_ref, split_sizes.tolist())
ys_ref = []
for x, w, b in zip(xs_ref, ws_ref, bs_ref):
ys_ref.append(torch.nn.functional.linear(x, w, bias=b))
y_ref = torch.cat(ys_ref)
if input_requires_grad or weight_requires_grad:
y_ref.backward(dy_ref)
# Construct fusible operation
recipe = make_recipe(quantization)
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.GroupedLinear(
group_size,
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
)
with torch.no_grad():
for group_idx in range(group_size):
getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx])
if bias:
getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx])
del ws_test, bs_test
for param in op.parameters():
param.requires_grad_(requires_grad=weight_requires_grad)
# Forward and backward pass with op
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = op(x_test, split_sizes)
if input_requires_grad or weight_requires_grad:
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
if input_requires_grad:
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
else:
assert x_test.grad is None
for group_idx in range(group_size):
w_test = getattr(op, f"weight{group_idx}")
if weight_requires_grad:
dw_test = w_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols)
else:
assert w_test.grad is None
if bias:
b_test = getattr(op, f"bias{group_idx}")
if weight_requires_grad:
db_test = b_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols)
else:
assert b_test.grad is None
@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("scales_requires_grad", (False, True))
def test_scaled_swiglu(
self,
*,
in_shape: Iterable[int],
glu_interleave_size: Optional[int] = None,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
input_requires_grad: bool,
scales_requires_grad: bool,
) -> None:
"""SwiGLU with post-scale"""
# Tensor dims
out_shape = list(in_shape)
out_shape[-1] //= 2
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=input_requires_grad,
)
scales_ref, scales_test = make_reference_and_test_tensors(
in_shape[:-1],
test_dtype=dtype,
test_device=device,
requires_grad=scales_requires_grad,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
x = x_ref
if glu_interleave_size is not None:
x = x.reshape(
-1,
in_shape[-1] // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(1, 2)
x = x.reshape(in_shape)
x1, x2 = x.chunk(2, dim=-1)
y = torch.nn.functional.silu(x1) * x2
y_ref = scales_ref.unsqueeze(-1) * y
if input_requires_grad or scales_requires_grad:
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
y_test = op(x_test, scales_test)
if input_requires_grad or scales_requires_grad:
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(scales_test, scales_ref, **tols)
def test_interleaved_scaled_swiglu(self):
"""SwiGLU with post-scale and block interleaved input format"""
self.test_scaled_swiglu(
in_shape=(32, 192),
glu_interleave_size=32,
input_requires_grad=True,
scales_requires_grad=True,
)
class TestFusedOps:
"""Tests for fused operations"""
@@ -2931,6 +3213,188 @@ class TestSequentialModules:
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("glu_interleave_size", (None, 32))
def test_grouped_mlp(
self,
*,
group_size: int = 4,
bias: bool,
hidden_size: int = 256,
dtype: torch.dtype,
quantization: Optional[str],
device: torch.device = "cuda",
split_alignment: int = 256,
glu_interleave_size: Optional[int],
) -> None:
"""GroupedLinear + ScaledSwiGLU + GroupedLinear"""
# Split sizes
split_sizes = [split_alignment * i for i in range(group_size)]
random.shuffle(split_sizes)
split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device)
# Make input shape
in_shape = (split_sizes.sum().item(), hidden_size)
out_shape = in_shape
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
if with_quantization and dtype not in (torch.bfloat16, torch.float16):
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
probs_ref, probs_test = make_reference_and_test_tensors(
(in_shape[0],),
test_dtype=dtype,
test_device=device,
)
fc1_ws_ref, fc1_ws_test = [], []
fc1_bs_ref, fc1_bs_test = [], []
fc2_ws_ref, fc2_ws_test = [], []
fc2_bs_ref, fc2_bs_test = [], []
for _ in range(group_size):
fc1_w_ref, fc1_w_test = make_reference_and_test_tensors(
(2 * hidden_size, hidden_size),
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
fc2_w_ref, fc2_w_test = make_reference_and_test_tensors(
(hidden_size, hidden_size),
min=-0.25,
max=0.25,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
fc1_b_ref, fc1_b_test = None, None
fc2_b_ref, fc2_b_test = None, None
if bias:
fc1_b_ref, fc1_b_test = make_reference_and_test_tensors(
(2 * hidden_size,),
min=-0.5,
max=0.5,
test_dtype=dtype,
test_device=device,
)
fc2_b_ref, fc2_b_test = make_reference_and_test_tensors(
(hidden_size,),
min=-0.5,
max=0.5,
test_dtype=dtype,
test_device=device,
)
fc1_ws_ref.append(fc1_w_ref)
fc1_bs_ref.append(fc1_b_ref)
fc1_ws_test.append(fc1_w_test)
fc1_bs_test.append(fc1_b_test)
fc2_ws_ref.append(fc2_w_ref)
fc2_bs_ref.append(fc2_b_ref)
fc2_ws_test.append(fc2_w_test)
fc2_bs_test.append(fc2_b_test)
# Reference implementation
xs = torch.split(x_ref, split_sizes.tolist())
probs = torch.split(probs_ref, split_sizes.tolist())
ys = []
for group_idx in range(group_size):
x = xs[group_idx]
x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx])
if glu_interleave_size is not None:
x = x.reshape(
-1,
2 * hidden_size // (2 * glu_interleave_size),
2,
glu_interleave_size,
)
x = x.transpose(1, 2)
x = x.reshape(-1, 2 * hidden_size)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
x = x * probs[group_idx].unsqueeze(-1)
x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx])
ys.append(x)
y_ref = torch.cat(ys)
y_ref.backward(dy_ref)
# Construct operations
recipe = make_recipe(quantization)
with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
fc1 = te_ops.GroupedLinear(
group_size,
hidden_size,
2 * hidden_size,
bias=bias,
device=device,
dtype=dtype,
)
fc2 = te_ops.GroupedLinear(
group_size,
hidden_size,
hidden_size,
bias=bias,
device=device,
dtype=dtype,
)
module = te_ops.Sequential(
fc1,
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size),
fc2,
)
# Copy weights
with torch.no_grad():
for group_idx in range(group_size):
getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx])
getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx])
if bias:
getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx])
getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx])
del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test
# Fuse ops and perform forward and backward pass
with te.autocast(enabled=with_quantization, recipe=recipe):
y_test = module(x_test, split_sizes, probs_test, split_sizes)
y_test.backward(dy_test)
# Loose tols for sanity checking
tols = {"rtol": 0.125, "atol": 0.25}
if quantization == "nvfp4":
tols = {"rtol": 0.25, "atol": 0.5}
# Check values
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(probs_test, probs_ref, **tols)
for group_idx in range(group_size):
assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols)
assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols)
assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols)
assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols)
class TestCustomOps:
"""Test with ops that are defined externally"""
...
@@ -15,7 +15,7 @@ import torch
import transformer_engine
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import InferenceParams, QuantizedTensor
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
get_attention_backend,
@@ -361,3 +361,48 @@ def get_available_attention_backends(
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, flash_attention_backend, fused_attn_backends
@torch.no_grad
def assert_close(
actual: Optional[torch.Tensor],
expected: Optional[torch.Tensor],
*,
check_device: bool = False,
check_dtype: bool = False,
check_layout: bool = False,
**kwargs,
) -> None:
"""Assert that two tensors are close.
This function is a wrapper around torch.testing.assert_close. It
changes the defaults for device and dtype checks (useful when the
reference implementation is computed in high precision on CPU) and
it can handle quantized tensors.
"""
if isinstance(actual, QuantizedTensor):
actual = actual.dequantize()
if isinstance(expected, QuantizedTensor):
expected = expected.dequantize()
torch.testing.assert_close(
actual,
expected,
check_device=check_device,
check_dtype=check_dtype,
check_layout=check_layout,
**kwargs,
)
def assert_close_grads(
actual: Optional[torch.Tensor],
expected: Optional[torch.Tensor],
**kwargs,
) -> None:
"""Assert that two tensors have close gradients."""
if actual is None and expected is None:
return
assert actual is not None
assert expected is not None
assert_close(actual.grad, expected.grad, **kwargs)
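For context, a minimal illustration (not part of this diff, assumes a CUDA device) of why the device and dtype checks default to False: test outputs typically live on GPU in reduced precision while references are computed on CPU in FP64, mirroring how the tests above call these helpers.

# Illustrative only; dtype_tols is the existing helper from this utils module.
ref = torch.randn(4, 8, dtype=torch.float64)            # FP64 reference on CPU
test = ref.to(device="cuda", dtype=torch.bfloat16)      # simulated op output on GPU
assert_close(test, ref, **dtype_tols(torch.bfloat16))   # passes despite device/dtype mismatch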
@@ -77,6 +77,8 @@ class _NoopCatFunc(torch.autograd.Function):
# Check first tensor
if not tensors:
raise ValueError("Attempted to concatenate 0 tensors")
# Check concat dim
num_dims = tensors[0].dim()
if not -num_dims <= dim < num_dims:
raise ValueError(
@@ -109,11 +111,24 @@ class _NoopCatFunc(torch.autograd.Function):
ctx.dim = dim
ctx.split_ranges = split_ranges
# Tensor properties from first tensor
dtype = tensors[0].dtype
device = tensors[0].device
strides = tensors[0].stride()
data_ptr_stride = strides[dim] * tensors[0].element_size()
# Out-of-place concatenation when view tensors have different storage
# Note: This works around an edge case with the split_quantize
# function, which might allocate a buffer and construct
# subviews. However, in order to reduce CPU overheads, these
# views are configured manually outside of PyTorch. PyTorch
# doesn't know these views share the same memory, and it
# blocks us from reconstructing the full tensor because it
# thinks we are accessing out-of-bounds memory.
if tensors[0].untyped_storage().nbytes() < out_shape[dim] * data_ptr_stride:
return torch.cat(tensors, dim=dim)
# Out-of-place concatenation if tensor properties do not match
data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride
for tensor in tensors[1:]:
if (
@@ -126,13 +141,7 @@ class _NoopCatFunc(torch.autograd.Function):
data_ptr += tensor.size(dim) * data_ptr_stride
# No-op concatenation
out = tensors[0].as_strided(out_shape, strides)
out.requires_grad = any(tensor.requires_grad for tensor in tensors)
return out
...
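The as_strided construction above relies on the view tensors sharing one storage and being laid out back to back along the concat dimension. A minimal standalone sketch of that idea (illustrative only, not part of this diff):

import torch

# Minimal sketch: when tensors are adjacent views into one buffer, the
# "concatenation" can be recovered with as_strided instead of copying.
buf = torch.arange(12.0)                   # shared storage
chunks = [buf[0:4], buf[4:8], buf[8:12]]   # contiguous, adjacent views

out = chunks[0].as_strided((12,), chunks[0].stride())
assert out.data_ptr() == buf.data_ptr()                # no copy was made
torch.testing.assert_close(out, torch.cat(chunks))     # same values as a real concat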
@@ -14,8 +14,6 @@ from .activation import (
SReLU,
SReGLU,
SiLU,
SwiGLU,
ClampedSwiGLU,
)
from .add_extra_input import AddExtraInput
from .all_gather import AllGather
@@ -24,6 +22,7 @@ from .basic_linear import BasicLinear
from .bias import Bias
from .constant_scale import ConstantScale
from .dropout import Dropout
from .grouped_linear import GroupedLinear
from .identity import Identity
from .l2normalization import L2Normalization
from .layer_norm import LayerNorm
@@ -32,3 +31,4 @@ from .quantize import Quantize
from .reduce_scatter import ReduceScatter
from .reshape import Reshape
from .rmsnorm import RMSNorm
from .swiglu import ClampedSwiGLU, ScaledSwiGLU, SwiGLU
@@ -27,8 +27,6 @@ __all__ = [
"SReLU",
"SReGLU",
"SiLU",
"SwiGLU",
"ClampedSwiGLU",
]
@@ -355,76 +353,3 @@ class SiLU(_ActivationOperation):
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dsilu(*args, **kwargs)
class SwiGLU(_ActivationOperation):
r"""Swish gated linear unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
.. math::
\text{GEGLU}(a,b) = \text{SiLU}(a) * b
where
.. math::
\text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
.. warning::
Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
The Sigmoid Linear Unit (SiLU) gating function is also known as
the swish function. See
`GLU Variants Improve Transformer<https://arxiv.org/abs/2002.05202>`__
and `Gaussian Error Linear Units (GELUs)<https://arxiv.org/abs/1606.08415>`__.
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.swiglu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dswiglu(*args, **kwargs)
class ClampedSwiGLU(_ActivationOperation):
r"""GPT-OSS
Implementation based on `GPT-OSS<https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>`__.
This activation has two differences compared to the original SwiGLU
1. Both gate and pre-activations are clipped based on parameter limit.
2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation.
.. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt
from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor.
Parameters
----------
limit : float
The clamp limit.
alpha : float
The scaling factor for the sigmoid function used in the activation.
cache_quantized_input : bool, default = False
Quantize input tensor when caching for use in the backward pass.
"""
def __init__(
self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False
):
super().__init__(cache_quantized_input=cache_quantized_input)
self.limit = limit
self.alpha = alpha
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for grouped linear layer."""
from __future__ import annotations
from collections.abc import Callable, Iterable, Sequence
import contextlib
import math
from typing import Any, Optional
import torch
import transformer_engine_torch as tex
from ...cpp_extensions import general_grouped_gemm
from ...distributed import CudaRNGStatesTracker
from ...module.base import (
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
get_dummy_wgrad,
)
from ...quantization import FP8GlobalStateManager, Recipe
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
from ...utils import (
canonicalize_device,
canonicalize_dtype,
clear_tensor_data,
devices_match,
round_up_to_nearest_multiple,
)
from .._common import is_quantized_tensor, maybe_dequantize
from ..op import BasicOperation, OperationContext
class GroupedLinear(BasicOperation):
r"""Apply multiple linear transformations: :math:``y_i = x_i W_i^T + b_i``
This feature is experimental and subject to change.
This is equivalent to splitting the input tensor along its first
dimension, applying a separate ``torch.nn.Linear`` to each split,
and concatenating along the first dimension.
Parameters
----------
num_groups : int
Number of linear transformations.
in_features : int
Inner dimension of input tensor.
out_features : int
Inner dimension of output tensor.
bias : bool, default = ``True``
Apply additive bias.
device : torch.device, default = default CUDA device
Tensor device.
dtype : torch.dtype, default = default dtype
Tensor datatype.
rng_state_tracker_function : callable
Function that returns ``CudaRNGStatesTracker``, which is used
for model-parallel weight initialization.
accumulate_into_main_grad : bool, default = ``False``
Whether to directly accumulate weight gradients into the
weight's ``main_grad`` attribute instead of relying on PyTorch
autograd. The weight's ``main_grad`` must be set externally
and there is no guarantee that `grad` will be set or be
meaningful. This is primarily intended to integrate with
Megatron-LM. This argument along with weight tensor having
attribute ``overwrite_main_grad`` set to True will overwrite
``main_grad`` instead of accumulating.
"""
# Operation expects input split sizes
num_extra_inputs: int = 1
def __init__(
self,
num_groups: int,
in_features: int,
out_features: int,
*,
bias: bool = True,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
accumulate_into_main_grad: bool = False,
) -> None:
super().__init__()
# Weight tensor dimensions
self.num_groups: int = num_groups
self.in_features: int = in_features
self.out_features: int = out_features
if self.num_groups <= 0:
raise ValueError(f"Invalid number of groups ({self.num_groups})")
if self.in_features <= 0:
raise ValueError(f"Invalid input size ({self.in_features})")
if self.out_features <= 0:
raise ValueError(f"Invalid output size ({self.out_features})")
# Weight tensor attributes
device = canonicalize_device(device)
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
# Initialize recipe state if needed for natively quantized weight
self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters()
if self._with_quantized_weight:
self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
# RNG state tracker
self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]]
self._rng_state_tracker_function = rng_state_tracker_function
# Register weights
self.weight0: torch.nn.Parameter
for group_idx in range(self.num_groups):
weight_tensor = torch.empty(
self.out_features,
self.in_features,
device="meta",
dtype=dtype,
)
self.register_parameter(
f"weight{group_idx}",
torch.nn.Parameter(weight_tensor),
)
# Register biases
self.bias0: Optional[torch.nn.Parameter]
for group_idx in range(self.num_groups):
bias_tensor = None
if bias:
bias_tensor = torch.empty(
self.out_features,
device="meta",
dtype=dtype,
)
bias_tensor = torch.nn.Parameter(bias_tensor)
self.register_parameter(f"bias{group_idx}", bias_tensor)
# Initialize weights if needed
if device.type != "meta":
self.reset_parameters()
# Whether to accumulate weight gradient into main_grad
self._accumulate_into_main_grad: bool = accumulate_into_main_grad
def num_quantizers(self, mode: str) -> int:
if mode == "forward":
return 2 * self.num_groups
if mode == "backward":
return self.num_groups
return 0
@property
def has_bias(self) -> bool:
"""Whether an additive bias is being applied"""
return self.bias0 is not None
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""
# Parameter device
device = self.weight0.device
if device.type == "meta":
device = canonicalize_device(None)
# Initialize weight values
# Note: Allocate a single buffer in order to support grouped
# GEMM kernels that expect a single weight buffer.
packed_weights = torch.empty(
self.num_groups,
self.out_features,
self.in_features,
dtype=self.weight0.dtype,
device=device,
)
weights = [packed_weights[idx] for idx in range(self.num_groups)]
for weight in weights:
init_context = contextlib.nullcontext()
if self._rng_state_tracker_function is not None:
init_context = self._rng_state_tracker_function().fork()
with init_context:
torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
# Quantize weights if needed
if self._with_quantized_weight:
# Configure quantizers
quantizers = [
self.get_quantizer("forward", 2 * idx + 1) for idx in range(self.num_groups)
]
with_rowwise_usage = True
with_columnwise_usage = torch.is_grad_enabled()
for quantizer in quantizers:
if quantizer is None:
raise RuntimeError(
"Tried to quantize weight with deferred initialization "
"due to meta device, but no quantizer was available. "
"This is most likely because the weight was initialized "
"within quantized_model_init, but the forward pass was not "
"performed within autocast."
)
quantizer.set_usage(
rowwise=with_rowwise_usage,
columnwise=with_columnwise_usage,
)
quantizer.internal = False
# Quantize weights
weights = self._quantize_weights(weights, quantizers)
# Register weights
for group_idx, weight in enumerate(weights):
if not isinstance(weight, torch.nn.Parameter):
weight = torch.nn.Parameter(weight)
setattr(self, f"weight{group_idx}", weight)
# Initialize biases if needed
if self.bias0 is not None:
packed_biases = torch.zeros(
self.num_groups,
self.out_features,
dtype=self.bias0.dtype,
device=device,
)
for group_idx in range(self.num_groups):
bias = torch.nn.Parameter(packed_biases[group_idx])
setattr(self, f"bias{group_idx}", bias)
def _quantize_weights(
self,
weights: Sequence[torch.Tensor],
quantizers: Sequence[Quantizer],
) -> Sequence[torch.Tensor]:
"""Construct quantized weight tensors."""
# Manually construct MXFP8 weights
if isinstance(quantizers[0], MXFP8Quantizer):
return self._quantize_weights_mxfp8(weights, quantizers)
# Use quantizers to construct quantized weights
with torch.no_grad():
return [quantizer(weight) for quantizer, weight in zip(quantizers, weights)]
def _quantize_weights_mxfp8(
self,
weights: Sequence[torch.Tensor],
quantizers: Sequence[Quantizer],
) -> Sequence[MXFP8Tensor]:
"""Construct MXFP8 weight tensors.
Instead of allocating separate buffers for each weight tensor,
this function constructs large buffers and assigns subviews to
each tensor. This is intended to support grouped GEMM kernels
that expect packed buffers.
"""
# Tensor dimensions
num_groups = len(weights)
out_features, in_features = weights[0].size()
packed_shape = (num_groups, out_features, in_features)
unpacked_shape = (out_features, in_features)
# Tensor attributes
device = weights[0].device
dtype = weights[0].dtype
requires_grad = torch.is_grad_enabled()
with_rowwise_usage = quantizers[0].rowwise_usage
with_columnwise_usage = quantizers[0].columnwise_usage
# Construct packed buffers
rowwise_data = [None] * num_groups
rowwise_scales = [None] * num_groups
columnwise_data = [None] * num_groups
columnwise_scales = [None] * num_groups
if with_rowwise_usage:
scale_shape = (
num_groups,
round_up_to_nearest_multiple(out_features, 128),
round_up_to_nearest_multiple(in_features // 32, 4),
)
packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device)
packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device)
rowwise_data = [packed_data[idx] for idx in range(num_groups)]
rowwise_scales = [packed_scales[idx] for idx in range(num_groups)]
if with_columnwise_usage:
scale_shape = (
num_groups,
round_up_to_nearest_multiple(out_features // 32, 4),
round_up_to_nearest_multiple(in_features, 128),
)
packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device)
packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device)
columnwise_data = [packed_data[idx] for idx in range(num_groups)]
columnwise_scales = [packed_scales[idx] for idx in range(num_groups)]
# Construct MXFP8 tensors and cast to MXFP8
out = []
with torch.no_grad():
for group_idx in range(num_groups):
weight = MXFP8Tensor(
shape=unpacked_shape,
dtype=dtype,
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise_data=rowwise_data[group_idx],
rowwise_scale_inv=rowwise_scales[group_idx],
columnwise_data=columnwise_data[group_idx],
columnwise_scale_inv=columnwise_scales[group_idx],
quantizer=quantizers[group_idx],
requires_grad=requires_grad,
with_gemm_swizzled_scales=False,
)
weight.copy_(weights[group_idx])
out.append(weight)
return out
def pre_first_fuser_forward(self) -> None:
super().pre_first_fuser_forward()
# Initialize params if needed
if any(param.device.type == "meta" for param in self.parameters()):
self.reset_parameters()
# Check that weights are consistent
dtype = self.weight0.dtype
device = self.weight0.device
weight_requires_grad = self.weight0.requires_grad
weight_tensor_type = type(self.weight0.data)
for group_idx in range(self.num_groups):
weight = getattr(self, f"weight{group_idx}")
if weight.dtype != dtype:
raise RuntimeError(
f"Weight {group_idx} has invalid dtype (expected {dtype}, got {weight.dtype})."
)
if not devices_match(weight.device, device):
raise RuntimeError(
f"Weight {group_idx} has invalid device "
f"(expected {device}, got {weight.device})."
)
if weight.requires_grad != weight_requires_grad:
raise RuntimeError(
f"Weight {group_idx} has requires_grad={weight.requires_grad}, "
f"but expected requires_grad={weight_requires_grad}."
)
if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck
raise RuntimeError(
f"Weight {group_idx} has invalid tensor type "
f"(expected {weight_tensor_type.__name__}, "
f"got {type(weight.data).__name__})."
)
# Check that biases are consistent
for group_idx in range(self.num_groups):
bias = getattr(self, f"bias{group_idx}")
if self.has_bias:
if bias is None:
raise RuntimeError(f"Expected biases, but bias {group_idx} is uninitialized")
if bias.dtype != dtype:
raise RuntimeError(
f"Bias {group_idx} has invalid dtype (expected {dtype}, got {bias.dtype})."
)
if not devices_match(bias.device, device):
raise RuntimeError(
f"Bias {group_idx} has invalid device "
f"(expected {device}, got {bias.device})."
)
if bias.requires_grad != weight_requires_grad:
raise RuntimeError(
f"Bias {group_idx} has requires_grad={bias.requires_grad}, "
f"but expected requires_grad={weight_requires_grad}."
)
else:
if bias is not None:
raise RuntimeError(f"Expected no biases, but bias {group_idx} is initialized")
def pre_fuser_forward(self, *, requires_grad: bool) -> None:
super().pre_fuser_forward(requires_grad=requires_grad)
if FP8GlobalStateManager.is_fp8_enabled():
# Assume weights have consistent grad requirement
weight_requires_grad = requires_grad and self.weight0.requires_grad
# Configure quantizer usages
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
for group_idx in range(self.num_groups):
input_quantizer = self.get_quantizer("forward", 2 * group_idx)
weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1)
grad_output_quantizer = self.get_quantizer("backward", group_idx)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_state(recipe=recipe)
for group_idx in range(self.num_groups):
# Input/grad output quantizers use internal tensors
input_quantizer = self.get_quantizer("forward", 2 * group_idx)
grad_output_quantizer = self.get_quantizer("backward", group_idx)
if input_quantizer is not None:
input_quantizer.internal = True
if grad_output_quantizer is not None:
grad_output_quantizer.internal = True
# Handle weight quantizer
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1)
if weight_quantizer is None:
pass
elif is_quantized_tensor(getattr(self, f"weight{group_idx}", None)):
# Make sure weight param has correct quantizer
weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
weight_quantizer.internal = False
getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy())
else:
# Use internal tensors if quantized weights will not be
# exposed externally
weight_quantizer.internal = (
not FP8GlobalStateManager.with_fp8_parameters()
and not getattr(self, "_with_quantized_weight", False)
)
# Recipe-specific configuration
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
if recipe is not None:
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
grad_output_quantizer.force_pow_2_scales = (
recipe.fp8_quant_bwd_grad.power_2_scale
)
grad_output_quantizer.amax_epsilon_scales = (
recipe.fp8_quant_bwd_grad.amax_epsilon
)
def op_forward(self, *args, **kwargs):
raise RuntimeError(
f"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_forward` instead of `op_forward`."
)
def op_backward(self, *args, **kwargs):
raise RuntimeError(
f"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_backward` instead of `op_backward`."
)
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
num_groups = self.num_groups
has_bias = self.has_bias
device = self.weight0.device
# Check which grads are required
ctx = basic_op_ctxs[0]
input_requires_grad = ctx.requires_grad
weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad
# Quantizers
input_quantizers = [None] * num_groups
weight_quantizers = [None] * num_groups
grad_output_quantizers = [None] * num_groups
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
for group_idx in range(num_groups):
input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx)
weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1)
grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx)
# Get autocast dtype if needed
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = self.weight0.dtype
# Extract split sizes from extra input
split_sizes = basic_op_extra_inputs[0][0]
split_sizes_int = [int(s) for s in split_sizes.tolist()]
if len(split_sizes_int) != num_groups:
raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_int)}.")
# Extract params
weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)]
bs = None
if has_bias:
bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(num_groups)]
# Convert weight dtype if needed
ws = []
for w, quantizer in zip(weights, weight_quantizers):
if not with_quantized_compute:
w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not is_quantized_tensor(w):
quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = quantizer(w)
ws.append(w)
# Split input tensor and convert dtypes if needed
x = maybe_dequantize(input_, dtype)
xs = None
if with_quantized_compute:
for quantizer in input_quantizers:
quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
xs = tex.split_quantize(x, split_sizes_int, input_quantizers)
else:
xs = torch.split(x, split_sizes_int)
# Allocate output tensor
in_shape = list(input_.size())
out_shape = in_shape[:-1] + [self.out_features]
out = torch.empty(out_shape, dtype=dtype, device=device)
# Perform GEMMs
general_grouped_gemm(
ws,
xs,
[out],
[None] * num_groups, # quantization_params
dtype,
m_splits=split_sizes_int,
bias=bs,
use_bias=has_bias,
use_split_accumulator=_2X_ACC_FPROP,
single_output=True,
)
# Prepare weight tensors for backward pass
if not input_requires_grad:
ws = [None] * num_groups
elif with_quantized_compute:
for w, weight_param in zip(ws, weights):
if w is not weight_param:
w.update_usage(rowwise_usage=False, columnwise_usage=True)
# Prepare input tensor for backward pass
if not weight_requires_grad:
xs = [None] * num_groups
elif with_quantized_compute:
for x in xs:
x.update_usage(rowwise_usage=False, columnwise_usage=True)
# Save state for backward pass
if ctx.requires_grad:
ctx.save_for_backward(split_sizes, *xs, *ws)
ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizers = input_quantizers
ctx.weight_quantizers = weight_quantizers
ctx.grad_output_quantizers = grad_output_quantizers
ctx.grad_input_quantizers = None
ctx.dtype = dtype
ctx.input_requires_grad = input_requires_grad
ctx.weight_requires_grad = weight_requires_grad
return out, [()]
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
Iterable[Iterable[Optional[torch.Tensor]]],
Iterable[Iterable[Optional[torch.Tensor]]],
]:
num_groups = self.num_groups
has_bias = self.has_bias
device = self.weight0.device
# Saved tensors from forward pass
ctx = basic_op_ctxs[0]
saved_tensors = ctx.saved_tensors
split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:]
xs, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:]
ws, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:]
# Split grad output tensor and convert dtypes if needed
split_sizes_int = [int(s) for s in split_sizes.tolist()]
dy = maybe_dequantize(grad_output, ctx.dtype)
dys = None
grad_biases = [None] * num_groups
if ctx.with_quantized_compute:
for quantizer in ctx.grad_output_quantizers:
quantizer.set_usage(
rowwise=ctx.input_requires_grad,
columnwise=ctx.weight_requires_grad,
)
dys = tex.split_quantize(dy, split_sizes_int, ctx.grad_output_quantizers)
if has_bias:
grad_biases = [
dy.reshape(-1, dy.size(-1)).sum(dim=0)
for dy in torch.split(grad_output, split_sizes_int)
]
else:
dys = torch.split(dy, split_sizes_int)
if has_bias:
grad_biases = [dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys]
# Initialize grad weight buffers
accumulate_into_main_grad = self._accumulate_into_main_grad
grad_weights = [None] * num_groups
if ctx.weight_requires_grad:
if accumulate_into_main_grad:
# Megatron-LM wgrad fusion
# Note: Get grad tensors from params so we can
# accumulate directly into it.
for group_idx in range(num_groups):
weight_param = getattr(self, f"weight{group_idx}")
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
grad_weights[group_idx] = weight_param.main_grad
accumulate_into_main_grad = not getattr(self.weight0, "overwrite_main_grad", False)
else:
weight_shape = ws[0].size()
for group_idx in range(num_groups):
grad_weights[group_idx] = torch.empty(
weight_shape,
dtype=ctx.dtype,
device=device,
)
else:
accumulate_into_main_grad = False
# Perform dgrad GEMMs
grad_input = None
if ctx.input_requires_grad:
out_shape = list(grad_output.size())
in_shape = out_shape[:-1] + [self.in_features]
grad_input = torch.empty(
in_shape,
dtype=ctx.dtype,
device=device,
)
general_grouped_gemm(
ws,
dys,
[grad_input],
[None] * num_groups, # quantization_params
ctx.dtype,
layout="NN",
m_splits=split_sizes_int,
use_split_accumulator=_2X_ACC_DGRAD,
single_output=True,
)
# Perform wgrad GEMMs
if ctx.weight_requires_grad:
general_grouped_gemm(
xs,
dys,
grad_weights,
[None] * num_groups, # quantization_params
ctx.dtype,
layout="NT",
m_splits=split_sizes_int,
use_split_accumulator=_2X_ACC_WGRAD,
accumulate=accumulate_into_main_grad,
)
# Clear input tensors if possible
clear_tensor_data(*xs)
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad:
grad_weights = [None] * num_groups
for group_idx in range(num_groups):
weight_param = getattr(self, f"weight{group_idx}")
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weights[group_idx] = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
grad_params = grad_weights + grad_biases if has_bias else grad_weights
return grad_input, [grad_params], [(None,)]
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for SwiGLU and variants."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import torch
import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize
__all__ = ["SwiGLU", "ClampedSwiGLU", "ScaledSwiGLU"]
class SwiGLU(BasicOperation):
r"""Swish gated linear unit
The input tensor is split into chunks :math:``a`` and :math:``b``
along the last dimension and the following is computed:
.. math::
\text{SwiGLU}(a,b) = \text{SiLU}(a) * b
where
.. math::
\text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
.. warning::
Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:``a`` and
:math:``b``. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
The Sigmoid Linear Unit (SiLU) gating function is also known as
the swish function. See
``GLU Variants Improve Transformer<https://arxiv.org/abs/2002.05202>``__.
Parameters
----------
cache_quantized_input : bool, default = False
Quantize input tensor when caching for use in the backward
pass. This will typically reduce memory usage but require
extra compute and increase numerical error. This feature is
highly experimental.
glu_interleave_size : int, optional
When set, the GLU activations will use a block interleaved
format. Instead of interpreting the input tensor as a
concatenation of gates and linear units (e.g.
:math:``[a_1, a_2, a_3, a_4, b_1, b_2, b_3, b_4]``
in the above notation), it will be interpreted
as alternating blocks of gates and linear units (e.g.
:math:``[a_1, a_2, b_1, b_2, a_3, a_4, b_3, b_4]``
when the interleave size is 2). This data format is highly
experimental and is primarily intended to support some advanced
fused kernels.
"""
def __init__(
self,
*,
cache_quantized_input: bool = False,
glu_interleave_size: Optional[int] = None,
):
super().__init__()
self.cache_quantized_input: bool = cache_quantized_input
self.glu_interleave_size: Optional[int] = glu_interleave_size
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Compute dtype
dtype: torch.dtype
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = input_.dtype
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise RuntimeError(f"Unsupported dtype ({dtype})")
# Check input tensor
input_ = maybe_dequantize(input_.contiguous(), dtype)
# Remove interleaving if needed
swiglu_in = input_
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Launch kernel
out = tex.swiglu(swiglu_in, next_op_input_quantizer)
# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
input_quantizer = Float8CurrentScalingQuantizer(
tex.DType.kFloat8E4M3,
input_.device,
)
input_quantizer.set_usage(rowwise=True, columnwise=False)
input_ = input_quantizer(input_)
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
ctx.save_for_backward(input_)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
return out
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
(input_,) = ctx.saved_tensors
# Make sure tensors have correct dtypes
x = maybe_dequantize(input_.contiguous(), ctx.dtype)
dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype)
# Remove interleaving if needed
swiglu_in = x
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Quantizer for grad input
quantizer = ctx.prev_op_grad_output_quantizer
if self.glu_interleave_size is not None:
quantizer = None
# Launch kernel
grad_swiglu_in = tex.dswiglu(dy, swiglu_in, quantizer)
# Apply interleaving if needed
dx = grad_swiglu_in
if self.glu_interleave_size is not None:
shape = dx.size()
dx = dx.reshape(
-1,
2,
shape[-1] // (2 * self.glu_interleave_size),
self.glu_interleave_size,
)
dx = dx.transpose(1, 2).contiguous()
dx = dx.view(shape)
# Clear input tensor if possible
clear_tensor_data(input_)
return dx, ()
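# ---------------------------------------------------------------------------
# Illustrative plain-PyTorch sketch (a hypothetical helper, not used by the
# ops in this module) of SwiGLU with the block-interleaved GLU layout
# described in the ``glu_interleave_size`` docs above. With an interleave
# size of 2, a last dim of [a1, a2, b1, b2, a3, a4, b3, b4] is rearranged
# into the concatenated [a1, a2, a3, a4, b1, b2, b3, b4] layout before gating.
# ---------------------------------------------------------------------------
def _reference_interleaved_swiglu(x: torch.Tensor, interleave_size: int) -> torch.Tensor:
    shape = x.size()
    # Undo block interleaving along the last dim
    x = x.reshape(-1, shape[-1] // (2 * interleave_size), 2, interleave_size)
    x = x.transpose(1, 2).contiguous().view(shape)
    # SwiGLU: gates (first half) passed through SiLU, multiplied with the
    # linear units (second half)
    gates, linear_units = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gates) * linear_units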
class ClampedSwiGLU(BasicOperation):
r"""GPT-OSS
Implementation based on ``GPT-OSS<https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>``__.
This activation has two differences compared to the original SwiGLU
1. Both gate and pre-activations are clipped based on parameter limit.
2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation.
.. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is different
from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor.
Parameters
----------
limit : float
The clamp limit.
alpha : float
The scaling factor for the sigmoid function used in the activation.
cache_quantized_input : bool, default = ``False``
Quantize input tensor when caching for use in the backward pass.
glu_interleave_size : int, optional
When set, the GLU activations will use an experimental block
interleaved format. See the corresponding option in the SwiGLU
operation for more details.
"""
def __init__(
self,
*,
limit: float = 7.0,
alpha: float = 1.702,
cache_quantized_input: bool = False,
glu_interleave_size: Optional[int] = None,
):
super().__init__()
self.limit: float = limit
self.alpha: float = alpha
self.cache_quantized_input: bool = cache_quantized_input
self.glu_interleave_size: Optional[int] = glu_interleave_size
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Compute dtype
dtype: torch.dtype
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = input_.dtype
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise RuntimeError(f"Unsupported dtype ({dtype})")
# Check input tensor
x = maybe_dequantize(input_.contiguous(), dtype)
# Remove interleaving if needed
swiglu_in = x
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Launch kernel
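# Kernel applies the clamping and alpha-scaled sigmoid described in the
# class docstring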
out = tex.clamped_swiglu(
swiglu_in,
next_op_input_quantizer,
limit=self.limit,
alpha=self.alpha,
)
# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
input_quantizer.set_usage(rowwise=True, columnwise=False)
x = input_quantizer(x)
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
return out
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
(input_,) = ctx.saved_tensors
# Make sure tensors have correct dtypes
x = maybe_dequantize(input_.contiguous(), ctx.dtype)
dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype)
# Remove interleaving if needed
swiglu_in = x
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Quantizer for grad input
quantizer = ctx.prev_op_grad_output_quantizer
if self.glu_interleave_size is not None:
quantizer = None
# Launch kernel
grad_swiglu_in = tex.clamped_dswiglu(
dy,
swiglu_in,
quantizer,
limit=self.limit,
alpha=self.alpha,
)
# Apply interleaving if needed
dx = grad_swiglu_in
if self.glu_interleave_size is not None:
shape = dx.size()
dx = dx.reshape(
-1,
2,
shape[-1] // (2 * self.glu_interleave_size),
self.glu_interleave_size,
)
dx = dx.transpose(1, 2).contiguous()
dx = dx.view(shape)
# Clear input tensor if possible
clear_tensor_data(input_)
return dx, ()
class ScaledSwiGLU(BasicOperation):
r"""SwiGLU with post-scaling.
If the SwiGLU output has shape ``(d_1, ..., d_n)``, it is
multiplied with an extra input tensor of shape
``(d_1, ..., d_{n-1})``.
Parameters
----------
glu_interleave_size : int, optional
When set, the GLU activations will use an experimental block
interleaved format. See the corresponding option in the SwiGLU
operation for more details.
"""
# Operation expects scales
num_extra_inputs: int = 1
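# The per-token scales arrive as an extra tensor input from the op fuser,
# which is why this op overrides fuser_forward/fuser_backward instead of
# op_forward/op_backward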
def __init__(self, glu_interleave_size: Optional[int] = None):
super().__init__()
self.glu_interleave_size: Optional[int] = glu_interleave_size
def op_forward(self, *args, **kwargs) -> None:
raise RuntimeError(
f"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_forward` instead of `op_forward`."
)
def op_backward(self, *args, **kwargs) -> None:
raise RuntimeError(
f"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_backward` instead of `op_backward`."
)
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
extra_input = basic_op_extra_inputs[0][0]
# Determine compute dtype
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
elif isinstance(input_, torch.Tensor):
dtype = input_.dtype
else:
dtype = extra_input.dtype
# Make sure inputs are in correct dtype
input_ = maybe_dequantize(input_, dtype)
scales = maybe_dequantize(extra_input, dtype)
# Remove gate interleaving if needed
swiglu_in = input_
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Compute scaled SwiGLU
swiglu_out = tex.swiglu(swiglu_in, None)
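# scales has shape (d_1, ..., d_{n-1}) (see class docstring); unsqueeze adds
# a trailing dim so it broadcasts over the last dim of the SwiGLU output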
out = swiglu_out * scales.unsqueeze(-1)
# Save state for backward pass
ctx = basic_op_ctxs[0]
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
ctx.input_requires_grad = True
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.dtype = dtype
ctx.save_for_backward(
input_,
scales if ctx.input_requires_grad else None,
)
return out, [()]
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
Iterable[Iterable[Optional[torch.Tensor]]],
Iterable[Iterable[Optional[torch.Tensor]]],
]:
ctx = basic_op_ctxs[0]
input_, scales = ctx.saved_tensors
input_ = maybe_dequantize(input_, ctx.dtype)
if scales is not None:
scales = maybe_dequantize(scales, ctx.dtype)
grad_output = maybe_dequantize(grad_output, ctx.dtype)
# Remove gate interleaving if needed
swiglu_in = input_
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Compute input grad
grad_input = None
if ctx.input_requires_grad:
grad_swiglu_out = grad_output * scales.unsqueeze(-1)
grad_swiglu_in = tex.dswiglu(grad_swiglu_out, swiglu_in, None)
grad_input = grad_swiglu_in
if self.glu_interleave_size is not None:
shape = grad_input.size()
grad_input = grad_input.reshape(
-1,
2,
shape[-1] // (2 * self.glu_interleave_size),
self.glu_interleave_size,
)
grad_input = grad_input.transpose(1, 2).contiguous()
grad_input = grad_input.view(shape)
# Compute scales grad by recomputing SwiGLU
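# The SwiGLU output is recomputed from the saved input rather than stored in
# the forward pass (extra compute for lower memory); the scales grad is the
# per-token dot product of that output with the output grad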
grad_extra_input = None
if ctx.extra_input_requires_grad:
swiglu_out = tex.swiglu(swiglu_in, None)
grad_extra_input = torch.linalg.vecdot(swiglu_out, grad_output)
# Clear input tensor if possible
clear_tensor_data(ctx.saved_tensors[0]) # input_
return grad_input, [()], [(grad_extra_input,)]