Commit f8c2af4c authored by yuguo

Merge commit '1d903f5e' of...

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
@@ -9,14 +9,19 @@ import math
from typing import Optional, Dict, Any, Tuple
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase
from ...constants import TE_DType_To_Torch
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor
class Float8BlockwiseQTensorBase:
class Float8BlockwiseQTensorBase(QuantizedTensorBase):
"""Mixin class that holds data attributes of Float8BlockwiseQTensor.
Float8BlockwiseQTensor inherits from the PyTorch tensor class and this
@@ -56,6 +61,17 @@ class Float8BlockwiseQTensorBase:
return instance
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
for t in (
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
):
if t is not None:
t.data = _empty_tensor()
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
@@ -73,14 +89,17 @@ class Float8BlockwiseQTensorBase:
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
"""
Prepare the tensor base for saving for backward
This does not clear the tensors currently, because with PP config
that clears the weight cache between micro-batches. If the rowwise
data is not required for backward, this is a possible memory
pessimization, but is consistent with the other quantized tensor
classes.
"""
tensors = [self._rowwise_data, self._columnwise_data]
tensors = [
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
]
self._rowwise_data = None
self._columnwise_data = None
self._rowwise_scale_inv = None
self._columnwise_scale_inv = None
return tensors, self
def restore_from_saved(
@@ -89,7 +108,9 @@ class Float8BlockwiseQTensorBase:
"""Restore the tensor base data from the saved tensors list."""
self._rowwise_data = tensors[0]
self._columnwise_data = tensors[1]
return tensors[2:]
self._rowwise_scale_inv = tensors[2]
self._columnwise_scale_inv = tensors[3]
return tensors[4:]
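For context, a minimal sketch (with hypothetical names, not part of this change) of how the prepare_for_saving / restore_from_saved pair is meant to be driven from an autograd function, so that only plain torch.Tensor objects ever reach save_for_backward:

class _SketchFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, quantized_weight):
        # Strip the raw tensors off the quantized tensor base before saving.
        saved, weight_base = quantized_weight.prepare_for_saving()
        ctx.save_for_backward(*saved)
        ctx.weight_base = weight_base
        return inp
    @staticmethod
    def backward(ctx, grad_out):
        # Reattach the saved tensors; the returned list is the unconsumed tail.
        leftover = ctx.weight_base.restore_from_saved(list(ctx.saved_tensors))
        assert leftover == []
        return grad_out, None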
def get_data_tensors(self):
"""Get this Tensor's data."""
@@ -232,6 +253,38 @@ class Float8BlockwiseQTensorBase:
reordered.append(dims[0])
return torch.Size(reordered)
def _create_columnwise(self):
"""
Update columnwise data and columnwise scale inv. Can only be used when using 2D scaling.
"""
assert self._is_2D_scaled, "Cannot create columnwise data when not using 2D scaling."
rowwise_data = self._rowwise_data
if not rowwise_data.is_contiguous():
rowwise_data = rowwise_data.contiguous()
self._columnwise_data = tex.fp8_transpose(
rowwise_data, self._fp8_dtype, out=self._columnwise_data
)
if self._columnwise_scale_inv is None:
assert self._quantizer is not None, (
"._quantizer of Float8BlockwiseQTensor cannot be None because all the blockwise "
"quantized tensors are supposed to be generated from the quantizer."
)
columnwise_scale_inv_shape = self._quantizer.get_scale_shape(rowwise_data.shape, True)
self._columnwise_scale_inv = torch.empty(
columnwise_scale_inv_shape,
dtype=self._rowwise_scale_inv.dtype,
device=self._rowwise_scale_inv.device,
)
assert len(self._rowwise_scale_inv.shape) == 2
assert len(self._columnwise_scale_inv.shape) == 2
rowwise_scale_inv = self._rowwise_scale_inv
columnwise_scale_inv = rowwise_scale_inv.transpose(-2, -1)
h = min(self._columnwise_scale_inv.shape[0], columnwise_scale_inv.shape[0])
w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1])
self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w])
def __repr__(self):
if self._rowwise_data is not None:
data = self.dequantize()
@@ -244,3 +297,55 @@ class Float8BlockwiseQTensorBase:
f"fp8_dtype={self._fp8_dtype}, "
f"{descriptor}_scaled_data={data}"
)
def update_usage(
self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None
):
"""
update_usage can be used to clear out one of two possible copies of the data.
"""
if rowwise_usage is None:
rowwise_usage = self._rowwise_data is not None
if columnwise_usage is None:
columnwise_usage = self._columnwise_data is not None
assert (
columnwise_usage or rowwise_usage
), "Must retain some data either columnwise or rowwise"
if columnwise_usage and rowwise_usage:
if not self._is_2D_scaled:
# For 1D scaling, we cannot create columnwise data/scale_inv from rowwise
# data/scale_inv because their scale values are different.
assert (
self._rowwise_data is not None
and self._rowwise_scale_inv is not None
and self._columnwise_data is not None
and self._columnwise_scale_inv is not None
), "Cannot update to rowwise and columnwise usage."
else:
# For 2D scaling, if columnwise data/scale_inv is None, we can create them from
# rowwise data/scale_inv.
assert (
self._rowwise_data is not None and self._rowwise_scale_inv is not None
), "Cannot update to rowwise and columnwise usage because rowwise data is None."
if self._columnwise_data is None or self._columnwise_scale_inv is None:
self._create_columnwise()
return
if rowwise_usage:
assert (
self._rowwise_data is not None and self._rowwise_scale_inv is not None
), "Cannot update to rowwise usage."
self._columnwise_data = None
self._columnwise_scale_inv = None
return
if columnwise_usage:
assert (
self._columnwise_data is not None and self._columnwise_scale_inv is not None
), "Cannot update to columnwise usage."
self._rowwise_data = None
self._rowwise_scale_inv = None
return
return
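As a rough usage sketch (hypothetical variable name), the new 2D-scaling path lets both usages be requested even when only the rowwise copy exists:

# blockwise_tensor: a 2D-scaled Float8BlockwiseQTensorBase holding only rowwise data.
blockwise_tensor.update_usage(rowwise_usage=True, columnwise_usage=True)
# _create_columnwise() rebuilds _columnwise_data/_columnwise_scale_inv from the rowwise copy.
blockwise_tensor.update_usage(rowwise_usage=False, columnwise_usage=True)
# The rowwise copy is then dropped, keeping only the columnwise buffers.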
@@ -12,10 +12,14 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase
from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor
class _FromFloat8Func(torch.autograd.Function):
"""Cast from FP8 to other dtype"""
@@ -48,7 +52,7 @@ class _FromFloat8Func(torch.autograd.Function):
return grad, None
class Float8TensorBase:
class Float8TensorBase(QuantizedTensorBase):
"""Mixin class that holds data attributes of Float8Tensor.
Float8Tensor inherits from the PyTorch tensor class and this mixin
@@ -90,6 +94,13 @@ class Float8TensorBase:
return instance
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
for t in (self._data, self._transpose, self._scale_inv):
if t is not None:
t.data = _empty_tensor()
self._transpose_invalid = True
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
@@ -100,9 +111,12 @@ class Float8TensorBase:
"quantizer": self._quantizer,
}
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
"""Prepare the tensor base for saving for backward"""
tensors = [self._data, self._transpose]
tensors = [self._data, self._transpose, self._scale_inv]
self._data = None
self._transpose = None
self._scale_inv = None
return tensors, self
def restore_from_saved(
@@ -111,7 +125,8 @@ class Float8TensorBase:
"""Restore the tensor base data from the saved tensors list"""
self._data = tensors[0]
self._transpose = tensors[1]
return tensors[2:]
self._scale_inv = tensors[2]
return tensors[3:]
def get_data_tensors(self):
"""Get this Tensor's data."""
@@ -144,3 +159,43 @@ class Float8TensorBase:
data = data.contiguous()
self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
self._transpose_invalid = False
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
"""
Generate or remove FP8 data based on provided usage. For
FP8, data cannot be generated even if transpose is available.
"""
has_data = self._data is not None
has_data_transpose = self._transpose is not None and not self._transpose_invalid
needs_data = has_data
needs_data_transpose = has_data_transpose
if is_non_tn_fp8_gemm_supported():
if rowwise_usage is not None and rowwise_usage:
needs_data = True
if columnwise_usage is not None and columnwise_usage:
needs_data = True
needs_data_transpose = False
else:
if rowwise_usage is not None:
needs_data = rowwise_usage
if columnwise_usage is not None:
needs_data_transpose = columnwise_usage
# Generate data that is required
if needs_data and not has_data:
raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
if needs_data_transpose and not has_data_transpose:
if not has_data:
raise RuntimeError("FP8 data is required to generate FP8 data transpose")
self._create_transpose()
# Delete data that is not required
if not needs_data:
self._data = None
if not needs_data_transpose:
self._transpose = None
self._transpose_invalid = True
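A similarly hedged sketch of the Float8TensorBase variant (hypothetical name); which buffer survives depends on whether non-TN FP8 GEMMs are supported:

# fp8_weight: a Float8TensorBase currently holding only rowwise _data.
fp8_weight.update_usage(rowwise_usage=False, columnwise_usage=True)
# Without non-TN FP8 GEMM support this creates the transpose (if missing) and drops _data;
# with such support, _data itself is kept for columnwise use and the transpose is dropped.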
@@ -11,10 +11,14 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase
from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor
class _FromMXFP8Func(torch.autograd.Function):
"""Cast from MXFP8 to other dtype"""
@@ -43,7 +47,7 @@ class _FromMXFP8Func(torch.autograd.Function):
return grad, None
class MXFP8TensorBase:
class MXFP8TensorBase(QuantizedTensorBase):
"""Mixin class that holds data attributes of MXFP8Tensor.
MXFP8Tensor inherits from the PyTorch tensor class and this mixin
@@ -81,6 +85,17 @@ class MXFP8TensorBase:
return instance
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
for t in (
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
):
if t is not None:
t.data = _empty_tensor()
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
@@ -94,7 +109,16 @@ class MXFP8TensorBase:
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
"""Prepare the tensor base for saving for backward"""
tensors = [self._rowwise_data, self._columnwise_data]
tensors = [
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
]
self._rowwise_data = None
self._columnwise_data = None
self._rowwise_scale_inv = None
self._columnwise_scale_inv = None
return tensors, self
def restore_from_saved(
@@ -103,7 +127,9 @@ class MXFP8TensorBase:
"""Restore the tensor base data from the saved tensors list."""
self._rowwise_data = tensors[0]
self._columnwise_data = tensors[1]
return tensors[2:]
self._rowwise_scale_inv = tensors[2]
self._columnwise_scale_inv = tensors[3]
return tensors[4:]
def get_data_tensors(self):
"""Get this Tensor's data."""
@@ -129,3 +155,48 @@ class MXFP8TensorBase:
f"rowwise_scale_inv={self._rowwise_scale_inv}, "
")"
)
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
"""
For MXFP8, columnwise scaled output is only produced by x2
scaling kernels, so this function only disables usages.
"""
# Default usage is based on available data
if rowwise_usage is None:
rowwise_usage = self._rowwise_data is not None
if columnwise_usage is None:
columnwise_usage = self._columnwise_data is not None
# Update row-scaled data
if rowwise_usage:
if self._rowwise_data is None:
raise RuntimeError(
"Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data"
)
if self._rowwise_scale_inv is None:
raise RuntimeError(
"Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses"
)
else:
self._rowwise_data = None
self._rowwise_scale_inv = None
# Update column-scaled data
if columnwise_usage:
if self._columnwise_data is None:
raise RuntimeError(
"Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data"
)
if self._columnwise_scale_inv is None:
raise RuntimeError(
"Requested column-wise usage, "
"but MXFP8Tensor is missing column-scaled scale-inverses"
)
else:
self._columnwise_data = None
self._columnwise_scale_inv = None
@@ -309,47 +309,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
# pylint: disable=missing-function-docstring
return Float8BlockwiseQTensor.make_like(self)
def update_usage(
self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None
):
"""
update_usage can be used to clear out one of two possible copies of the data.
"""
if rowwise_usage is None:
rowwise_usage = self._rowwise_data is not None
if columnwise_usage is None:
columnwise_usage = self._columnwise_data is not None
assert (
columnwise_usage or rowwise_usage
), "Must retain some data either columnwise or rowwise"
if columnwise_usage and rowwise_usage:
assert (
self._rowwise_data is not None
and self._rowwise_scale_inv is not None
and self._columnwise_data is not None
and self._columnwise_scale_inv is not None
), "Cannot update to rowwise and columnwise usage."
return
if rowwise_usage:
assert (
self._rowwise_data is not None and self._rowwise_scale_inv is not None
), "Cannot update to rowwise usage."
self._columnwise_data = None
self._columnwise_scale_inv = None
return
if columnwise_usage:
assert (
self._columnwise_data is not None and self._columnwise_scale_inv is not None
), "Cannot update to columnwise usage."
self._rowwise_data = None
self._rowwise_scale_inv = None
return
return
def clone(self) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
rowwise_data = None
@@ -421,11 +380,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
return self
raise ValueError("Float8BlockwiseQTensor does not support different memory formats!")
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
@classmethod
def _make_in_reduce_ex(
cls,
@@ -544,14 +498,64 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
if ctx is not None:
ctx.shape = tensor.shape
if shape is None:
return tensor
if list(shape) != list(tensor.shape):
raise NotImplementedError("View not implemented.")
return tensor
# Canonicalize shape
if not isinstance(shape, Iterable):
shape = [shape]
elif len(shape) == 1 and isinstance(shape[0], Iterable):
shape = shape[0]
if -1 in shape:
shape = list(shape)
d_inferred = -math.prod(ctx.shape) // math.prod(shape)
for i, d in enumerate(shape):
if d == -1:
shape[i] = d_inferred
break
if tensor._is_2D_scaled:
# For the case of 2D scaled tensor, the last 2 dimensions should not change
if shape[-1] != ctx.shape[-1] or shape[-2] != ctx.shape[-2]:
raise RuntimeError(
"2D scaled Float8BlockwiseQTensor does not support a view that changes "
"the last 2 dimensions "
f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)})"
)
else:
# For the case of 1D scaled tensor, the last dimension should not change
if shape[-1] != ctx.shape[-1]:
raise RuntimeError(
"1D scaled Float8BlockwiseQTensor does not support a view that changes "
"the last dimension "
f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)})"
)
if list(shape) == list(tensor.shape):
return tensor
# Construct new tensor if shape is provided
new_rowwise_data = None
new_columnwise_data = None
if tensor._rowwise_data is not None:
new_rowwise_data = tensor._rowwise_data.view(*shape)
if tensor._columnwise_data is not None:
columnwise_shape = [shape[-1]] + list(shape[:-1])
new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
return Float8BlockwiseQTensor(
shape=shape,
dtype=tensor.dtype,
fp8_dtype=tensor._fp8_dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=tensor._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=tensor._columnwise_scale_inv,
quantizer=tensor._quantizer,
is_2D_scaled=tensor._is_2D_scaled,
requires_grad=tensor.requires_grad,
)
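Since the columnwise buffer stores the last dimension first, its view shape is the cyclic rotation of the requested shape; a small worked example with hypothetical sizes:

# Viewing a (4, 6, 8) tensor as (24, 8); the last dimension is unchanged,
# which satisfies the 1D-scaling check above.
shape = [24, 8]
columnwise_shape = [shape[-1]] + list(shape[:-1])  # -> [8, 24]
# The rowwise buffer is viewed to (24, 8) and the columnwise buffer to (8, 24).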
@staticmethod
def backward(
@@ -561,7 +565,27 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
raise NotImplementedError("View bwd not implemented")
new_data = (
grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
)
if grad._columnwise_data is not None:
columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1])
new_columnwise_data = grad._columnwise_data.view(columnwise_shape)
else:
new_columnwise_data = None
dgrad = Float8BlockwiseQTensor(
shape=ctx.shape,
dtype=grad.dtype,
rowwise_data=new_data,
rowwise_scale_inv=grad._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=grad._columnwise_scale_inv,
fp8_dtype=grad._fp8_dtype,
quantizer=grad._quantizer,
is_2D_scaled=grad._is_2D_scaled,
requires_grad=grad.requires_grad,
)
return dgrad, None
return grad.view(ctx.shape), None
@@ -581,8 +605,7 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
if ctx is not None:
ctx.shape = tensor.shape
if shape is None:
return tensor
@@ -598,9 +621,47 @@ class _ReshapeFunc(torch.autograd.Function):
if d == -1:
shape[i] = d_inferred
break
if list(shape) != list(tensor.shape):
raise NotImplementedError("Reshape not implemented yet.")
return tensor
if tensor._is_2D_scaled:
# For the case of 2D scaled tensor, the last 2 dimensions should not change
if shape[-1] != ctx.shape[-1] or shape[-2] != ctx.shape[-2]:
raise RuntimeError(
"2D scaled Float8BlockwiseQTensor does not support reshaping "
"the last 2 dimensions "
f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
)
else:
# For the case of 1D scaled tensor, the last dimension should not change
if shape[-1] != ctx.shape[-1]:
raise RuntimeError(
"1D scaled Float8BlockwiseQTensor does not support reshaping "
"the last dimension "
f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
)
if list(shape) == list(tensor.shape):
return tensor
# Construct new tensor if shape is provided
new_rowwise_data = None
new_columnwise_data = None
if tensor._rowwise_data is not None:
new_rowwise_data = tensor._rowwise_data.reshape(*shape)
if tensor._columnwise_data is not None:
columnwise_shape = [shape[-1]] + list(shape[:-1])
new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
return Float8BlockwiseQTensor(
shape=shape,
dtype=tensor.dtype,
fp8_dtype=tensor._fp8_dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=tensor._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=tensor._columnwise_scale_inv,
quantizer=tensor._quantizer,
is_2D_scaled=tensor._is_2D_scaled,
requires_grad=tensor.requires_grad,
)
@staticmethod
def backward(
@@ -610,5 +671,24 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
raise NotImplementedError("Reshape bwd not implemented yet.")
new_rowwise_data = None
new_columnwise_data = None
if grad._rowwise_data is not None:
new_rowwise_data = grad._rowwise_data.view(*ctx.shape)
if grad._columnwise_data is not None:
columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1])
new_columnwise_data = grad._columnwise_data.view(columnwise_shape)
dgrad = Float8BlockwiseQTensor(
shape=ctx.shape,
dtype=grad.dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=grad._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=grad._columnwise_scale_inv,
fp8_dtype=grad._fp8_dtype,
quantizer=grad._quantizer,
is_2D_scaled=grad._is_2D_scaled,
requires_grad=grad.requires_grad,
)
return dgrad, None
return grad.view(ctx.shape), None
@@ -11,7 +11,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..utils import canonicalize_process_group, devices_match, non_tn_fp8_gemm_supported
from ..utils import canonicalize_process_group, devices_match
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..constants import dist_group_type
@@ -422,43 +422,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
# pylint: disable=missing-function-docstring
return Float8Tensor.make_like(self)
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
# Figure out what data is available and what is required
has_data = self._data is not None
has_data_transpose = self._transpose is not None and not self._transpose_invalid
needs_data = has_data
needs_data_transpose = has_data_transpose
if non_tn_fp8_gemm_supported():
if rowwise_usage is not None and rowwise_usage:
needs_data = True
if columnwise_usage is not None and columnwise_usage:
needs_data = True
needs_data_transpose = False
else:
if rowwise_usage is not None:
needs_data = rowwise_usage
if columnwise_usage is not None:
needs_data_transpose = columnwise_usage
# Generate data that is required
if needs_data and not has_data:
raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
if needs_data_transpose and not has_data_transpose:
if not has_data:
raise RuntimeError("FP8 data is required to generate FP8 data transpose")
self._create_transpose()
# Delete data that is not required
if not needs_data:
self._data = None
if not needs_data_transpose:
self._transpose = None
self._transpose_invalid = True
def clone(self) -> Float8Tensor:
# pylint: disable=missing-function-docstring
assert self._data is not None
@@ -516,12 +479,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
del self._transpose # explicitly deletes the data for safety
self._transpose = None
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
self._data = torch.Tensor() if self._data is not None else None
self._transpose = torch.Tensor() if self._transpose is not None else None
self._transpose_invalid = True
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
...
@@ -217,51 +217,6 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
# TODO(ksivamani): Fix the detach bug
return MXFP8Tensor.make_like(self)
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
"""
For MXFP8, columnwise scaled output is only produced by x2
scaling kernels, so this function only disables usages.
"""
# Default usage is based on available data
if rowwise_usage is None:
rowwise_usage = self._rowwise_data is not None
if columnwise_usage is None:
columnwise_usage = self._columnwise_data is not None
# Update row-scaled data
if rowwise_usage:
if self._rowwise_data is None:
raise RuntimeError(
"Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data"
)
if self._rowwise_scale_inv is None:
raise RuntimeError(
"Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses"
)
else:
self._rowwise_data = None
self._rowwise_scale_inv = None
# Update column-scaled data
if columnwise_usage:
if self._columnwise_data is None:
raise RuntimeError(
"Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data"
)
if self._columnwise_scale_inv is None:
raise RuntimeError(
"Requested column-wise usage, "
"but MXFP8Tensor is missing column-scaled scale-inverses"
)
else:
self._columnwise_data = None
self._columnwise_scale_inv = None
def clone(self) -> MXFP8Tensor:
# pylint: disable=missing-function-docstring
assert self._rowwise_data is not None
@@ -304,11 +259,6 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
return self
raise ValueError("MXFP8Tensor does not support different memory formats!")
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
...
@@ -15,9 +15,66 @@ from torch.utils._pytree import tree_map
import transformer_engine_torch as tex
class QuantizedTensorBase:
r"""Base class for all *TensorBase classes.
This class (and its subclasses) is an optimization for when
the full QuantizedTensor is not needed (when it is fully
contained inside a torch.autograd function and not visible to
PyTorch's autograd).
When creating a new tensor type X, one should create both an
XTensorBase class inheriting from QuantizedTensorBase and an
XTensor class inheriting from XTensorBase and QuantizedTensor.
XTensorBase should contain all data members needed to
implement the functionality of the tensor, while
XTensor should only implement the functionality needed
to behave like a regular torch.Tensor (like __torch_dispatch__)."""
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
r"""
Generate or remove quantized data based on provided usage.
Parameters
----------
rowwise_usage : Optional[bool], default = `None`
Whether to create or keep the data needed for using the tensor
in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it as `None`
preserves the original value in the tensor.
columnwise_usage : Optional[bool], default = `None`
Whether to create or keep the data needed for using the tensor
in columnwise fashion (e.g. as A argument in TN GEMM). Leaving it as
`None` preserves the original value in the tensor.
"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement update_usage function"
)
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
"""Prepare the tensor base for saving for backward"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement prepare_for_saving function"
)
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the tensor base data from the saved tensors list"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement restore_from_saved function"
)
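A minimal sketch of the subclassing pattern the docstring describes; XTensorBase and XTensor are placeholder names, not classes that exist in the library:

class XTensorBase(QuantizedTensorBase):
    # Holds the raw data members and implements update_usage /
    # prepare_for_saving / restore_from_saved for the new format.
    def __init__(self, data):
        self._data = data

class XTensor(XTensorBase, QuantizedTensor):
    # Adds only the torch.Tensor-facing behavior (__torch_dispatch__, detach, ...)
    # on top of the data handling inherited from XTensorBase.
    pass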
def prepare_for_saving(
*tensors,
) -> Tuple[list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], Optional[Any]]:
*tensors: Union[torch.Tensor, QuantizedTensorBase],
) -> Tuple[
list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorBase]]
]:
"""Prepare tensors for saving. Needed because save_for_backward accepts only
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
the internal TensorBase types too."""
@@ -35,10 +92,13 @@ def prepare_for_saving(
def restore_from_saved(
tensors: list[Optional[Any]],
tensors: list[Optional[Union[torch.Tensor, QuantizedTensorBase]]],
saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
return_saved_tensors: bool = False,
) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]:
) -> (
list[Optional[torch.Tensor | QuantizedTensorBase]]
| tuple[list[Optional[torch.Tensor | QuantizedTensorBase]], list[Optional[torch.Tensor]]]
):
"""Recombine the tensor data and metadata during backward pass."""
tensor_objects = []
for tensor in tensors:
@@ -294,21 +354,6 @@ class QuantizedTensor(torch.Tensor):
f"{self.__class__.__name__} class does not implement detach function"
)
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
"""Indicate to the tensor how it is going to be used
This enables optimizations to memory usage in some cases
where forward and backward passes use the tensor in
different directions.
"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement update_usage function"
)
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully"""
...
@@ -6,13 +6,13 @@
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
from .quantized_tensor import QuantizedTensor
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
from ..optimizers.multi_tensor_apply import multi_tensor_applier
@@ -33,6 +33,12 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
new_raw_data.detach().copy_(old_raw_data)
tensor._data = new_raw_data
del old_raw_data
elif isinstance(tensor, Float8BlockwiseQTensor):
old_raw_data = tensor._rowwise_data
assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match"
new_raw_data.detach().copy_(old_raw_data)
tensor._rowwise_data = new_raw_data
del old_raw_data
elif isinstance(tensor, MXFP8Tensor):
raise NotImplementedError("replace_raw_data for MXFP8Tensor is not supported yet")
else:
@@ -66,6 +72,7 @@ def cast_master_weights_to_fp8(
delayed_scaling_params = []
current_scaling_params = []
blockwise_scaling_params = []
if fsdp_shard_model_weights is None:
use_fsdp_shard_model_weights = False
@@ -107,6 +114,10 @@ def cast_master_weights_to_fp8(
current_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, Float8BlockQuantizer):
blockwise_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, MXFP8Quantizer):
raise NotImplementedError(
"cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet"
@@ -124,6 +135,10 @@ def cast_master_weights_to_fp8(
_cast_master_weights_to_fp8_current_scaling(
current_scaling_params, group, use_fsdp_shard_model_weights
)
if len(blockwise_scaling_params) > 0:
_cast_master_weights_to_fp8_blockwise_scaling(
blockwise_scaling_params, group, use_fsdp_shard_model_weights
)
def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False):
@@ -314,3 +329,125 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
model_weight.dtype,
)
quantizer.update_quantized(master_weight, model_weight_fragment)
def _cast_master_weights_to_fp8_blockwise_scaling(
params, group, use_fsdp_shard_model_weights=False
):
r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling.
Parameters
----------
params : List of tuple, each tuple contains a model weight, a master weight, and an offset
indicating the starting index of the master weight in the model weight.
group : The distributed group to do amax reduction. Typically it's the data parallel
group.
use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded.
"""
# Parameter attributes
device = params[0][0].device
block_len = params[0][0]._get_quantizer().block_len
fp8_dtype = params[0][0]._get_quantizer().dtype
force_pow_2_scales = params[0][0]._get_quantizer().force_pow_2_scales
amax_epsilon = params[0][0]._get_quantizer().amax_epsilon
# Create a dummy overflow buffer, it's needed by multi_tensor_applier.
dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=device)
# Get the total number of amax elements in all the model weights.
cu_amax_sizes = [0]
for model_weight, _, _, _ in params:
scale_shape = model_weight._get_quantizer().get_scale_shape(model_weight.shape, False)
num_amaxes = scale_shape[0] * scale_shape[1]
cu_amax_sizes.append(cu_amax_sizes[-1] + num_amaxes)
# Create a contiguous buffer to store amaxes temporarily, so we can perform all all-reduce
# NCCL kernels at once.
packed_amaxes = torch.zeros(cu_amax_sizes[-1], dtype=torch.float32, device=device)
# ---------------------------------------------------------------------------------------------
# Step 1: Iterate through all the non-empty master weights and compute their amax. Store the
# amaxes in a contiguous buffer. If a block of a master weight is empty, the
# corresponding amax will be set to 0.
# ---------------------------------------------------------------------------------------------
amaxes, scales, scale_invs = [], [], []
for i, (model_weight, master_weight, start_offset, _) in enumerate(params):
# Make sure all the model weights have the same numerical options.
quantizer = model_weight._get_quantizer()
assert block_len == quantizer.block_len
assert fp8_dtype == quantizer.dtype
assert force_pow_2_scales == quantizer.force_pow_2_scales
assert amax_epsilon == quantizer.amax_epsilon
scale_shape = quantizer.get_scale_shape(model_weight.shape, False)
amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape)
scale = torch.empty(scale_shape, dtype=torch.float32, device=device)
scale_inv = model_weight._rowwise_scale_inv
assert len(scale_shape) == 2
assert len(scale_inv.shape) == 2
assert scale_inv.shape[0] == scale_shape[0]
assert scale_inv.shape[1] == scale_shape[1]
amaxes.append(amax)
scales.append(scale)
scale_invs.append(scale_inv)
# Compute amax of the master weight and store it in packed_amaxes.
if master_weight is not None:
assert len(model_weight.shape) == 2
h, w = model_weight.shape
tex.fp8_block_scaling_compute_partial_amax(
master_weight, amax, h, w, start_offset, block_len
)
# ---------------------------------------------------------------------------------------------
# Step 2: Perform all-reduce on packed_amaxes to get the global amax.
# ---------------------------------------------------------------------------------------------
torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group)
# ---------------------------------------------------------------------------------------------
# Step 3: Update scales and scale_invs.
# ---------------------------------------------------------------------------------------------
if fp8_dtype == tex.DType.kFloat8E4M3:
max_fp8 = 448.0
elif fp8_dtype == tex.DType.kFloat8E5M2:
max_fp8 = 57344.0
else:
raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
multi_tensor_applier(
multi_tensor_compute_scale_and_scale_inv,
dummy_overflow_buf,
[amaxes, scales, scale_invs],
max_fp8,
force_pow_2_scales,
amax_epsilon,
)
# ---------------------------------------------------------------------------------------------
# Step 4: Cast master weights to FP8.
# ---------------------------------------------------------------------------------------------
for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
params, scales
):
# Clear columnwise data for all model weights.
# We cannot create columnwise data here because users (like megatron) may want to overlap
# the all-gather of model weights and forward process, so the model weight is not updated
# at this moment.
model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)
# If master weight is None, it means that the master weight of the current model weight
# is in other DP ranks.
if master_weight is None:
continue
# Cast master weight to FP8
end_offset = start_offset + master_weight.numel()
if not use_fsdp_shard_model_weights:
model_weight_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset]
assert len(model_weight.shape) == 2
h, w = model_weight.shape
tex.fp8_block_scaling_partial_cast(
master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype
)
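For reference, the per-block scale computed in Step 3 by multi_tensor_compute_scale_and_scale_inv boils down to the following; this is an illustrative Python equivalent under stated assumptions, not the kernel itself:

import math

def _reference_scale(amax, max_fp8, force_pow_2_scales, amax_epsilon):
    # Map the block's absolute maximum onto the FP8 representable range.
    amax = max(amax, amax_epsilon)
    scale = max_fp8 / amax if amax > 0.0 else 1.0
    if force_pow_2_scales:
        # Round the scale down to a power of two so it can be applied and undone exactly.
        scale = 2.0 ** math.floor(math.log2(scale))
    return scale, 1.0 / scale  # (scale, scale_inv)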
@@ -10,13 +10,11 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.attention import (
MultiheadAttention,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import check_set_window_size
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.jit import (
set_jit_fusion_options,
warmup_jit_bias_dropout_add_all_dtypes,
@@ -27,6 +25,7 @@ from transformer_engine.pytorch.jit import (
from transformer_engine.pytorch.utils import (
cast_if_needed,
get_default_init_method,
torch_get_autocast_gpu_dtype,
)
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
@@ -169,6 +168,8 @@ class TransformerLayer(torch.nn.Module):
interpretation is that the individual `q`, `k`, and `v` weights for each
attention head are interleaved. This parameter is set to `False` when
using :attr:`fuse_qkv_params=False`.
rotary_pos_interleaved : bool, default = `False`
whether to use interleaved rotary position embeddings.
bias : bool, default = `True`
if set to `False`, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
@@ -268,6 +269,7 @@ class TransformerLayer(torch.nn.Module):
drop_path_rate: float = 0.0,
set_parallel_mode: bool = False,
fuse_qkv_params: bool = False,
rotary_pos_interleaved: bool = False,
zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
ub_tp_comm_overlap: bool = False,
@@ -286,11 +288,9 @@ class TransformerLayer(torch.nn.Module):
super().__init__()
self.self_attn_mask_type = self_attn_mask_type
self.window_size = check_set_window_size(self_attn_mask_type, window_size)
self.window_size = window_size
self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
self.enc_dec_window_size = check_set_window_size(
enc_dec_attn_mask_type, enc_dec_window_size
)
self.enc_dec_window_size = enc_dec_window_size
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
@@ -366,6 +366,7 @@ class TransformerLayer(torch.nn.Module):
"fuse_qkv_params": fuse_qkv_params,
"zero_centered_gamma": zero_centered_gamma,
"qkv_weight_interleaved": qkv_weight_interleaved,
"rotary_pos_interleaved": rotary_pos_interleaved,
"ub_bulk_wgrad": ub_bulk_wgrad,
"ub_bulk_dgrad": ub_bulk_dgrad,
"ub_overlap_ag": ub_overlap_ag,
@@ -440,9 +441,7 @@ class TransformerLayer(torch.nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
use_nvfuser = torch_version() >= (1, 10, 0) and torch_version() < (2, 2, 0)
self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
if self.bias_dropout_fusion:
@@ -657,12 +656,10 @@ class TransformerLayer(torch.nn.Module):
self_attn_mask_type = self.self_attn_mask_type
if window_size is None:
window_size = self.window_size
window_size = check_set_window_size(self_attn_mask_type, window_size)
if enc_dec_attn_mask_type is None:
enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
if enc_dec_window_size is None:
enc_dec_window_size = self.enc_dec_window_size
enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size)
assert (
self_attn_mask_type in AttnMaskTypes
@@ -694,7 +691,7 @@ class TransformerLayer(torch.nn.Module):
# For AMP
if torch.is_autocast_enabled():
hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())
# Self attention.
self_attention_outputs = self.self_attention(
...
@@ -95,6 +95,7 @@ def cross_entropy_kernel(
m_d_X_y_stride,
rank,
world_size,
ignore_idx,
n_cols,
n_non_ignore,
label_smoothing: tl.constexpr,
@@ -114,6 +115,7 @@ def cross_entropy_kernel(
m_d_X_y_stride: The stride of m/d/X_y tensor.
rank (int): The rank of this device in the TP group.
world_size (int): The size of world involved in this distributed loss calculation.
ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -129,6 +131,13 @@ def cross_entropy_kernel(
Y_ptr += program_id * Y_stride
y = tl.load(Y_ptr)
if y == ignore_idx:
# set all X_ptr as 0
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
return
loss_ptr += program_id * loss_stride
m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride
@@ -248,6 +257,7 @@ def cross_entropy_forward(
label_smoothing: float,
reduce_loss: bool,
dist_process_group: Union[dist.ProcessGroup, None],
ignore_idx: int,
):
"""Forward implementation of Cross Entropy kernel"""
@@ -306,6 +316,7 @@ def cross_entropy_forward(
m_d_X_y_stride=m_d_X_y_gathered.stride(-1),
rank=rank,
world_size=world_size,
ignore_idx=ignore_idx,
n_cols=V,
n_non_ignore=n_rows,
label_smoothing=label_smoothing,
...
@@ -7,13 +7,13 @@ from __future__ import annotations
import functools
import math
import os
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch
import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
from .tensor.quantized_tensor import QuantizedTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
@@ -24,6 +24,12 @@ def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
return False
@functools.lru_cache(maxsize=None)
def _empty_tensor() -> torch.Tensor:
"""Get tensor with no entries and no data"""
return torch.Tensor().cuda()
def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
"""
Trick to deallocate tensor memory when delete operation does not
@@ -33,17 +39,22 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
"""
for t in tensors:
if t is not None:
if isinstance(t, QuantizedTensor):
if hasattr(t, "clear"):
t.clear()
else:
t.data = torch.Tensor()
t.data = _empty_tensor()
del t
@functools.lru_cache
def _get_device_compute_capability(device: torch.device) -> Tuple[int, int]:
props = torch.cuda.get_device_properties(device)
return (props.major, props.minor)
def get_device_compute_capability() -> Tuple[int, int]: def get_device_compute_capability() -> Tuple[int, int]:
"""CUDA compute capability of current GPU""" """CUDA compute capability of current GPU"""
props = torch.cuda.get_device_properties(torch.cuda.current_device()) return _get_device_compute_capability(torch.cuda.current_device())
return (props.major, props.minor)
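Usage sketch for the cached capability lookup above, assuming at least one CUDA device is visible; repeated calls now hit functools.lru_cache instead of re-querying device properties.

major, minor = get_device_compute_capability()
if (major, minor) >= (8, 0):
    print("Ampere-class or newer GPU: bf16 kernels are available")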
def attention_mask_func(
...@@ -155,6 +166,184 @@ def split_tensor_along_dim(
return tensor_list
# @klakhani TODO: Consider combining with split_tensor_along_dim() and no_op_cat() and SplitAlongDim
def combine_tensors(
tensors: List[torch.Tensor],
dim: int,
) -> torch.Tensor:
"""Combine tensors along a particular dimension"""
num_tensors = len(tensors)
new_shape = list(tensors[0].shape)
new_shape.insert(dim, num_tensors)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
if isinstance(tensors[0], Float8Tensor):
new_stride = list(tensors[0]._data.stride())
new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors))
combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype)
combined_tensor.set_(
tensors[0]._data.untyped_storage(),
tensors[0]._data.storage_offset(),
new_shape,
new_stride,
)
combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor, shape=new_shape)
else:
new_stride = list(tensors[0].stride())
new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors))
combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0].dtype)
combined_tensor.set_(
tensors[0].untyped_storage(), tensors[0].storage_offset(), new_shape, new_stride
)
return combined_tensor
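Hypothetical usage of combine_tensors() above with plain tensors. The inputs are assumed to be views into one contiguous buffer (for example, the outputs of an earlier split), since the function re-wraps tensors[0]'s storage with a new shape and stride rather than copying data.

import torch

base = torch.randn(4, 3, 8)                  # one contiguous buffer
chunks = [base[:, i, :] for i in range(3)]    # three views sharing base's storage
combined = combine_tensors(chunks, dim=1)     # shape (4, 3, 8), no data movement
assert combined.shape == (4, 3, 8)
assert combined.data_ptr() == base.data_ptr()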
class SplitAlongDim(torch.autograd.Function):
"""
Split tensor along given dimension
"""
@staticmethod
def forward(
ctx,
mixed_x_layer: torch.Tensor,
split_dim: int,
split_size_or_sections: Union[int, List[int], Tuple[int]],
squeeze=False,
) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance(
mixed_x_layer, Float8Tensor
):
return tuple(
Float8TensorBase(
fp8_scale_inv=mixed_x_layer._scale_inv,
fp8_dtype=mixed_x_layer._fp8_dtype,
data=x.squeeze(split_dim) if squeeze else x,
shape=x.squeeze(split_dim).shape if squeeze else x.shape,
quantizer=mixed_x_layer._quantizer,
)
for x in torch.split(
mixed_x_layer._data,
split_size_or_sections=split_size_or_sections,
dim=split_dim,
)
)
if isinstance(mixed_x_layer, Float8Tensor):
return tuple(
Float8Tensor.make_like(
mixed_x_layer,
data=x.squeeze(split_dim) if squeeze else x,
shape=x.squeeze(split_dim).shape if squeeze else x.shape,
)
for x in torch.split(
mixed_x_layer._data,
split_size_or_sections=split_size_or_sections,
dim=split_dim,
)
)
out_list = torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)
if squeeze:
out_list = [x.squeeze(split_dim) for x in out_list]
return out_list
@staticmethod
def backward(ctx, *grad_outputs):
# pylint: disable=missing-function-docstring
assert len(grad_outputs) > 0, "No gradients received for backprop!"
if isinstance(ctx.split_size_or_sections, (list, tuple)):
split_sizes = ctx.split_size_or_sections
assert len(grad_outputs) == len(
split_sizes
), "Unequal number of gradients vs split sections for backprop!"
if isinstance(ctx.split_size_or_sections, int):
split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
dims = len(grad_outputs[0].shape)
split_dim = (ctx.split_dim + dims) % dims
from transformer_engine.pytorch.float8_tensor import Float8Tensor
if isinstance(grad_outputs[0], Float8Tensor):
noop_ok = True
strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0]._data.untyped_storage().data_ptr()
shape = list(grad_outputs[0].shape)
for i, tensor in enumerate(grad_outputs):
shape_i = shape
shape_i[split_dim] = split_sizes[i]
offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
if (
tensor.stride() != strides
or list(tensor.shape) != shape_i
or tensor._data.untyped_storage().data_ptr() != data_ptr
or tensor.storage_offset() != offset_size
):
noop_ok = False
break
if noop_ok:
ret = torch.Tensor().to(
device=grad_outputs[0].device, dtype=grad_outputs[0]._data.dtype
)
new_shape = list(shape)
new_shape[split_dim] = sum(split_sizes)
ret.set_(
grad_outputs[0]._data.untyped_storage(),
grad_outputs[0]._data.storage_offset(),
new_shape,
strides,
)
return (
Float8Tensor.make_like(grad_outputs[0], data=ret, shape=ret.shape),
None,
None,
)
grad_outputs_data = [x._data for x in grad_outputs]
data = torch.cat(grad_outputs_data, dim=split_dim)
return (
Float8Tensor.make_like(grad_outputs[0], data=data, shape=data.shape),
None,
None,
None,
)
noop_ok = True
strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0].untyped_storage().data_ptr()
shape = list(grad_outputs[0].shape)
for i, tensor in enumerate(grad_outputs):
shape_i = shape
shape_i[split_dim] = split_sizes[i]
offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
if (
tensor.stride() != strides
or list(tensor.shape) != shape_i
or tensor.untyped_storage().data_ptr() != data_ptr
or tensor.storage_offset() != offset_size
):
noop_ok = False
break
if noop_ok:
ret = torch.Tensor().to(device=grad_outputs[0].device, dtype=grad_outputs[0].dtype)
new_shape = list(shape)
new_shape[split_dim] = sum(split_sizes)
ret.set_(
grad_outputs[0].untyped_storage(),
grad_outputs[0].storage_offset(),
new_shape,
strides,
)
return ret, None, None
return torch.cat(grad_outputs, dim=split_dim), None, None
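A hedged usage sketch of SplitAlongDim with ordinary (non-FP8) tensors, for example slicing a fused QKV projection; the backward pass re-assembles the full gradient, taking the no-copy path only when the incoming gradients are still contiguous views of one buffer.

import torch

qkv = torch.randn(2, 16, 3 * 64, requires_grad=True)
q, k, v = SplitAlongDim.apply(qkv, 2, [64, 64, 64])
(q.sum() + 2.0 * k.sum()).backward()
assert qkv.grad.shape == (2, 16, 192)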
def validate_ctx_manager(ctx: Callable) -> None:
"""Checks if passed in object can be used as a context manager."""
try:
...@@ -237,10 +426,10 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
"""Assert that tensor or tensors dimensions are supported for FP8 TN GEMM."""
for tensor in tensors:
assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, ( assert math.prod(tensor.shape[:-1]) % 8 == 0 and tensor.shape[-1] % 16 == 0, (
"FP8 execution requires 2D input matrices with " "FP8 execution requires the product of all dimensions except the last to be divisible"
"height divisible by 8 and width divisible by 16, " " by 8 and the last dimension to be divisible by 16, but got tensor with"
f"but got tensor with dims={list(tensor.size())}" f" dims={list(tensor.size())}"
)
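A worked example of the relaxed shape check (illustration only): the tensor no longer has to be 2D, it only has to flatten into a valid FP8 GEMM operand, i.e. the product of all leading dimensions divisible by 8 and the last dimension divisible by 16.

import math
import torch

x = torch.empty(4, 32, 1024)     # prod of leading dims = 128, last dim = 1024 -> passes
assert math.prod(x.shape[:-1]) % 8 == 0 and x.shape[-1] % 16 == 0

y = torch.empty(3, 5, 1024)      # prod of leading dims = 15 -> fails the divisibility check
assert not (math.prod(y.shape[:-1]) % 8 == 0 and y.shape[-1] % 16 == 0)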
if IS_HIP_EXTENSION:
...@@ -273,11 +462,12 @@ def is_bf16_compatible() -> None:
return torch.cuda.get_device_capability()[0] >= 8
def non_tn_fp8_gemm_supported() -> bool: def is_non_tn_fp8_gemm_supported() -> bool:
"""Checks whether the device supports """Checks whether the device supports
non-TN layouts for FP8 GEMMs. non-TN layouts for FP8 GEMMs.
""" """
return torch.cuda.get_device_capability() >= (10, 0) device_capability = torch.cuda.get_device_capability()
return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
@functools.lru_cache(maxsize=None)
...@@ -438,3 +628,16 @@ def canonicalize_process_group(
if group is None:
return torch.distributed.distributed_c10d._get_default_group()
return group
def torch_get_autocast_gpu_dtype() -> torch.dtype:
"""Get PyTorch autocast GPU dtype."""
if torch_version() >= (2, 4, 0):
return torch.get_autocast_dtype("cuda")
return torch.get_autocast_gpu_dtype()
if torch_version() >= (2, 4, 0):
gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
else:
gpu_autocast_ctx = torch.cuda.amp.autocast
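A minimal usage sketch of the version-dispatched autocast helpers above, assuming a CUDA device is available; both dispatch paths expose the same behavior to callers.

import torch

with gpu_autocast_ctx(enabled=True, dtype=torch.bfloat16):
    assert torch.is_autocast_enabled()
    assert torch_get_autocast_gpu_dtype() == torch.bfloat16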