Unverified Commit e1221735 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Cache RHT device tensors properly (#2395)



* Cache device tensors properly
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix annotation and add test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Skip NVFP4 test if not supported
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d677a269
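For readers skimming the diff below: the root cause is that the RHT helper tensors are memoized with `functools.lru_cache`, but they used to be allocated with `device="cuda"`, which resolves to whichever GPU is current at the first call. Any later caller on a different GPU then received the cached wrong-device tensor. The fix threads an explicit device index through the helpers so the device becomes part of the cache key. A minimal sketch of the failure mode and the fix, using a hypothetical `make_table` helper in place of the real constructors:

```python
import functools

import torch

# Buggy pattern: lru_cache keys on the (empty) argument list, so the
# tensor allocated on the first caller's current device is returned to
# every later caller, regardless of their device.
@functools.lru_cache(maxsize=None)
def make_table_buggy() -> torch.Tensor:
    # "cuda" resolves to torch.cuda.current_device() at call time.
    return torch.eye(16, dtype=torch.float32, device="cuda")

# Fixed pattern (what this commit does): the device index is an
# argument, so lru_cache keeps one entry per device.
@functools.lru_cache(maxsize=None)
def make_table(device: int) -> torch.Tensor:
    return torch.eye(16, dtype=torch.float32, device=device)

if torch.cuda.device_count() > 1:
    with torch.cuda.device(1):
        make_table_buggy()  # cache is now pinned to cuda:1
    # Back on cuda:0, the buggy helper still hands out a cuda:1 tensor...
    assert make_table_buggy().device.index == 1
    # ...while the fixed helper allocates (and caches) per device.
    assert make_table(torch.cuda.current_device()).device.index == 0
```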
@@ -7,7 +7,16 @@ import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear, GroupedLinear
from transformer_engine.pytorch import (
DotProductAttention,
TransformerLayer,
Linear,
GroupedLinear,
NVFP4Quantizer,
autocast,
is_nvfp4_available,
)
from transformer_engine.common import recipe
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
@@ -17,6 +26,8 @@ model_configs = {
"small": ModelConfig(2, 10, 2, 16),
}
nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize(
@@ -138,3 +149,24 @@ def test_current_device(model, module):
assert (
tensor_device_grad == tensor_device
), "The gradient tensor should be on the same device as the input tensors!"
@pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4)
def test_nvfp4_rht_cache():
"""Ensure correct RHT cache for NVFP4."""
num_devices = torch.cuda.device_count()
assert num_devices > 1, "This test requires more than one GPU!"
# Populate cache on last device.
with torch.cuda.device(num_devices - 1):
_ = NVFP4Quantizer()
# Run a forward pass on the default device; it must not reuse the tensors cached above.
hidden_size = 128
dtype = torch.bfloat16
model = Linear(hidden_size, hidden_size, params_dtype=dtype)
inp = torch.randn(hidden_size, hidden_size, device=torch.cuda.current_device(), dtype=dtype)
fp4_recipe = recipe.NVFP4BlockScaling()
with autocast(recipe=fp4_recipe):
_ = model(inp)
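Why the test is shaped this way: constructing `NVFP4Quantizer()` under `torch.cuda.device(num_devices - 1)` populates the `lru_cache` behind `get_rht_matrix` on the last GPU; before this fix, the forward pass on the default device would then pick up that cached `cuda:<N-1>` matrix. With the device in the cache key, the same scenario can also be checked directly on the quantizer (a sketch, assuming `rht_matrix` stays a public attribute as it appears in this diff):

```python
import torch
from transformer_engine.pytorch import NVFP4Quantizer

if torch.cuda.device_count() > 1:
    with torch.cuda.device(torch.cuda.device_count() - 1):
        q_last = NVFP4Quantizer()   # RHT matrix created on the last GPU
    q_default = NVFP4Quantizer()    # RHT matrix created on the default GPU
    # Each quantizer should hold tensors on its own construction device.
    assert q_last.rht_matrix.device.index == torch.cuda.device_count() - 1
    assert q_default.rht_matrix.device.index == torch.cuda.current_device()
```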
@@ -28,9 +28,9 @@ from ._quantization_helpers import _IdentityFunc
aten = torch.ops.aten
def get_no_random_sign_vector() -> torch.Tensor:
def get_no_random_sign_vector(device: int) -> torch.Tensor:
"""Non-random sign vector for Hadamard transform."""
return torch.tensor([1], dtype=torch.float32, device="cuda")
return torch.tensor([1], dtype=torch.float32, device=device)
def get_sign_from_vector(vector: torch.Tensor) -> int:
@@ -45,7 +45,7 @@ def get_sign_from_vector(vector: torch.Tensor) -> int:
return mask.item()
def get_wgrad_sign_vector() -> torch.Tensor:
def get_wgrad_sign_vector(device: int) -> torch.Tensor:
"""Hard-coded random signs for Hadamard transform.
https://xkcd.com/221/
@@ -54,11 +54,11 @@ def get_wgrad_sign_vector() -> torch.Tensor:
return torch.tensor(
[1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1],
dtype=torch.float32,
device="cuda",
device=device,
)
def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
def get_hadamard_matrix(hadamard_dimension: int, device: int) -> torch.Tensor:
"""Construct a 16x16 Hadamard matrix."""
assert hadamard_dimension == 16, "Only hadamard dimension 16 is supported."
hadamard_scale = 1 / math.sqrt(hadamard_dimension)
@@ -83,30 +83,30 @@ def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
[1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1],
],
dtype=torch.float32,
device="cuda",
device=device,
)
* hadamard_scale
)
@functools.lru_cache(maxsize=None)
def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor:
def get_rht_matrix(with_random_sign_mask: bool, device: int) -> torch.Tensor:
"""Construct matrix used in random Hadamard transform."""
hadamard_dimension = 16
if with_random_sign_mask:
signs = get_wgrad_sign_vector()
signs = get_wgrad_sign_vector(device=device)
else:
signs = get_no_random_sign_vector()
sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device="cuda")
rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension)
signs = get_no_random_sign_vector(device=device)
sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device=device)
rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension, device=device)
return rht_matrix.to(dtype=torch.bfloat16)
@functools.lru_cache(maxsize=None)
def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int:
def get_random_sign_mask_for_rht(with_random_sign_mask: bool, device: int) -> int:
"""Sign mask for random Hadamard transform."""
if with_random_sign_mask:
return get_sign_from_vector(get_wgrad_sign_vector())
return get_sign_from_vector(get_wgrad_sign_vector(device=device))
return 0
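A subtlety with the memoization above: `functools.lru_cache` keys on argument hash and equality, so the device must be passed in a consistent form; the commit standardizes on the integer index returned by `torch.cuda.current_device()`. Mixing forms would silently duplicate cache entries, as this small illustration (hypothetical `cached_on` helper) shows:

```python
import functools

import torch

@functools.lru_cache(maxsize=None)
def cached_on(device) -> torch.Tensor:
    return torch.zeros(1, device=device)

if torch.cuda.is_available():
    a = cached_on(0)                       # keyed by the int 0
    b = cached_on(torch.device("cuda:0"))  # keyed by a torch.device
    assert a.device == b.device                  # same GPU...
    assert cached_on.cache_info().currsize == 2  # ...but two cache entries
```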
@@ -152,8 +152,10 @@ class NVFP4Quantizer(Quantizer):
self.amax_reduction_group = amax_reduction_group
self.with_2d_quantization = with_2d_quantization
self.stochastic_rounding = stochastic_rounding
self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(with_random_sign_mask)
self.rht_matrix = get_rht_matrix(with_random_sign_mask)
self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(
with_random_sign_mask, torch.cuda.current_device()
)
self.rht_matrix = get_rht_matrix(with_random_sign_mask, torch.cuda.current_device())
def update_quantized(
self,
......
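One design consequence of the last hunk worth noting: `torch.cuda.current_device()` is resolved once, in `__init__`, so a quantizer stays pinned to its construction device. Code that spans GPUs should construct a quantizer per device rather than reusing one, roughly:

```python
import torch
from transformer_engine.pytorch import NVFP4Quantizer

# Each quantizer captures the device that is current at construction time.
with torch.cuda.device(0):
    q0 = NVFP4Quantizer()       # RHT tensors on cuda:0
if torch.cuda.device_count() > 1:
    with torch.cuda.device(1):
        q1 = NVFP4Quantizer()   # RHT tensors on cuda:1
```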