Unverified Commit 6fd62098 authored by Sudhakar Singh, committed by GitHub

[PyTorch] Make sure Float8Tensor.contiguous supports autograd (#2533)



* add early return back (removed in #2427)
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make sure Float8Tensor.contiguous supports autograd

Expand quantized tensor tests to check identity ops.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 3e693970
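For context, here is a minimal sketch of the behavior this commit guarantees. The quantizer construction mirrors the test code in the diff below; the import paths for `Float8Quantizer` and the `tex` extension module are assumptions that may differ across Transformer Engine versions, and a CUDA device with FP8 support is required.

```python
# Minimal sketch (not part of this commit) of the guarantee being tested:
# Float8Tensor.contiguous() must stay connected to the autograd graph.
import torch
import transformer_engine_torch as tex  # assumed extension module name
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer  # assumed path

quantizer = Float8Quantizer(
    scale=torch.ones(1, dtype=torch.float32, device="cuda").squeeze(),
    amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
x = quantizer(torch.randn(128, 128, device="cuda"))
x.requires_grad_()

y = x.contiguous()  # identity op: values unchanged, autograd preserved
y.backward(torch.ones(128, 128, device="cuda"))
assert x.grad is not None  # gradient reaches the quantized input
```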
@@ -20,6 +20,7 @@ from transformer_engine.pytorch import (
     Float8Tensor,
     MXFP8Tensor,
     NVFP4Tensor,
+    QuantizedTensor,
 )
 from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
@@ -50,14 +51,22 @@ def _to_list(x: Union[Iterable, Any]) -> List:
 # Types that can be interpreted as tensor dims
 DimsType = Union[Iterable[int], int]

-# Check if FP8 is supported
+# Supported quantization recipes
 fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
 fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
     return_reason=True
 )
 mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
 nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
+_quantization_list: List[str] = []
+if fp8_available:
+    _quantization_list.append("fp8")
+if fp8_block_scaling_available:
+    _quantization_list.append("fp8_blockwise")
+if mxfp8_available:
+    _quantization_list.append("mxfp8")
+if nvfp4_available:
+    _quantization_list.append("nvfp4")

 # delayed scaling
@@ -98,6 +107,79 @@ def to_float8_CS(
     return quantizer(tensor)

+
+@torch.no_grad()
+def make_reference_and_test_tensors(
+    shape: int | Iterable[int],
+    quantization: Optional[str] = None,
+    ref_dtype: torch.dtype = torch.float64,
+    ref_device: torch.device = "cpu",
+    test_dtype: torch.dtype = torch.float32,
+    test_device: torch.device = "cuda",
+    requires_grad: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Construct tensors with the same values
+
+    The reference tensor is intended for use in plain PyTorch
+    operations in high precision. The test tensor is intended for use
+    in Transformer Engine operations.
+
+    If a quantization scheme is provided, the tensor values are
+    quantized so that they are representable.
+
+    """
+
+    # Random reference tensor
+    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
+
+    # Construct test tensor from reference tensor
+    test = ref.to(device=test_device, dtype=test_dtype)
+    if quantization is None:
+        if test.data_ptr() == ref.data_ptr():
+            test = test.clone()
+    elif quantization in ("fp8", "fp8_delayed_scaling"):
+        quantizer = Float8Quantizer(
+            scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
+            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
+            fp8_dtype=tex.DType.kFloat8E4M3,
+        )
+        test = quantizer(test)
+    elif quantization == "fp8_current_scaling":
+        quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            device=test_device,
+        )
+        test = quantizer(test)
+    elif quantization == "fp8_blockwise":
+        quantizer = Float8BlockQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            rowwise=True,
+            columnwise=True,
+            force_pow_2_scales=True,
+            amax_epsilon=0.0,
+            block_scaling_dim=1,
+        )
+        test = quantizer(test)
+    elif quantization == "mxfp8":
+        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+    elif quantization == "nvfp4":
+        test = NVFP4Quantizer(
+            with_rht=False,
+            with_post_rht_amax=False,
+            with_2d_quantization=False,
+            stochastic_rounding=False,
+            with_random_sign_mask=False,
+        )(test)
+    else:
+        raise ValueError(f"Unsupported quantization scheme ({quantization})")
+
+    # Make sure reference and test tensors match each other
+    ref.copy_(test)
+
+    ref.requires_grad_(requires_grad)
+    test.requires_grad_(requires_grad)
+    return ref, test
+
+
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 class TestFloat8Tensor:
@@ -466,7 +548,7 @@ class TestCurrentScalingFloat8Tensor:
         torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype])


-class TestAllQuantizedTensors:
+class TestQuantizedTensor:
     @staticmethod
     def setup_class(cls) -> None:
         # Configure RNG
@@ -474,10 +556,69 @@ class TestAllQuantizedTensors:
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)

-    @pytest.mark.parametrize("quantization", ["fp8", "mxfp8", "nvfp4", "fp8_blockwise"])
+    @pytest.mark.parametrize("op", ("clone", "view", "reshape", "contiguous"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    def test_identity_op(
+        self,
+        *,
+        op: str,
+        quantization: str,
+        shape: Iterable[int] = (128, 128),
+        dtype: torch.dtype = torch.bfloat16,
+        device: torch.device = "cuda",
+    ) -> None:
+        """Test operations that do not affect tensor values.
+
+        These operations must produce outputs that are bit-wise
+        equivalent to the inputs. They must support autograd.
+
+        """
+
+        # Create reference and quantized tensor
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape=shape,
+            quantization=quantization,
+            test_dtype=dtype,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape=shape,
+            test_dtype=dtype,
+            requires_grad=False,
+        )
+
+        # Apply identity operation
+        if op == "clone":
+            y_ref = x_ref.clone()
+            y_test = x_test.clone()
+        elif op == "view":
+            y_ref = x_ref.view(shape)
+            y_test = x_test.view(shape)
+        elif op == "reshape":
+            y_ref = x_ref.reshape(shape)
+            y_test = x_test.reshape(shape)
+        elif op == "contiguous":
+            y_ref = x_ref.contiguous()
+            y_test = x_test.contiguous()
+
+        # Check autograd
+        y_test.backward(dy_test)
+        assert x_test.grad is not None
+
+        # Check values
+        tols = dict(rtol=0, atol=0)
+        if isinstance(y_test, QuantizedTensor):
+            y_test = y_test.dequantize()
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        dx_ref = dy_ref
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, dx_ref, **tols)
+
+    @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("dim", [0, 1])
     def test_chunk(
         self,
+        *,
         quantization: str,
         dim: int,
         shape: Iterable[int] = (128, 128),
@@ -485,67 +626,33 @@ class TestAllQuantizedTensors:
         dtype: torch.dtype = torch.bfloat16,
         device: torch.device = "cuda",
     ) -> None:
-
-        # Skip invalid configs
-        if quantization == "fp8" and not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if quantization == "fp8_blockwise" and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if quantization == "mxfp8" and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if quantization == "nvfp4" and not nvfp4_available:
-            pytest.skip(reason_for_no_nvfp4)
-
-        # Create quantizer
-        if quantization == "fp8":
-            quantizer = Float8Quantizer(
-                scale=torch.ones(1, dtype=torch.float32, device=device).squeeze(),
-                amax=torch.zeros(1, dtype=torch.float32, device=device),
-                fp8_dtype=tex.DType.kFloat8E4M3,
-            )
-        elif quantization == "mxfp8":
-            quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
-        elif quantization == "fp8_blockwise":
-            quantizer = Float8BlockQuantizer(
-                fp8_dtype=tex.DType.kFloat8E4M3,
-                rowwise=True,
-                columnwise=True,
-                force_pow_2_scales=True,
-                amax_epsilon=0.0,
-                block_scaling_dim=1,
-            )
-        elif quantization == "nvfp4":
-            quantizer = NVFP4Quantizer(
-                with_rht=False,
-                with_post_rht_amax=False,
-                with_2d_quantization=False,
-                stochastic_rounding=False,
-                with_random_sign_mask=False,
-            )
-        else:
-            raise ValueError(f"Unknown quantizer ({quantizer})")
-
         # Create reference and quantized tensor
-        ref_tensor = torch.randn(shape, device=device, dtype=dtype)
-        quantized_tensor = quantizer(ref_tensor)
-        ref_tensor.copy_(quantized_tensor)
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape=shape,
+            quantization=quantization,
+            test_dtype=dtype,
+        )

         # Chunk tensors
-        ref_splits = torch.chunk(ref_tensor, chunks, dim=dim)
-        quantized_splits = torch.chunk(quantized_tensor, chunks, dim=dim)
+        ys_ref = torch.chunk(x_ref, chunks, dim=dim)
+        ys_test = torch.chunk(x_test, chunks, dim=dim)

         # Check splits
-        for ref_split, quantized_split in zip(ref_splits, quantized_splits):
+        for y_ref, y_test in zip(ys_ref, ys_test):

             # Check split shapes
-            assert ref_split.size() == quantized_split.size()
+            assert y_ref.size() == y_test.size()

             # Check that splits are quantized when expected
             if quantization == "fp8":
-                assert isinstance(quantized_split, Float8Tensor)
-                expected_value = quantized_split.dequantize()
+                assert isinstance(y_test, Float8Tensor)
+                y_test = y_test.dequantize()
             elif quantization == "mxfp8" and dim == 0:
-                assert isinstance(quantized_split, MXFP8Tensor)
-                expected_value = quantized_split.dequantize()
-            else:
-                # Otherwise torch dispatch would default to base implementation
-                # dequantize and computing output and hence output from torch chunk
-                # is already dequantized.
-                expected_value = quantized_split
+                assert isinstance(y_test, MXFP8Tensor)
+                y_test = y_test.dequantize()

             # Check values
-            torch.testing.assert_close(expected_value, ref_split)
+            tols = dict(rtol=0, atol=0)  # Chunking is exact
+            y_test = y_test.to(dtype=torch.float64, device="cpu")
+            torch.testing.assert_close(y_test, y_ref, **tols)
@@ -551,24 +551,31 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
     ) -> Float8Tensor:
         """Returns tensor with data in provided memory format

-        Returns `self` if data is already in correct memory format.
+        Returns ``self`` if data is already in correct memory format.

         """
-        # requires_grad remains unaltered when calling contiguous on
-        # torch tensor and so should be the case for our custom float8 tensor
-        # as well.
-        return Float8Tensor.make_like(
-            tensor=self,
-            data=self._data.contiguous(memory_format=memory_format),
-            data_transpose=(
-                self._transpose.contiguous(memory_format=memory_format)
-                if self._transpose is not None
-                else None
-            ),
-            requires_grad=self.requires_grad,
-        )
-        # raise ValueError("Float8Tensor does not support different memory formats!")
+
+        # Check if tensor already has correct memory format
+        if self._data is not None and not self._data.is_contiguous(memory_format=memory_format):
+            pass
+        elif self._transpose is not None and not self._transpose.is_contiguous(
+            memory_format=memory_format
+        ):
+            pass
+        else:
+            # Tensor has correct memory format, so return immediately
+            return self
+
+        # Construct tensor with correct data format
+        data, data_transpose = None, None
+        if self._data is not None:
+            data = self._data.contiguous(memory_format=memory_format)
+        if self._transpose is not None and not self._transpose_invalid:
+            data_transpose = self._transpose.contiguous(memory_format=memory_format)
+        return _IdentityFunc.apply(
+            self,
+            {"data": data, "data_transpose": data_transpose},
+        )

     def _reset_caches(self) -> None:
         """
...
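`_IdentityFunc` is referenced by the new `contiguous` implementation but not shown in this diff. As a rough illustration of the pattern (a hypothetical simplification, not Transformer Engine's actual code): the forward returns a tensor with reformatted storage but identical values, and the backward passes the gradient through unchanged, which is what keeps autograd intact across the layout change.

```python
# Hypothetical sketch of an identity autograd function in the spirit of
# _IdentityFunc; Transformer Engine's real implementation handles
# quantized storage and metadata, which this plain-tensor version omits.
import torch


class IdentitySketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, new_data: torch.Tensor) -> torch.Tensor:
        # Forward swaps in the reformatted buffer; values are identical
        # to the input, so no state needs to be saved for backward.
        return new_data.clone()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Identity op: the gradient flows straight through to the input
        # tensor; the replacement buffer gets no gradient.
        return grad_output, None


# Usage: y carries the (possibly re-laid-out) data, but gradients still
# reach x through the custom backward.
x = torch.randn(4, 4, requires_grad=True)
y = IdentitySketch.apply(x, x.detach().clone())
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))
```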