Unverified Commit 5b89f1ad authored by Tim Moon, committed by GitHub

[PyTorch] Debug dtype casting in operation-based API (#1202)



* Handle Float8Tensor when casting module dtype

Keep the data in FP8 and only change the Float8Tensor's nominal dtype. Monkey-patch PyTorch's module casting functions (`Module.float`, `Module.half`, `Module.bfloat16`) to handle Float8Tensor. Add tests.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
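
For illustration, a minimal usage sketch of the intended behavior (not part of this commit; import paths are assumed from the test suite in this PR, and `te_ops.Linear` under `fp8_model_init` stores its weight as a `Float8Tensor`):

```python
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Linear op whose weight is stored as FP8 data (uint8 buffer + scale)
with te.fp8_model_init(enabled=True):
    op = te_ops.Linear(16, 16, bias=False, device="cuda", dtype=torch.float32)
assert isinstance(op.weight, Float8Tensor)

# With the patched torch.nn.Module.half, this keeps the underlying FP8
# data buffer untouched and only changes the tensor's nominal dtype
op.half()
assert isinstance(op.weight, Float8Tensor)  # still FP8 under the hood
assert op.weight.dtype == torch.float16     # nominal dtype updated
```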

* Respect autocast dtype in linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>
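
A short sketch of what this fixes, mirroring the new `test_pyt_autocast` below (same assumed imports as the sketch above): the linear op's compute dtype now follows the active autocast context rather than the module's nominal dtype.

```python
import torch
import transformer_engine.pytorch.ops as te_ops

op = te_ops.Linear(16, 16, bias=False, device="cuda", dtype=torch.float32)
x = torch.zeros((16, 16), device="cuda", requires_grad=True)

# Compute in bf16 under autocast even though the op was built as fp32
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = op(x)
y.backward(torch.zeros_like(y))

assert y.dtype == torch.bfloat16      # output follows the autocast dtype
assert x.grad.dtype == torch.float32  # input grad keeps the model dtype
```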

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Suppress linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Suppress linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak comments

Review suggestion from @ptrendx
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e762592e
@@ -89,7 +89,7 @@ class TestFloat8Tensor:
             fp8_dtype=fp8_dtype,
             scale=torch.full([1], scale),
         )
-        x_fp8 = x_fp8.from_float8().cpu()
+        x_fp8 = x_fp8.dequantize().cpu()

         # Check results
         torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
@@ -144,7 +144,7 @@ class TestFloat8Tensor:
             fp8_meta=fp8_meta,
             fp8_meta_index=fp8_meta_index,
         )
-        x_ref = x_fp8.from_float8()
+        x_ref = x_fp8.dequantize()
         assert list(x_fp8.size()) == dims, "Incorrect dims"
         assert x_fp8.dtype == dtype, "Incorrect nominal dtype"
         assert x_fp8.is_cuda, "Incorrect device"
@@ -194,8 +194,8 @@ class TestFloat8Tensor:
             fp8_dtype=fp8_dtype,
             scale=torch.full([1], scale),
         )
-        x_ref = x_fp8.from_float8()
-        y_ref = y_fp8.from_float8()
+        x_ref = x_fp8.dequantize()
+        y_ref = y_fp8.dequantize()

         # Exact operations
         torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0)
@@ -237,23 +237,23 @@ class TestFloat8Tensor:
             fp8_dtype=fp8_dtype,
             scale=torch.full([1], scale),
         )
-        x_ref = x_fp8.from_float8()
-        y_ref = y_fp8.from_float8()
+        x_ref = x_fp8.dequantize()
+        y_ref = y_fp8.dequantize()

         # In-place operations
         tols = _tols[fp8_dtype]
         x_fp8 += y_ref
         x_ref += y_ref
         torch.testing.assert_close(x_fp8, x_ref, **tols)
-        x_ref = x_fp8.from_float8()
+        x_ref = x_fp8.dequantize()
         x_fp8 -= y_fp8
         x_ref -= y_fp8
         torch.testing.assert_close(x_fp8, x_ref, **tols)
-        x_ref = x_fp8.from_float8()
+        x_ref = x_fp8.dequantize()
         x_fp8 *= 2
         x_ref *= 2
         torch.testing.assert_close(x_fp8, x_ref, **tols)
-        x_ref = x_fp8.from_float8()
+        x_ref = x_fp8.dequantize()

         # Make sure we are not trivially passing tests
         x_ref += 123
@@ -278,7 +278,7 @@ class TestFloat8Tensor:
             fp8_dtype=fp8_dtype,
             scale=torch.full([1], scale),
         )
-        x = x_fp8.from_float8()
+        x = x_fp8.dequantize()

         # Perform transpose
         x_fp8_t = x_fp8.transpose_2d()
@@ -296,7 +296,7 @@ class TestFloat8Tensor:
         # Caching test
         assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching."
         x_fp8 += 0.5
-        x = x_fp8.from_float8()
+        x = x_fp8.dequantize()
         x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True))
         x_t = x.transpose(0, 1)
         torch.testing.assert_close(x_fp8_t, x_t, **tols)
@@ -305,7 +305,7 @@ class TestFloat8Tensor:
         # Inplace update test
         x_fp8 += 0.5
         assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
-        x = x_fp8.from_float8()
+        x = x_fp8.dequantize()
         x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose)
         x_t = x.transpose(0, 1)
         torch.testing.assert_close(x_fp8_t, x_t, **tols)
@@ -326,7 +326,7 @@ class TestFloat8Tensor:
             fp8_dtype=fp8_dtype,
             scale=torch.full([1], scale),
         )
-        x_ref = x_fp8.from_float8()
+        x_ref = x_fp8.dequantize()

         # Serialize tensor
         byte_stream = io.BytesIO()
@@ -351,3 +351,47 @@ class TestFloat8Tensor:
         x_fp8._scale_inv.zero_()
         with pytest.raises(AssertionError):
             torch.testing.assert_close(x_fp8, x_ref, **tols)
+
+    def test_set_data(self):
+        """Test directly setting .data attr"""
+
+        # Initialize Float8Tensor
+        x0 = torch.zeros(4, dtype=torch.float32)
+        x = Float8Tensor.to_float8(x0)
+        assert isinstance(x, Float8Tensor)
+        assert x0.size() == x.size() == x._data.size()
+        assert x.dtype == torch.float32
+        assert x.is_cuda and x._data.is_cuda
+        y = x.dequantize()
+        assert not isinstance(y, Float8Tensor)
+        assert x.size() == y.size()
+        assert x.dtype == y.dtype
+        assert x.device == y.device
+
+        # Set data to plain tensor
+        x0 = torch.zeros((3, 2), dtype=torch.float16, device=x.device)
+        x.data = x0
+        assert isinstance(x, Float8Tensor)
+        assert x0.size() == x.size() == x._data.size()
+        assert x0.dtype == x.dtype
+        assert x0.device == x.device == x._data.device
+        y = x.dequantize()
+        assert not isinstance(y, Float8Tensor)
+        assert x.size() == y.size()
+        assert x.dtype == y.dtype
+        assert x.device == y.device
+
+        # Set data to Float8Tensor
+        x0 = Float8Tensor.to_float8(torch.zeros((4, 3, 1), dtype=torch.float32))
+        x.data = x0
+        assert isinstance(x, Float8Tensor)
+        assert x0.size() == x.size() == x._data.size()
+        assert x0.dtype == x.dtype
+        assert x0.device == x.device == x._data.device
+        assert x0._data is x._data
+        assert x0._scale_inv is x._scale_inv
+        y = x.dequantize()
+        assert not isinstance(y, Float8Tensor)
+        assert x.size() == y.size()
+        assert x.dtype == y.dtype
+        assert x.device == y.device
@@ -307,6 +307,128 @@ class TestFuser:
         torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref))
         torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref))

+    @pytest.mark.parametrize("init_dtype", _dtypes)
+    @pytest.mark.parametrize("final_dtype", _dtypes)
+    @pytest.mark.parametrize("fp8_weight", (False, True))
+    def test_dtype_cast(
+        self,
+        *,
+        size: int = 16,
+        init_dtype: torch.dtype,
+        final_dtype: torch.dtype,
+        device: torch.device = "cuda",
+        fp8_weight: bool,
+    ) -> None:
+        """Check dtype cast functions"""
+
+        # Skip invalid configurations
+        if fp8_weight:
+            if not fp8_available:
+                pytest.skip(reason_for_no_fp8)
+            if torch.device(device).type != "cuda":
+                pytest.skip("FP8 is only supported on CUDA devices")
+
+        # Random data
+        dtype = torch.float32
+        if torch.float16 in (init_dtype, final_dtype):
+            dtype = torch.float16
+        if torch.bfloat16 in (init_dtype, final_dtype):
+            dtype = torch.bfloat16
+        w_ref, w_test = make_reference_and_test_tensors(
+            (size, size),
+            test_dtype=dtype,
+            test_device=device,
+            test_is_fp8=fp8_weight,
+        )
+
+        # Construct operation
+        with te.fp8_model_init(enabled=fp8_weight):
+            op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
+        with torch.no_grad():
+            op.weight.copy_(w_test)
+            del w_test
+
+        # Cast operation dtype
+        if final_dtype == torch.float32:
+            op.float()
+        elif final_dtype == torch.float16:
+            op.half()
+        elif final_dtype == torch.bfloat16:
+            op.bfloat16()
+
+        # Check weights
+        assert isinstance(op.weight, Float8Tensor) == fp8_weight
+        assert op.weight.dtype == final_dtype
+        w_test = op.weight.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0)
+
+        # Check forward and backward pass
+        x = torch.zeros(
+            (size, size),
+            dtype=init_dtype,
+            device=device,
+            requires_grad=True,
+        )
+        y = op(x)
+        y.backward(torch.zeros_like(y))
+        assert y.dtype == final_dtype
+        assert x.grad.dtype == init_dtype
+        assert op.weight.grad.dtype == final_dtype
+
+    @pytest.mark.parametrize("model_dtype", _dtypes)
+    @pytest.mark.parametrize("autocast_dtype", _dtypes)
+    @pytest.mark.parametrize("fp8_compute", (False, True))
+    def test_pyt_autocast(
+        self,
+        *,
+        size: int = 16,
+        model_dtype: torch.dtype,
+        autocast_dtype: torch.dtype,
+        device: torch.device = "cuda",
+        fp8_weight: bool = False,
+        fp8_compute: bool,
+    ) -> None:
+        """Test with PyTorch autocast"""
+        device = torch.device(device)
+
+        # Skip invalid configurations
+        if fp8_weight or fp8_compute:
+            if not fp8_available:
+                pytest.skip(reason_for_no_fp8)
+            if torch.device(device).type != "cuda":
+                pytest.skip("FP8 is only supported on CUDA devices")
+
+        # Construct operation
+        with te.fp8_model_init(enabled=fp8_weight):
+            op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)
+
+        # Check forward and backward pass
+        x = torch.zeros(
+            (size, size),
+            dtype=model_dtype,
+            device=device,
+            requires_grad=True,
+        )
+        with te.fp8_autocast(enabled=fp8_compute):
+            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
+                y = op(x)
+        y.backward(torch.zeros_like(y))
+        assert y.dtype == autocast_dtype
+        assert x.grad.dtype == model_dtype
+        assert op.weight.grad.dtype == model_dtype
+
+        # Check forward and backward pass (swapped context order)
+        if fp8_compute:
+            x.grad = None
+            op.weight.grad = None
+            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
+                with te.fp8_autocast(enabled=fp8_compute):
+                    y = op(x)
+            y.backward(torch.zeros_like(y))
+            assert y.dtype == autocast_dtype
+            assert x.grad.dtype == model_dtype
+            assert op.weight.grad.dtype == model_dtype
+

 class TestBasicOps:
     """Tests for individual operations"""
@@ -119,7 +119,6 @@ class BasicLinear(BasicOperation):
         dtype = canonicalize_dtype(dtype)
         if dtype not in (torch.float32, torch.float16, torch.bfloat16):
             raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
-        self.dtype: torch.dtype = canonicalize_dtype(dtype)

         # Tensor parallel configuration
         self.tensor_parallel_mode: Optional[str]
@@ -278,7 +277,8 @@ class BasicLinear(BasicOperation):
         weight = self.weight
         if weight.device.type != "cuda" or is_float8_tensor(weight):
             weight = torch.empty_like(weight, device=self.device)
-        weight = weight.to(device=self.device, dtype=self.dtype)
+        else:
+            weight = weight.to(device=self.device)

         # Initialize values
         init_context = contextlib.nullcontext
@@ -1082,12 +1082,17 @@ class BasicLinear(BasicOperation):
         if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
             grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")

+        # Get autocast dtype if needed
+        dtype = None
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_dtype("cuda")
+
         # Linear forward
         output, x_local, _ = BasicLinear._functional_forward(
             input=input_,
             weight=self.weight,
             device=self.device,
-            dtype=self.dtype,
+            dtype=dtype,
             tensor_parallel_mode=self.tensor_parallel_mode,
             tensor_parallel_group=self.tensor_parallel_group,
             sequence_parallel=self.sequence_parallel,
@@ -1103,6 +1108,7 @@ class BasicLinear(BasicOperation):
         ctx.weight_fp8_meta = weight_fp8_meta
         ctx.grad_output_fp8_meta = grad_output_fp8_meta
         ctx.grad_input_fp8_meta = grad_input_fp8_meta
+        ctx.dtype = dtype
         ctx.input_dims = input_.size()
         ctx.input_requires_grad = input_.requires_grad
         ctx.weight_requires_grad = self.weight.requires_grad
@@ -1143,7 +1149,7 @@ class BasicLinear(BasicOperation):
             input_requires_grad=ctx.input_requires_grad,
             weight_requires_grad=ctx.weight_requires_grad,
             device=self.device,
-            dtype=self.dtype,
+            dtype=ctx.dtype,
             grad_weight=grad_weight,
             accumulate_into_grad_weight=accumulate_into_main_grad,
             tensor_parallel_mode=self.tensor_parallel_mode,
@@ -62,9 +62,6 @@ class Bias(BasicOperation):
             device = canonicalize_device(None)
         self.device: torch.device = device

-        # Bias tensor datatype
-        self.dtype: torch.dtype = canonicalize_dtype(dtype)
-
         # Tensor parallel configuration
         tensor_parallel_size = 1
         local_size = size
@@ -88,7 +85,7 @@ class Bias(BasicOperation):
         bias = torch.empty(
             local_size,
             device="meta",
-            dtype=dtype,
+            dtype=canonicalize_dtype(dtype),
         )
         bias = torch.nn.Parameter(bias)
         self.bias: torch.nn.Parameter
@@ -103,7 +100,8 @@ class Bias(BasicOperation):
         bias = self.bias
         if bias.device.type != "cuda":
             bias = torch.empty_like(bias, device=self.device)
-        bias = bias.to(device=self.device, dtype=self.dtype)
+        else:
+            bias = bias.to(device=self.device)

         # Initialize values
         bias.zero_()
@@ -78,7 +78,7 @@ class BackwardLinearAdd(FusedOperation):
             input_requires_grad=linear_op_ctx.input_requires_grad,
             weight_requires_grad=linear_op_ctx.weight_requires_grad,
             device=linear_op.device,
-            dtype=linear_op.dtype,
+            dtype=grad_input.dtype,
             grad_weight=grad_weight,
             accumulate_into_grad_weight=accumulate_into_main_grad,
             grad_input=grad_input,
@@ -104,13 +104,18 @@ class ForwardLinearBiasActivation(FusedOperation):
         if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
             grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")

+        # Get autocast dtype if needed
+        dtype = None
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_dtype("cuda")
+
         # Linear forward
         output, x_local, _ = BasicLinear._functional_forward(
             input=input_,
             weight=linear_op.weight,
             bias=bias,
             device=linear_op.device,
-            dtype=linear_op.dtype,
+            dtype=dtype,
             tensor_parallel_mode=linear_op.tensor_parallel_mode,
             tensor_parallel_group=linear_op.tensor_parallel_group,
             sequence_parallel=linear_op.sequence_parallel,
@@ -126,6 +131,7 @@ class ForwardLinearBiasActivation(FusedOperation):
         linear_op_ctx.weight_fp8_meta = weight_fp8_meta
         linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
         linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
+        linear_op_ctx.dtype = dtype
         linear_op_ctx.input_dims = input_.size()
         linear_op_ctx.input_requires_grad = input_.requires_grad
         linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
@@ -167,7 +173,7 @@ def fuse_forward_linear_bias_activation(
            # Row tensor-parallelism requires communication after the
            # GEMM
            continue
-        if op1.dtype not in (torch.float16, torch.bfloat16):
+        if op1.weight.dtype not in (torch.float16, torch.bfloat16):
            # cuBLAS only supports fused GEMM+bias+activation with
            # FP16 and BF16 output
            continue
@@ -95,6 +95,11 @@ class ForwardLinearBiasAdd(FusedOperation):
         if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
             grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")

+        # Get autocast dtype if needed
+        dtype = None
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_dtype("cuda")
+
         # Linear forward
         output = basic_op_extra_inputs[self._op_idxs["add"]][0]
         output, x_local, _ = BasicLinear._functional_forward(
@@ -102,7 +107,6 @@ class ForwardLinearBiasAdd(FusedOperation):
             weight=linear_op.weight,
             bias=bias,
             device=linear_op.device,
-            dtype=linear_op.dtype,
             out=output,
             accumulate_into_out=True,
             tensor_parallel_mode=linear_op.tensor_parallel_mode,
@@ -120,6 +124,7 @@ class ForwardLinearBiasAdd(FusedOperation):
         linear_op_ctx.weight_fp8_meta = weight_fp8_meta
         linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
         linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
+        linear_op_ctx.dtype = dtype
         linear_op_ctx.input_dims = input_.size()
         linear_op_ctx.input_requires_grad = input_.requires_grad
         linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
@@ -4,5 +4,44 @@
 """Custom tensor classes"""
+
+import torch
+
 from .float8_tensor import Float8Tensor
 from .quantized_tensor import QuantizedTensor

 __all__ = ["Float8Tensor", "QuantizedTensor"]
+
+
+def _make_module_cast_func(dtype):
+    """Make module cast function that can handle QuantizedTensor"""
+    cast_func_name = {
+        torch.float32: "float",
+        torch.float16: "half",
+        torch.bfloat16: "bfloat16",
+    }[dtype]
+
+    def tensor_cast_func(tensor: torch.Tensor) -> torch.Tensor:
+        """Cast tensor dtype"""
+        if isinstance(tensor, Float8Tensor):
+            return Float8Tensor.make_like(
+                tensor,
+                data=tensor._data,
+                fp8_attrs=tensor._fp8_attrs,
+                dtype=dtype,
+                requires_grad=tensor.requires_grad,
+            )
+        if tensor.is_floating_point():
+            return getattr(tensor, cast_func_name)()
+        return tensor
+
+    def module_cast_func(self: torch.nn.Module) -> torch.nn.Module:
+        """Cast module dtype"""
+        return self._apply(tensor_cast_func)
+
+    return module_cast_func
+
+
+# Monkey-patch module cast functions to handle QuantizedTensor
+torch.nn.Module.float = _make_module_cast_func(torch.float32)
+torch.nn.Module.half = _make_module_cast_func(torch.float16)
+torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16)
@@ -346,6 +346,7 @@ class Float8Tensor(QuantizedTensor):
         fp8_dtype: TE_DType = TE_DType.kFloat8E4M3,
         fp8_scale_inv: Optional[torch.Tensor] = None,
         dtype: torch.dtype = torch.float32,
+        requires_grad: bool = False,
         data_transpose: Optional[torch.Tensor] = None,
     ):
@@ -367,7 +368,7 @@ class Float8Tensor(QuantizedTensor):
             storage_offset=data.storage_offset(),
             dtype=dtype,
             layout=data.layout,
-            requires_grad=data.requires_grad,
+            requires_grad=requires_grad,
             device=data.device,
         )
         self._data: torch.Tensor = data
@@ -947,14 +948,81 @@ class Float8Tensor(QuantizedTensor):
         """Get tensor data property"""
         return super().data

+    @torch.no_grad()
     def _set_data(self, tensor: torch.Tensor) -> None:
         """Set tensor data property

-        Cast tensor to FP8 and store in FP8 buffer.
+        Just takes FP8 data if setting from a Float8Tensor. Otherwise
+        casts to FP8.

         """
-        with torch.no_grad():
-            self.copy_(tensor)
+
+        # Tensor device
+        new_device = tensor.device if tensor.is_cuda else self.device
+
+        # Check whether grad is required
+        if self.requires_grad != tensor.requires_grad:
+            self.requires_grad_(requires_grad=tensor.requires_grad)
+
+        # Just copy FP8 data if other tensor is Float8Tensor
+        if isinstance(tensor, Float8Tensor):
+            if (  # pylint: disable=too-many-boolean-expressions
+                self.size() != tensor.size()
+                or self.stride() != tensor.stride()
+                or self.storage_offset() != tensor.storage_offset()
+                or self.dtype != tensor.dtype
+                or self.layout != tensor.layout
+                or not devices_match(self.device, new_device)
+            ):
+                dummy_tensor = torch.Tensor._make_wrapper_subclass(
+                    Float8Tensor,
+                    tensor.size(),
+                    strides=tensor.stride(),
+                    storage_offset=tensor.storage_offset(),
+                    dtype=tensor.dtype,
+                    layout=tensor.layout,
+                    requires_grad=tensor.requires_grad,
+                    device=new_device,
+                )
+                super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor)
+            self._data = tensor._data
+            self._fp8_attrs = tensor._fp8_attrs
+            return
+
+        # Reallocate FP8 data if needed
+        if (
+            self.size() != tensor.size()
+            or self.stride() != tensor.stride()
+            or self.dtype != tensor.dtype
+            or self.layout != tensor.layout
+            or not devices_match(self.device, new_device)
+        ):
+            self._data = torch.empty_like(
+                tensor,
+                dtype=torch.uint8,
+                device=new_device,
+            )
+            dummy_tensor = torch.Tensor._make_wrapper_subclass(
+                Float8Tensor,
+                self._data.size(),
+                strides=self._data.stride(),
+                storage_offset=self._data.storage_offset(),
+                dtype=tensor.dtype,
+                layout=self._data.layout,
+                requires_grad=tensor.requires_grad,
+                device=self._data.device,
+            )
+            super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor)
+            if self._transpose is not None:
+                self._transpose = torch.empty(
+                    (self._data.size(-1), self._data.numel() // self._data.size(-1)),
+                    dtype=torch.uint8,
+                    device=self.device,
+                )
+            self._transpose_invalid = True
+
+        # Copy values from other tensor
+        self.quantize_(tensor)

     # Cast to FP8 when setting Float8Tensor.data
     data = property(_get_data, _set_data)