Unverified commit dab931a7 authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Improve L2Normalization basic op (#1964)



* Increase intermediate precision and reuse tensors from fwd
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* JIT warmup only when required
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Recompute only rsqrt_norm
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2a293456
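
The central numerics change, per the first commit bullet: upcast to fp32 before the squared-sum reduction and cast back to the input dtype only at the end. A minimal standalone sketch of the pattern (illustrative, not the library code; the F.normalize cross-check uses PyTorch's slightly different eps convention, hence the loose tolerance):

import torch

def l2norm_fp32_intermediates(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Upcast once, reduce and normalize in fp32, cast back at the end.
    # bf16 keeps only 8 mantissa bits, so accumulating sum(x^2) directly
    # in bf16 loses precision that fp32 (24 mantissa bits) retains.
    x_fp32 = x.float()
    rsqrt_norm = torch.rsqrt(x_fp32.pow(2).sum(dim=-1, keepdim=True) + eps)
    return (x_fp32 * rsqrt_norm).to(x.dtype)

x = torch.randn(4, 1024, dtype=torch.bfloat16)
y = l2norm_fp32_intermediates(x)
ref = torch.nn.functional.normalize(x.float(), dim=-1).to(torch.bfloat16)
torch.testing.assert_close(y.float(), ref.float(), rtol=0.0, atol=1e-2)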
@@ -1400,9 +1400,6 @@ class TestBasicOps:
         dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
         torch.testing.assert_close(y_test, y_ref, **tols)
-        # L2Norm backward pass requires slightly looser atol for bfloat16
-        if dtype == torch.bfloat16:
-            tols["atol"] = 2e-3
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

     @pytest.mark.parametrize("dtype", _dtypes)
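
The atol loosening deleted above was compensating for bf16 round-off in the old backward, which did all of its math in the input dtype. With fp32 intermediates the error against an fp64 reference shrinks by orders of magnitude, so the default tolerances suffice. A quick sketch to see the effect (shapes and seed arbitrary):

import torch

def l2norm_bwd(x, dy, eps=1e-6):
    # dx = rsqrt(n2) * (dy - x * sum(x*dy) / n2), with n2 = sum(x^2) + eps
    n2 = x.pow(2).sum(dim=-1, keepdim=True) + eps
    return torch.rsqrt(n2) * (dy - x * (x * dy).sum(dim=-1, keepdim=True) / n2)

torch.manual_seed(0)
x = torch.randn(32, 256, dtype=torch.float64)
dy = torch.randn_like(x)
ref = l2norm_bwd(x, dy)
for dtype in (torch.bfloat16, torch.float32):
    err = (l2norm_bwd(x.to(dtype), dy.to(dtype)).double() - ref).abs().max().item()
    print(f"{dtype}: max abs error {err:.2e}")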
@@ -134,30 +134,43 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
 @jit_fuser
 def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor:
     """L2 normalization fused - inference version"""
-    x_squared = x.pow(2)
+    x_fp32 = x.float()
+    x_squared = x_fp32.pow(2)
     l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
     rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
-    return x * rsqrt_norm
+    y_fp32 = x_fp32 * rsqrt_norm
+    return y_fp32.to(x.dtype)


 @jit_fuser
 def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
     """L2 normalization fused - training version that returns intermediate values"""
-    x_squared = x.pow(2)
-    l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
-    rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
-    y = x * rsqrt_norm
+    x_fp32 = x.float()
+    x_squared = x_fp32.pow(2)
+    l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
+    l2_norm_squared_eps = l2_norm_squared + eps
+    rsqrt_norm = torch.rsqrt(l2_norm_squared_eps)
+    y_fp32 = x_fp32 * rsqrt_norm
+    y = y_fp32.to(x.dtype)
     return y, rsqrt_norm


 @jit_fuser
 def l2normalization_backward_fused_(
-    grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
+    grad_output: torch.Tensor,
+    x: torch.Tensor,
+    rsqrt_norm: torch.Tensor,
+    eps: float,
 ) -> torch.Tensor:
     """L2 normalization backward fused"""
-    x_dy_sum = (x * grad_output).sum(dim=-1, keepdim=True)
-    x_norm_squared = x.pow(2).sum(dim=-1, keepdim=True) + eps
-    return rsqrt_norm * (grad_output - x * x_dy_sum / x_norm_squared)
+    x_fp32 = x.float()
+    grad_output_fp32 = grad_output.float()
+    x_dy_sum = (x_fp32 * grad_output_fp32).sum(dim=-1, keepdim=True)
+    x_squared = x_fp32.pow(2)
+    l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
+    x_norm_squared = l2_norm_squared + eps
+    dx_fp32 = rsqrt_norm * (grad_output_fp32 - x_fp32 * x_dy_sum / x_norm_squared)
+    return dx_fp32.to(x.dtype)


 def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
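
For the record, the backward above is the exact gradient of the forward: with n2 = sum(x^2) + eps and y = x * rsqrt(n2), differentiating gives dx = rsqrt(n2) * (dy - x * sum(x*dy) / n2). A plain-PyTorch sketch (no @jit_fuser, names illustrative) that mirrors the two fused functions and checks them against autograd:

import torch

def fwd(x, eps):
    # Mirrors l2normalization_fwd_fused_: fp32 intermediates, rsqrt_norm
    # saved for the backward pass.
    x_fp32 = x.float()
    rsqrt_norm = torch.rsqrt(x_fp32.pow(2).sum(dim=-1, keepdim=True) + eps)
    return (x_fp32 * rsqrt_norm).to(x.dtype), rsqrt_norm

def bwd(dy, x, rsqrt_norm, eps):
    # Mirrors l2normalization_backward_fused_: reuses rsqrt_norm from the
    # forward pass and recomputes only the eps-shifted squared norm.
    x_fp32, dy_fp32 = x.float(), dy.float()
    n2 = x_fp32.pow(2).sum(dim=-1, keepdim=True) + eps
    x_dy = (x_fp32 * dy_fp32).sum(dim=-1, keepdim=True)
    return (rsqrt_norm * (dy_fp32 - x_fp32 * x_dy / n2)).to(x.dtype)

eps = 1e-6
x = torch.randn(8, 64, requires_grad=True)
y, rsqrt_norm = fwd(x, eps)
dy = torch.randn_like(y)
y.backward(dy)
torch.testing.assert_close(bwd(dy, x.detach(), rsqrt_norm.detach(), eps), x.grad)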
@@ -191,7 +204,10 @@ def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
 def l2normalization_backward_fused(
-    grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
+    grad_output: torch.Tensor,
+    x: torch.Tensor,
+    rsqrt_norm: torch.Tensor,
+    eps: float,
 ) -> torch.Tensor:
     """Disable native AMP for l2normalization_backward_fused_"""
     with gpu_autocast_ctx(enabled=False):
@@ -6,10 +6,12 @@
 from __future__ import annotations

 from typing import Optional

+import os
 import torch

 from ...utils import clear_tensor_data
+from ... import torch_version
 from .._common import maybe_dequantize
 from ..op import BasicOperation, OperationContext
 from ...jit import (
@@ -60,7 +62,11 @@ class L2Normalization(BasicOperation):
         # JIT warmup for L2Normalization fused operations
         if seq_length and micro_batch_size:
-            if torch.cuda.is_available():
+            if (
+                torch.cuda.is_available()
+                and torch_version() >= (2, 0, 0)
+                and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1")))
+            ):
                 set_jit_fusion_options()
                 # For L2Normalization, we don't know the hidden size until forward pass,
                 # but we can warm up with common sizes. For QK normalization, this will be
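
Per the second commit bullet, warmup now runs only when it can matter: on CUDA builds, on a recent enough PyTorch, and unless the user opted out via NVTE_TORCH_COMPILE=0. torch_version() and set_jit_fusion_options() are Transformer Engine internals; a standalone equivalent of the new gate might look like this (hypothetical helper name):

import os
import torch

def jit_warmup_enabled() -> bool:
    # Skip warmup on CPU-only builds, on PyTorch < 2.0, or when TE's torch
    # compilation path is disabled via NVTE_TORCH_COMPILE=0 (default: on).
    major, minor = (int(v) for v in torch.__version__.split("+")[0].split(".")[:2])
    return (
        torch.cuda.is_available()
        and (major, minor) >= (2, 0)
        and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1")))
    )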
@@ -86,7 +92,7 @@ class L2Normalization(BasicOperation):
         # Compute L2 normalization using fused implementation
         # L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps)
         if requires_grad:
-            # Training: use version that returns both output and intermediate values
+            # Training: use version that returns output and intermediate values for backward pass
             y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps)
         else:
             # Inference: use lightweight version that only returns output
@@ -110,7 +116,7 @@ class L2Normalization(BasicOperation):

         dy = maybe_dequantize(grad_output)

-        # Compute L2 norm backward pass using fused implementation
+        # Compute L2 norm backward pass using fused implementation - recalculates l2_norm_squared_eps
         dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)

         # Clear saved tensors if possible
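
Taken together, the backward keeps only x (saved anyway) and rsqrt_norm from the forward pass, recomputing l2_norm_squared_eps on the fly. Since rsqrt_norm has a trailing dimension of 1, its footprint is roughly 1/hidden of the activation's. Rough accounting, with illustrative shapes:

import torch

x = torch.empty(8, 4096, 1024, dtype=torch.bfloat16)       # activation, saved anyway
rsqrt_norm = torch.empty(8, 4096, 1, dtype=torch.float32)  # the only extra saved tensor
# Recomputing l2_norm_squared_eps in backward is one multiply-reduce over
# the hidden dimension, in exchange for not keeping more state alive.
print(f"x: {x.nbytes / 2**20:.0f} MiB, rsqrt_norm: {rsqrt_norm.nbytes / 2**10:.0f} KiB")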