Unverified commit d26cc3a0, authored by Jan Bielak and committed by GitHub

Add test for `LayerNormMLP` implementation using `te.ops.Sequential` to `test_fusible_ops.py` (#1924)

* Add e2e test for LayerNormMLP implemented with te.Sequential
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix bugs uncovered by test
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix reshaping columnwise_data of MXFP8Tensor
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix taking dtype from weight or grad_output in BasicLinear._functional_backward
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 1ae1d228
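
For context, the composition the new test exercises can be sketched standalone. This is a hedged illustration, not part of the commit; the sizes, bfloat16 dtype, and CUDA device mirror the test's defaults, and te_ops is assumed to alias transformer_engine.pytorch.ops as in the test file:

    import torch
    import transformer_engine.pytorch.ops as te_ops

    # LayerNorm -> Linear -> GELU -> Linear: the op-based analogue of LayerNormMLP
    mlp = te_ops.Sequential(
        te_ops.LayerNorm(32, eps=1e-5, device="cuda", dtype=torch.bfloat16),
        te_ops.Linear(32, 64, device="cuda", dtype=torch.bfloat16),
        te_ops.GELU(),
        te_ops.Linear(64, 32, device="cuda", dtype=torch.bfloat16),
    )
    x = torch.randn(512, 4, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    y = mlp(x)
    y.backward(torch.randn_like(y))

Wrapped in te.fp8_model_init and te.fp8_autocast, the same container runs with quantized weights and/or compute, which is exactly the grid the test parametrizes.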
@@ -152,7 +152,7 @@ def make_reference_and_test_tensors(
     return ref, test


-class TestSequential:
+class TestSequentialContainer:
     """Tests for sequential container"""

     def test_modules(self) -> None:
@@ -2080,3 +2080,109 @@ class TestCheckpointing:
         torch.testing.assert_close(y_load, y_save, **tols)
         for x_load, x_save in zip(xs_load, xs_save):
             torch.testing.assert_close(x_load.grad, x_save.grad, **tols)
+
+
+class TestSequentialModules:
+    """Tests for larger Sequentials with modules commonly used together"""
+
+    @staticmethod
+    def setup_class(cls) -> None:
+        # Configure RNG
+        seed = 1234
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    @pytest.mark.parametrize("bias", (False, True))
+    @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
+    @pytest.mark.parametrize("quantized_compute", (False, True))
+    @pytest.mark.parametrize("quantized_weight", (False, True))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    def test_layernorm_mlp(
+        self,
+        *,
+        bias: bool,
+        normalization: str,
+        quantized_compute: bool,
+        quantized_weight: bool,
+        dtype: torch.dtype,
+        quantization: Optional[str],
+        device: torch.device = "cuda",
+        hidden_size: int = 32,
+        sequence_length: int = 512,
+        batch_size: int = 4,
+        ffn_hidden_size: int = 64,
+        layernorm_epsilon: float = 1e-5,
+    ) -> None:
+        """
+        LayerNorm/RMSNorm + Linear + GELU + Linear
+
+        This test only checks that the module runs, since numerical
+        accuracy is hard to validate when chaining multiple modules.
+        """
+
+        # Make input shape
+        in_shape = (sequence_length, batch_size, hidden_size)
+        ffn_shape = in_shape[:-1] + (ffn_hidden_size,)
+
+        # Skip invalid configurations
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
+        quantization_needed = quantized_compute or quantized_weight
+        if quantization is None and quantization_needed:
+            pytest.skip("Quantization scheme is not specified")
+        if quantization is not None and not quantization_needed:
+            pytest.skip("Quantization scheme is not used")
+
+        # Random data
+        _, x_test = make_reference_and_test_tensors(
+            in_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        _, dy_test = make_reference_and_test_tensors(
+            in_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Implementation with fusible operations
+        recipe = make_recipe(quantization)
+        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+            if normalization == "LayerNorm":
+                norm = te_ops.LayerNorm(
+                    hidden_size,
+                    eps=layernorm_epsilon,
+                    device=device,
+                    dtype=dtype,
+                )
+            else:
+                norm = te_ops.RMSNorm(
+                    hidden_size,
+                    eps=layernorm_epsilon,
+                    device=device,
+                    dtype=dtype,
+                )
+            ffn1 = te_ops.Linear(
+                hidden_size,
+                ffn_hidden_size,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            )
+            act = te_ops.GELU()
+            ffn2 = te_ops.Linear(
+                ffn_hidden_size,
+                hidden_size,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            )
+            forward = te_ops.Sequential(norm, ffn1, act, ffn2)
+        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            y_test = forward(x_test)
+        y_test.backward(dy_test)
@@ -663,9 +663,9 @@ class BasicLinear(BasicOperation):

         # Check datatype
         if dtype is None:
-            if weight is not None and not is_quantized_tensor(weight):
+            if isinstance(weight, torch.Tensor):
                 dtype = weight.dtype
-            else:
+            elif isinstance(grad_output, torch.Tensor):
                 dtype = grad_output.dtype
         dtype = canonicalize_dtype(dtype)
         if dtype not in (torch.float32, torch.float16, torch.bfloat16):
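
The dtype fix above replaces a "not quantized" test with positive isinstance checks, which also cover weight=None instead of falling through to grad_output.dtype even when grad_output is not a tensor. A standalone restatement (resolve_dtype is a hypothetical helper, not library code):

    import torch

    def resolve_dtype(dtype, weight, grad_output):
        # Sketch of the fixed fallback order: explicit dtype wins, then a
        # plain-tensor weight, then a plain-tensor grad output.
        if dtype is None:
            if isinstance(weight, torch.Tensor):
                dtype = weight.dtype
            elif isinstance(grad_output, torch.Tensor):
                dtype = grad_output.dtype
        return dtype

    assert resolve_dtype(None, None, torch.ones(1, dtype=torch.bfloat16)) == torch.bfloat16
    assert resolve_dtype(torch.float32, None, None) == torch.float32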
@@ -693,6 +693,11 @@ class BasicLinear(BasicOperation):
            else:
                if not is_quantized_tensor(dy_local):
                    dy_local = grad_output_quantizer(dy_local)
+               else:
+                   dy_local.update_usage(
+                       rowwise_usage=input_requires_grad,
+                       columnwise_usage=weight_requires_grad,
+                   )
                dy = dy_local
        else:
            dy_local = maybe_dequantize(dy_local, dtype)
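
The new else branch tells an already-quantized grad output which layouts the backward GEMMs will read. The excerpt below repeats the call with the assumed mapping as comments; the mapping is standard linear-layer algebra, not quoted from the library, and the variable names come from the hunk:

    # For y = x @ W^T with incoming gradient dy (assumed mapping):
    #   dgrad: dx = dy @ W    reads dy row-wise    -> needed iff input_requires_grad
    #   wgrad: dW = dy^T @ x  reads dy column-wise -> needed iff weight_requires_grad
    # Dropping the unused layout avoids keeping both copies of a quantized dy alive.
    dy_local.update_usage(
        rowwise_usage=input_requires_grad,
        columnwise_usage=weight_requires_grad,
    )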
@@ -53,7 +53,10 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         *args,
         **kwargs,
     ):
-        instance = super().__new__(cls, *args, **kwargs)
+        if cls is Float8BlockwiseQTensorBase:
+            instance = object.__new__(cls)
+        else:
+            instance = super().__new__(cls, *args, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
         instance._quantizer = quantizer
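
This __new__ change, and the identical one in MXFP8TensorBase below, follow a generic Python pattern: a base class that may be instantiated directly must route around object.__new__, which rejects extra constructor arguments, while subclasses that also inherit from torch.Tensor still need the cooperative super().__new__ call. A self-contained illustration with invented names (Mixin, TupleMixin), not TransformerEngine code:

    class Mixin:
        # Cooperative __new__: plain object when instantiated directly,
        # forward extra args down the MRO when used as a mixin.
        def __new__(cls, payload, *args, **kwargs):
            if cls is Mixin:
                instance = object.__new__(cls)  # object.__new__ rejects extra arguments
            else:
                instance = super().__new__(cls, *args, **kwargs)
            instance.payload = payload
            return instance

    class TupleMixin(Mixin, tuple):
        pass

    m = Mixin("meta", extra_kwarg=1)  # works; forwarding extra_kwarg would raise TypeError
    t = TupleMixin("meta", (1, 2))    # tuple.__new__ receives (1, 2)
    assert tuple(t) == (1, 2) and t.payload == "meta"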
@@ -145,16 +145,20 @@ class Float8TensorBase(QuantizedTensorBase):

     def view(self, shape: torch.Size):
         # pylint: disable=missing-function-docstring
-        data = self._data
-        if data is not None:
-            return Float8TensorBase(
-                data=data.view(shape),
-                fp8_scale_inv=self._scale_inv,
-                fp8_dtype=self._fp8_dtype,
-                data_transpose=None,
-                quantizer=self._quantizer,
-            )
-        raise RuntimeError("No data available to view")
+        out_data = self._data.view(shape)
+        out_transpose = None if self._transpose_invalid else self._transpose
+        if out_transpose is not None:
+            out_transpose_shape = out_transpose.size()
+            if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]:
+                out_transpose = None
+        return Float8TensorBase(
+            data=out_data,
+            fp8_scale_inv=self._scale_inv,
+            fp8_dtype=self._fp8_dtype,
+            data_transpose=out_transpose,
+            quantizer=self._quantizer,
+        )

     def __repr__(self):
         return (
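
The rewritten view keeps the cached transpose when its layout still matches the requested shape, instead of always discarding it. A small sketch of just that shape check, with assumed sizes:

    def transpose_still_valid(shape, t_shape):
        # Mirrors the condition above: a cached transpose of shape
        # (d_last, *d_rest) survives a view only if the new shape is (*d_rest, d_last).
        return t_shape[0] == shape[-1] and tuple(t_shape[1:]) == tuple(shape[:-1])

    assert transpose_still_valid((512, 4, 32), (32, 512, 4))    # identity view: kept
    assert not transpose_still_valid((2048, 32), (32, 512, 4))  # flatten view: dropped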
@@ -6,6 +6,8 @@

 from __future__ import annotations
 from typing import Optional, Dict, Any, Tuple
+from collections.abc import Iterable
+import math

 import torch
 import transformer_engine_torch as tex
@@ -75,7 +77,10 @@ class MXFP8TensorBase(QuantizedTensorBase):
         quantizer: Optional[Quantizer] = None,
         **kwargs,
     ):
-        instance = super().__new__(cls, *args, **kwargs)
+        if cls is MXFP8TensorBase:
+            instance = object.__new__(cls)
+        else:
+            instance = super().__new__(cls, *args, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
         instance._quantizer = quantizer
@@ -145,6 +150,51 @@ class MXFP8TensorBase(QuantizedTensorBase):
             return self._rowwise_data.size(*args, **kwargs)
         return self._columnwise_data.size(*args, **kwargs)

+    def view(self, shape: torch.Size):
+        # pylint: disable=missing-function-docstring
+
+        # Return input tensor if view not needed
+        cur_shape = self.size()
+        if shape is None or shape == cur_shape:
+            return self
+
+        # Canonicalize shape
+        if not isinstance(shape, Iterable):
+            shape = [shape]
+        elif len(shape) == 1 and isinstance(shape[0], Iterable):
+            shape = shape[0]
+        if -1 in shape:
+            shape = list(shape)
+            d_inferred = -math.prod(cur_shape) // math.prod(shape)
+            for i, d in enumerate(shape):
+                if d == -1:
+                    shape[i] = d_inferred
+                    break
+        if shape[-1] != cur_shape[-1]:
+            raise RuntimeError(
+                "MXFP8Tensor does not support reshaping inner dimension "
+                f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})"
+            )
+
+        # Construct new tensor
+        cur_rowwise_data = self._rowwise_data
+        cur_columnwise_data = self._columnwise_data
+        new_rowwise_data = None
+        new_columnwise_data = None
+        if cur_rowwise_data is not None:
+            new_rowwise_data = cur_rowwise_data.view(*shape)
+        if cur_columnwise_data is not None:
+            new_columnwise_data = cur_columnwise_data.view(*shape)
+        return MXFP8TensorBase(
+            rowwise_data=new_rowwise_data,
+            rowwise_scale_inv=self._rowwise_scale_inv,
+            columnwise_data=new_columnwise_data,
+            columnwise_scale_inv=self._columnwise_scale_inv,
+            fp8_dtype=self._fp8_dtype,
+            quantizer=self._quantizer,
+        )
+
     def __repr__(self):
         data_rowwise = self.dequantize()
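
The -1 handling in the new view method relies on a sign trick: math.prod(shape) is negative exactly when a single -1 placeholder is present, so negating the element count makes the floor division come out positive. A quick numeric check with assumed sizes:

    import math

    cur_shape = (512, 4, 32)  # 65536 elements
    shape = [-1, 32]          # math.prod(shape) == -32
    d_inferred = -math.prod(cur_shape) // math.prod(shape)
    assert d_inferred == 2048  # (-65536) // (-32)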
@@ -437,8 +437,7 @@ class _ViewFunc(torch.autograd.Function):
         if tensor._rowwise_data is not None:
             new_rowwise_data = tensor._rowwise_data.view(*shape)
         if tensor._columnwise_data is not None:
-            columnwise_shape = [shape[-1]] + list(shape[:-1])
-            new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
+            new_columnwise_data = tensor._columnwise_data.view(*shape)
         return MXFP8Tensor(
             shape,
             tensor.dtype,
@@ -462,7 +461,7 @@
             grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
         )
         if grad._columnwise_data is not None:
-            new_columnwise_data = grad._columnwise_data.view(ctx.shape[-1], -1)
+            new_columnwise_data = grad._columnwise_data.view(*ctx.shape)
         else:
             new_columnwise_data = None
         dgrad = MXFP8Tensor(
@@ -523,8 +522,7 @@ class _ReshapeFunc(torch.autograd.Function):
         if tensor._rowwise_data is not None:
             new_rowwise_data = tensor._rowwise_data.reshape(*shape)
         if tensor._columnwise_data is not None:
-            columnwise_shape = [shape[-1]] + list(shape[:-1])
-            new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
+            new_columnwise_data = tensor._columnwise_data.view(*shape)
         return MXFP8Tensor(
             shape,
@@ -550,8 +548,7 @@
         if grad._rowwise_data is not None:
             new_rowwise_data = grad._rowwise_data.view(*ctx.shape)
         if grad._columnwise_data is not None:
-            columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1])
-            new_columnwise_data = grad._columnwise_data.view(columnwise_shape)
+            new_columnwise_data = grad._columnwise_data.view(*ctx.shape)
         dgrad = MXFP8Tensor(
             ctx.shape,
             grad.dtype,
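
All four _ViewFunc/_ReshapeFunc corrections make the same point: MXFP8 keeps the columnwise buffer in the tensor's natural shape (as the size() method shown earlier already assumes by returning _columnwise_data.size() directly), so a view that preserves the inner dimension applies to both buffers unchanged; the old transposed shapes were a holdover from per-tensor FP8. A sketch under that assumption:

    import torch

    # Assumed MXFP8 layout: rowwise and columnwise buffers share the tensor's
    # natural shape; only the axis along which block scales are computed differs.
    rowwise = torch.empty(512, 4, 32, dtype=torch.uint8)
    columnwise = torch.empty(512, 4, 32, dtype=torch.uint8)

    # A view that preserves the inner dimension therefore applies to both as-is;
    # the old code viewed the columnwise data as (32, 512, 4), which is wrong here.
    assert rowwise.view(2048, 32).shape == columnwise.view(2048, 32).shape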