Unverified Commit 2d57db8b authored by Tim Moon, committed by GitHub

[PyTorch] Proxy class for low-precision tensor (#1127)



* Add base class for tensor proxies
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move tensor detaching logic to tensor proxy base class
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use Python wrappers to PyTorch extensions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Include transpose caching logic in proxy encode function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug dimension mismatch with amax history
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move dequantize logic to proxy_decode func
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename to "QuantizedTensor"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename "proxy_detach" to "detach"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Include transpose cache in detach and clone funcs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update FP8 workspaces with QuantizedTensor functions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move logic for FP8 transpose cache in FP8 workspaces to base class
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove cast-transpose logic from linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unnecessary args for Float8Tensor when using FP8 attr dict
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove __torch_function__ to QuantizedTensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update tests/pytorch/test_float8tensor.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Debug FP8 transpose test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug cast functions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 40dda924
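At a glance, the PR replaces ad-hoc `Float8Tensor` handling with a `QuantizedTensor` base class. A minimal sketch of the resulting interface (values illustrative; see the diff below for the real call sites):

```python
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

x = torch.randn(128, 128, device="cuda")

# Quantize to FP8, optionally filling the new transpose cache eagerly.
x_fp8 = Float8Tensor.to_float8(x, with_transpose_cache=True)

y = x_fp8.dequantize()  # QuantizedTensor API: back to high precision
x_fp8.quantize_(2 * y)  # in-place requantize
d = x_fp8.detach()      # detached proxy sharing the same FP8 data
```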
......@@ -31,7 +31,7 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install pip -y
pip install torch
pip install torch numpy
export PYTHON_ONLY=1
export TE_PATH=.
bash ./qa/L0_pytorch_lint/test.sh
......
......@@ -293,7 +293,7 @@ class TestFloat8Tensor:
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8_t, x, **tols)
# Caching test.
# Caching test
assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching."
x_fp8 += 0.5
x = x_fp8.from_float8()
......@@ -302,14 +302,13 @@ class TestFloat8Tensor:
torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
# Inplace update test.
# Inplace update test
x_fp8 += 0.5
assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly."
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
x = x_fp8.from_float8()
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True))
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose)
x_t = x.transpose(0, 1)
torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
def test_serialization(
self,
......
......@@ -88,10 +88,7 @@ def make_reference_and_test_tensors(
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
test = Float8Tensor.to_float8(test)
test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1)
test._transpose = test._transpose.contiguous()
test._transpose_invalid = False
test = Float8Tensor.to_float8(test, with_transpose_cache=True)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
ref.copy_(test)
......
......@@ -68,13 +68,13 @@ def canonicalize_fp8_scales(
# Force offsets to be the same if needed
if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset:
if scale_offset != 0:
scale = scale[scale_offset]
scale = scale[scale_offset:]
scale_offset = 0
if amax_offset != 0:
amax = amax[0][amax_offset]
amax = amax[:, amax_offset:]
amax_offset = 0
if scale_inv_offset != 0:
scale_inv = scale_inv[scale_inv_offset]
scale_inv = scale_inv[scale_inv_offset:]
scale_inv_offset = 0
# Pack tensors and offsets into dicts
......
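The offset fix above matters because `scale[scale_offset]` extracts a 0-D element, while `scale[scale_offset:]` keeps a 1-D view whose element 0 aliases the intended entry (and `amax[:, amax_offset:]` likewise preserves the 2-D amax-history layout). Illustrative only:

```python
import torch

scale = torch.tensor([1.0, 2.0, 3.0])
offset = 1

elem = scale[offset]   # 0-D tensor: wrong rank for kernels expecting a buffer
view = scale[offset:]  # 1-D view starting at the offset; view[0] is scale[1]

view[0] = 7.0
assert scale[1].item() == 7.0  # writes still land in the original meta buffer
```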
......@@ -8,7 +8,7 @@ from typing import Optional, Union
import torch
import transformer_engine_torch as tex
from ._common import canonicalize_fp8_scales, empty_tensor
from ._common import canonicalize_fp8_scales
__all__ = ["cast_to_fp8", "cast_from_fp8"]
......@@ -81,8 +81,7 @@ def cast_from_fp8(
# Construct empty tensors if needed
if scale_inv is None:
scale_inv = empty_tensor()
scale_inv_offset = 0
raise ValueError("Did not provide either `scale_inv` or `fp8_meta_tensor`")
# Launch FP8 cast kernel
return torch.ops.tex_ts.cast_from_fp8_ts(
......
This diff is collapsed.
......@@ -865,11 +865,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# If primary weights are in fp8, wrap the parameter as Float8Tensor
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
if self.primary_weights_in_fp8 and fp8_meta_index is not None:
dummy_amax = torch.empty(
(1, 1),
dtype=torch.float32,
device=param.device,
) # Dummy buffer to avoid overwriting amax history
param = Float8Tensor.to_float8(
param,
fp8_meta=self.fp8_meta,
fp8_meta_index=fp8_meta_index,
amax=torch.empty(1, device="cuda"), # Dummy amax to avoid overwriting history.
amax=dummy_amax,
with_transpose_cache=torch.is_grad_enabled(),
)
# Redo parameter wrap in case we broke it above
......@@ -891,7 +897,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
cache_name: Optional[str] = None,
update_workspace: bool = True,
skip_update_flag: Optional[torch.Tensor] = None,
with_transpose: bool = False,
fsdp_group: dist_group_type = None,
) -> Float8Tensor:
"""Get FP8 workspace buffer and maybe update its values
......@@ -917,27 +922,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
skip_update_flag: torch.Tensor, optional
GPU flag to skip updating the workspace. Takes precedence
over `update_workspace` if provided.
with_transpose: bool, default = `False`
Whether to initialize cached transpose in workspace.
fsdp_group: dist_group_type, default = None
FSDP process group that the weights are distributed over.
"""
# Construct workspace if needed
# Try getting workspace from cache
out = None
if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None)
# Gather cached FP8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for FP8 buffers and will not work
# for models initialized with FP8 primary weights.
if (
not isinstance(out, Float8Tensor)
out is not None
and not isinstance(out, Float8Tensor)
and fsdp_group is not None
and out._data.shape != tensor.data.shape
):
_fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)
# Construct workspace if needed
if out is None:
# FP8 data
if tensor is None or fp8_meta_forward is None or fp8_meta_index is None:
raise ValueError(
"tensor, fp8_meta_forward, and fp8_meta_index kwargs "
......@@ -947,16 +955,38 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["recipe"],
fprop_tensor=fp8_meta_forward,
)
data = torch.empty_like(tensor, dtype=torch.uint8)
scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device)
# Transpose cache
with_transpose_cache = torch.is_grad_enabled()
if (
not with_transpose_cache
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose_cache = True
data_transpose = None
if with_transpose_cache:
data_transpose = torch.empty(
(tensor.size(-1), tensor.numel() // tensor.size(-1)),
dtype=torch.uint8,
device=tensor.device,
)
# Construct FP8 tensor
out = Float8Tensor(
data=torch.empty_like(tensor, dtype=torch.uint8),
data=data,
fp8_meta=self.fp8_meta,
fp8_meta_forward=fp8_meta_forward,
fp8_meta_index=fp8_meta_index,
fp8_dtype=fp8_dtype,
fp8_scale_inv=scale_inv,
dtype=tensor.dtype,
data_transpose=data_transpose,
)
# Update cache
if cache_name is not None:
self._fp8_workspaces[cache_name] = out
update_workspace = True
......@@ -968,33 +998,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if update_workspace:
if tensor is None:
raise ValueError("tensor kwarg must be provided to update FP8 workspace")
if with_transpose:
out.cast_transpose_(
tensor,
noop_flag=skip_update_flag,
)
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=out._fp8_meta_forward,
)
fp8_meta = out._fp8_meta[fp8_meta_key]
fp8_meta_index = out._fp8_meta_index
cast_to_fp8(
tensor,
fp8_meta,
fp8_meta_index,
out._fp8_dtype,
out=out._data,
)
if is_in_onnx_export_mode():
# ONNX export expects FP8 scales can be
# represented with constant ops. However, copying
# into a buffer involves an expand op for array
# broadcasting. We work around this by filling the
# buffer instead.
out._scale_inv.fill_(fp8_meta.scale_inv[fp8_meta_index].item())
# ONNX export does not support fused cast-transpose
# kernel and requires that FP8 scales can be
# represented with constant ops.
transpose_cache = out._transpose
out._transpose = None
out.quantize_(tensor)
out._scale_inv.fill_(out._scale_inv.item())
out._transpose = transpose_cache
else:
out._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index])
out.quantize_(tensor, noop_flag=skip_update_flag)
return out
......
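With the `with_transpose` kwarg gone, the workspace decides on its own whether to allocate a transpose cache (grad enabled, or the first pass of FP8 activation recompute). A module-side call now looks roughly like this sketch (runs inside a `TransformerEngineBaseModule`; the `fp8_meta_index` value is illustrative):

```python
weight_fp8 = self.get_fp8_workspace(
    tensor=weight_tensor,
    fp8_meta_forward=True,
    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,  # illustrative index
    cache_name=None if is_first_microbatch is None else "weight",
    update_workspace=is_first_microbatch is None or is_first_microbatch,
    skip_update_flag=skip_fp8_weight_update,
    fsdp_group=self.fsdp_group,
)
```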
......@@ -28,8 +28,6 @@ from ..utils import (
from ..distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
)
from ..cpp_extensions import (
cast_to_fp8,
......@@ -760,22 +758,12 @@ class GroupedLinear(TransformerEngineBaseModule):
weight_tensors_fp8 = [None] * self.num_gemms
if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
for i in range(self.num_gemms):
if isinstance(weight_tensors[i], Float8Tensor):
# Fill transpose cache in FP8 tensor if needed
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch or skip_fp8_weight_update is not None
)
if update_transpose_cache:
# Make sure transpose cache is valid, if present
# Note: Transpose cache may have been invalidated
# externally, e.g. by optimizer.
if weight_tensors[i]._transpose is not None:
weight_tensors[i].transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
......@@ -790,7 +778,6 @@ class GroupedLinear(TransformerEngineBaseModule):
cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)
from ..cpu_offload import CPUOffloadEnabled
......
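The per-module transpose handling thus reduces to one pattern: if a weight already carries a transpose cache, revalidate it (the optimizer may have invalidated it); everything else is left to the workspace. Sketch, with names as in the diff:

```python
if isinstance(weight, Float8Tensor) and weight._transpose is not None:
    # noop_flag lets the kernel skip itself on-device, which keeps the
    # launch pattern CUDA-graph safe.
    weight.transpose_2d(fill_cache=True, noop_flag=skip_fp8_weight_update)
```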
......@@ -36,8 +36,6 @@ from ..distributed import (
allreduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
_fsdp_scatter_tensors,
_fsdp_gather_tensors,
)
......@@ -47,6 +45,7 @@ from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor
__all__ = ["LayerNormLinear"]
......@@ -1151,14 +1150,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, Float8Tensor) for w in unfused_weights):
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting Float8Tensor into multiple params is not supported"
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
unfused_weights = [w.from_float8() for w in unfused_weights]
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat(
......@@ -1170,32 +1169,18 @@ class LayerNormLinear(TransformerEngineBaseModule):
# Initialize FP8 weights if needed
weight_fp8 = None
if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
if isinstance(weight_tensor, Float8Tensor):
# Fill transpose cache in FP8 tensor if needed
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch or skip_fp8_weight_update is not None
)
if update_transpose_cache:
# Make sure transpose cache is valid, if present
# Note: Transpose cache may have been invalidated
# externally, e.g. by optimizer.
if weight_tensor._transpose is not None:
weight_tensor.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
)
else:
# FP8 cast to workspace buffer
update_workspace = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
update_workspace = is_first_microbatch is None or is_first_microbatch
weight_fp8 = self.get_fp8_workspace(
tensor=weight_tensor,
fp8_meta_forward=True,
......@@ -1203,7 +1188,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
cache_name=(None if is_first_microbatch is None else "weight"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)
from ..cpu_offload import CPUOffloadEnabled
......
......@@ -42,8 +42,6 @@ from ..distributed import (
allreduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
use_reentrant_activation_recompute,
_fsdp_scatter_tensors,
_fsdp_gather_tensors,
......@@ -1485,19 +1483,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_fp8 = None
if self.fp8:
update_workspace = is_first_microbatch is None or is_first_microbatch
with_transpose = torch.is_grad_enabled()
if (
is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch or skip_fp8_weight_update is not None
)
if isinstance(fc1_weight, Float8Tensor):
if update_transpose_cache:
if fc1_weight._transpose is not None:
fc1_weight.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
......@@ -1513,10 +1500,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
cache_name=cache_name,
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)
if isinstance(fc2_weight, Float8Tensor):
if update_transpose_cache:
if fc2_weight._transpose is not None:
fc2_weight.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
......@@ -1532,7 +1518,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
cache_name=cache_name,
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
......
......@@ -33,8 +33,6 @@ from ..distributed import (
allreduce,
reduce_scatter_along_first_dim,
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
_fsdp_scatter_tensors,
_fsdp_gather_tensors,
)
......@@ -49,6 +47,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor
__all__ = ["Linear"]
......@@ -938,19 +937,19 @@ class Linear(TransformerEngineBaseModule):
with self.prepare_forward(
inp,
is_first_microbatch,
allow_non_contiguous=isinstance(inp, Float8Tensor),
allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, Float8Tensor) for w in unfused_weights):
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting Float8Tensor into multiple params is not supported"
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
unfused_weights = [w.from_float8() for w in unfused_weights]
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat(
......@@ -962,21 +961,11 @@ class Linear(TransformerEngineBaseModule):
# Initialize FP8 weights if needed
weight_fp8 = None
if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
if isinstance(weight_tensor, Float8Tensor):
# Fill transpose cache in FP8 tensor if needed
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch or skip_fp8_weight_update is not None
)
if update_transpose_cache:
# Make sure transpose cache is valid, if present
# Note: Transpose cache may have been invalidated
# externally, e.g. by optimizer.
if weight_tensor._transpose is not None:
weight_tensor.transpose_2d(
fill_cache=True,
noop_flag=skip_fp8_weight_update,
......@@ -991,7 +980,6 @@ class Linear(TransformerEngineBaseModule):
cache_name=(None if is_first_microbatch is None else "weight"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
fsdp_group=self.fsdp_group,
)
......
......@@ -9,54 +9,12 @@ from typing import Any, Iterable, Optional
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor
def canonicalize_device(device: Optional[torch.device | str]) -> torch.device:
"""Canonicalize PyTorch device
If `None`, then returns the default CUDA device.
"""
if device is None:
# Use default CUDA device
device = torch.get_default_device()
if device.type != "cuda":
device = torch.device("cuda", torch.cuda.current_device())
elif not isinstance(device, torch.device):
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device("cuda", torch.cuda.current_device())
return device
def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
"""Canonicalize PyTorch datatype
If `None`, then returns the default PyTorch datatype.
"""
if dtype is None:
# Use default dtype
dtype = torch.get_default_dtype()
return dtype
def devices_match(device1: torch.device, device2: torch.device) -> bool:
"""Whether two devices are the same"""
device1 = torch.device(device1)
device2 = torch.device(device2)
if device1.type != device2.type:
return False
if device1.type == "cuda":
index1 = device1.index
index2 = device2.index
if index1 is None:
index1 = torch.cuda.current_device()
if index2 is None:
index2 = torch.cuda.current_device()
return index1 == index2
return device1 == device2
from ..tensor import Float8Tensor
from ..utils import (
canonicalize_device, # pylint: disable=unused-import
canonicalize_dtype, # pylint: disable=unused-import
devices_match, # pylint: disable=unused-import
)
def is_float8_tensor(tensor: Any) -> bool:
......@@ -92,7 +50,13 @@ def convert_tensor(
# Convert FP8 tensor
if is_float8_tensor(tensor):
data = tensor._data.to(device=device, memory_format=memory_format)
data = tensor._data
if not devices_match(device, data.device):
data = data.to(device=device)
if memory_format != torch.preserve_format and not data.is_contiguous(
memory_format=memory_format
):
data = data.contiguous(memory_format=memory_format)
return Float8Tensor.make_like(
tensor,
data=data,
......
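The rewritten `convert_tensor` only copies when the device or layout actually differs, keeping the FP8 fast path allocation-free. The guard pattern in isolation (a sketch, assuming `devices_match` from `transformer_engine.pytorch.utils` as relocated by this PR):

```python
import torch
from transformer_engine.pytorch.utils import devices_match

def maybe_copy(data: torch.Tensor, device: torch.device, memory_format) -> torch.Tensor:
    """Copy only when something actually differs (sketch)."""
    if not devices_match(device, data.device):
        data = data.to(device=device)
    if memory_format != torch.preserve_format and not data.is_contiguous(
        memory_format=memory_format
    ):
        data = data.contiguous(memory_format=memory_format)
    return data
```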
......@@ -9,11 +9,8 @@ from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import is_float8_tensor
from ...tensor import QuantizedTensor
from ..op import BasicOperation, OperationContext
class AllReduce(BasicOperation):
......@@ -54,8 +51,8 @@ class AllReduce(BasicOperation):
# Perform all-reduce
x = input_
if is_float8_tensor(x):
x = x.from_float8()
if isinstance(x, QuantizedTensor):
x = x.dequantize()
x = x.contiguous()
torch.distributed.all_reduce(x, group=self.process_group)
return x
......
......@@ -289,10 +289,18 @@ class BasicLinear(BasicOperation):
# Cast to FP8 if needed
if self._with_fp8_parameters:
dummy_amax = torch.empty(
(1, 1),
dtype=torch.float32,
device=self.device,
) # Dummy buffer to avoid overwriting amax history
weight = Float8Tensor.to_float8(
weight,
fp8_meta=self.get_fp8_meta("param"),
fp8_meta_forward=True,
fp8_meta_index=0,
amax=dummy_amax,
with_transpose_cache=torch.is_grad_enabled(),
)
# Save updated parameter
......@@ -467,25 +475,19 @@ class BasicLinear(BasicOperation):
input_fp8_meta["recipe"],
fprop_tensor=True,
)
x_fp8 = Float8Tensor(
data=torch.empty_like(x_local, dtype=torch.uint8),
with_transpose_cache = weight.requires_grad
if tensor_parallel_mode == "column" and sequence_parallel:
with_transpose_cache = False
x_local = Float8Tensor.to_float8(
x_local,
fp8_meta=input_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
with_transpose_cache=with_transpose_cache,
)
with_cast_transpose = weight.requires_grad
if tensor_parallel_mode == "column" and sequence_parallel:
with_cast_transpose = False
if with_cast_transpose:
x_fp8.cast_transpose_(x_local)
else:
x_fp8.copy_(x_local)
x_local = x_fp8
elif not with_fp8_compute and is_float8_tensor(x_local):
x_local = x_local.from_float8()
x_local = x_local.dequantize()
x = x_local
x_async = None
if tensor_parallel_mode == "column" and sequence_parallel:
......@@ -510,11 +512,12 @@ class BasicLinear(BasicOperation):
w = Float8Tensor.to_float8(
w,
fp8_meta=weight_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
)
elif not with_fp8_compute and is_float8_tensor(w):
w = w.from_float8()
w = w.dequantize()
# Check bias tensor
b = None
......@@ -815,25 +818,19 @@ class BasicLinear(BasicOperation):
grad_output_fp8_meta["recipe"],
fprop_tensor=False,
)
dy_fp8 = Float8Tensor(
data=torch.empty_like(dy, dtype=torch.uint8),
with_transpose_cache = weight_requires_grad
if tensor_parallel_mode == "row" and sequence_parallel:
with_transpose_cache = False
dy = Float8Tensor.to_float8(
dy,
fp8_meta=grad_output_fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
with_transpose_cache=with_transpose_cache,
)
with_cast_transpose = weight_requires_grad
if tensor_parallel_mode == "row" and sequence_parallel:
with_cast_transpose = False
if with_cast_transpose:
dy_fp8.cast_transpose_(dy)
else:
dy_fp8.copy_(dy)
dy = dy_fp8
elif not with_fp8_compute and is_float8_tensor(dy):
dy = dy.from_float8()
dy = dy.dequantize()
if tensor_parallel_mode == "row" and sequence_parallel:
dy, dy_async = gather_along_first_dim(
dy,
......@@ -853,26 +850,24 @@ class BasicLinear(BasicOperation):
device=device,
dtype=dtype,
)
x_is_sharded = tensor_parallel_mode == "column" and sequence_parallel
if with_fp8_compute and not is_float8_tensor(x_local):
fp8_dtype = get_fp8_te_dtype(
input_fp8_meta["recipe"],
fprop_tensor=True,
)
x_fp8 = Float8Tensor(
data=torch.empty_like(x_local, dtype=torch.uint8),
x_local = Float8Tensor.to_float8(
x_local,
fp8_meta=input_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
with_transpose_cache=(not x_is_sharded),
)
x_fp8.cast_transpose_(x_local)
x_local = x_fp8
elif not with_fp8_compute and is_float8_tensor(x_local):
x_local = x_local.from_float8()
x = x_local
if tensor_parallel_mode == "column" and sequence_parallel:
if x_is_sharded:
x, x_async = gather_along_first_dim(
x_local,
tensor_parallel_group,
......@@ -898,19 +893,16 @@ class BasicLinear(BasicOperation):
weight_fp8_meta["recipe"],
fprop_tensor=True,
)
w_fp8 = Float8Tensor(
data=torch.empty_like(w, dtype=torch.uint8),
w = Float8Tensor.to_float8(
w,
fp8_meta=weight_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
with_transpose_cache=True,
)
w_fp8.cast_transpose_(w)
w = w_fp8
elif not with_fp8_compute and is_float8_tensor(w):
w = w.from_float8()
w = w.dequantize()
# Construct grad input tensor
if grad_input is not None:
......
......@@ -9,12 +9,9 @@ from typing import Optional
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import convert_tensor, is_float8_tensor
from ...tensor import Float8Tensor, QuantizedTensor
from ..op import BasicOperation, OperationContext
from .._common import convert_tensor
class ReduceScatter(BasicOperation):
......@@ -63,8 +60,8 @@ class ReduceScatter(BasicOperation):
# Check input tensor
x = input_
if is_float8_tensor(x):
x = x.from_float8()
if isinstance(x, QuantizedTensor):
x = x.dequantize()
x = x.contiguous()
# Perform reduce-scatter
......@@ -96,7 +93,7 @@ class ReduceScatter(BasicOperation):
# Perform all-gather
dy = convert_tensor(grad_output, memory_format=torch.contiguous_format)
dx = None
if is_float8_tensor(dy):
if isinstance(dy, Float8Tensor):
dx = Float8Tensor.make_like(
dy,
data=torch.empty(
......@@ -111,6 +108,8 @@ class ReduceScatter(BasicOperation):
group=self.process_group,
)
else:
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device)
torch.distributed.all_gather_into_tensor(
dx,
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Custom tensor classes"""
from .float8_tensor import Float8Tensor
from .quantized_tensor import QuantizedTensor
This diff is collapsed.
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor with quantized data"""
from __future__ import annotations
from typing import Optional, Tuple
import torch
from torch.utils._pytree import tree_map
class _DequantizeFunc(torch.autograd.Function):
"""Autograd function to convert quantized tensor to standard tensor"""
@staticmethod
def forward(
_ctx: torch.autograd.function.FunctionCtx, # unused
tensor: QuantizedTensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
return tensor.dequantize(dtype=dtype)
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
return grad, None
class _IdentityFunc(torch.autograd.Function):
"""Autograd function to create quantized tensor with same data"""
@staticmethod
def forward(
_ctx: torch.autograd.function.FunctionCtx, # unused
tensor: QuantizedTensor,
) -> QuantizedTensor:
return tensor.detach()
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> torch.Tensor:
return grad
class QuantizedTensor(torch.Tensor):
"""Abstract base class for tensor with quantized data
This is a proxy class with the interface of a standard PyTorch
tensor, but with data that has been encoded with some quantization
scheme. Derived classes should implement the quantization scheme
by overriding the `quantize_`, `dequantize`, and `detach` functions.
"""
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""Convert quantized data to standard PyTorch tensor"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement dequantize function"
)
def quantize_(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Update quantized data in-place"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement quantize_ function"
)
def detach(self) -> QuantizedTensor:
"""Create new quantized tensor with same data
Output tensor must be detached from the current autograd
graph.
"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement detach function"
)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"
def float(self) -> torch.Tensor:
return _DequantizeFunc.apply(self, torch.float32)
def bfloat16(self) -> torch.Tensor:
return _DequantizeFunc.apply(self, torch.bfloat16)
def half(self) -> torch.Tensor:
return _DequantizeFunc.apply(self, torch.float16)
def cpu(self) -> torch.Tensor:
return _DequantizeFunc.apply(self).cpu()
def expand_as(self, other: torch.Tensor) -> torch.Tensor:
if other is self:
# Note: expand_as is used as a hack to create dummy autograd nodes
# and access the backward graph (see
# https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026).
# We handle this case with a dummy identity function.
return _IdentityFunc.apply(self)
return super().expand_as(other)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# Detach op
if func == torch.ops.aten.detach.default:
return args[0].detach()
# In-place copy op
if func == torch.ops.aten.copy_.default:
dst = args[0]
src = args[1]
if isinstance(dst, QuantizedTensor):
dst.quantize_(src)
else:
if isinstance(src, QuantizedTensor):
src = src.dequantize()
dst.copy_(src)
return None
# View op
if func == torch.ops.aten.view.default:
raise NotImplementedError("{cls.__name__} class does not support tensor views")
def maybe_unwrap(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize(dtype=arg.dtype)
return arg
def maybe_update_inplace(arg, new_arg, schema_arg):
if (
isinstance(arg, QuantizedTensor)
and isinstance(new_arg, torch.Tensor)
and hasattr(schema_arg, "alias_info")
and hasattr(schema_arg.alias_info, "is_write")
and schema_arg.alias_info.is_write
):
arg.quantize_(new_arg)
# In-place op: dequantize, perform op, and quantize
if func._schema.is_mutable:
new_args = tree_map(maybe_unwrap, args)
new_kwargs = tree_map(maybe_unwrap, kwargs)
schema_args = func._schema.arguments
args_len = len(args)
super().__torch_dispatch__(func, types, new_args, new_kwargs)
for arg, new_arg, schema_arg in zip(args, new_args, schema_args):
maybe_update_inplace(arg, new_arg, schema_arg)
for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match"
maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg)
return None
# Default op: dequantize and perform op
args = tree_map(maybe_unwrap, args)
if kwargs is not None:
kwargs = tree_map(maybe_unwrap, kwargs)
out = super().__torch_dispatch__(func, types, args, kwargs)
return out
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
# Do not force the QuantizedTensor type on the returned tensor
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
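Per the docstring, a concrete format overrides `quantize_`, `dequantize`, and `detach`. A toy, hypothetical subclass just to show the contract (symmetric int8 with a fixed scale; real implementations such as `Float8Tensor` also manage FP8 scaling metadata and transpose caches):

```python
from typing import Optional

import torch

class ToyInt8Tensor(QuantizedTensor):
    """Hypothetical int8 proxy, for illustration only."""

    def __new__(cls, data: torch.Tensor, scale: float, dtype: torch.dtype):
        # Wrapper subclass: the int8 payload lives in an attribute while
        # the proxy advertises the nominal shape and dtype.
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=dtype, device=data.device
        )
        self._data = data
        self._scale = scale
        return self

    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        return (self._data.to(torch.float32) * self._scale).to(dtype or self.dtype)

    def quantize_(self, tensor: torch.Tensor) -> QuantizedTensor:
        q = (tensor.detach() / self._scale).round().clamp(-127, 127)
        self._data.copy_(q.to(torch.int8))
        return self

    def detach(self) -> QuantizedTensor:
        return ToyInt8Tensor(self._data, self._scale, self.dtype)
```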
......@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Utility functions for Transformer Engine modules"""
from __future__ import annotations
import functools
import math
from typing import Any, Callable, Optional, Tuple
......@@ -251,3 +252,52 @@ def get_cudnn_version() -> Tuple[int, int, int]:
major, encoded_version = divmod(encoded_version, major_version_magnitude)
minor, patch = divmod(encoded_version, 100)
return (major, minor, patch)
def canonicalize_device(device: Optional[torch.device | str]) -> torch.device:
"""Canonicalize PyTorch device
If `None`, then returns the default CUDA device.
"""
if device is None:
# Use default CUDA device
device = torch.get_default_device()
if device.type != "cuda":
device = torch.device("cuda", torch.cuda.current_device())
elif not isinstance(device, torch.device):
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device("cuda", torch.cuda.current_device())
return device
def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
"""Canonicalize PyTorch datatype
If `None`, then returns the default PyTorch datatype.
"""
if dtype is None:
# Use default dtype
dtype = torch.get_default_dtype()
return dtype
def devices_match(device1: torch.device, device2: torch.device) -> bool:
"""Whether two devices are the same"""
device1 = torch.device(device1)
device2 = torch.device(device2)
if device1.type != device2.type:
return False
if device1.type == "cuda":
index1 = device1.index
index2 = device2.index
if index1 == index2:
return True
if index1 is None:
index1 = torch.cuda.current_device()
if index2 is None:
index2 = torch.cuda.current_device()
return index1 == index2
return device1 == device2
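Quick sanity check of the relocated helpers (illustrative; assumes a CUDA build):

```python
import torch

# A bare "cuda" device is treated as the current CUDA device.
assert devices_match(torch.device("cuda"), torch.device("cuda", torch.cuda.current_device()))
assert not devices_match(torch.device("cpu"), torch.device("cuda"))

# canonicalize_device(None) resolves to a concrete CUDA device.
assert canonicalize_device(None) == torch.device("cuda", torch.cuda.current_device())
```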