Commit 2b05e121 authored by yuguo's avatar yuguo


Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
......@@ -4,13 +4,14 @@
"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable
from typing import Optional, Tuple, Iterable, Union
import warnings
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
from ..utils import canonicalize_process_group, devices_match
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
......@@ -166,6 +167,9 @@ class Float8Quantizer(Quantizer):
quantizer=self,
)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return DelayedScaling
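As a rough illustration of how this new hook can be used, the sketch below checks whether an active recipe matches a quantizer's declared recipe class. The `recipe` and `quantizer` objects are placeholders, not part of this diff, and calling the private method directly is for illustration only.

# Hedged sketch: verify that a quantizer is usable under the active recipe.
# `quantizer` is assumed to be a Float8Quantizer instance and `recipe` any
# transformer_engine.common.recipe.Recipe; neither is defined in this diff.
def recipe_matches_quantizer(recipe, quantizer) -> bool:
    compatible_cls = quantizer._get_compatible_recipe()
    if compatible_cls is None:
        # Quantizer does not declare a compatible recipe class.
        return False
    return isinstance(recipe, compatible_cls)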
class Float8CurrentScalingQuantizer(Quantizer):
"""Builder class for FP8 tensors with per-tensor current scaling
......@@ -328,6 +332,9 @@ class Float8CurrentScalingQuantizer(Quantizer):
"""Get process group for amax reduction"""
return canonicalize_process_group(self.amax_reduction_group)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return Float8CurrentScaling
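For intuition about what per-tensor current scaling computes, here is a minimal sketch of deriving an FP8 scale from a tensor's current amax. The 448 constant assumes the E4M3 format; this helper is illustrative and is not the class's actual kernel path.

import torch

# Illustrative only: current scaling derives the scale from the tensor's
# present absolute maximum instead of a delayed amax history.
FP8_E4M3_MAX = 448.0  # assumed maximum representable magnitude for E4M3

def current_scale(x: torch.Tensor) -> torch.Tensor:
    amax = x.abs().amax().float()
    # Guard against all-zero tensors to avoid division by zero.
    return torch.where(amax > 0, FP8_E4M3_MAX / amax, torch.ones_like(amax))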
class Float8Tensor(Float8TensorBase, QuantizedTensor):
"""Experimental tensor class with FP8 data
......
......@@ -6,12 +6,13 @@
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match, round_up_to_nearest_multiple
......@@ -135,6 +136,9 @@ class MXFP8Quantizer(Quantizer):
# TODO(ksivamani): No calibration needed for mxfp8?
pass
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return MXFP8BlockScaling
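To illustrate the block-scaling idea behind MXFP8 without reproducing the actual kernels, the following sketch computes one amax per contiguous block along the last dimension. The block size of 32 is an assumption about `MXFP8_BLOCK_SCALING_SIZE`.

import torch

BLOCK = 32  # assumed value of MXFP8_BLOCK_SCALING_SIZE

def per_block_amax(x: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch: one amax per BLOCK-sized group along the last dim.
    # Assumes the last dimension is divisible by BLOCK.
    *lead, last = x.shape
    blocks = x.reshape(*lead, last // BLOCK, BLOCK)
    return blocks.abs().amax(dim=-1)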
class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
"""Experimental tensor class with FP8 data
......@@ -380,6 +384,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
# Quantize to FP8
assert self._quantizer is not None, "Can't quantize without a quantizer"
self._quantizer.internal = False
self.data = self._quantizer.quantize(tensor)
if self.requires_grad != tensor.requires_grad:
self.requires_grad_(requires_grad=tensor.requires_grad)
......
......@@ -8,11 +8,13 @@ from __future__ import annotations
from typing import Optional, Tuple, Iterable, Any, Dict, Union
import abc
import copy
import warnings
import torch
from torch.utils._pytree import tree_map
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
class QuantizedTensorBase:
......@@ -31,6 +33,8 @@ class QuantizedTensorBase:
XTensor should only implement the functionality needed
to behave like regular torch.Tensor (like __torch_dispatch__)."""
_quantizer: Optional[Quantizer]
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
......@@ -69,6 +73,14 @@ class QuantizedTensorBase:
f"{self.__class__.__name__} class does not implement restore_from_saved function"
)
def update_quantizer(self, quantizer: Quantizer):
"""Update quantizer for the tensor"""
if self._quantizer is None:
raise RuntimeError("To be updated, quantizer must be set")
if self._quantizer is not quantizer:
warnings.warn("Quantizer is being updated, this may affect model behavior")
self._quantizer = quantizer
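A minimal usage sketch of the new `update_quantizer` hook, assuming `qtensor` is a quantized tensor that already carries a quantizer and `new_quantizer` is a freshly built quantizer for the same format; both names are placeholders.

def swap_quantizer(qtensor, new_quantizer):
    # Hedged sketch: a mismatch between the old and new quantizer triggers the
    # warning shown above, and a missing quantizer raises RuntimeError.
    qtensor.update_quantizer(new_quantizer)
    return qtensor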
def prepare_for_saving(
*tensors: Union[torch.Tensor, QuantizedTensorBase],
......@@ -238,6 +250,10 @@ class Quantizer(abc.ABC):
"""Create shallow copy"""
return copy.copy(self)
@abc.abstractmethod
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Returns recipe class that is compatible with this quantizer"""
class _QuantizeFunc(torch.autograd.Function):
"""Cast to FP8 from other dtype"""
......
......@@ -179,12 +179,12 @@ class TransformerLayer(torch.nn.Module):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
This controls whether the dimensions of the
intermediate hidden states is 'batch first' ('bshd') or
'sequence first' ('sbhd'). `s` stands for the sequence
length, `b` batch size, `h` the number of heads, `d`
head size. Note that these formats are very closely
intermediate hidden states are 'sequence first' ('sbhd'), 'batch first' ('bshd'),
or 'token first' ('thd'). `s` stands for the sequence length, `b` batch size,
`t` the total number of tokens, `h` the number of heads, `d` head size.
Note that these formats are very closely
related to the `qkv_format` in the `MultiHeadAttention`
and `DotProductAttention` modules.
name: str, default = `None`
......@@ -235,6 +235,14 @@ class TransformerLayer(torch.nn.Module):
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatenations/splits and also enables the argument
`fuse_wgrad_accumulation`.
use_qk_norm: bool, default = `False`
If set to `True`, L2 normalization is applied to query and key tensors
after RoPE (if applicable) but before the attention computation.
This follows the Llama4 approach to QK normalization to improve
training stability and model performance.
qk_norm_eps: float, default = 1e-6
Epsilon value for the L2 normalization of query and key tensors.
Only used when `use_qk_norm` is `True`.
"""
def __init__(
......@@ -284,6 +292,8 @@ class TransformerLayer(torch.nn.Module):
device: Union[torch.device, str] = "cuda",
attn_input_format: str = "sbhd",
name: str = None,
use_qk_norm: bool = False,
qk_norm_eps: float = 1e-6,
) -> None:
super().__init__()
......@@ -373,6 +383,8 @@ class TransformerLayer(torch.nn.Module):
"ub_overlap_rs": ub_overlap_rs,
"ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
"qkv_format": self.attn_input_format,
"seq_length": seq_length,
"micro_batch_size": micro_batch_size,
}
self.self_attention = MultiheadAttention(
......@@ -384,6 +396,8 @@ class TransformerLayer(torch.nn.Module):
return_bias=not self.parallel_attention_mlp,
normalization=normalization,
device=device,
use_qk_norm=use_qk_norm,
qk_norm_eps=qk_norm_eps,
name=name + ".self_attention" if name is not None else None,
)
......@@ -398,6 +412,8 @@ class TransformerLayer(torch.nn.Module):
return_bias=True,
normalization=normalization,
device=device,
use_qk_norm=use_qk_norm,
qk_norm_eps=qk_norm_eps,
name=name + ".inter_attention" if name is not None else None,
)
......@@ -552,6 +568,8 @@ class TransformerLayer(torch.nn.Module):
alibi_slopes: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
fast_zero_fill: bool = True,
......@@ -568,88 +586,99 @@ class TransformerLayer(torch.nn.Module):
Parameters
----------
hidden_states : torch.Tensor
Input tensor.
Input tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input. It should be
in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
mask. It should be `None` for causal masks and "`no_mask`" type.
A `True` value means the corresponding position is masked out and
a `False` means that position is allowed to participate in attention.
Boolean tensor used to mask out self-attention softmax input. It should be
in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
mask. It should be `None` for causal masks and "`no_mask`" type.
A `True` value means the corresponding position is masked out and
a `False` means that position is allowed to participate in attention.
self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'},
default = `causal`
Type of attention mask passed into softmax operation for encoder.
By default, causal masks are aligned to the top left corner of
the softmax matrix. When "`bottom_right`" is specified in the mask type,
causal masks are aligned to the bottom right corner.
'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'},
default = `causal`
Type of attention mask passed into softmax operation for encoder.
By default, causal masks are aligned to the top left corner of
the softmax matrix. When "`bottom_right`" is specified in the mask type,
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = `None`
Sliding window size for local attention in encoder.
Sliding window size for local attention in encoder.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = `None`. Boolean tensors used to mask out inter-attention softmax input if
using `layer_type="decoder"`. It should be a tuple of two masks in
[batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
A `True` value means the corresponding position is masked out and a `False`
means that position is allowed to participate in attention.
default = `None`. Boolean tensors used to mask out inter-attention softmax input if
using `layer_type="decoder"`. It should be a tuple of two masks in
[batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
A `True` value means the corresponding position is masked out and a `False`
means that position is allowed to participate in attention.
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = `None`
Type of attention mask passed into softmax operation for decoder.
default = `None`
Type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
Sliding window size for local attention in decoder.
Sliding window size for local attention in decoder.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
checkpoint_core_attention: bool, default = `False`
If true, forward activations for core attention are recomputed
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
If true, forward activations for core attention are recomputed
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
core_attention_bias_type: str, default = `no_bias`
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T
Bias tensor for Q * K.T
alibi_slopes: Optional[torch.Tensor], default = `None`
ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
to the attention score of query i and key j.
ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
to the attention score of query i and key j.
cu_seqlens_q: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32.
Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32.
Used by encoders, or decoders' self-attention.
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
Used by decoders' cross-attention.
cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None.
Used by encoders, or decoders' self-attention.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention.
max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided.
Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q_padded` if not provided.
max_seqlen_kv: Optional[int], default = `None`
Maximum sequence length in `key_layer` and `value_layer`.
Calculated from `cu_seqlens_kv` if not provided.
Maximum sequence length in `key_layer` and `value_layer`.
Calculated from `cu_seqlens_kv_padded` if not provided.
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use.
Whether to set output tensors to 0 or not before use.
inference_params: InferenceParams, default = None
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch.
If true, there are padding tokens between individual sequences in a packed batch,
i.e. qkv_format = 'thd'.
"""
if self_attn_mask_type is None:
......@@ -678,7 +707,9 @@ class TransformerLayer(torch.nn.Module):
if (
"padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
) and attention_mask is not None:
assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor"
assert all(
attention_mask[i].dtype == torch.bool for i in range(len(attention_mask))
), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors"
if (
"padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
) and enc_dec_attn_mask is not None:
......@@ -707,9 +738,11 @@ class TransformerLayer(torch.nn.Module):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_kv=cu_seqlens_q,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_q_padded,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
max_seqlen_kv=max_seqlen_q,
fast_zero_fill=fast_zero_fill,
pad_between_seqs=pad_between_seqs,
)
......@@ -733,12 +766,21 @@ class TransformerLayer(torch.nn.Module):
attn_mask_type=enc_dec_attn_mask_type,
window_size=enc_dec_window_size,
encoder_output=encoder_output,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
fast_zero_fill=fast_zero_fill,
pad_between_seqs=pad_between_seqs,
)
if self.apply_residual_connection_post_layernorm:
attention_output, attention_bias, residual = inter_attention_outputs
......
......@@ -37,8 +37,16 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
Must be used carefully.
"""
for t in tensors:
if t is not None:
# Workaround for double buffering in cpu offload
if hasattr(t, "do_not_clear"):
continue
if hasattr(t, "get_data_tensors"):
if any(hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()):
continue
if hasattr(t, "clear"):
t.clear()
else:
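The `do_not_clear` attribute is checked purely via `hasattr`, so code that double-buffers activations for CPU offload can opt a tensor out of clearing simply by tagging it. A hedged usage sketch follows; the attribute name comes from this diff, everything else is illustrative.

import torch

# Illustrative: mark an offload buffer so clear_tensor_data() leaves it alone.
offload_buffer = torch.empty(1024)
offload_buffer.do_not_clear = True  # plain attribute; only its presence matters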
......@@ -462,6 +470,7 @@ def is_bf16_compatible() -> None:
return torch.cuda.get_device_capability()[0] >= 8
@functools.lru_cache(maxsize=None)
def is_non_tn_fp8_gemm_supported() -> bool:
"""Checks whether the device supports
non-TN layouts for FP8 GEMMs.
......