Unverified commit c09411d8, authored by vthumbe1503 and committed by GitHub

[Pytorch][Bug] MXFP8 split tensor bug fix (#2427)



* bug fixed, test added
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix contiguous
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* revert unnecessary change
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* revert another change
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* missed adding renamed file
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix minor issue
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix CI issue
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix the test for bfloat16
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent fd91bae3
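The sketch below distills the failure mode this PR targets, based on the new `test_chunk` test added further down. It is a hedged illustration rather than code from the commit, and it needs a GPU with MXFP8 support. Chunking an `MXFP8Tensor` along dim 0 keeps the splits quantized, so each split's scale-inverse block must be handled per split:

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch import MXFP8Quantizer

quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
x_mxfp8 = quantizer(x)
x_ref = x_mxfp8.dequantize()

# Splits along dim 0 stay in MXFP8; before this fix their scale-inverse
# tensors were padded incorrectly, corrupting the dequantized values.
for q_split, ref_split in zip(torch.chunk(x_mxfp8, 2, dim=0), torch.chunk(x_ref, 2, dim=0)):
    torch.testing.assert_close(q_split.dequantize(), ref_split)
```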
@@ -32,7 +32,7 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
-python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
...
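For a local run of the renamed suite (path as in the CI line above; a convenience sketch, not part of the commit), pytest can also be driven from Python:

```python
# Invoke the renamed test file directly; equivalent to the CI line above.
import sys
import pytest

sys.exit(pytest.main(["tests/pytorch/test_quantized_tensor.py", "--tb=auto"]))
```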
@@ -13,9 +13,15 @@ import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
 from transformer_engine.pytorch import (
     Float8Quantizer,
-    Float8Tensor,
     Float8CurrentScalingQuantizer,
+    Float8BlockQuantizer,
+    MXFP8Quantizer,
+    NVFP4Quantizer,
+    Float8Tensor,
+    MXFP8Tensor,
+    NVFP4Tensor,
 )
 from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
 import transformer_engine_torch as tex
@@ -47,6 +53,12 @@ DimsType = Union[Iterable[int], int]
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
+    return_reason=True
+)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
+nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
+
 # delayed scaling
 def to_float8(
@@ -452,3 +464,88 @@ class TestCurrentScalingFloat8Tensor:
         # Make sure we are not trivially passing the test
         with pytest.raises(AssertionError):
             torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype])
+
+
+class TestAllQuantizedTensors:
+
+    @staticmethod
+    def setup_class(cls) -> None:
+        # Configure RNG
+        seed = 1234
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    @pytest.mark.parametrize("quantization", ["fp8", "mxfp8", "nvfp4", "fp8_blockwise"])
+    @pytest.mark.parametrize("dim", [0, 1])
+    def test_chunk(
+        self,
+        quantization: str,
+        dim: int,
+        shape: Iterable[int] = (128, 128),
+        chunks: int = 2,
+        dtype: torch.dtype = torch.bfloat16,
+        device: torch.device = "cuda",
+    ) -> None:
+
+        # Skip invalid configs
+        if quantization == "fp8" and not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if quantization == "fp8_blockwise" and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if quantization == "mxfp8" and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if quantization == "nvfp4" and not nvfp4_available:
+            pytest.skip(reason_for_no_nvfp4)
+
+        # Create quantizer
+        if quantization == "fp8":
+            quantizer = Float8Quantizer(
+                scale=torch.ones(1, dtype=torch.float32, device=device).squeeze(),
+                amax=torch.zeros(1, dtype=torch.float32, device=device),
+                fp8_dtype=tex.DType.kFloat8E4M3,
+            )
+        elif quantization == "mxfp8":
+            quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
+        elif quantization == "fp8_blockwise":
+            quantizer = Float8BlockQuantizer(
+                fp8_dtype=tex.DType.kFloat8E4M3,
+                rowwise=True,
+                columnwise=True,
+                force_pow_2_scales=True,
+                amax_epsilon=0.0,
+                block_scaling_dim=1,
+            )
+        elif quantization == "nvfp4":
+            quantizer = NVFP4Quantizer(
+                with_rht=False,
+                with_post_rht_amax=False,
+                with_2d_quantization=False,
+                stochastic_rounding=False,
+                with_random_sign_mask=False,
+            )
+        else:
+            raise ValueError(f"Unknown quantization scheme ({quantization})")
+
+        # Create reference and quantized tensor
+        ref_tensor = torch.randn(shape, device=device, dtype=dtype)
+        quantized_tensor = quantizer(ref_tensor)
+        ref_tensor.copy_(quantized_tensor)
+
+        # Chunk tensors
+        ref_splits = torch.chunk(ref_tensor, chunks, dim=dim)
+        quantized_splits = torch.chunk(quantized_tensor, chunks, dim=dim)
+
+        # Check splits
+        for ref_split, quantized_split in zip(ref_splits, quantized_splits):
+
+            # Check split shapes
+            assert ref_split.size() == quantized_split.size()
+
+            # Check that splits are quantized when expected
+            if quantization == "fp8":
+                assert isinstance(quantized_split, Float8Tensor)
+                expected_value = quantized_split.dequantize()
+            elif quantization == "mxfp8" and dim == 0:
+                assert isinstance(quantized_split, MXFP8Tensor)
+                expected_value = quantized_split.dequantize()
+            else:
+                # Otherwise torch dispatch falls back to the base
+                # implementation, which dequantizes before computing, so the
+                # output of torch.chunk is already a high-precision tensor.
+                expected_value = quantized_split
+
+            # Check values
+            torch.testing.assert_close(expected_value, ref_split)
@@ -493,10 +493,10 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
         # Convert PyTorch dtype to TE dtype
         if dtype is None:
             dtype = self.dtype
+        tensor = self.contiguous()
         if torch.is_grad_enabled():
-            return _FromFloat8Func.apply(self, dtype)
-        return _FromFloat8Func.forward(None, self, dtype)
+            return _FromFloat8Func.apply(tensor, dtype)
+        return _FromFloat8Func.forward(None, tensor, dtype)

     def quantize_(
         self,
@@ -554,13 +554,19 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
         Returns `self` if data is already in correct memory format.
         """
-        if self._data is not None and self._data.is_contiguous(memory_format=memory_format):
-            return self
-        if self._transpose is not None and self._transpose.is_contiguous(
-            memory_format=memory_format
-        ):
-            return self
-        return Float8Tensor.make_like(tensor=self, data=self._data.contiguous())
+        # requires_grad remains unaltered when calling contiguous on a torch
+        # tensor, and so should be the case for our custom Float8Tensor as well.
+        return Float8Tensor.make_like(
+            tensor=self,
+            data=self._data.contiguous(memory_format=memory_format),
+            data_transpose=(
+                self._transpose.contiguous(memory_format=memory_format)
+                if self._transpose is not None
+                else None
+            ),
+            requires_grad=self.requires_grad,
+        )
         # raise ValueError("Float8Tensor does not support different memory formats!")
...
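A hedged illustration of what the two `Float8Tensor` hunks above guarantee, using the same `Float8Quantizer` construction as the new test (not code from the commit): chunking along dim 1 produces non-contiguous `Float8Tensor` views, `dequantize()` now routes through `contiguous()`, and `contiguous()` preserves `requires_grad` and the transpose cache:

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch import Float8Quantizer, Float8Tensor

quantizer = Float8Quantizer(
    scale=torch.ones(1, dtype=torch.float32, device="cuda").squeeze(),
    amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
x_fp8 = quantizer(torch.randn(128, 128, device="cuda", dtype=torch.bfloat16))

# A dim-1 chunk is a Float8Tensor view over non-contiguous FP8 data.
left, _ = torch.chunk(x_fp8, 2, dim=1)
assert isinstance(left, Float8Tensor)

# dequantize() first materializes a contiguous copy, so the values match
# the corresponding slice of the fully dequantized tensor.
torch.testing.assert_close(left.dequantize(), x_fp8.dequantize()[:, :64])
```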
@@ -434,13 +434,16 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                 if scale_inv is not None
                 else None
             )
+            scale_inv_out = list(scale_inv_out) if scale_inv_out is not None else None
             # Pad scale_inv_out to be a multiple of pad_multiple
             if scale_inv_out is not None:
-                current_shape = scale_inv_out.shape
-                pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
-                if pad_dim0 > 0:
-                    scale_inv_out = torch.nn.functional.pad(scale_inv_out, (0, 0, 0, pad_dim0))
+                for idx, split_scale_inv_out in enumerate(scale_inv_out):
+                    current_shape = split_scale_inv_out.shape
+                    pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
+                    if pad_dim0 > 0:
+                        scale_inv_out[idx] = torch.nn.functional.pad(
+                            split_scale_inv_out, (0, 0, 0, pad_dim0)
+                        )
             out_data.append(scale_inv_out)
             return [
                 MXFP8Tensor(
...
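The padding arithmetic from the fixed loop, shown standalone (`pad_multiple` here is a made-up value for illustration; Transformer Engine derives the real one from the MXFP8 scale layout):

```python
import torch

pad_multiple = 4  # assumption for illustration only
split_scale_inv = torch.ones(6, 4)  # stand-in for one per-split scale-inverse block

# Round dim 0 up to the next multiple of pad_multiple, exactly as the loop above does.
pad_dim0 = (pad_multiple - split_scale_inv.shape[0] % pad_multiple) % pad_multiple
padded = torch.nn.functional.pad(split_scale_inv, (0, 0, 0, pad_dim0))
assert padded.shape == (8, 4)
```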