"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "1bd45b9708ed6877e1f89e1d7b95eb61f3bdd90c"
Unverified commit 40a30a5f authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Support L2Normalization basic op -> use for qk_norm (#1864)



* Support L2Norm basic op
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Add L2Norm module wrapper
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Expose qk_norm to MHA and transformer layer
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Move tests into separate file
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix pass
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Add license
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Remove  module
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Resolve comments
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c9d7f3f2
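
Below is a brief usage sketch, not part of the diff, showing how the new `use_qk_norm` and `qk_norm_eps` arguments are meant to be used from `MultiheadAttention`; the sizes and the default "sbhd" input layout are illustrative assumptions.

import torch
from transformer_engine.pytorch import MultiheadAttention

# Enable L2 normalization of query and key tensors (applied after RoPE, before attention).
mha = MultiheadAttention(
    hidden_size=256,
    num_attention_heads=8,
    use_qk_norm=True,   # new argument introduced by this PR
    qk_norm_eps=1e-6,   # new argument introduced by this PR
    bias=False,
    device="cuda",
)
hidden_states = torch.randn(128, 2, 256, device="cuda")  # (seq, batch, hidden) for the default "sbhd" format
with torch.no_grad():
    out = mha(hidden_states)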
@@ -1273,6 +1273,58 @@ class TestBasicOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("in_shape", ((32,), (6, 16, 64), (32, 64)))
@pytest.mark.parametrize("dtype", _dtypes)
def test_l2normalization(
self,
*,
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
eps: float = 1e-6,
) -> None:
"""L2 Normalization"""
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
# L2 norm: x / ||x||_2 = x / sqrt(sum(x^2) + eps)
l2_norm_squared = x_ref.pow(2).sum(dim=-1, keepdim=True)
rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
y_ref = x_ref * rsqrt_norm
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.L2Normalization(
eps=eps,
)
y_test = op(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
# L2Norm backward pass requires slightly looser atol for bfloat16
if dtype == torch.bfloat16:
tols["atol"] = 2e-3
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
......
@@ -63,3 +63,62 @@ def test_lazy_compile():
from transformer_engine.pytorch.jit import dgelu_fused_
dgelu_fused_(torch.randn(10, 10), torch.randn(10, 10))
def test_l2normalization_fused():
"""Smoke test for L2Normalization fusion functions."""
from transformer_engine.pytorch.jit import (
l2normalization_fused,
l2normalization_fwd_fused,
l2normalization_backward_fused,
)
# Basic smoke test like other JIT functions
x = torch.randn(10, 128, device="cuda", dtype=torch.float32)
eps = 1e-6
# Test inference version
output_inf = l2normalization_fused(x, eps)
# Test training version with backward
x_train = torch.randn(10, 128, device="cuda", dtype=torch.float32, requires_grad=True)
output_train, rsqrt_norm = l2normalization_fwd_fused(x_train, eps)
grad_output = torch.randn_like(output_train)
grad_input = l2normalization_backward_fused(grad_output, x_train, rsqrt_norm, eps)
def test_l2normalization_fused_correctness():
"""Simple verification that L2Normalization fusion matches reference implementation."""
from transformer_engine.pytorch.jit import (
l2normalization_fwd_fused,
l2normalization_backward_fused,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(16, 64, device=device, dtype=torch.float32, requires_grad=True)
eps = 1e-6
# Test fused forward
output_fused, rsqrt_norm = l2normalization_fwd_fused(x, eps)
# Reference implementation
x_ref = x.clone().detach().requires_grad_(True)
x_squared = x_ref.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
rsqrt_norm_ref = torch.rsqrt(l2_norm_squared + eps)
output_ref = x_ref * rsqrt_norm_ref
# Check forward pass matches
torch.testing.assert_close(output_fused, output_ref, atol=1e-6, rtol=1e-5)
torch.testing.assert_close(rsqrt_norm, rsqrt_norm_ref, atol=1e-6, rtol=1e-5)
# Test fused backward
grad_output = torch.randn_like(output_fused)
grad_input_fused = l2normalization_backward_fused(grad_output, x, rsqrt_norm, eps)
# Reference backward
output_ref.backward(grad_output)
grad_input_ref = x_ref.grad
# Check backward pass matches
torch.testing.assert_close(grad_input_fused, grad_input_ref, atol=1e-5, rtol=1e-4)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from transformer_engine.pytorch import MultiheadAttention
import pytest
import torch
@pytest.mark.parametrize("use_qk_norm", [False, True])
@pytest.mark.parametrize("attention_type", ["self", "cross"])
@pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None:
"""Test QK normalization functionality, module structure, and numerical behavior."""
hidden_size = 256
num_attention_heads = 8
seq_len = 128
# Create MultiheadAttention module
mha = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_type=attention_type,
use_qk_norm=use_qk_norm,
qk_norm_eps=qk_norm_eps,
bias=False,
device="cuda",
).cuda()
# Check module structure based on use_qk_norm parameter
if use_qk_norm:
assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True"
assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module"
assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module"
# Check that the module is L2Norm type
from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
assert isinstance(
mha.qk_norm, L2Normalization
), "qk_norm should be an L2Normalization module"
else:
assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False"
# Create input tensors
batch_size = 2 # Use a fixed batch size for testing
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
if attention_type == "cross":
encoder_output = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
else:
encoder_output = None
# Test forward pass
with torch.no_grad():
if attention_type == "cross":
output = mha(hidden_states, encoder_output=encoder_output)
else:
output = mha(hidden_states)
# Check output shape and numerical properties
assert output.shape == (
seq_len,
batch_size,
hidden_size,
), f"Output shape mismatch: {output.shape}"
assert not torch.isnan(output).any(), "Output contains NaN"
assert not torch.isinf(output).any(), "Output contains Inf"
# Test with RoPE (if self-attention)
if attention_type == "self":
head_dim = hidden_size // num_attention_heads
rotary_dim = head_dim // 2
rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32)
with torch.no_grad():
output_with_rope = mha(hidden_states, rotary_pos_emb=rotary_pos_emb)
assert output_with_rope.shape == (
seq_len,
batch_size,
hidden_size,
), "Output shape with RoPE mismatch"
assert not torch.isnan(output_with_rope).any(), "RoPE output contains NaN"
assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf"
def test_qk_norm_output_difference() -> None:
"""Test that QK normalization actually changes the output compared to no normalization."""
hidden_size = 256
num_attention_heads = 8
seq_len = 128
batch_size = 2
# Use same random seed to ensure identical weight initialization
current_rng_state = torch.get_rng_state()
current_cuda_rng_state = torch.cuda.get_rng_state()
# Reset to a known seed for reproducible initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create model with QK normalization
mha_with_norm = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=True,
bias=False,
device="cuda",
).cuda()
# Reset to same seed for identical initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create identical model without QK normalization
mha_no_norm = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=False,
bias=False,
device="cuda",
).cuda()
# Create input tensors
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
# Compare outputs with identical weights but different QK norm settings
with torch.no_grad():
output_with_norm = mha_with_norm(hidden_states)
output_no_norm = mha_no_norm(hidden_states)
# Outputs should be different when QK normalization is enabled
assert not torch.allclose(
output_with_norm, output_no_norm, atol=1e-6
), "QK normalization should change the output, but outputs are identical"
def test_qk_norm_with_fused_qkv() -> None:
"""Test QK normalization works with fused QKV parameters."""
hidden_size = 256
num_attention_heads = 8
seq_len = 64
mha = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
fuse_qkv_params=True,
use_qk_norm=True,
bias=False,
device="cuda",
).cuda()
# Create input and test forward pass
batch_size = 2 # Use a fixed batch size for testing
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
with torch.no_grad():
output = mha(hidden_states)
assert output.shape == (
seq_len,
batch_size,
hidden_size,
), f"Output shape mismatch: {output.shape}"
def test_qk_norm_transformer_layer_output_difference() -> None:
"""Test that QK normalization actually changes TransformerLayer output compared to no normalization."""
from transformer_engine.pytorch import TransformerLayer
hidden_size = 256
ffn_hidden_size = 1024
num_attention_heads = 8
seq_len = 128
batch_size = 2
# Use same random seed to ensure identical weight initialization
current_rng_state = torch.get_rng_state()
current_cuda_rng_state = torch.cuda.get_rng_state()
# Reset to a known seed for reproducible initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create TransformerLayer with QK normalization
transformer_with_norm = TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=True,
bias=False,
device="cuda",
).cuda()
# Reset to same seed for identical initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create identical TransformerLayer without QK normalization
transformer_no_norm = TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=False,
bias=False,
device="cuda",
).cuda()
# Create input tensors
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
# Compare outputs with identical weights but different QK norm settings
with torch.no_grad():
output_with_norm = transformer_with_norm(hidden_states)
output_no_norm = transformer_no_norm(hidden_states)
# Outputs should be different when QK normalization is enabled
assert not torch.allclose(
output_with_norm, output_no_norm, atol=1e-6
), "QK normalization should change the TransformerLayer output, but outputs are identical"
# Check that outputs have expected shapes and properties
assert output_with_norm.shape == (
seq_len,
batch_size,
hidden_size,
), f"Output shape mismatch: {output_with_norm.shape}"
assert not torch.isnan(output_with_norm).any(), "Output with QK norm contains NaN"
assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf"
assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN"
assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf"
@@ -12,6 +12,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
from transformer_engine.pytorch.utils import (
SplitAlongDim,
divide,
@@ -174,6 +175,22 @@ class MultiheadAttention(torch.nn.Module):
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatenations/splits and also enables the argument
`fuse_wgrad_accumulation`.
use_qk_norm: bool, default = `False`
if set to `True`, L2 normalization is applied to query and key tensors
after RoPE (if applicable) but before attention computation.
This follows the Llama4 approach for QK normalization to improve
training stability and model performance.
qk_norm_eps: float, default = 1e-6
epsilon value for L2 normalization of query and key tensors.
Only used when `use_qk_norm` is True.
seq_length: Optional[int], default = `None`
sequence length of input samples. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are used for
forward propagation and activation recompute phase.
micro_batch_size: Optional[int], default = `None`
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
used for forward propagation and activation recompute phase.
"""
def __init__(
@@ -214,6 +231,10 @@ class MultiheadAttention(torch.nn.Module):
device: Union[torch.device, str] = "cuda",
qkv_format: str = "sbhd",
name: str = None,
use_qk_norm: bool = False,
qk_norm_eps: float = 1e-6,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
) -> None:
super().__init__()
@@ -267,6 +288,7 @@ class MultiheadAttention(torch.nn.Module):
self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
self.name = name
self.use_qk_norm = use_qk_norm
common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
@@ -278,6 +300,14 @@ class MultiheadAttention(torch.nn.Module):
"device": device,
}
# Initialize L2 normalization modules for query and key if enabled
if self.use_qk_norm:
self.qk_norm = L2Normalization(
eps=qk_norm_eps,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
)
qkv_parallel_mode = "column" if set_parallel_mode else None
if self.attention_type == "self":
@@ -812,6 +842,14 @@ class MultiheadAttention(torch.nn.Module):
interleaved=self.rotary_pos_interleaved,
)
# ===========================
# Apply L2 normalization to query and key tensors
# ===========================
if self.use_qk_norm:
query_layer = self.qk_norm(query_layer)
key_layer = self.qk_norm(key_layer)
# ===========================
# Core attention computation
# ===========================
......
@@ -121,6 +121,35 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
return dgelu
@jit_fuser
def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor:
"""L2 normalization fused - inference version"""
x_squared = x.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
return x * rsqrt_norm
@jit_fuser
def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
"""L2 normalization fused - training version that returns intermediate values"""
x_squared = x.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
y = x * rsqrt_norm
return y, rsqrt_norm
@jit_fuser
def l2normalization_backward_fused_(
grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
) -> torch.Tensor:
"""L2 normalization backward fused"""
x_dy_sum = (x * grad_output).sum(dim=-1, keepdim=True)
x_norm_squared = x.pow(2).sum(dim=-1, keepdim=True) + eps
return rsqrt_norm * (grad_output - x * x_dy_sum / x_norm_squared)
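# Note (not part of the diff): the expression above follows from differentiating
# y = x * r with r = rsqrt(sum(x^2) + eps):
#   dx = r * dy - r**3 * x * sum(x * dy)
#      = r * (dy - x * sum(x * dy) / (sum(x**2) + eps))
# since r**2 == 1 / (sum(x**2) + eps), which is exactly what this function returns.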
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_"""
with gpu_autocast_ctx(enabled=False):
@@ -139,6 +168,26 @@ def bgrad_dgelu_fused(
return None, dgelu_fused_(grad_output, inp)
def l2normalization_fused(x: torch.Tensor, eps: float) -> torch.Tensor:
"""Disable native AMP for l2normalization_fused_ - inference version"""
with gpu_autocast_ctx(enabled=False):
return l2normalization_fused_(x, eps)
def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
"""Disable native AMP for l2normalization_fwd_fused_ - training version"""
with gpu_autocast_ctx(enabled=False):
return l2normalization_fwd_fused_(x, eps)
def l2normalization_backward_fused(
grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
) -> torch.Tensor:
"""Disable native AMP for l2normalization_backward_fused_"""
with gpu_autocast_ctx(enabled=False):
return l2normalization_backward_fused_(grad_output, x, rsqrt_norm, eps)
def bias_dropout_add(
x: torch.Tensor,
bias: torch.Tensor,
@@ -264,3 +313,45 @@ def warmup_jit_bias_gelu_all_dtypes(
"""Call `warmup_jit_bias_gelu` for all training dtypes"""
for dtype in [torch.float32, torch.bfloat16, torch.float16]:
warmup_jit_bias_gelu(ffn_hidden_size, dtype, seq_length, micro_batch_size)
def warmup_jit_l2normalization(
hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int
) -> None:
"""Compile L2Normalization JIT function before the main training steps"""
# Save cuda RNG state to ensure warmup does not affect reproducibility.
rng_state = torch.cuda.get_rng_state()
inp = torch.rand(
(seq_length * micro_batch_size, hidden_size),
dtype=dtype,
device="cuda",
)
eps = 1e-6
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for input_grad in [False, True]:
inp.requires_grad = input_grad
for _ in range(5):
if input_grad:
# Test training version that returns intermediate values
output, rsqrt_norm = l2normalization_fwd_fused_(inp, eps)
# Test backward pass as well
grad_out = torch.rand_like(output)
_ = l2normalization_backward_fused_(grad_out, inp, rsqrt_norm, eps)
else:
# Test inference version
output = l2normalization_fused_(inp, eps)
del inp, output
torch.cuda.empty_cache()
torch.cuda.set_rng_state(rng_state)
def warmup_jit_l2normalization_all_dtypes(
hidden_size: int, seq_length: int, micro_batch_size: int
) -> None:
"""Call `warmup_jit_l2normalization` for all training dtypes"""
for dtype in [torch.float32, torch.bfloat16, torch.float16]:
warmup_jit_l2normalization(hidden_size, dtype, seq_length, micro_batch_size)
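
A minimal sketch, not part of the diff, of invoking the new warmup helper directly; the head dimension, sequence length, and micro-batch size are illustrative values.

from transformer_engine.pytorch.jit import warmup_jit_l2normalization_all_dtypes

# Pre-compile the fused L2 normalization kernels for one head dimension before training starts.
warmup_jit_l2normalization_all_dtypes(hidden_size=64, seq_length=128, micro_batch_size=2)

In practice this warmup is triggered automatically by `L2Normalization` when it is constructed with `seq_length` and `micro_batch_size`, as shown in the l2normalization.py file below.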
@@ -11,6 +11,7 @@ from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
from .identity import Identity
from .l2normalization import L2Normalization
from .layer_norm import LayerNorm
from .make_extra_output import MakeExtraOutput
from .quantize import Quantize
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusable operation for L2 Normalization."""
from __future__ import annotations
from typing import Optional
import torch
from ...tensor import QuantizedTensor
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from ...jit import (
l2normalization_fused,
l2normalization_fwd_fused,
l2normalization_backward_fused,
set_jit_fusion_options,
warmup_jit_l2normalization_all_dtypes,
)
class L2Normalization(BasicOperation):
r"""L2 Normalization
Applies L2 normalization over the last dimension of input tensors.
This is a parameter-free normalization that scales each vector to unit L2 norm.
.. math::
y = \frac{x}{\sqrt{\sum_{i} x_i^2 + \varepsilon}}
This operation is used e.g. for query-key normalization in attention mechanisms.
Parameters
----------
eps : float, default = 1e-6
A value added to the denominator for numerical stability
seq_length: int, default = None
sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
functions are warmed up before training to ensure same kernels are used for forward
propagation and activation recompute phase.
micro_batch_size: int, default = None
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
used for forward propagation and activation recompute phase.
"""
def __init__(
self,
*,
eps: float = 1e-6,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
) -> None:
super().__init__()
self.eps: float = eps
# JIT warmup for L2Normalization fused operations
if seq_length and micro_batch_size:
if torch.cuda.is_available():
set_jit_fusion_options()
# For L2Normalization, we don't know the hidden size until forward pass,
# but we can warm up with common sizes. For QK normalization, this will be
# the attention head dimension (hidden_size_per_attention_head), not the full
# model hidden dimension. Common head dimensions are 32, 64, 80, 96, 128, 256.
common_hidden_sizes = [32, 64, 80, 96, 128, 256]
for hidden_size in common_hidden_sizes:
warmup_jit_l2normalization_all_dtypes(hidden_size, seq_length, micro_batch_size)
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Use input directly - torch.compile can handle multi-dimensional tensors
x = input_
if isinstance(x, QuantizedTensor):
x = x.dequantize()
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Compute L2 normalization using fused implementation
# L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps)
if requires_grad:
# Training: use version that returns both output and intermediate values
y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps)
else:
# Inference: use lightweight version that only returns output
y = l2normalization_fused(x, self.eps)
rsqrt_norm = None # Not needed for inference
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, rsqrt_norm)
ctx.has_prev_op = prev_op is not None
return y
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
x, rsqrt_norm = ctx.saved_tensors
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
# Compute L2 norm backward pass using fused implementation
dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)
# Clear saved tensors if possible
if ctx.has_prev_op:
clear_tensor_data(x)
clear_tensor_data(rsqrt_norm)
# No parameters, so empty tuple for param grads
return dx, ()
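
A brief standalone usage sketch of the new fusible op, not part of the diff; the tensor shape and epsilon are illustrative.

import torch
import transformer_engine.pytorch.ops as te_ops

# Normalize each vector along the last dimension to (near) unit L2 norm.
op = te_ops.L2Normalization(eps=1e-6)
x = torch.randn(2, 16, 64, device="cuda", requires_grad=True)
y = op(x)
y.backward(torch.randn_like(y))  # backward goes through the fused l2normalization_backward_fused path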
@@ -235,6 +235,14 @@ class TransformerLayer(torch.nn.Module):
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatenations/splits and also enables the argument
`fuse_wgrad_accumulation`.
use_qk_norm: bool, default = `False`
if set to `True`, L2 normalization is applied to query and key tensors
after RoPE (if applicable) but before attention computation.
This follows the Llama4 approach for QK normalization to improve
training stability and model performance.
qk_norm_eps: float, default = 1e-6
epsilon value for L2 normalization of query and key tensors.
Only used when `use_qk_norm` is True.
"""
def __init__(
@@ -284,6 +292,8 @@ class TransformerLayer(torch.nn.Module):
device: Union[torch.device, str] = "cuda",
attn_input_format: str = "sbhd",
name: str = None,
use_qk_norm: bool = False,
qk_norm_eps: float = 1e-6,
) -> None:
super().__init__()
@@ -373,6 +383,8 @@ class TransformerLayer(torch.nn.Module):
"ub_overlap_rs": ub_overlap_rs,
"ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
"qkv_format": self.attn_input_format,
"seq_length": seq_length,
"micro_batch_size": micro_batch_size,
}
self.self_attention = MultiheadAttention(
@@ -384,6 +396,8 @@ class TransformerLayer(torch.nn.Module):
return_bias=not self.parallel_attention_mlp,
normalization=normalization,
device=device,
use_qk_norm=use_qk_norm,
qk_norm_eps=qk_norm_eps,
name=name + ".self_attention" if name is not None else None,
)
@@ -398,6 +412,8 @@ class TransformerLayer(torch.nn.Module):
return_bias=True,
normalization=normalization,
device=device,
use_qk_norm=use_qk_norm,
qk_norm_eps=qk_norm_eps,
name=name + ".inter_attention" if name is not None else None,
)
......
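
Finally, a hedged sketch, not part of the diff, of enabling QK normalization at the `TransformerLayer` level; all sizes are illustrative.

import torch
from transformer_engine.pytorch import TransformerLayer

layer = TransformerLayer(
    hidden_size=256,
    ffn_hidden_size=1024,
    num_attention_heads=8,
    use_qk_norm=True,   # forwarded to the underlying MultiheadAttention modules
    qk_norm_eps=1e-6,
    bias=False,
    device="cuda",
)
x = torch.randn(128, 2, 256, device="cuda")
with torch.no_grad():
    y = layer(x)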