Unverified Commit 3b89c36f authored by Przemyslaw Tredak, committed by GitHub

Small fixes to Float8Tensor (#1225)



* Fixes to Float8Tensor
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 85e60e64
@@ -39,3 +39,4 @@ develop-eggs/
 dist/
 downloads/
 .pytest_cache/
+compile_commands.json
@@ -126,12 +126,9 @@ class _ToFloat8Func(torch.autograd.Function):

         # Check scale
         if scale is None and fp8_meta is None:
-            scale = 1
+            scale = torch.full([1], 1, dtype=torch.float32, device=device)
         if scale is not None:
-            if isinstance(scale, torch.Tensor):
-                scale = scale.to(device=device, dtype=torch.float32)
-            else:
-                scale = torch.full([1], scale, dtype=torch.float32, device=device)
+            scale = scale.to(device=device, dtype=torch.float32)

         # Check scale-inverse
         if scale_inv is None:
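Note: with this hunk, `_ToFloat8Func` assumes any provided `scale` is already a `torch.Tensor`; the fallback that wrapped Python scalars is gone. A minimal caller-side sketch of the new expectation (the scalar value and device string are illustrative, not from the diff):

import torch

scale_value = 0.5  # Python scalar a caller might hold (illustrative)
scale = torch.full([1], scale_value, dtype=torch.float32, device="cuda")
# `scale` is now in the one-element float32 tensor form the quantization path expects.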
@@ -335,6 +332,18 @@ class Float8Tensor(QuantizedTensor):

     """

+    _data: torch.Tensor
+    _fp8_attrs: Dict[str, Any]
+    _fp8_meta: Optional[Dict[str, Any]]
+    _fp8_meta_forward: bool
+    _fp8_meta_index: Optional[int]
+    _fp8_dtype: TE_DType
+    _scale_inv: torch.Tensor
+
+    # FP8 transpose cache
+    _transpose: Optional[torch.Tensor]
+    _transpose_invalid: bool
+
     def __new__(
         cls,
         *,
@@ -371,13 +380,12 @@ class Float8Tensor(QuantizedTensor):
             requires_grad=requires_grad,
             device=data.device,
         )
-        self._data: torch.Tensor = data
+        self._data = data

         # Initialize dict of class attributes
         # Note: We store FP8 attributes in a dictionary so we can
         # share them between tensors with the same data, e.g. detached
         # tensors.
-        self._fp8_attrs: dict
         if fp8_attrs is None:
             self._fp8_attrs = {}
         else:
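The comment above explains why `_fp8_attrs` is a dict rather than individual fields: detached or viewed tensors can share it. A standalone sketch of that sharing pattern (the demo class and key name are illustrative stand-ins, not the real implementation):

class _SharedAttrsDemo:
    """Toy class mimicking the shared-attributes pattern."""

    def __init__(self, attrs=None):
        # Reuse the caller's dict object instead of copying it.
        self._fp8_attrs = {} if attrs is None else attrs

    def detach(self):
        # A "detached" tensor gets the same attrs dict, not a copy.
        return _SharedAttrsDemo(attrs=self._fp8_attrs)

a = _SharedAttrsDemo()
b = a.detach()
a._fp8_attrs["scale_inv"] = 0.5
assert b._fp8_attrs["scale_inv"] == 0.5  # visible through both tensors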
@@ -390,16 +398,16 @@ class Float8Tensor(QuantizedTensor):
                 "To initialize Float8Tensor with FP8 meta tensors, "
                 "the FP8 meta tensor index must also be provided"
             )
-        self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta
-        self._fp8_meta_forward: bool = fp8_meta_forward
-        self._fp8_meta_index: Optional[int] = fp8_meta_index
+        self._fp8_meta = fp8_meta
+        self._fp8_meta_forward = fp8_meta_forward
+        self._fp8_meta_index = fp8_meta_index

         # FP8 dtype
         assert fp8_dtype in (
             TE_DType.kFloat8E4M3,
             TE_DType.kFloat8E5M2,
         ), f"Unsupported fp8_dtype {fp8_dtype}."
-        self._fp8_dtype: TE_DType = fp8_dtype
+        self._fp8_dtype = fp8_dtype

         # FP8 scale-inverse
         if fp8_scale_inv is None and self._fp8_meta is not None:
@@ -412,13 +420,6 @@ class Float8Tensor(QuantizedTensor):
             raise ValueError(
                 "Attempted to initialize Float8Tensor without specifying scale-inverse"
             )
-        if not isinstance(fp8_scale_inv, torch.Tensor):
-            fp8_scale_inv = torch.full(
-                [1],
-                fp8_scale_inv,
-                dtype=torch.float32,
-                device=self._data.device,
-            )
         if fp8_scale_inv.numel() != 1:
             raise ValueError(
                 "Attempted to initialize Float8Tensor with invalid scale-inverse tensor"
@@ -433,11 +434,11 @@ class Float8Tensor(QuantizedTensor):
                 device=self._data.device,
                 dtype=torch.float32,
             )
-        self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
+        self._scale_inv = fp8_scale_inv

         # FP8 transpose cache
-        self._transpose: Optional[Float8Tensor] = data_transpose
-        self._transpose_invalid: bool = self._transpose is None
+        self._transpose = data_transpose
+        self._transpose_invalid = self._transpose is None

         return self
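Taken together, these hunks mean direct construction now requires `fp8_scale_inv` to already be a one-element float32 tensor; Python scalars are no longer wrapped automatically. A hedged sketch of the new contract, assuming the keyword-only arguments shown in this diff, a CUDA device, and an import path that matches where this file lives in the tree:

import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor  # assumed import path

data = torch.zeros(16, 16, dtype=torch.uint8, device="cuda")  # raw FP8 bytes (illustrative)
scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")  # must be a tensor now
tensor = Float8Tensor(
    data=data,
    fp8_scale_inv=scale_inv,  # passing a Python float here would now fail
)
# fp8_dtype is left at its default; per the assert in the diff it must be
# TE_DType.kFloat8E4M3 or TE_DType.kFloat8E5M2 if given explicitly.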
@@ -477,7 +478,7 @@ class Float8Tensor(QuantizedTensor):
             ")"
         )

-    def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:

         # Convert PyTorch dtype to TE dtype
         if dtype is None:
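Since `dtype` is now keyword-only, any positional call sites need updating. A short usage sketch (reusing the `tensor` object from the construction sketch above):

# The old positional form would now raise a TypeError:
#     tensor.dequantize(torch.bfloat16)
# The keyword form is required:
dequantized = tensor.dequantize(dtype=torch.bfloat16)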
@@ -603,11 +604,8 @@ class Float8Tensor(QuantizedTensor):

         # Make sure FP8 scaling factors are in expected format
         if scale is not None:
-            if isinstance(scale, torch.Tensor):
-                if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32:
-                    scale = scale.to(device=dst.device, dtype=torch.float32)
-            else:
-                scale = torch.full([1], scale, dtype=torch.float32, device=dst.device)
+            if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32:
+                scale = scale.to(device=dst.device, dtype=torch.float32)
         if amax is not None:
             while amax.dim() < 2:
                 amax = amax.unsqueeze(0)
@@ -781,23 +779,21 @@ class Float8Tensor(QuantizedTensor):
             fill_cache = False

         # Need to compute transpose if cache is invalid
-        need_compute = force_compute
-        if self._transpose is None:
-            need_compute = True
-        elif self._transpose_invalid:
-            need_compute = True
-
-        # Need to apply transpose kernel if noop flag is applied
-        if noop_flag is not None:
-            need_compute = True
+        need_compute = (
+            force_compute
+            or (self._transpose is None)
+            or self._transpose_invalid
+            or (noop_flag is not None)
+        )

         # Return cached transpose if possible
         if not need_compute:
+            assert self._transpose is not None
             return self._transpose

         # Allocate output if needed
         data = self._data.contiguous().reshape(-1, self.size(-1))
-        out = self._transpose
+        out: Optional[torch.Tensor] = self._transpose
         if out is None:
             out = torch.empty(
                 (data.size(1), data.size(0)),
...
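The last hunk collapses four sequential checks into a single boolean expression. A standalone sketch confirming the two forms agree on every input (names mirror the diff; the function wrappers and exhaustive check are illustrative):

import itertools

def needs_compute_old(force_compute, transpose, transpose_invalid, noop_flag):
    # The pre-change control flow, verbatim in spirit.
    need_compute = force_compute
    if transpose is None:
        need_compute = True
    elif transpose_invalid:
        need_compute = True
    if noop_flag is not None:
        need_compute = True
    return need_compute

def needs_compute_new(force_compute, transpose, transpose_invalid, noop_flag):
    # The post-change single expression.
    return (
        force_compute
        or (transpose is None)
        or transpose_invalid
        or (noop_flag is not None)
    )

# Exhaustive check over the small input space shows identical behavior,
# including the elif case (transpose_invalid only matters when a cached
# transpose exists).
for fc, tr, ti, nf in itertools.product(
    [False, True], [None, object()], [False, True], [None, object()]
):
    assert needs_compute_old(fc, tr, ti, nf) == needs_compute_new(fc, tr, ti, nf)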