Unverified Commit 2d57db8b authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Proxy class for low-precision tensor (#1127)



* Add base class for tensor proxies
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Move tensor detaching logic to tensor proxy base class
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Use Python wrappers to PyTorch extensions
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Include transpose caching logic in proxy encode function
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug dimension mismatch with amax history
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Move dequantize logic to proxy_decode func
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Rename to "QuantizedTensor"
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Rename "proxy_detach" to "detach"
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Include transpose cache in detach and clone funcs
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update FP8 workspaces with QuantizedTensor functions
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Move logic for FP8 transpose cache in FP8 workspaces to base class
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Remove cast-transpose logic from linear op
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Remove unnecessary args for Float8Tensor when using FP8 attr dict
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Remove __torch_function__ to QuantizedTensor
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Update tests/pytorch/test_float8tensor.py
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* Debug FP8 transpose test
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug cast functions
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 40dda924
...@@ -31,7 +31,7 @@ jobs: ...@@ -31,7 +31,7 @@ jobs:
run: | run: |
sudo apt-get update sudo apt-get update
sudo apt-get install pip -y sudo apt-get install pip -y
pip install torch pip install torch numpy
export PYTHON_ONLY=1 export PYTHON_ONLY=1
export TE_PATH=. export TE_PATH=.
bash ./qa/L0_pytorch_lint/test.sh bash ./qa/L0_pytorch_lint/test.sh
......
...@@ -293,7 +293,7 @@ class TestFloat8Tensor: ...@@ -293,7 +293,7 @@ class TestFloat8Tensor:
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8_t, x, **tols) torch.testing.assert_close(x_fp8_t, x, **tols)
# Caching test. # Caching test
assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching." assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching."
x_fp8 += 0.5 x_fp8 += 0.5
x = x_fp8.from_float8() x = x_fp8.from_float8()
...@@ -302,14 +302,13 @@ class TestFloat8Tensor: ...@@ -302,14 +302,13 @@ class TestFloat8Tensor:
torch.testing.assert_close(x_fp8_t, x_t, **tols) torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
# Inplace update test. # Inplace update test
x_fp8 += 0.5 x_fp8 += 0.5
assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly." assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
x = x_fp8.from_float8() x = x_fp8.from_float8()
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True)) x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose)
x_t = x.transpose(0, 1) x_t = x.transpose(0, 1)
torch.testing.assert_close(x_fp8_t, x_t, **tols) torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
def test_serialization( def test_serialization(
self, self,
......
...@@ -88,10 +88,7 @@ def make_reference_and_test_tensors( ...@@ -88,10 +88,7 @@ def make_reference_and_test_tensors(
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
test = ref.to(device=test_device, dtype=test_dtype) test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8: if test_is_fp8:
test = Float8Tensor.to_float8(test) test = Float8Tensor.to_float8(test, with_transpose_cache=True)
test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1)
test._transpose = test._transpose.contiguous()
test._transpose_invalid = False
elif test.data_ptr() == ref.data_ptr(): elif test.data_ptr() == ref.data_ptr():
test = test.clone() test = test.clone()
ref.copy_(test) ref.copy_(test)
......
...@@ -68,13 +68,13 @@ def canonicalize_fp8_scales( ...@@ -68,13 +68,13 @@ def canonicalize_fp8_scales(
# Force offsets to be the same if needed # Force offsets to be the same if needed
if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset: if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset:
if scale_offset != 0: if scale_offset != 0:
scale = scale[scale_offset] scale = scale[scale_offset:]
scale_offset = 0 scale_offset = 0
if amax_offset != 0: if amax_offset != 0:
amax = amax[0][amax_offset] amax = amax[:, amax_offset:]
amax_offset = 0 amax_offset = 0
if scale_inv_offset != 0: if scale_inv_offset != 0:
scale_inv = scale_inv[scale_inv_offset] scale_inv = scale_inv[scale_inv_offset:]
scale_inv_offset = 0 scale_inv_offset = 0
# Pack tensors and offsets into dicts # Pack tensors and offsets into dicts
......
...@@ -8,7 +8,7 @@ from typing import Optional, Union ...@@ -8,7 +8,7 @@ from typing import Optional, Union
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ._common import canonicalize_fp8_scales, empty_tensor from ._common import canonicalize_fp8_scales
__all__ = ["cast_to_fp8", "cast_from_fp8"] __all__ = ["cast_to_fp8", "cast_from_fp8"]
...@@ -81,8 +81,7 @@ def cast_from_fp8( ...@@ -81,8 +81,7 @@ def cast_from_fp8(
# Construct empty tensors if needed # Construct empty tensors if needed
if scale_inv is None: if scale_inv is None:
scale_inv = empty_tensor() raise ValueError("Did not provide either `scale_inv` or `fp8_meta_tensor`")
scale_inv_offset = 0
# Launch FP8 cast kernel # Launch FP8 cast kernel
return torch.ops.tex_ts.cast_from_fp8_ts( return torch.ops.tex_ts.cast_from_fp8_ts(
......
This diff is collapsed.
...@@ -865,11 +865,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -865,11 +865,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# If primary weights are in fp8, wrap the parameter as Float8Tensor # If primary weights are in fp8, wrap the parameter as Float8Tensor
fp8_meta_index = self.param_init_meta[name].fp8_meta_index fp8_meta_index = self.param_init_meta[name].fp8_meta_index
if self.primary_weights_in_fp8 and fp8_meta_index is not None: if self.primary_weights_in_fp8 and fp8_meta_index is not None:
dummy_amax = torch.empty(
(1, 1),
dtype=torch.float32,
device=param.device,
) # Dummy buffer to avoid overwriting amax history
param = Float8Tensor.to_float8( param = Float8Tensor.to_float8(
param, param,
fp8_meta=self.fp8_meta, fp8_meta=self.fp8_meta,
fp8_meta_index=fp8_meta_index, fp8_meta_index=fp8_meta_index,
amax=torch.empty(1, device="cuda"), # Dummy amax to avoid overwriting history. amax=dummy_amax,
with_transpose_cache=torch.is_grad_enabled(),
) )
# Redo parameter wrap in case we broke it above # Redo parameter wrap in case we broke it above
...@@ -891,7 +897,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -891,7 +897,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
cache_name: Optional[str] = None, cache_name: Optional[str] = None,
update_workspace: bool = True, update_workspace: bool = True,
skip_update_flag: Optional[torch.Tensor] = None, skip_update_flag: Optional[torch.Tensor] = None,
with_transpose: bool = False,
fsdp_group: dist_group_type = None, fsdp_group: dist_group_type = None,
) -> Float8Tensor: ) -> Float8Tensor:
"""Get FP8 workspace buffer and maybe update its values """Get FP8 workspace buffer and maybe update its values
...@@ -917,27 +922,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -917,27 +922,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
skip_update_flag: torch.Tensor, optional skip_update_flag: torch.Tensor, optional
GPU flag to skip updating the workspace. Take precedence GPU flag to skip updating the workspace. Take precedence
over `update_workspace` if provided. over `update_workspace` if provided.
with_transpose: bool, default = `False`
Whether to initialize cached transpose in workspace.
fsdp_group: bool, default = None fsdp_group: bool, default = None
FSDP process group that the weights are distributed over. FSDP process group that the weights are distributed over.
""" """
# Construct workspace if needed # Try getting workspace from cache
out = None out = None
if cache_name is not None: if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None) out = self._fp8_workspaces.get(cache_name, None)
# Gather cached Fp8 workspace if it's distributed # Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
# for models initialized with Fp8 primary weights. # for models initialized with Fp8 primary weights.
if ( if (
not isinstance(out, Float8Tensor) out is not None
and not isinstance(out, Float8Tensor)
and fsdp_group is not None and fsdp_group is not None
and out._data.shape != tensor.data.shape and out._data.shape != tensor.data.shape
): ):
_fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)
# Construct workspace if needed
if out is None: if out is None:
# FP8 data
if tensor is None or fp8_meta_forward is None or fp8_meta_index is None: if tensor is None or fp8_meta_forward is None or fp8_meta_index is None:
raise ValueError( raise ValueError(
"tensor, fp8_meta_forward, and fp8_meta_index kwargs " "tensor, fp8_meta_forward, and fp8_meta_index kwargs "
...@@ -947,16 +955,38 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -947,16 +955,38 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["recipe"], self.fp8_meta["recipe"],
fprop_tensor=fp8_meta_forward, fprop_tensor=fp8_meta_forward,
) )
data = torch.empty_like(tensor, dtype=torch.uint8)
scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device)
# Transpose cache
with_transpose_cache = torch.is_grad_enabled()
if (
not with_transpose_cache
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose_cache = True
data_transpose = None
if with_transpose_cache:
data_transpose = torch.empty(
(tensor.size(-1), tensor.numel() // tensor.size(-1)),
dtype=torch.uint8,
device=tensor.device,
)
# Construct FP8 tensor
out = Float8Tensor( out = Float8Tensor(
data=torch.empty_like(tensor, dtype=torch.uint8), data=data,
fp8_meta=self.fp8_meta, fp8_meta=self.fp8_meta,
fp8_meta_forward=fp8_meta_forward, fp8_meta_forward=fp8_meta_forward,
fp8_meta_index=fp8_meta_index, fp8_meta_index=fp8_meta_index,
fp8_dtype=fp8_dtype, fp8_dtype=fp8_dtype,
fp8_scale_inv=scale_inv, fp8_scale_inv=scale_inv,
dtype=tensor.dtype, dtype=tensor.dtype,
data_transpose=data_transpose,
) )
# Update cache
if cache_name is not None: if cache_name is not None:
self._fp8_workspaces[cache_name] = out self._fp8_workspaces[cache_name] = out
update_workspace = True update_workspace = True
...@@ -968,33 +998,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -968,33 +998,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if update_workspace: if update_workspace:
if tensor is None: if tensor is None:
raise ValueError("tensor kwarg must be provided to update FP8 workspace") raise ValueError("tensor kwarg must be provided to update FP8 workspace")
if with_transpose:
out.cast_transpose_(
tensor,
noop_flag=skip_update_flag,
)
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=out._fp8_meta_forward,
)
fp8_meta = out._fp8_meta[fp8_meta_key]
fp8_meta_index = out._fp8_meta_index
cast_to_fp8(
tensor,
fp8_meta,
fp8_meta_index,
out._fp8_dtype,
out=out._data,
)
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
# ONNX export expects FP8 scales can be # ONNX export does not support fused cast-transpose
# represented with constant ops. However, copying # kernel and requires that FP8 scales can be
# into a buffer involves an expand op for array # represented with constant ops.
# broadcasting. We work around this by filling the transpose_cache = out._transpose
# buffer instead. out._transpose = None
out._scale_inv.fill_(fp8_meta.scale_inv[fp8_meta_index].item()) out.quantize_(tensor)
out._scale_inv.fill_(out._scale_inv.item())
out._transpose = transpose_cache
else: else:
out._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index]) out.quantize_(tensor, noop_flag=skip_update_flag)
return out return out
......
...@@ -28,8 +28,6 @@ from ..utils import ( ...@@ -28,8 +28,6 @@ from ..utils import (
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
get_distributed_world_size, get_distributed_world_size,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
) )
from ..cpp_extensions import ( from ..cpp_extensions import (
cast_to_fp8, cast_to_fp8,
...@@ -760,22 +758,12 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -760,22 +758,12 @@ class GroupedLinear(TransformerEngineBaseModule):
weight_tensors_fp8 = [None] * self.num_gemms weight_tensors_fp8 = [None] * self.num_gemms
if self.fp8: if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
for i in range(self.num_gemms): for i in range(self.num_gemms):
if isinstance(weight_tensors[i], Float8Tensor): if isinstance(weight_tensors[i], Float8Tensor):
# Fill transpose cache in FP8 tensor if needed # Make sure transpose cache is valid, if present
update_transpose_cache = with_transpose # Note: Transpose cache may have been invalidated
if update_transpose_cache: # externally, e.g. by optimizer.
update_transpose_cache = ( if weight_tensors[i]._transpose is not None:
is_first_microbatch or skip_fp8_weight_update is not None
)
if update_transpose_cache:
weight_tensors[i].transpose_2d( weight_tensors[i].transpose_2d(
fill_cache=True, fill_cache=True,
noop_flag=skip_fp8_weight_update, noop_flag=skip_fp8_weight_update,
...@@ -790,7 +778,6 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -790,7 +778,6 @@ class GroupedLinear(TransformerEngineBaseModule):
cache_name=(None if is_first_microbatch is None else f"weight{i}"), cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
) )
from ..cpu_offload import CPUOffloadEnabled from ..cpu_offload import CPUOffloadEnabled
......
...@@ -36,8 +36,6 @@ from ..distributed import ( ...@@ -36,8 +36,6 @@ from ..distributed import (
allreduce, allreduce,
reduce_scatter_along_first_dim, reduce_scatter_along_first_dim,
gather_along_first_dim, gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
_fsdp_scatter_tensors, _fsdp_scatter_tensors,
_fsdp_gather_tensors, _fsdp_gather_tensors,
) )
...@@ -47,6 +45,7 @@ from ..graph import is_graph_capturing ...@@ -47,6 +45,7 @@ from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor
__all__ = ["LayerNormLinear"] __all__ = ["LayerNormLinear"]
...@@ -1151,14 +1150,14 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1151,14 +1150,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
# Get concatenated weight and bias tensors # Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names] unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, Float8Tensor) for w in unfused_weights): if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8: if self.fp8:
if len(unfused_weights) != 1: if len(unfused_weights) != 1:
raise RuntimeError( raise RuntimeError(
"Splitting Float8Tensor into multiple params is not supported" "Splitting QuantizedTensor into multiple params is not supported"
) )
else: else:
unfused_weights = [w.from_float8() for w in unfused_weights] unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights) weight_tensor = _noop_cat(unfused_weights)
if self.use_bias: if self.use_bias:
bias_tensor = _noop_cat( bias_tensor = _noop_cat(
...@@ -1170,32 +1169,18 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1170,32 +1169,18 @@ class LayerNormLinear(TransformerEngineBaseModule):
# Initialize FP8 weights if needed # Initialize FP8 weights if needed
weight_fp8 = None weight_fp8 = None
if self.fp8: if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
if isinstance(weight_tensor, Float8Tensor): if isinstance(weight_tensor, Float8Tensor):
# Fill transpose cache in FP8 tensor if needed # Make sure transpose cache is valid, if present
update_transpose_cache = with_transpose # Note: Transpose cache may have been invalidated
if update_transpose_cache: # externally, e.g. by optimizer.
update_transpose_cache = ( if weight_tensor._transpose is not None:
is_first_microbatch or skip_fp8_weight_update is not None
)
if update_transpose_cache:
weight_tensor.transpose_2d( weight_tensor.transpose_2d(
fill_cache=True, fill_cache=True,
noop_flag=skip_fp8_weight_update, noop_flag=skip_fp8_weight_update,
) )
else: else:
# FP8 cast to workspace buffer # FP8 cast to workspace buffer
update_workspace = ( update_workspace = is_first_microbatch is None or is_first_microbatch
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
weight_fp8 = self.get_fp8_workspace( weight_fp8 = self.get_fp8_workspace(
tensor=weight_tensor, tensor=weight_tensor,
fp8_meta_forward=True, fp8_meta_forward=True,
...@@ -1203,7 +1188,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1203,7 +1188,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
cache_name=(None if is_first_microbatch is None else "weight"), cache_name=(None if is_first_microbatch is None else "weight"),
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
) )
from ..cpu_offload import CPUOffloadEnabled from ..cpu_offload import CPUOffloadEnabled
......
...@@ -42,8 +42,6 @@ from ..distributed import ( ...@@ -42,8 +42,6 @@ from ..distributed import (
allreduce, allreduce,
reduce_scatter_along_first_dim, reduce_scatter_along_first_dim,
gather_along_first_dim, gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
use_reentrant_activation_recompute, use_reentrant_activation_recompute,
_fsdp_scatter_tensors, _fsdp_scatter_tensors,
_fsdp_gather_tensors, _fsdp_gather_tensors,
...@@ -1485,19 +1483,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1485,19 +1483,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_fp8 = None fc2_weight_fp8 = None
if self.fp8: if self.fp8:
update_workspace = is_first_microbatch is None or is_first_microbatch update_workspace = is_first_microbatch is None or is_first_microbatch
with_transpose = torch.is_grad_enabled()
if (
is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
update_transpose_cache = with_transpose
if update_transpose_cache:
update_transpose_cache = (
is_first_microbatch or skip_fp8_weight_update is not None
)
if isinstance(fc1_weight, Float8Tensor): if isinstance(fc1_weight, Float8Tensor):
if update_transpose_cache: if fc1_weight._transpose is not None:
fc1_weight.transpose_2d( fc1_weight.transpose_2d(
fill_cache=True, fill_cache=True,
noop_flag=skip_fp8_weight_update, noop_flag=skip_fp8_weight_update,
...@@ -1513,10 +1500,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1513,10 +1500,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
cache_name=cache_name, cache_name=cache_name,
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
) )
if isinstance(fc2_weight, Float8Tensor): if isinstance(fc2_weight, Float8Tensor):
if update_transpose_cache: if fc2_weight._transpose is not None:
fc2_weight.transpose_2d( fc2_weight.transpose_2d(
fill_cache=True, fill_cache=True,
noop_flag=skip_fp8_weight_update, noop_flag=skip_fp8_weight_update,
...@@ -1532,7 +1518,6 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1532,7 +1518,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
cache_name=cache_name, cache_name=cache_name,
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
) )
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
......
...@@ -33,8 +33,6 @@ from ..distributed import ( ...@@ -33,8 +33,6 @@ from ..distributed import (
allreduce, allreduce,
reduce_scatter_along_first_dim, reduce_scatter_along_first_dim,
gather_along_first_dim, gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
_fsdp_scatter_tensors, _fsdp_scatter_tensors,
_fsdp_gather_tensors, _fsdp_gather_tensors,
) )
...@@ -49,6 +47,7 @@ from ..jit import no_torch_dynamo ...@@ -49,6 +47,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor
__all__ = ["Linear"] __all__ = ["Linear"]
...@@ -938,19 +937,19 @@ class Linear(TransformerEngineBaseModule): ...@@ -938,19 +937,19 @@ class Linear(TransformerEngineBaseModule):
with self.prepare_forward( with self.prepare_forward(
inp, inp,
is_first_microbatch, is_first_microbatch,
allow_non_contiguous=isinstance(inp, Float8Tensor), allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp: ) as inp:
# Get concatenated weight and bias tensors # Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names] unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, Float8Tensor) for w in unfused_weights): if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8: if self.fp8:
if len(unfused_weights) != 1: if len(unfused_weights) != 1:
raise RuntimeError( raise RuntimeError(
"Splitting Float8Tensor into multiple params is not supported" "Splitting QuantizedTensor into multiple params is not supported"
) )
else: else:
unfused_weights = [w.from_float8() for w in unfused_weights] unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights) weight_tensor = _noop_cat(unfused_weights)
if self.use_bias: if self.use_bias:
bias_tensor = _noop_cat( bias_tensor = _noop_cat(
...@@ -962,21 +961,11 @@ class Linear(TransformerEngineBaseModule): ...@@ -962,21 +961,11 @@ class Linear(TransformerEngineBaseModule):
# Initialize FP8 weights if needed # Initialize FP8 weights if needed
weight_fp8 = None weight_fp8 = None
if self.fp8: if self.fp8:
with_transpose = torch.is_grad_enabled()
if (
not with_transpose
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
with_transpose = True
if isinstance(weight_tensor, Float8Tensor): if isinstance(weight_tensor, Float8Tensor):
# Fill transpose cache in FP8 tensor if needed # Make sure transpose cache is valid, if present
update_transpose_cache = with_transpose # Note: Transpose cache may have been invalidated
if update_transpose_cache: # externally, e.g. by optimizer.
update_transpose_cache = ( if weight_tensor._transpose is not None:
is_first_microbatch or skip_fp8_weight_update is not None
)
if update_transpose_cache:
weight_tensor.transpose_2d( weight_tensor.transpose_2d(
fill_cache=True, fill_cache=True,
noop_flag=skip_fp8_weight_update, noop_flag=skip_fp8_weight_update,
...@@ -991,7 +980,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -991,7 +980,6 @@ class Linear(TransformerEngineBaseModule):
cache_name=(None if is_first_microbatch is None else "weight"), cache_name=(None if is_first_microbatch is None else "weight"),
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
fsdp_group=self.fsdp_group, fsdp_group=self.fsdp_group,
) )
......
...@@ -9,54 +9,12 @@ from typing import Any, Iterable, Optional ...@@ -9,54 +9,12 @@ from typing import Any, Iterable, Optional
import torch import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor from ..tensor import Float8Tensor
from ..utils import (
canonicalize_device, # pylint: disable=unused-import
def canonicalize_device(device: Optional[torch.device | str]) -> torch.device: canonicalize_dtype, # pylint: disable=unused-import
"""Canonicalize PyTorch device devices_match, # pylint: disable=unused-import
)
If `None`, then returns the default CUDA device.
"""
if device is None:
# Use default CUDA device
device = torch.get_default_device()
if device.type != "cuda":
device = torch.device("cuda", torch.cuda.current_device())
elif not isinstance(device, torch.device):
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device("cuda", torch.cuda.current_device())
return device
def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
"""Canonicalize PyTorch datatype
If `None`, then returns the default PyTorch datatype.
"""
if dtype is None:
# Use default dtype
dtype = torch.get_default_dtype()
return dtype
def devices_match(device1: torch.device, device2: torch.device) -> bool:
"""Whether two devices are the same"""
device1 = torch.device(device1)
device2 = torch.device(device2)
if device1.type != device2.type:
return False
if device1.type == "cuda":
index1 = device1.index
index2 = device2.index
if index1 is None:
index1 = torch.cuda.current_device()
if index2 is None:
index2 = torch.cuda.current_device()
return index1 == index2
return device1 == device2
def is_float8_tensor(tensor: Any) -> bool: def is_float8_tensor(tensor: Any) -> bool:
...@@ -92,7 +50,13 @@ def convert_tensor( ...@@ -92,7 +50,13 @@ def convert_tensor(
# Convert FP8 tensor # Convert FP8 tensor
if is_float8_tensor(tensor): if is_float8_tensor(tensor):
data = tensor._data.to(device=device, memory_format=memory_format) data = tensor._data
if not devices_match(device, data.device):
data = data.to(device=device)
if memory_format != torch.preserve_format and not data.is_contiguous(
memory_format=memory_format
):
data = data.contiguous(memory_format=memory_format)
return Float8Tensor.make_like( return Float8Tensor.make_like(
tensor, tensor,
data=data, data=data,
......
...@@ -9,11 +9,8 @@ from typing import Optional ...@@ -9,11 +9,8 @@ from typing import Optional
import torch import torch
from transformer_engine.pytorch.ops.op import ( from ...tensor import QuantizedTensor
BasicOperation, from ..op import BasicOperation, OperationContext
OperationContext,
)
from .._common import is_float8_tensor
class AllReduce(BasicOperation): class AllReduce(BasicOperation):
...@@ -54,8 +51,8 @@ class AllReduce(BasicOperation): ...@@ -54,8 +51,8 @@ class AllReduce(BasicOperation):
# Perform all-reduce # Perform all-reduce
x = input_ x = input_
if is_float8_tensor(x): if isinstance(x, QuantizedTensor):
x = x.from_float8() x = x.dequantize()
x = x.contiguous() x = x.contiguous()
torch.distributed.all_reduce(x, group=self.process_group) torch.distributed.all_reduce(x, group=self.process_group)
return x return x
......
...@@ -289,10 +289,18 @@ class BasicLinear(BasicOperation): ...@@ -289,10 +289,18 @@ class BasicLinear(BasicOperation):
# Cast to FP8 if needed # Cast to FP8 if needed
if self._with_fp8_parameters: if self._with_fp8_parameters:
dummy_amax = torch.empty(
(1, 1),
dtype=torch.float32,
device=self.device,
) # Dummy buffer to avoid overwriting amax history
weight = Float8Tensor.to_float8( weight = Float8Tensor.to_float8(
weight, weight,
fp8_meta=self.get_fp8_meta("param"), fp8_meta=self.get_fp8_meta("param"),
fp8_meta_forward=True,
fp8_meta_index=0, fp8_meta_index=0,
amax=dummy_amax,
with_transpose_cache=torch.is_grad_enabled(),
) )
# Save updated parameter # Save updated parameter
...@@ -467,25 +475,19 @@ class BasicLinear(BasicOperation): ...@@ -467,25 +475,19 @@ class BasicLinear(BasicOperation):
input_fp8_meta["recipe"], input_fp8_meta["recipe"],
fprop_tensor=True, fprop_tensor=True,
) )
x_fp8 = Float8Tensor( with_transpose_cache = weight.requires_grad
data=torch.empty_like(x_local, dtype=torch.uint8), if tensor_parallel_mode == "column" and sequence_parallel:
with_transpose_cache = False
x_local = Float8Tensor.to_float8(
x_local,
fp8_meta=input_fp8_meta, fp8_meta=input_fp8_meta,
fp8_meta_forward=True, fp8_meta_forward=True,
fp8_meta_index=0, fp8_meta_index=0,
fp8_dtype=fp8_dtype, fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), with_transpose_cache=with_transpose_cache,
dtype=dtype,
) )
with_cast_transpose = weight.requires_grad
if tensor_parallel_mode == "column" and sequence_parallel:
with_cast_transpose = False
if with_cast_transpose:
x_fp8.cast_transpose_(x_local)
else:
x_fp8.copy_(x_local)
x_local = x_fp8
elif not with_fp8_compute and is_float8_tensor(x_local): elif not with_fp8_compute and is_float8_tensor(x_local):
x_local = x_local.from_float8() x_local = x_local.dequantize()
x = x_local x = x_local
x_async = None x_async = None
if tensor_parallel_mode == "column" and sequence_parallel: if tensor_parallel_mode == "column" and sequence_parallel:
...@@ -510,11 +512,12 @@ class BasicLinear(BasicOperation): ...@@ -510,11 +512,12 @@ class BasicLinear(BasicOperation):
w = Float8Tensor.to_float8( w = Float8Tensor.to_float8(
w, w,
fp8_meta=weight_fp8_meta, fp8_meta=weight_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0, fp8_meta_index=0,
fp8_dtype=fp8_dtype, fp8_dtype=fp8_dtype,
) )
elif not with_fp8_compute and is_float8_tensor(w): elif not with_fp8_compute and is_float8_tensor(w):
w = w.from_float8() w = w.dequantize()
# Check bias tensor # Check bias tensor
b = None b = None
...@@ -815,25 +818,19 @@ class BasicLinear(BasicOperation): ...@@ -815,25 +818,19 @@ class BasicLinear(BasicOperation):
grad_output_fp8_meta["recipe"], grad_output_fp8_meta["recipe"],
fprop_tensor=False, fprop_tensor=False,
) )
dy_fp8 = Float8Tensor( with_transpose_cache = weight_requires_grad
data=torch.empty_like(dy, dtype=torch.uint8), if tensor_parallel_mode == "row" and sequence_parallel:
with_transpose_cache = False
dy = Float8Tensor.to_float8(
dy,
fp8_meta=grad_output_fp8_meta, fp8_meta=grad_output_fp8_meta,
fp8_meta_forward=False, fp8_meta_forward=False,
fp8_meta_index=0, fp8_meta_index=0,
fp8_dtype=fp8_dtype, fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), with_transpose_cache=with_transpose_cache,
dtype=dtype,
) )
with_cast_transpose = weight_requires_grad
if tensor_parallel_mode == "row" and sequence_parallel:
with_cast_transpose = False
if with_cast_transpose:
dy_fp8.cast_transpose_(dy)
else:
dy_fp8.copy_(dy)
dy = dy_fp8
elif not with_fp8_compute and is_float8_tensor(dy): elif not with_fp8_compute and is_float8_tensor(dy):
dy = dy.from_float8() dy = dy.dequantize()
if tensor_parallel_mode == "row" and sequence_parallel: if tensor_parallel_mode == "row" and sequence_parallel:
dy, dy_async = gather_along_first_dim( dy, dy_async = gather_along_first_dim(
dy, dy,
...@@ -853,26 +850,24 @@ class BasicLinear(BasicOperation): ...@@ -853,26 +850,24 @@ class BasicLinear(BasicOperation):
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
x_is_sharded = tensor_parallel_mode == "column" and sequence_parallel
if with_fp8_compute and not is_float8_tensor(x_local): if with_fp8_compute and not is_float8_tensor(x_local):
fp8_dtype = get_fp8_te_dtype( fp8_dtype = get_fp8_te_dtype(
input_fp8_meta["recipe"], input_fp8_meta["recipe"],
fprop_tensor=True, fprop_tensor=True,
) )
x_fp8 = Float8Tensor( x_local = Float8Tensor.to_float8(
data=torch.empty_like(x_local, dtype=torch.uint8), x_local,
fp8_meta=input_fp8_meta, fp8_meta=input_fp8_meta,
fp8_meta_forward=True, fp8_meta_forward=True,
fp8_meta_index=0, fp8_meta_index=0,
fp8_dtype=fp8_dtype, fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), with_transpose_cache=(not x_is_sharded),
dtype=dtype,
) )
x_fp8.cast_transpose_(x_local)
x_local = x_fp8
elif not with_fp8_compute and is_float8_tensor(x_local): elif not with_fp8_compute and is_float8_tensor(x_local):
x_local = x_local.from_float8() x_local = x_local.from_float8()
x = x_local x = x_local
if tensor_parallel_mode == "column" and sequence_parallel: if x_is_sharded:
x, x_async = gather_along_first_dim( x, x_async = gather_along_first_dim(
x_local, x_local,
tensor_parallel_group, tensor_parallel_group,
...@@ -898,19 +893,16 @@ class BasicLinear(BasicOperation): ...@@ -898,19 +893,16 @@ class BasicLinear(BasicOperation):
weight_fp8_meta["recipe"], weight_fp8_meta["recipe"],
fprop_tensor=True, fprop_tensor=True,
) )
w_fp8 = Float8Tensor( w = Float8Tensor.to_float8(
data=torch.empty_like(w, dtype=torch.uint8), w,
fp8_meta=weight_fp8_meta, fp8_meta=weight_fp8_meta,
fp8_meta_forward=True, fp8_meta_forward=True,
fp8_meta_index=0, fp8_meta_index=0,
fp8_dtype=fp8_dtype, fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), with_transpose_cache=True,
dtype=dtype,
) )
w_fp8.cast_transpose_(w)
w = w_fp8
elif not with_fp8_compute and is_float8_tensor(w): elif not with_fp8_compute and is_float8_tensor(w):
w = w.from_float8() w = w.dequantize()
# Construct grad input tensor # Construct grad input tensor
if grad_input is not None: if grad_input is not None:
......
...@@ -9,12 +9,9 @@ from typing import Optional ...@@ -9,12 +9,9 @@ from typing import Optional
import torch import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor from ...tensor import Float8Tensor, QuantizedTensor
from transformer_engine.pytorch.ops.op import ( from ..op import BasicOperation, OperationContext
BasicOperation, from .._common import convert_tensor
OperationContext,
)
from .._common import convert_tensor, is_float8_tensor
class ReduceScatter(BasicOperation): class ReduceScatter(BasicOperation):
...@@ -63,8 +60,8 @@ class ReduceScatter(BasicOperation): ...@@ -63,8 +60,8 @@ class ReduceScatter(BasicOperation):
# Check input tensor # Check input tensor
x = input_ x = input_
if is_float8_tensor(x): if isinstance(x, QuantizedTensor):
x = x.from_float8() x = x.dequantize()
x = x.contiguous() x = x.contiguous()
# Perform reduce-scatter # Perform reduce-scatter
...@@ -96,7 +93,7 @@ class ReduceScatter(BasicOperation): ...@@ -96,7 +93,7 @@ class ReduceScatter(BasicOperation):
# Perform all-gather # Perform all-gather
dy = convert_tensor(grad_output, memory_format=torch.contiguous_format) dy = convert_tensor(grad_output, memory_format=torch.contiguous_format)
dx = None dx = None
if is_float8_tensor(dy): if isinstance(dy, Float8Tensor):
dx = Float8Tensor.make_like( dx = Float8Tensor.make_like(
dy, dy,
data=torch.empty( data=torch.empty(
...@@ -111,6 +108,8 @@ class ReduceScatter(BasicOperation): ...@@ -111,6 +108,8 @@ class ReduceScatter(BasicOperation):
group=self.process_group, group=self.process_group,
) )
else: else:
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device) dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device)
torch.distributed.all_gather_into_tensor( torch.distributed.all_gather_into_tensor(
dx, dx,
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Custom tensor classes"""
from .float8_tensor import Float8Tensor
from .quantized_tensor import QuantizedTensor
This diff is collapsed.
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor with quantized data"""
from __future__ import annotations
from typing import Optional, Tuple
import torch
from torch.utils._pytree import tree_map
class _DequantizeFunc(torch.autograd.Function):
    """Autograd function to convert quantized tensor to standard tensor"""

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        tensor: QuantizedTensor,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        # Nothing needs to be stashed in the context: the backward pass
        # is a straight pass-through.
        out = tensor.dequantize(dtype=dtype)
        return out

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # Gradient flows through unchanged; the dtype argument gets no grad
        return grad, None
class _IdentityFunc(torch.autograd.Function):
    """Autograd function to create quantized tensor with same data"""

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        tensor: QuantizedTensor,
    ) -> QuantizedTensor:
        # Detach so the output is a fresh autograd node backed by the
        # same quantized data.
        out = tensor.detach()
        return out

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        grad: torch.Tensor,
    ) -> torch.Tensor:
        # Identity: gradient passes straight through
        return grad
class QuantizedTensor(torch.Tensor):
    """Abstract base class for tensor with quantized data

    This is a proxy class with the interface of a standard PyTorch
    tensor, but with data that has been encoded with some quantization
    scheme. Derived classes should implement the quantization scheme
    by overriding the `quantize_` and `dequantize` functions.

    """

    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """Convert quantized data to standard PyTorch tensor

        Must be overridden by derived classes.

        """
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement dequantize function"
        )

    def quantize_(self, tensor: torch.Tensor) -> QuantizedTensor:
        """Update quantized data in-place

        Must be overridden by derived classes.

        """
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement quantize_ function"
        )

    def detach(self) -> QuantizedTensor:
        """Create new quantized tensor with same data

        Output tensor must be detached from the current autograd
        graph. Must be overridden by derived classes.

        """
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement detach function"
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"

    def float(self) -> torch.Tensor:
        # Differentiable dequantize to FP32
        return _DequantizeFunc.apply(self, torch.float32)

    def bfloat16(self) -> torch.Tensor:
        # Differentiable dequantize to BF16
        return _DequantizeFunc.apply(self, torch.bfloat16)

    def half(self) -> torch.Tensor:
        # Differentiable dequantize to FP16
        return _DequantizeFunc.apply(self, torch.float16)

    def cpu(self) -> torch.Tensor:
        # Dequantize first, then move the plain tensor to host memory
        return _DequantizeFunc.apply(self).cpu()

    def expand_as(self, other: torch.Tensor) -> torch.Tensor:
        if other is self:
            # Note: expand_as is hackily used to create dummy autograd nodes
            # and access the backward graph (see
            # https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026).
            # We hackily add a dummy function to handle this case.
            return _IdentityFunc.apply(self)
        return super().expand_as(other)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        """Dispatch handler: dequantize args, run op, re-quantize if mutated"""

        # Normalize kwargs so the in-place handling below can safely
        # iterate and zip over it (PyTorch may pass None)
        if kwargs is None:
            kwargs = {}

        # Detach op: delegate to the subclass's detach implementation
        if func == torch.ops.aten.detach.default:
            return args[0].detach()

        # In-place copy op: quantize into a quantized destination,
        # otherwise dequantize the source before copying
        if func == torch.ops.aten.copy_.default:
            dst = args[0]
            src = args[1]
            if isinstance(dst, QuantizedTensor):
                dst.quantize_(src)
            else:
                if isinstance(src, QuantizedTensor):
                    src = src.dequantize()
                dst.copy_(src)
            return None

        # View op: not supported since views would alias quantized data
        if func == torch.ops.aten.view.default:
            # Fix: message was previously a literal string (missing f-prefix)
            raise NotImplementedError(f"{cls.__name__} class does not support tensor views")

        def maybe_unwrap(arg):
            """Replace QuantizedTensor args with dequantized tensors"""
            if isinstance(arg, QuantizedTensor):
                return arg.dequantize(dtype=arg.dtype)
            return arg

        def maybe_update_inplace(arg, new_arg, schema_arg):
            """Quantize results back into args the op mutated in-place"""
            if (
                isinstance(arg, QuantizedTensor)
                and isinstance(new_arg, torch.Tensor)
                and hasattr(schema_arg, "alias_info")
                and hasattr(schema_arg.alias_info, "is_write")
                and schema_arg.alias_info.is_write
            ):
                arg.quantize_(new_arg)

        # In-place op: dequantize, perform op, and quantize
        if func._schema.is_mutable:
            new_args = tree_map(maybe_unwrap, args)
            new_kwargs = tree_map(maybe_unwrap, kwargs)
            schema_args = func._schema.arguments
            args_len = len(args)
            super().__torch_dispatch__(func, types, new_args, new_kwargs)
            for arg, new_arg, schema_arg in zip(args, new_args, schema_args):
                maybe_update_inplace(arg, new_arg, schema_arg)
            # Keyword args correspond to schema arguments after the positionals
            for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
                assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match"
                maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg)
            return None

        # Default op: dequantize and perform op
        args = tree_map(maybe_unwrap, args)
        kwargs = tree_map(maybe_unwrap, kwargs)
        out = super().__torch_dispatch__(func, types, args, kwargs)
        return out

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Do not force the QuantizedTensor type on the returned tensor
        return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Utility functions for Transformer Engine modules""" """Utility functions for Transformer Engine modules"""
from __future__ import annotations
import functools import functools
import math import math
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
...@@ -251,3 +252,52 @@ def get_cudnn_version() -> Tuple[int, int, int]: ...@@ -251,3 +252,52 @@ def get_cudnn_version() -> Tuple[int, int, int]:
major, encoded_version = divmod(encoded_version, major_version_magnitude) major, encoded_version = divmod(encoded_version, major_version_magnitude)
minor, patch = divmod(encoded_version, 100) minor, patch = divmod(encoded_version, 100)
return (major, minor, patch) return (major, minor, patch)
def canonicalize_device(device: Optional[torch.device | str]) -> torch.device:
    """Canonicalize PyTorch device

    If `None`, then returns the default CUDA device.

    """
    if device is None:
        # Start from the global default device, coerced onto CUDA
        default = torch.get_default_device()
        device = (
            default
            if default.type == "cuda"
            else torch.device("cuda", torch.cuda.current_device())
        )
    else:
        # Parse strings etc. into a proper torch.device (no-op for devices)
        device = torch.device(device)
    if device.type == "cuda" and device.index is None:
        # Make the CUDA device index explicit
        device = torch.device("cuda", torch.cuda.current_device())
    return device
def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
    """Canonicalize PyTorch datatype

    If `None`, then returns the default PyTorch datatype.

    """
    # Substitute the global default when no dtype is provided
    return torch.get_default_dtype() if dtype is None else dtype
def devices_match(device1: torch.device, device2: torch.device) -> bool:
    """Whether two devices are the same"""
    dev1 = torch.device(device1)
    dev2 = torch.device(device2)
    if dev1.type != dev2.type:
        return False
    if dev1.type != "cuda":
        # Non-CUDA devices compare directly
        return dev1 == dev2
    # CUDA devices: a None index means the current CUDA device
    idx1 = dev1.index
    idx2 = dev2.index
    if idx1 == idx2:
        # Fast path that avoids initializing CUDA
        return True
    if idx1 is None:
        idx1 = torch.cuda.current_device()
    if idx2 is None:
        idx2 = torch.cuda.current_device()
    return idx1 == idx2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment