Commit 2b05e121 authored by yuguo

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
@@ -4,13 +4,14 @@
 """Tensor class with FP8 data"""
 from __future__ import annotations
-from typing import Optional, Tuple, Iterable
+from typing import Optional, Tuple, Iterable, Union
 import warnings
 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
+from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
 from ..utils import canonicalize_process_group, devices_match
 from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
@@ -166,6 +167,9 @@ class Float8Quantizer(Quantizer):
             quantizer=self,
         )
+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return DelayedScaling
 class Float8CurrentScalingQuantizer(Quantizer):
     """Builder class for FP8 tensors with per-tensor current scaling
@@ -328,6 +332,9 @@ class Float8CurrentScalingQuantizer(Quantizer):
         """Get process group for amax reduction"""
         return canonicalize_process_group(self.amax_reduction_group)
+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return Float8CurrentScaling
 class Float8Tensor(Float8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data
...
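Reviewer note: the two `_get_compatible_recipe` overrides above let each FP8 quantizer report which recipe class it was built for. The sketch below illustrates how such a hook could be used for a compatibility check; `check_quantizer_matches_recipe` is a hypothetical helper written for this note, not part of the TransformerEngine API.

# Minimal sketch (hypothetical helper, not TransformerEngine API).
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling

def check_quantizer_matches_recipe(quantizer, recipe) -> bool:
    """Return True if `quantizer` was built for the class of `recipe`."""
    compatible_cls = quantizer._get_compatible_recipe()  # e.g. DelayedScaling
    return compatible_cls is not None and isinstance(recipe, compatible_cls)

# A Float8Quantizer reports DelayedScaling, so it matches a DelayedScaling
# recipe but not a Float8CurrentScaling one; the current-scaling quantizer
# reports Float8CurrentScaling and behaves the other way around.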
@@ -6,12 +6,13 @@
 from __future__ import annotations
 from collections.abc import Iterable
 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
+from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
 from ..constants import MXFP8_BLOCK_SCALING_SIZE
 from ..utils import devices_match, round_up_to_nearest_multiple
@@ -135,6 +136,9 @@ class MXFP8Quantizer(Quantizer):
         # TODO(ksivamani): No calibration needed for mxfp8?
         pass
+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return MXFP8BlockScaling
 class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data
@@ -380,6 +384,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
         # Quantize to FP8
         assert self._quantizer is not None, "Can't quantize without a quantizer"
+        self._quantizer.internal = False
         self.data = self._quantizer.quantize(tensor)
         if self.requires_grad != tensor.requires_grad:
             self.requires_grad_(requires_grad=tensor.requires_grad)
...
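Reviewer note: the MXFP8 context above imports `MXFP8_BLOCK_SCALING_SIZE` and `round_up_to_nearest_multiple`. Below is a small self-contained sketch of the block-padding arithmetic those names imply; the block size of 32 and the helper's exact semantics are assumptions made for this illustration, not values taken from the commit.

# Assumed block size: one shared scale per 32 elements (illustrative only).
BLOCK = 32

def round_up_to_nearest_multiple(value: int, multiple: int) -> int:
    # Assumed semantics of the helper imported in the hunk above.
    return ((value + multiple - 1) // multiple) * multiple

# A 1000-element row needs ceil(1000 / 32) = 32 block scales, i.e. it is
# handled as if padded out to 1024 elements.
assert round_up_to_nearest_multiple(1000, BLOCK) == 1024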
@@ -8,11 +8,13 @@ from __future__ import annotations
 from typing import Optional, Tuple, Iterable, Any, Dict, Union
 import abc
 import copy
+import warnings
 import torch
 from torch.utils._pytree import tree_map
 import transformer_engine_torch as tex
+from transformer_engine.common.recipe import Recipe
 class QuantizedTensorBase:
@@ -31,6 +33,8 @@ class QuantizedTensorBase:
     XTensor should only implement the functionality needed
     to behave like regular torch.Tensor (liek __torch_dispatch__)."""
+    _quantizer: Optional[Quantizer]
     def update_usage(
         self,
         rowwise_usage: Optional[bool] = None,
@@ -69,6 +73,14 @@ class QuantizedTensorBase:
             f"{self.__class__.__name__} class does not implement restore_from_saved function"
         )
+    def update_quantizer(self, quantizer: Quantizer):
+        """Update quantizer for the tensor"""
+        if self._quantizer is None:
+            raise RuntimeError("To be updated, quantizer must be set")
+        if self._quantizer is not quantizer:
+            warnings.warn("Quantizer is being updated, this may affect model behavior")
+        self._quantizer = quantizer
 def prepare_for_saving(
     *tensors: Union[torch.Tensor, QuantizedTensorBase],
@@ -238,6 +250,10 @@ class Quantizer(abc.ABC):
         """Create shallow copy"""
         return copy.copy(self)
+    @abc.abstractmethod
+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        """Returns recipe class that is compatible with this quantizer"""
 class _QuantizeFunc(torch.autograd.Function):
     """Cast to FP8 from other dtype"""
...
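Reviewer note: the new `update_quantizer` method above rebinds the quantizer attached to an already quantized tensor, warning when the replacement is a different object and raising if no quantizer was set. A self-contained toy that mirrors those semantics, for readers looking at the diff without the rest of the class:

import warnings

class _ToyQuantizedTensor:
    """Stand-in that copies the update_quantizer behaviour added above."""

    def __init__(self, quantizer):
        self._quantizer = quantizer

    def update_quantizer(self, quantizer):
        if self._quantizer is None:
            raise RuntimeError("To be updated, quantizer must be set")
        if self._quantizer is not quantizer:
            warnings.warn("Quantizer is being updated, this may affect model behavior")
        self._quantizer = quantizer

q_old, q_new = object(), object()
t = _ToyQuantizedTensor(q_old)
t.update_quantizer(q_old)  # same object: silent no-op
t.update_quantizer(q_new)  # different object: warns, then rebinds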
@@ -179,12 +179,12 @@ class TransformerLayer(torch.nn.Module):
             The device on which the parameters of the model will be allocated. It is the user's
             responsibility to ensure all parameters are moved to the GPU before running the
             forward pass.
-    attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
+    attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
             This controls whether the dimensions of the
-            intermediate hidden states is 'batch first' ('bshd') or
-            'sequence first' ('sbhd'). `s` stands for the sequence
-            length, `b` batch size, `h` the number of heads, `d`
-            head size. Note that these formats are very closely
+            intermediate hidden states is 'sequence first' ('sbhd'), 'batch first' ('bshd'),
+            or 'token first' ('thd'). `s` stands for the sequence length, `b` batch size,
+            `t` the total number of tokens, `h` the number of heads, `d` head size.
+            Note that these formats are very closely
             related to the `qkv_format` in the `MultiHeadAttention`
             and `DotProductAttention` modules.
     name: str, default = `None`
@@ -235,6 +235,14 @@ class TransformerLayer(torch.nn.Module):
             parameter for query-key-value. This enables optimizations such as QKV
             fusion without concatentations/splits and also enables the argument
             `fuse_wgrad_accumulation`.
+    use_qk_norm: bool, default = 'False'
+            if set to `True`, L2 normalization is applied to query and key tensors
+            after RoPE (if applicable) but before attention computation.
+            This follows the Llama4 approach for QK normalization to improve
+            training stability and model performance.
+    qk_norm_eps: float, default = 1e-6
+            epsilon value for L2 normalization of query and key tensors.
+            Only used when `use_qk_norm` is True.
     """
     def __init__(
@@ -284,6 +292,8 @@ class TransformerLayer(torch.nn.Module):
         device: Union[torch.device, str] = "cuda",
         attn_input_format: str = "sbhd",
         name: str = None,
+        use_qk_norm: bool = False,
+        qk_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
@@ -373,6 +383,8 @@ class TransformerLayer(torch.nn.Module):
             "ub_overlap_rs": ub_overlap_rs,
             "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
             "qkv_format": self.attn_input_format,
+            "seq_length": seq_length,
+            "micro_batch_size": micro_batch_size,
         }
         self.self_attention = MultiheadAttention(
@@ -384,6 +396,8 @@ class TransformerLayer(torch.nn.Module):
             return_bias=not self.parallel_attention_mlp,
             normalization=normalization,
             device=device,
+            use_qk_norm=use_qk_norm,
+            qk_norm_eps=qk_norm_eps,
             name=name + ".self_attention" if name is not None else None,
         )
@@ -398,6 +412,8 @@ class TransformerLayer(torch.nn.Module):
             return_bias=True,
             normalization=normalization,
             device=device,
+            use_qk_norm=use_qk_norm,
+            qk_norm_eps=qk_norm_eps,
             name=name + ".inter_attention" if name is not None else None,
         )
@@ -552,6 +568,8 @@ class TransformerLayer(torch.nn.Module):
         alibi_slopes: Optional[torch.Tensor] = None,
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
+        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
         max_seqlen_q: Optional[int] = None,
         max_seqlen_kv: Optional[int] = None,
         fast_zero_fill: bool = True,
@@ -633,15 +651,25 @@ class TransformerLayer(torch.nn.Module):
         cu_seqlens_q: Optional[torch.Tensor], default = `None`
             Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
             with shape [batch_size + 1] and dtype torch.int32.
+            Used by encoders, or decoders' self-attention.
         cu_seqlens_kv: Optional[torch.Tensor], default = `None`
             Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
             and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+            Used by decoders' cross-attention.
+        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
+            with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None.
+            Used by encoders, or decoders' self-attention.
+        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
+            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+            Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention.
         max_seqlen_q: Optional[int], default = `None`
             Maximum sequence length in `query_layer`.
-            Calculated from `cu_seqlens_q` if not provided.
+            Calculated from `cu_seqlens_q_padded` if not provided.
         max_seqlen_kv: Optional[int], default = `None`
             Maximum sequence length in `key_layer` and `value_layer`.
-            Calculated from `cu_seqlens_kv` if not provided.
+            Calculated from `cu_seqlens_kv_padded` if not provided.
         fast_zero_fill: bool, default = `True`
             Whether to set output tensors to 0 or not before use.
         inference_params: InferenceParams, default = None
@@ -649,7 +677,8 @@ class TransformerLayer(torch.nn.Module):
             to efficiently calculate and store the context during inference.
         pad_between_seqs: Optional[bool], default = `None`
             If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
-            If true, there are padding tokens between individual sequences in a packed batch.
+            If true, there are padding tokens between individual sequences in a packed batch,
+            i.e. qkv_format = 'thd'.
         """
         if self_attn_mask_type is None:
@@ -678,7 +707,9 @@ class TransformerLayer(torch.nn.Module):
         if (
             "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
         ) and attention_mask is not None:
-            assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor"
+            assert all(
+                attention_mask[i].dtype == torch.bool for i in range(len(attention_mask))
+            ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors"
         if (
             "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
         ) and enc_dec_attn_mask is not None:
@@ -707,9 +738,11 @@ class TransformerLayer(torch.nn.Module):
             core_attention_bias=core_attention_bias,
             alibi_slopes=alibi_slopes,
             cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_kv=cu_seqlens_kv,
+            cu_seqlens_kv=cu_seqlens_q,
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_q_padded,
             max_seqlen_q=max_seqlen_q,
-            max_seqlen_kv=max_seqlen_kv,
+            max_seqlen_kv=max_seqlen_q,
             fast_zero_fill=fast_zero_fill,
             pad_between_seqs=pad_between_seqs,
         )
@@ -733,12 +766,21 @@ class TransformerLayer(torch.nn.Module):
             attn_mask_type=enc_dec_attn_mask_type,
             window_size=enc_dec_window_size,
             encoder_output=encoder_output,
+            inference_params=inference_params,
             is_first_microbatch=is_first_microbatch,
             checkpoint_core_attention=checkpoint_core_attention,
+            rotary_pos_emb=rotary_pos_emb,
             core_attention_bias_type=core_attention_bias_type,
             core_attention_bias=core_attention_bias,
             alibi_slopes=alibi_slopes,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_kv=max_seqlen_kv,
+            fast_zero_fill=fast_zero_fill,
+            pad_between_seqs=pad_between_seqs,
         )
         if self.apply_residual_connection_post_layernorm:
             attention_output, attention_bias, residual = inter_attention_outputs
...
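Reviewer note: the TransformerLayer changes add the `use_qk_norm` / `qk_norm_eps` knobs, document the 'thd' input format, and thread the packed-sequence arguments through to cross-attention. A minimal usage sketch of the new constructor arguments follows; the sizes are arbitrary illustrative values, not taken from the commit.

import torch
import transformer_engine.pytorch as te

# Arbitrary illustrative sizes.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    use_qk_norm=True,   # L2-normalize Q and K after RoPE (new in this commit)
    qk_norm_eps=1e-6,
)

# Default attn_input_format is 'sbhd': [sequence, batch, hidden].
x = torch.randn(128, 2, 1024, device="cuda")
out = layer(x)

# For packed inputs one would instead pass attn_input_format="thd" at construction
# and supply cu_seqlens_q / cu_seqlens_kv (and optionally the *_padded variants)
# to forward(), as described in the updated docstring above.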
@@ -37,8 +37,16 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     Must be used carefully.
     """
     for t in tensors:
         if t is not None:
+            # Workaround for double buffering in cpu offload
+            if hasattr(t, "do_not_clear"):
+                continue
+            if hasattr(t, "get_data_tensors"):
+                if any(hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()):
+                    continue
             if hasattr(t, "clear"):
                 t.clear()
             else:
@@ -462,6 +470,7 @@ def is_bf16_compatible() -> None:
     return torch.cuda.get_device_capability()[0] >= 8
+@functools.lru_cache(maxsize=None)
 def is_non_tn_fp8_gemm_supported() -> bool:
     """Checks whether the device supports
     non-TN layouts for FP8 GEMMs.
...
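Reviewer note: the `@functools.lru_cache` decorator added above memoizes `is_non_tn_fp8_gemm_supported`, so repeated calls skip the device query. A self-contained sketch of the same pattern; the helper name and capability threshold below are made up for illustration.

import functools
import torch

@functools.lru_cache(maxsize=None)
def _device_capability_at_least(min_major: int) -> bool:
    # Hypothetical helper showing the memoization pattern used above:
    # the CUDA query runs once per distinct argument, later calls hit the cache.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability()[0] >= min_major

first = _device_capability_at_least(9)   # queries the device
second = _device_capability_at_least(9)  # served from the cache
assert first == second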