Commit 53fa872c authored by wenjh

Merge branch 'nv_release_v2.8' into release_v2.8

parents 27ddce40 40c69e75
@@ -10,7 +10,6 @@ import pytest
 import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
-import transformer_engine_torch as tex
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.common.recipe import Float8CurrentScaling
 from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
@@ -274,6 +273,14 @@ class TestFP8RecipeLinearBase:
             if bgrad_list is not None and bgrad is not None:
                 bgrad_list.append(bgrad.detach().clone())
+        # Stack the results
+        return (
+            torch.stack(y_q_list),
+            torch.stack(dgrad_list),
+            torch.stack(wgrad_list),
+            torch.stack(bgrad_list) if bgrad_list is not None else None,
+        )
+
     @classmethod
     def run_linear(
         cls,
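For context, the block added above collects per-split outputs that were accumulated in Python lists into single tensors. A tiny illustration of the torch.stack pattern it relies on (the names and shapes here are illustrative, not taken from the test):

# Illustrative only: torch.stack joins N same-shape tensors along a new dim 0.
import torch

y_q_list = [torch.randn(4, 8) for _ in range(3)]  # e.g. outputs from 3 splits
stacked = torch.stack(y_q_list)                   # shape: (3, 4, 8)
assert stacked.shape == (3, 4, 8)
assert torch.equal(stacked[1], y_q_list[1])       # order is preserved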
@@ -20,6 +20,7 @@ from transformer_engine.pytorch.fp8 import (
     fp8_model_init,
 )
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
 from transformer_engine.pytorch.distributed import fp8_autocast
@@ -500,3 +501,39 @@ class TestFP8Recipe:
             y = module(x, [batch_size])
         else:
             y = module(x)
+
+
+fp4_available, reason_for_no_fp4 = FP8GlobalStateManager.is_nvfp4_available()
+
+
+@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize(
+    "M, N",
+    [
+        # Full-tile cases
+        (128, 128),
+        (256, 1024),
+        (1024, 256),
+        # Cases that require padding
+        (256, 272),
+        (304, 304),
+        (320, 256),
+        # Largest tile
+        (8192, 8192),
+    ],
+)
+def test_fp4_dequantize(dtype, M, N):
+    q = NVFP4Quantizer()
+    a = torch.rand((M, N)).cuda().to(dtype=dtype)
+    starting_tensor = q(a)
+    dequantized_tensor = starting_tensor.dequantize()
+    new_tensor = q(dequantized_tensor)
+    # Re-quantizing the dequantized values must reproduce identical packed data.
+    torch.testing.assert_close(
+        new_tensor._rowwise_data,
+        starting_tensor._rowwise_data,
+        rtol=0,
+        atol=0,
+    )
+    new_dequantized_tensor = new_tensor.dequantize()
+    torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor)
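The new test checks quantization idempotence: values produced by dequantize() are by construction exactly representable in NVFP4, so quantizing them a second time must yield bit-identical packed data. A minimal standalone sketch of the same round trip, assuming an NVFP4-capable GPU and the no-argument NVFP4Quantizer construction used in the test above:

# Minimal round-trip idempotence sketch, mirroring the test above.
# Assumes transformer_engine is installed and the GPU supports NVFP4.
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

if FP8GlobalStateManager.is_nvfp4_available()[0]:
    q = NVFP4Quantizer()
    x = torch.rand(256, 1024, device="cuda", dtype=torch.bfloat16)
    qx = q(x)              # first quantization (lossy)
    deq = qx.dequantize()  # values now exactly representable in NVFP4
    qx2 = q(deq)           # second quantization should be lossless
    # Packed rowwise payloads must match bit-for-bit.
    assert torch.equal(qx._rowwise_data, qx2._rowwise_data)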