Unverified commit 374849e3 authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Enable generic QK norm support (+ RMSNorm/LayerNorm) (#1966)



* Support RMSNorm for QK
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rms -> RMSNorm, l2 -> L2Normalization (align with current pattern)
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Support LayerNorm + init refactor
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Before/after RoPE
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix pylint
Signed-off-by: Evgeny <etsykunov@nvidia.com>

---------

Signed-off-by: Evgeny <etsykunov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e950ceb0
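For orientation, a minimal usage sketch of the API introduced here. The argument names come from the diff below; the import path and the tensor layout are assumptions based on the accompanying tests, so treat this as illustrative rather than authoritative.

```python
# Hedged sketch of the new qk_norm_* arguments added in this commit.
# Assumes MultiheadAttention is importable from transformer_engine.pytorch
# (the tests below import TransformerLayer from there); shapes follow the
# default "sbhd" layout (seq, batch, hidden).
import torch
from transformer_engine.pytorch import MultiheadAttention

mha = MultiheadAttention(
    hidden_size=256,
    num_attention_heads=8,
    qk_norm_type="RMSNorm",      # None | "L2Normalization" | "RMSNorm" | "LayerNorm"
    qk_norm_eps=1e-6,            # epsilon used by the chosen norm
    qk_norm_before_rope=False,   # True -> normalize Q/K before rotary embeddings
    bias=False,
    device="cuda",
).cuda()

hidden_states = torch.randn(64, 2, 256, device="cuda")  # (seq, batch, hidden)
with torch.no_grad():
    out = mha(hidden_states)
# out.shape == torch.Size([64, 2, 256])
```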
@@ -8,10 +8,10 @@ import pytest
 import torch
-@pytest.mark.parametrize("use_qk_norm", [False, True])
+@pytest.mark.parametrize("qk_norm_type", [None, "L2Normalization", "RMSNorm", "LayerNorm"])
 @pytest.mark.parametrize("attention_type", ["self", "cross"])
 @pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
-def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None:
+def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None:
     """Test QK normalization functionality, module structure, and numerical behavior."""
     hidden_size = 256
     num_attention_heads = 8
@@ -22,25 +22,59 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None
         hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
         attention_type=attention_type,
-        use_qk_norm=use_qk_norm,
+        qk_norm_type=qk_norm_type,
         qk_norm_eps=qk_norm_eps,
         bias=False,
         device="cuda",
     ).cuda()
-    # Check module structure based on use_qk_norm parameter
-    if use_qk_norm:
-        assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True"
-        assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module"
-        assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module"
-        # Check that the module is L2Norm type
-        from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
-        assert isinstance(
-            mha.qk_norm, L2Normalization
-        ), "qk_norm should be an L2Normalization module"
+    # Check module structure based on qk_norm_type parameter
+    if qk_norm_type is not None:
+        assert mha.q_norm is not None, "Should have q_norm module when qk_norm_type is not None"
+        assert mha.k_norm is not None, "Should have k_norm module when qk_norm_type is not None"
+        # Check that the modules are of the correct type
+        if qk_norm_type == "L2Normalization":
+            from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
+            assert isinstance(
+                mha.q_norm, L2Normalization
+            ), "q_norm should be an L2Normalization module"
+            assert isinstance(
+                mha.k_norm, L2Normalization
+            ), "k_norm should be an L2Normalization module"
+            # For L2 normalization, q_norm and k_norm should be the same instance (parameter-free)
+            assert (
+                mha.q_norm is mha.k_norm
+            ), "q_norm and k_norm should be the same instance for L2 normalization"
+        elif qk_norm_type == "RMSNorm":
+            from transformer_engine.pytorch.module.rmsnorm import RMSNorm
+            assert isinstance(mha.q_norm, RMSNorm), "q_norm should be an RMSNorm module"
+            assert isinstance(mha.k_norm, RMSNorm), "k_norm should be an RMSNorm module"
+            # For RMS normalization, q_norm and k_norm should be separate instances
+            assert (
+                mha.q_norm is not mha.k_norm
+            ), "q_norm and k_norm should be separate instances for RMS normalization"
+        elif qk_norm_type == "LayerNorm":
+            from transformer_engine.pytorch.module.layernorm import LayerNorm
+            assert isinstance(mha.q_norm, LayerNorm), "q_norm should be a LayerNorm module"
+            assert isinstance(mha.k_norm, LayerNorm), "k_norm should be a LayerNorm module"
+            # For LayerNorm, q_norm and k_norm should be separate instances
+            assert (
+                mha.q_norm is not mha.k_norm
+            ), "q_norm and k_norm should be separate instances for LayerNorm"
+        else:
+            # For extensibility - just ensure they exist
+            assert mha.q_norm is not None, f"q_norm should exist for qk_norm_type={qk_norm_type}"
+            assert mha.k_norm is not None, f"k_norm should exist for qk_norm_type={qk_norm_type}"
     else:
-        assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False"
+        assert mha.q_norm is None, "Should not have q_norm module when qk_norm_type is None"
+        assert mha.k_norm is None, "Should not have k_norm module when qk_norm_type is None"
     # Create input tensors
     batch_size = 2  # Use a fixed batch size for testing
@@ -89,17 +123,14 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None
     assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf"
-def test_qk_norm_output_difference() -> None:
+@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
+def test_qk_norm_output_difference(qk_norm_type) -> None:
     """Test that QK normalization actually changes the output compared to no normalization."""
     hidden_size = 256
     num_attention_heads = 8
     seq_len = 128
     batch_size = 2
-    # Use same random seed to ensure identical weight initialization
-    current_rng_state = torch.get_rng_state()
-    current_cuda_rng_state = torch.cuda.get_rng_state()
     # Reset to a known seed for reproducible initialization
     torch.manual_seed(42)
     torch.cuda.manual_seed(42)
@@ -108,7 +139,7 @@ def test_qk_norm_output_difference() -> None:
     mha_with_norm = MultiheadAttention(
         hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
-        use_qk_norm=True,
+        qk_norm_type=qk_norm_type,
         bias=False,
         device="cuda",
     ).cuda()
@@ -121,7 +152,7 @@ def test_qk_norm_output_difference() -> None:
     mha_no_norm = MultiheadAttention(
         hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
-        use_qk_norm=False,
+        qk_norm_type=None,
         bias=False,
         device="cuda",
     ).cuda()
@@ -139,10 +170,11 @@ def test_qk_norm_output_difference() -> None:
     # Outputs should be different when QK normalization is enabled
     assert not torch.allclose(
         output_with_norm, output_no_norm, atol=1e-6
-    ), "QK normalization should change the output, but outputs are identical"
-def test_qk_norm_with_fused_qkv() -> None:
+    ), f"QK normalization ({qk_norm_type}) should change the output, but outputs are identical"
+@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
+def test_qk_norm_with_fused_qkv(qk_norm_type) -> None:
     """Test QK normalization works with fused QKV parameters."""
     hidden_size = 256
     num_attention_heads = 8
@@ -152,7 +184,7 @@ def test_qk_norm_with_fused_qkv() -> None:
         hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
         fuse_qkv_params=True,
-        use_qk_norm=True,
+        qk_norm_type=qk_norm_type,
         bias=False,
         device="cuda",
     ).cuda()
@@ -173,7 +205,8 @@ def test_qk_norm_with_fused_qkv() -> None:
     ), f"Output shape mismatch: {output.shape}"
-def test_qk_norm_transformer_layer_output_difference() -> None:
+@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
+def test_qk_norm_transformer_layer_output_difference(qk_norm_type) -> None:
     """Test that QK normalization actually changes TransformerLayer output compared to no normalization."""
     from transformer_engine.pytorch import TransformerLayer
@@ -183,10 +216,6 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
     seq_len = 128
     batch_size = 2
-    # Use same random seed to ensure identical weight initialization
-    current_rng_state = torch.get_rng_state()
-    current_cuda_rng_state = torch.cuda.get_rng_state()
     # Reset to a known seed for reproducible initialization
     torch.manual_seed(42)
     torch.cuda.manual_seed(42)
@@ -196,7 +225,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
         hidden_size=hidden_size,
         ffn_hidden_size=ffn_hidden_size,
         num_attention_heads=num_attention_heads,
-        use_qk_norm=True,
+        qk_norm_type=qk_norm_type,
         bias=False,
         device="cuda",
     ).cuda()
@@ -210,7 +239,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
         hidden_size=hidden_size,
         ffn_hidden_size=ffn_hidden_size,
         num_attention_heads=num_attention_heads,
-        use_qk_norm=False,
+        qk_norm_type=None,
         bias=False,
         device="cuda",
     ).cuda()
@@ -226,9 +255,10 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
     output_no_norm = transformer_no_norm(hidden_states)
     # Outputs should be different when QK normalization is enabled
-    assert not torch.allclose(
-        output_with_norm, output_no_norm, atol=1e-6
-    ), "QK normalization should change the TransformerLayer output, but outputs are identical"
+    assert not torch.allclose(output_with_norm, output_no_norm, atol=1e-6), (
+        f"QK normalization ({qk_norm_type}) should change the TransformerLayer output, but outputs"
+        " are identical"
+    )
     # Check that outputs have expected shapes and properties
     assert output_with_norm.shape == (
@@ -240,3 +270,120 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
     assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf"
     assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN"
     assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf"
+
+
+@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
+def test_qk_norm_before_after_rope(qk_norm_type) -> None:
+    """Test that QK normalization before and after RoPE works without errors."""
+    hidden_size = 256
+    num_attention_heads = 8
+    seq_len = 64
+    batch_size = 2
+
+    # Create model with QK norm after RoPE (default)
+    mha_after = MultiheadAttention(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        qk_norm_type=qk_norm_type,
+        qk_norm_before_rope=False,
+        bias=False,
+        device="cuda",
+    ).cuda()
+
+    # Create model with QK norm before RoPE
+    mha_before = MultiheadAttention(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        qk_norm_type=qk_norm_type,
+        qk_norm_before_rope=True,
+        bias=False,
+        device="cuda",
+    ).cuda()
+
+    hidden_states = torch.randn(
+        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
+    )
+
+    # Create RoPE embeddings
+    head_dim = hidden_size // num_attention_heads
+    rotary_dim = head_dim // 2
+    rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32)
+
+    with torch.no_grad():
+        output_after_rope = mha_after(hidden_states, rotary_pos_emb=rotary_pos_emb)
+        output_before_rope = mha_before(hidden_states, rotary_pos_emb=rotary_pos_emb)
+        output_after_no_rope = mha_after(hidden_states)
+        output_before_no_rope = mha_before(hidden_states)
+
+    # Check output shapes and properties
+    expected_shape = (seq_len, batch_size, hidden_size)
+    for output in [
+        output_after_rope,
+        output_before_rope,
+        output_after_no_rope,
+        output_before_no_rope,
+    ]:
+        assert output.shape == expected_shape, f"Output shape mismatch: {output.shape}"
+        assert not torch.isnan(output).any(), "Output contains NaN"
+        assert not torch.isinf(output).any(), "Output contains Inf"
+
+    assert output_after_rope.shape == output_before_rope.shape, "Outputs should have same shape"
+    assert mha_after.qk_norm_before_rope == False, "mha_after should have qk_norm_before_rope=False"
+    assert mha_before.qk_norm_before_rope == True, "mha_before should have qk_norm_before_rope=True"
+
+
+def test_different_qk_norm_types_produce_different_outputs() -> None:
+    """Test that different QK normalization types produce different outputs."""
+    hidden_size = 256
+    num_attention_heads = 8
+    seq_len = 128
+    batch_size = 2
+
+    # Use same random seed to ensure identical weight initialization
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
+
+    # Create model with L2 normalization
+    mha_l2 = MultiheadAttention(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        qk_norm_type="L2Normalization",
+        bias=False,
+        device="cuda",
+    ).cuda()
+
+    # Reset to same seed for identical initialization
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
+
+    # Create model with RMS normalization
+    mha_rms = MultiheadAttention(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        qk_norm_type="RMSNorm",
+        bias=False,
+        device="cuda",
+    ).cuda()
+
+    # Create input tensors
+    hidden_states = torch.randn(
+        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
+    )
+
+    # Compare outputs with identical weights but different QK norm types
+    with torch.no_grad():
+        output_l2 = mha_l2(hidden_states)
+        output_rms = mha_rms(hidden_states)
+
+    # Outputs should be different when using different normalization types
+    assert not torch.allclose(
+        output_l2, output_rms, atol=1e-6
+    ), "L2 and RMS normalization should produce different outputs, but outputs are identical"
+
+    # Check that outputs have expected shapes and properties
+    assert output_l2.shape == output_rms.shape, "L2 and RMS outputs should have same shape"
+    assert not torch.isnan(output_l2).any(), "L2 output contains NaN"
+    assert not torch.isinf(output_l2).any(), "L2 output contains Inf"
+    assert not torch.isnan(output_rms).any(), "RMS output contains NaN"
+    assert not torch.isinf(output_rms).any(), "RMS output contains Inf"
@@ -11,7 +11,7 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
-from transformer_engine.pytorch.module import LayerNormLinear, Linear
+from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm
 from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
 from transformer_engine.pytorch.utils import (
     SplitAlongDim,
@@ -175,14 +175,23 @@ class MultiheadAttention(torch.nn.Module):
                 parameter for query-key-value. This enables optimizations such as QKV
                 fusion without concatentations/splits and also enables the argument
                 `fuse_wgrad_accumulation`.
-    use_qk_norm: bool, default = 'False'
-                if set to `True`, L2 normalization is applied to query and key tensors
-                after RoPE (if applicable) but before attention computation.
-                This follows the Llama4 approach for QK normalization to improve
-                training stability and model performance.
+    qk_norm_type: Optional[str], default = None
+                type of normalization to apply to query and key tensors.
+                Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied.
+                When 'L2Normalization', L2 normalization is applied to query and key tensors.
+                When 'RMSNorm', RMS normalization is applied to query and key tensors.
+                When 'LayerNorm', layer normalization is applied to query and key tensors.
+                Normalization is applied after RoPE (if applicable) but before attention computation
+                when `qk_norm_before_rope` is False. This follows e.g. the Llama4 approach
+                to QK normalization to improve training stability and model performance.
     qk_norm_eps: float, default = 1e-6
-                epsilon value for L2 normalization of query and key tensors.
-                Only used when `use_qk_norm` is True.
+                epsilon value for normalization of query and key tensors.
+                Only used when `qk_norm_type` is not None.
+    qk_norm_before_rope: bool, default = `False`
+                if set to `True`, query and key normalization is applied before rotary position
+                embedding. When `False` (default), normalization is applied after RoPE.
+                This parameter allows supporting different architectural variants that apply
+                QK normalization at different points.
     seq_length: Optional[int], default = `None`
                 sequence length of input samples. Needed for JIT Warmup, a technique where jit
                 fused functions are warmed up before training to ensure same kernels are used for
@@ -231,8 +240,9 @@ class MultiheadAttention(torch.nn.Module):
         device: Union[torch.device, str] = "cuda",
         qkv_format: str = "sbhd",
         name: str = None,
-        use_qk_norm: bool = False,
+        qk_norm_type: Optional[str] = None,
         qk_norm_eps: float = 1e-6,
+        qk_norm_before_rope: bool = False,
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
     ) -> None:
@@ -264,6 +274,7 @@ class MultiheadAttention(torch.nn.Module):
             qkv_weight_interleaved = False
         self.qkv_weight_interleaved = qkv_weight_interleaved
         self.rotary_pos_interleaved = rotary_pos_interleaved
+        self.qk_norm_before_rope = qk_norm_before_rope
         assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
         if layer_number is not None:
@@ -288,7 +299,6 @@ class MultiheadAttention(torch.nn.Module):
         self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
         self.name = name
-        self.use_qk_norm = use_qk_norm
         common_gemm_kwargs = {
             "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
@@ -300,13 +310,9 @@ class MultiheadAttention(torch.nn.Module):
             "device": device,
         }
-        # Initialize L2 normalization modules for query and key if enabled
-        if self.use_qk_norm:
-            self.qk_norm = L2Normalization(
-                eps=qk_norm_eps,
-                seq_length=seq_length,
-                micro_batch_size=micro_batch_size,
-            )
+        self.q_norm, self.k_norm = self._create_qk_norm_modules(
+            qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size
+        )
         qkv_parallel_mode = "column" if set_parallel_mode else None
@@ -427,6 +433,78 @@ class MultiheadAttention(torch.nn.Module):
             **common_gemm_kwargs,
         )
+
+    def _create_qk_norm_modules(
+        self,
+        qk_norm_type: Optional[str],
+        qk_norm_eps: float,
+        device: Union[torch.device, str],
+        seq_length: Optional[int] = None,
+        micro_batch_size: Optional[int] = None,
+    ) -> Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]:
+        """
+        Create query and key normalization modules based on the specified normalization type.
+
+        Parameters
+        ----------
+        qk_norm_type : Optional[str]
+            Type of normalization to apply. Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'
+        qk_norm_eps : float
+            Epsilon value for numerical stability
+        device : Union[torch.device, str]
+            Device to place the normalization modules on
+        seq_length : Optional[int], default = None
+            Sequence length for L2Normalization optimization
+        micro_batch_size : Optional[int], default = None
+            Micro batch size for L2Normalization optimization
+
+        Returns
+        -------
+        Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]
+            Query and key normalization modules (q_norm, k_norm)
+        """
+        if qk_norm_type is None:
+            return None, None
+
+        if qk_norm_type == "L2Normalization":
+            l2_norm = L2Normalization(
+                eps=qk_norm_eps,
+                seq_length=seq_length,
+                micro_batch_size=micro_batch_size,
+            )
+            # L2Normalization is parameter-free, so we can share the same instance
+            return l2_norm, l2_norm
+        if qk_norm_type == "RMSNorm":
+            q_norm = RMSNorm(
+                normalized_shape=self.hidden_size_per_attention_head,
+                eps=qk_norm_eps,
+                device=device,
+            )
+            k_norm = RMSNorm(
+                normalized_shape=self.hidden_size_per_attention_head,
+                eps=qk_norm_eps,
+                device=device,
+            )
+            return q_norm, k_norm
+        if qk_norm_type == "LayerNorm":
+            q_norm = LayerNorm(
+                normalized_shape=self.hidden_size_per_attention_head,
+                eps=qk_norm_eps,
+                device=device,
+            )
+            k_norm = LayerNorm(
+                normalized_shape=self.hidden_size_per_attention_head,
+                eps=qk_norm_eps,
+                device=device,
+            )
+            return q_norm, k_norm
+
+        raise ValueError(
+            f"Unsupported QK norm type: {qk_norm_type}. "
+            "Supported types: ['L2Normalization', 'RMSNorm', 'LayerNorm']"
+        )
+
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
         """
         Set the tensor parallel group for the given
@@ -789,6 +867,14 @@ class MultiheadAttention(torch.nn.Module):
             )
             query_layer = query_layer.view(*new_tensor_shape)
+
+        # ===========================
+        # Apply normalization to query and key tensors (before RoPE if configured)
+        # ===========================
+        if self.q_norm is not None and self.qk_norm_before_rope:
+            query_layer = self.q_norm(query_layer)
+            key_layer = self.k_norm(key_layer)
+
         # ======================================================
         # Apply relative positional encoding (rotary embedding)
         # ======================================================
@@ -843,12 +929,12 @@ class MultiheadAttention(torch.nn.Module):
             )
         # ===========================
-        # Apply L2 normalization to query and key tensors
+        # Apply normalization to query and key tensors (after RoPE if not applied before)
         # ===========================
-        if self.use_qk_norm:
-            query_layer = self.qk_norm(query_layer)
-            key_layer = self.qk_norm(key_layer)
+        if self.q_norm is not None and not self.qk_norm_before_rope:
+            query_layer = self.q_norm(query_layer)
+            key_layer = self.k_norm(key_layer)
         # ===========================
         # Core attention computation
...
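The diff above adds per-head Q/K normalization modules and applies them either before or after RoPE. Below is a plain-PyTorch sketch of that selection and placement logic, not the Transformer Engine implementation: it uses `torch.nn.RMSNorm`/`torch.nn.LayerNorm` (RMSNorm requires PyTorch >= 2.4) and a hand-rolled L2 norm as stand-ins for the TE modules, and the epsilon handling of the real `L2Normalization` op may differ.

```python
# Illustrative stand-in only: mirrors the structure of _create_qk_norm_modules
# and the before/after-RoPE hook, using stock torch.nn modules.
from typing import Optional, Tuple
import torch


class L2Norm(torch.nn.Module):
    """Parameter-free L2 normalization over the last dimension (illustrative)."""

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize each head vector to unit L2 norm; eps guards against division by zero.
        return x / x.norm(dim=-1, keepdim=True).clamp_min(self.eps)


def make_qk_norms(
    qk_norm_type: Optional[str], head_dim: int, eps: float = 1e-6
) -> Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]:
    if qk_norm_type is None:
        return None, None
    if qk_norm_type == "L2Normalization":
        # Parameter-free, so one shared instance serves both Q and K.
        l2 = L2Norm(eps)
        return l2, l2
    if qk_norm_type == "RMSNorm":
        # Learnable scale per head dimension -> separate instances for Q and K.
        return torch.nn.RMSNorm(head_dim, eps=eps), torch.nn.RMSNorm(head_dim, eps=eps)
    if qk_norm_type == "LayerNorm":
        return torch.nn.LayerNorm(head_dim, eps=eps), torch.nn.LayerNorm(head_dim, eps=eps)
    raise ValueError(f"Unsupported QK norm type: {qk_norm_type}")


# Each norm acts over the last (per-head) dimension of Q/K shaped (seq, batch, heads, head_dim).
q = torch.randn(64, 2, 8, 32)
k = torch.randn(64, 2, 8, 32)
q_norm, k_norm = make_qk_norms("RMSNorm", head_dim=32)
q, k = q_norm(q), k_norm(k)  # applied before or after RoPE, per qk_norm_before_rope
```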
@@ -236,14 +236,23 @@ class TransformerLayer(torch.nn.Module):
                 parameter for query-key-value. This enables optimizations such as QKV
                 fusion without concatentations/splits and also enables the argument
                 `fuse_wgrad_accumulation`.
-    use_qk_norm: bool, default = 'False'
-                if set to `True`, L2 normalization is applied to query and key tensors
-                after RoPE (if applicable) but before attention computation.
-                This follows the Llama4 approach for QK normalization to improve
-                training stability and model performance.
+    qk_norm_type: Optional[str], default = None
+                type of normalization to apply to query and key tensors.
+                Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied.
+                When 'L2Normalization', L2 normalization is applied to query and key tensors.
+                When 'RMSNorm', RMS normalization is applied to query and key tensors.
+                When 'LayerNorm', layer normalization is applied to query and key tensors.
+                Normalization is applied after RoPE (if applicable) but before attention computation
+                when `qk_norm_before_rope` is False. This follows e.g. the Llama4 approach to
+                QK normalization to improve training stability and model performance.
     qk_norm_eps: float, default = 1e-6
-                epsilon value for L2 normalization of query and key tensors.
-                Only used when `use_qk_norm` is True.
+                epsilon value for normalization of query and key tensors.
+                Only used when `qk_norm_type` is not None.
+    qk_norm_before_rope: bool, default = `False`
+                if set to `True`, query and key normalization is applied before rotary position
+                embedding. When `False` (default), normalization is applied after RoPE.
+                This parameter allows supporting different architectural variants that apply
+                QK normalization at different points.
     """

     def __init__(
@@ -293,8 +302,9 @@ class TransformerLayer(torch.nn.Module):
         device: Union[torch.device, str] = "cuda",
         attn_input_format: str = "sbhd",
         name: str = None,
-        use_qk_norm: bool = False,
+        qk_norm_type: Optional[str] = None,
         qk_norm_eps: float = 1e-6,
+        qk_norm_before_rope: bool = False,
     ) -> None:
         super().__init__()
@@ -397,8 +407,9 @@ class TransformerLayer(torch.nn.Module):
                 return_bias=not self.parallel_attention_mlp,
                 normalization=normalization,
                 device=device,
-                use_qk_norm=use_qk_norm,
+                qk_norm_type=qk_norm_type,
                 qk_norm_eps=qk_norm_eps,
+                qk_norm_before_rope=qk_norm_before_rope,
                 name=name + ".self_attention" if name is not None else None,
             )
@@ -413,8 +424,9 @@ class TransformerLayer(torch.nn.Module):
                 return_bias=True,
                 normalization=normalization,
                 device=device,
-                use_qk_norm=use_qk_norm,
+                qk_norm_type=qk_norm_type,
                 qk_norm_eps=qk_norm_eps,
+                qk_norm_before_rope=qk_norm_before_rope,
                 name=name + ".inter_attention" if name is not None else None,
             )
...
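As the last two hunks show, `TransformerLayer` simply forwards the new arguments to its self-attention (and, for decoders, inter-attention) submodules, so the same knobs are available one level up. A minimal sketch, assuming the import path used in the tests above; the `ffn_hidden_size` value is illustrative:

```python
# Hedged sketch: TransformerLayer passes qk_norm_type / qk_norm_eps /
# qk_norm_before_rope through to its MultiheadAttention submodules.
import torch
from transformer_engine.pytorch import TransformerLayer

layer = TransformerLayer(
    hidden_size=256,
    ffn_hidden_size=1024,        # illustrative value, not taken from the diff
    num_attention_heads=8,
    qk_norm_type="LayerNorm",
    qk_norm_eps=1e-6,
    qk_norm_before_rope=True,
    bias=False,
    device="cuda",
).cuda()

x = torch.randn(128, 2, 256, device="cuda")  # (seq, batch, hidden)
with torch.no_grad():
    y = layer(x)
assert y.shape == x.shape
```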