Commit 970620a5 authored by wenjh

merge nv_release_v2.10 to release_v2.10


Signed-off-by: wenjh <wenjh@sugon.com>
parents c1a1c04e 769ed778
......@@ -127,12 +127,18 @@ def _run_layer_with_overlap(
os.environ["PYTORCH_JIT"] = "0"
os.environ["NVTE_TORCH_COMPILE"] = "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
if te.get_device_compute_capability() <= (8, 0):
# We've experienced numerical discrepancies in Flash Attention
# backward when running with Userbuffers on A100s. This does
# not show up in more recent GPUs.
os.environ["NVTE_FLASH_ATTN"] = "0"
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
os.unsetenv("PYTORCH_JIT")
os.unsetenv("NVTE_TORCH_COMPILE")
os.unsetenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
os.unsetenv("NVTE_FLASH_ATTN")
if (
result.returncode != 0
......
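
The hunk above temporarily forces `NVTE_FLASH_ATTN=0` on pre-Hopper GPUs before launching the test subprocess and unsets the variables afterwards. A minimal sketch of the same temporary-override pattern, passing a copied environment instead of mutating `os.environ` (the command and variable values here are placeholders, not part of the diff):

```python
import os
import subprocess

def run_with_env_overrides(cmd, overrides):
    """Run a command with extra env vars without mutating the parent os.environ."""
    env = os.environ.copy()
    env.update({key: str(value) for key, value in overrides.items()})
    return subprocess.run(cmd, env=env, capture_output=True, check=False)

# Hypothetical usage mirroring the hunk above.
result = run_with_env_overrides(
    ["python", "-m", "pytest", "some_test.py"],
    {"PYTORCH_JIT": "0", "NVTE_TORCH_COMPILE": "0", "NVTE_FLASH_ATTN": "0"},
)
```
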
......@@ -7,7 +7,7 @@ import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear
from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear, GroupedLinear
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
......@@ -19,7 +19,9 @@ model_configs = {
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
@pytest.mark.parametrize(
"module", ["TransformerLayer", "DotProductAttention", "Linear", "GroupedLinear"]
)
def test_current_device(model, module):
"""Test cases where current device is different from tensor device"""
......@@ -42,7 +44,29 @@ def test_current_device(model, module):
self_attn_mask_type="padding",
device=f"cuda:{tensor_device}",
)
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
seqlens_q = torch.randint(
1,
config.max_seqlen_q,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_q = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
seqlens_kv = torch.randint(
1,
config.max_seqlen_kv,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_kv = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
num_tokens = cu_seqlens_q[-1]
args = [
torch.randn(
(num_tokens, config.hidden_size),
......@@ -51,37 +75,55 @@ def test_current_device(model, module):
requires_grad=True,
)
]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
if module == "DotProductAttention":
elif module == "DotProductAttention":
model = DotProductAttention(
config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
)
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
seqlens_q = torch.randint(
1,
config.max_seqlen_q,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_q = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
seqlens_kv = torch.randint(
1,
config.max_seqlen_kv,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_kv = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
num_tokens = cu_seqlens_q[-1]
args = [
torch.randn(
num_tokens,
config.num_heads,
config.head_dim_qk,
dtype=dtype,
device=tensor_device,
device=f"cuda:{tensor_device}",
requires_grad=True,
)
for _ in range(3)
]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)]
bwd_args = [
torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=f"cuda:{tensor_device}")
]
elif module == "Linear":
model = Linear(
config.hidden_size,
......@@ -97,6 +139,24 @@ def test_current_device(model, module):
requires_grad=True,
)
]
elif module == "GroupedLinear":
num_gemms = 4
model = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
params_dtype=dtype,
device=f"cuda:{tensor_device}",
)
args = [
torch.randn(
(config.max_seqlen_q * config.batch_size * (num_gemms - 1), config.hidden_size),
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
),
[0] + [config.max_seqlen_q * config.batch_size] * (num_gemms - 1), # Empty first split.
]
current_device_before = torch.cuda.current_device()
out = model(*args, **kwargs)
......
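
The `cu_seqlens_q`/`cu_seqlens_kv` construction added above (random per-sequence lengths followed by a cumulative sum into a `[batch_size + 1]` tensor) is how variable-length batches are described for the `"thd"` layout. A self-contained sketch of just that step, with made-up sizes:

```python
import torch

batch_size, max_seqlen, hidden = 4, 16, 32
device = "cuda" if torch.cuda.is_available() else "cpu"

# Per-sequence lengths in [1, max_seqlen).
seqlens = torch.randint(1, max_seqlen, (batch_size,), dtype=torch.int32, device=device)

# Cumulative sequence lengths: shape [batch_size + 1], with cu_seqlens[0] == 0.
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

# Total number of packed tokens; in the thd layout all sequences are
# concatenated along the leading ("t") dimension.
num_tokens = int(cu_seqlens[-1])
packed = torch.randn(num_tokens, hidden, device=device)

# Sequence i occupies packed[cu_seqlens[i] : cu_seqlens[i + 1]].
```
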
......@@ -913,15 +913,15 @@ class TestBasicOps:
dtype=dtype,
accumulate_into_main_grad=accumulate_into_main_grad,
)
forward = te_ops.Sequential(
te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
op,
te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
)
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
forward = te_ops.Sequential(
te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
op,
te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
)
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
......
......@@ -46,7 +46,6 @@ from transformer_engine.pytorch import (
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states
......@@ -2757,7 +2756,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
general_gemm(
A[i],
B[i],
get_workspace(),
dtype,
grad=grad,
accumulate=accumulate,
......@@ -2772,7 +2770,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
B,
out,
dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
grad=grad,
accumulate=accumulate,
......@@ -2832,7 +2829,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
quantized_out, *_ = general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
quantization_params=out_quantizer,
bias=None,
......@@ -2842,7 +2838,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
out, *_ = general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
quantization_params=None,
bias=None,
......@@ -2918,7 +2913,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
general_gemm(
A_fp8[i],
B_fp8[i],
get_workspace(),
dtype,
out=out_ref[i],
accumulate=accumulate,
......@@ -2928,7 +2922,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
B_fp8,
out,
dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
accumulate=accumulate,
)
......
......@@ -37,7 +37,6 @@ from transformer_engine.pytorch import (
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from utils import ModelConfig
......@@ -961,7 +960,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
inp = torch.reshape(scratchpad[offset:-offset], (N, N))
weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
_ = general_gemm(A=weight, B=inp, workspace=get_workspace())
_ = general_gemm(A=weight, B=inp)
torch.cuda.synchronize()
......@@ -985,7 +984,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
bias=None,
use_split_accumulator=False,
......
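
These hunks remove the explicit workspace argument from `general_gemm` and `general_grouped_gemm`; per this diff, callers no longer allocate or pass a cuBLAS workspace. A sketch mirroring the plain (non-quantized) call shape from the unalignment test above, assuming a CUDA device and the same in-tree import; the shapes are illustrative:

```python
import torch
from transformer_engine.pytorch.cpp_extensions import general_gemm

N = 256
inp = torch.randn(N, N, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(N, N, dtype=torch.bfloat16, device="cuda")

# The workspace is now managed internally; callers pass only the operands
# (plus the optional dtype/quantization/bias/accumulate arguments seen above).
_ = general_gemm(A=weight, B=inp)
torch.cuda.synchronize()
```
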
......@@ -8,6 +8,7 @@ import logging
import os
from contextlib import contextmanager
from typing import Optional, Tuple, Dict, Any, List
from packaging.version import Version as PkgVersion
import torch
......@@ -210,6 +211,7 @@ class ModelConfig:
max_ctx_len: int = None,
num_layers: int = 1,
eps: float = 1e-5,
num_splits=1,
):
self.batch_size = batch_size
self.max_seqlen_q = max_seqlen_q
......@@ -239,6 +241,7 @@ class ModelConfig:
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
self.eps = eps
self.num_splits = num_splits
@contextmanager
......@@ -321,6 +324,9 @@ def get_available_attention_backends(
inference_params=inference_params,
softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
# allow all backends to pass so they can be used for testing;
# check for FA3 availability later
num_splits=1,
)
(
use_flash_attention,
......@@ -330,6 +336,10 @@ def get_available_attention_backends(
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Check if FA3 is an available backend when num_splits != 1
if available_backends[0]:
if config.num_splits != 1 and not flash_attention_backend > PkgVersion("3.0.0b"):
available_backends[0] = False
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
......
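
The new `num_splits` plumbing filters flash attention out of the available backends unless the resolved backend is FA3-class; the gate is a `packaging.version` comparison. A minimal, generic sketch of that comparison (the function name and version strings are illustrative, not part of the diff):

```python
from packaging.version import Version as PkgVersion

def supports_num_splits(flash_attention_backend, num_splits):
    """Mirror of the gate above: num_splits != 1 requires an FA3-class backend."""
    if num_splits == 1:
        return True
    return flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b")

assert supports_num_splits(PkgVersion("3.0.0b1"), num_splits=2)
assert not supports_num_splits(PkgVersion("2.7.3"), num_splits=2)
```
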
......@@ -278,7 +278,7 @@ void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, in
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens
* 2. cu_new_lens and cu_cached_lens are of shape [b + 1]; cu_cached_lens include the added lens
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
......
......@@ -131,7 +131,7 @@ enum NVTE_Mask_Type {
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter in shape [H].
* where alpha is a learnable parameter of shape [H].
*/
enum NVTE_Softmax_Type {
/*! Vanilla softmax */
......
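
The enum comment above defines the three softmax variants only by their denominators. A small torch transcription of those formulas (written literally, without the usual max-subtraction for numerical stability), just to make them concrete:

```python
import torch

def softmax_variants(scores, alpha=None):
    """scores: [b, h, s_q, s_kv] attention logits; alpha: [h] learnable offsets."""
    exp_s = torch.exp(scores)
    denom = exp_s.sum(dim=-1, keepdim=True)

    vanilla = exp_s / denom                                   # NVTE_VANILLA_SOFTMAX
    off_by_one = exp_s / (1.0 + denom)                        # NVTE_OFF_BY_ONE_SOFTMAX
    learnable = None
    if alpha is not None:                                     # NVTE_LEARNABLE_SOFTMAX
        learnable = exp_s / (torch.exp(alpha).view(1, -1, 1, 1) + denom)
    return vanilla, off_by_one, learnable

scores = torch.randn(2, 4, 8, 8)
vanilla, off_by_one, learnable = softmax_variants(scores, alpha=torch.zeros(4))
```
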
......@@ -50,7 +50,7 @@ class MMParams:
Parameters
----------
use_split_accumulator : bool, default = `True`
use_split_accumulator : bool, default = True
Use FP8 fast accumulation on Hopper or Ada. For more details,
see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
"""
......@@ -159,7 +159,7 @@ class DelayedScaling(Recipe):
recipe: DelayedScaling) -> Tensor
where `Tensor` is a framework tensor type.
reduce_amax: bool, default = `True`
reduce_amax: bool, default = True
By default, if `torch.distributed` is initialized, the `amax` value for FP8
tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
call). This keeps the amaxes and scaling factors synced across the given
......@@ -167,13 +167,13 @@ class DelayedScaling(Recipe):
GPU maintains local amaxes and scaling factors. To ensure results are
numerically identical across checkpointing boundaries in this case, all
ranks must checkpoint in order to store the local tensors.
fp8_dpa: bool, default = `False`
fp8_dpa: bool, default = False
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA currently is only supported in the
`FusedAttention` backend.
fp8_mha: bool, default = `False`
fp8_mha: bool, default = False
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
operations mentioned above at the DPA boundaries. Currently only standard MHA modules
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
......@@ -422,11 +422,11 @@ class NVFP4BlockScaling(Recipe):
----------
fp4_format : {Format.E2M1}, default = Format.E2M1
FP4 data type.
disable_rht : bool, default = `False`
disable_rht : bool, default = False
If set to `True`, random Hadamard transforms are not applied to any tensor.
disable_stochastic_rounding : bool, default = `False`
disable_stochastic_rounding : bool, default = False
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
disable_2d_quantization : bool, default = `False`
disable_2d_quantization : bool, default = False
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
"""
......@@ -492,17 +492,19 @@ class CustomRecipe(Recipe):
Parameters
----------
qfactory : Callable
Factory callable that returns a quantizer instance for a
given semantic tensor role.
The callable is typically invoked as:
qfactory(
role: str,
)
Where `role` is one of the following strings for e.g. te.Linear
(stable public contract):
- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
Factory callable that returns a quantizer instance for a
given semantic tensor role.
The callable is typically invoked as::
qfactory(
role: str,
)
Where `role` is one of the following strings for e.g. te.Linear
(stable public contract):
- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
"""
qfactory: Callable[..., Any]
......
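
The `CustomRecipe` docstring describes `qfactory` as a callable keyed on a semantic role string. A hedged sketch of such a factory; `_DummyQuantizer` is a hypothetical stand-in for a real quantizer class and is not part of the diff:

```python
class _DummyQuantizer:
    """Placeholder for whatever quantizer class the recipe should produce."""
    def __init__(self, role):
        self.role = role

def qfactory(role: str):
    """Return a quantizer instance for a semantic tensor role of te.Linear."""
    known_roles = {
        "linear_input", "linear_weight", "linear_output",   # forward
        "linear_grad_output", "linear_grad_input",          # backward
    }
    if role not in known_roles:
        raise ValueError(f"Unknown role: {role}")
    return _DummyQuantizer(role)

# Hypothetical usage, following the documented contract:
# recipe = CustomRecipe(qfactory=qfactory)
```
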
......@@ -736,12 +736,17 @@ int nvte_is_non_tn_fp8_gemm_supported() {
#if USE_ROCM
return true;
#else
int deviceComputeCapability =
transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
// Note: this is temporary restriction and should be lifted in the future.
// (remove the note once it's done.)
return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
deviceComputeCapability >= 130;
int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
std::call_once(flags[device_id], [&]() {
int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id);
// Note: this is temporary restriction and should be lifted in the future.
// (remove the note once it's done.)
cache[device_id] = (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
deviceComputeCapability >= 130;
});
return cache[device_id];
#endif
}
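
The change above memoizes the per-device compute-capability check with `std::call_once`, so the CUDA query runs once per device. A rough Python analogue of the same idea (cache keyed by device index, lazily populated), using torch only for the capability lookup; this is a sketch, not the library's implementation:

```python
import functools
import torch

@functools.lru_cache(maxsize=None)
def sm_arch(device_id: int) -> int:
    """Query the SM architecture once per device and cache the result."""
    major, minor = torch.cuda.get_device_capability(device_id)
    return major * 10 + minor

def is_non_tn_fp8_gemm_supported(device_id: int = 0) -> bool:
    # Mirrors the gate in the hunk above: SM 100-119, or SM >= 130.
    cc = sm_arch(device_id)
    return (100 <= cc < 120) or cc >= 130
```
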
......@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_QKV_Format
from transformer_engine_jax import nvte_get_qkv_format
from transformer_engine_jax import NVTE_Softmax_Type
from . import cpp_extensions as tex
......@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
]
class AttnSoftmaxType(Enum):
"""
VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)),
LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter of shape [H].
"""
VANILLA_SOFTMAX = NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX
OFF_BY_ONE_SOFTMAX = NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX
LEARNABLE_SOFTMAX = NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX
@classmethod
def from_str(cls, softmax_type: str) -> "AttnSoftmaxType":
"""Convert string to AttnSoftmaxType: 'vanilla', 'off_by_one', or 'learnable'."""
softmax_type_map = {
"vanilla": cls.VANILLA_SOFTMAX,
"off_by_one": cls.OFF_BY_ONE_SOFTMAX,
"learnable": cls.LEARNABLE_SOFTMAX,
}
result = softmax_type_map.get(softmax_type)
if result is None:
raise ValueError(
f"Unknown softmax_type: {softmax_type}. "
"Valid options: 'vanilla', 'off_by_one', 'learnable'"
)
return result
class QKVFormat(Enum):
"""
SBHD: q,k,v memory layout with [s, b, ..., h, d]
......@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_probability,
q_num_heads,
kv_num_heads,
......@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available(
"""
To check whether the fused attention kernel is supported
"""
window_size_tuple = (-1, -1) if window_size is None else window_size
def make_helper(attn_mask_type):
return tex.FusedAttnHelper(
......@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available(
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_probability,
q_num_heads,
kv_num_heads,
......@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available(
kv_max_seqlen,
head_dim_qk,
head_dim_v,
(-1, -1) if window_size is None else window_size,
window_size_tuple,
)
return make_helper(attn_mask_type).is_fused_attn_kernel_available()
......@@ -497,6 +530,11 @@ def _segment_ids_pos_to_seqlens_offsets(
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.
# For the seqlens and seq-offsets calculations, building the intermediate (temporary) attn_mask
# from the segment ids and positions, together with the mask type (causal or brcm), is sufficient.
# The sliding window (SW) does not need to be involved in creating this mask.
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if (attn_mask_type.is_causal() and window_size is None) or (
window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
......@@ -558,21 +596,6 @@ def _segment_ids_pos_to_seqlens_offsets(
)
attn_mask = jnp.logical_and(segment_mask, causal_mask)
# TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets
swa_mask = (
make_swa_mask(
segment_pos_q,
segment_pos_kv,
window_size,
dtype=jnp.bool,
segment_ids_q=segment_ids_q,
segment_ids_kv=segment_ids_kv,
)
if attn_mask_type.is_bottom_right()
else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
)
attn_mask = jnp.logical_and(attn_mask, swa_mask)
attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
attn_mask_with_id, max_segments_per_seq
......@@ -786,6 +809,7 @@ def _legacy_fused_attn(
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
......@@ -793,6 +817,7 @@ def _legacy_fused_attn(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
softmax_offset: Optional[jnp.ndarray] = None,
):
"""
Perform non-THD (non-packed) cuDNN fused attention.
......@@ -815,6 +840,7 @@ def _legacy_fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
......@@ -863,10 +889,12 @@ def _legacy_fused_attn(
output = _fused_attn(
qkv,
bias,
softmax_offset,
SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)),
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
......@@ -900,6 +928,7 @@ def fused_attn_thd(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
softmax_offset: Optional[jnp.ndarray] = None,
):
"""
Deprecated THD fused attn, please use fused_attn with SequenceDescriptor
......@@ -937,6 +966,7 @@ def fused_attn_thd(
output = _fused_attn(
qkv,
bias,
softmax_offset,
SequenceDescriptor.from_seqlens_and_offsets(
(q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets)
),
......@@ -945,6 +975,7 @@ def fused_attn_thd(
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
......@@ -957,15 +988,17 @@ def fused_attn_thd(
return output
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
softmax_offset: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray],
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
......@@ -979,11 +1012,13 @@ def _fused_attn(
output, _ = _fused_attn_fwd_rule(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type,
attn_mask_type,
qkv_layout,
softmax_type,
scaling_factor,
dropout_probability,
is_training,
......@@ -1000,11 +1035,13 @@ def _fused_attn(
def _fused_attn_fwd_rule(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type,
attn_mask_type,
qkv_layout,
softmax_type,
scaling_factor,
dropout_probability,
is_training,
......@@ -1018,10 +1055,12 @@ def _fused_attn_fwd_rule(
output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
......@@ -1041,6 +1080,7 @@ def _fused_attn_fwd_rule(
sequence_descriptor,
softmax_aux,
rng_state,
softmax_offset,
output,
)
......@@ -1049,6 +1089,7 @@ def _fused_attn_bwd_rule(
attn_bias_type,
attn_mask_type,
qkv_layout,
softmax_type,
scaling_factor,
dropout_probability,
is_training,
......@@ -1068,11 +1109,13 @@ def _fused_attn_bwd_rule(
sequence_descriptor,
softmax_aux,
rng_state,
softmax_offset,
output,
) = ctx
grad_qkv, grad_bias = tex.fused_attn_bwd(
grad_qkv, grad_bias, grad_softmax_offset = tex.fused_attn_bwd(
qkv,
bias,
softmax_offset,
softmax_aux,
rng_state,
output,
......@@ -1080,6 +1123,7 @@ def _fused_attn_bwd_rule(
sequence_descriptor,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
......@@ -1092,9 +1136,12 @@ def _fused_attn_bwd_rule(
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
if softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX:
grad_softmax_offset = None
return (
grad_qkv,
grad_bias,
grad_softmax_offset,
None,
None,
)
......@@ -1111,6 +1158,7 @@ def fused_attn(
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
......@@ -1120,6 +1168,7 @@ def fused_attn(
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
context_checkpoint_name: str = "context",
softmax_offset: Optional[jnp.ndarray] = None,
):
"""
Perform cuDNN fused attention.
......@@ -1139,6 +1188,7 @@ def fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
......@@ -1153,6 +1203,9 @@ def fused_attn(
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
[1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
If provided, this parameter will receive gradients during backpropagation.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
......@@ -1200,6 +1253,7 @@ def fused_attn(
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
......@@ -1208,15 +1262,18 @@ def fused_attn(
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
softmax_offset=softmax_offset,
)
output = _fused_attn(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
softmax_type=softmax_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
......
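
The `nondiff_argnums` tuple on the `custom_vjp` decorator shifts from `(4, ..., 15)` to `(5, ..., 17)` because `softmax_offset` is inserted as a new differentiable positional argument (and `softmax_type` as a new non-differentiable one), so every non-differentiable index moves accordingly. A minimal, generic `jax.custom_vjp` sketch of how `nondiff_argnums` behaves, unrelated to the TE internals:

```python
from functools import partial
import jax
import jax.numpy as jnp

# `scale` (argument index 2) is non-differentiable, so it is listed in
# nondiff_argnums and later handed to the bwd rule as a leading argument.
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def scaled_dot(x, y, scale):
    return scale * jnp.dot(x, y)

def scaled_dot_fwd(x, y, scale):
    return scaled_dot(x, y, scale), (x, y)

def scaled_dot_bwd(scale, residuals, g):
    x, y = residuals
    # One gradient per *differentiable* argument; nothing is returned for `scale`.
    return (scale * g * y, scale * g * x)

scaled_dot.defvjp(scaled_dot_fwd, scaled_dot_bwd)

grads = jax.grad(scaled_dot, argnums=(0, 1))(jnp.ones(3), jnp.arange(3.0), 2.0)
```
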
......@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
QuantizeLayout,
)
......
......@@ -39,12 +39,12 @@ from ..quantize import (
Quantizer,
GroupedQuantizer,
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
get_quantize_config_with_recipe,
get_global_quantize_recipe,
QuantizeLayout,
)
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
......
......@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
transpose. Note, transpose_axis should be greater than static_axis_boundary
examples:
X in shape (dim0, dim1, dim2, dim3, dim4)
X of shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1)
......
......@@ -35,9 +35,9 @@ from ..sharding import (
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
QuantizeLayout,
)
......
......@@ -40,11 +40,11 @@ from ..quantize import (
GroupedScaledTensor1x,
Quantizer,
GroupedQuantizer,
QuantizeLayout,
ScalingMode,
compute_scale_from_amax,
NoScaleTensor,
get_rht_matrix,
QuantizeLayout,
)
......@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x_spec = get_padded_spec(arg_infos[0])
amax_spec = get_padded_spec(arg_infos[2])
sr_rng_state_spec = get_padded_spec(arg_infos[3])
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec),
......@@ -551,11 +552,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
)
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[3] = NamedSharding(
mesh,
PartitionSpec(tuple(x for x in x_spec if x is not None), None),
desc="BaseDBiasQuantizePrimitive.sr_rng_state",
)
if len(sr_rng_state_spec) > 1:
# sr_rng_state shape [n_devices, state_per_device]
sr_rng_state_spec = (*tuple(x for x in x_spec if x is not None), None)
arg_shardings[3] = NamedSharding(
mesh,
PartitionSpec(*sr_rng_state_spec),
desc="BaseDBiasQuantizePrimitive.sr_rng_state",
)
arg_shardings = tuple(arg_shardings)
out_shardings = (
out_sharding,
......@@ -654,10 +658,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",)
amax = (BATCHING + prefix + "_amax",)
scale = (BATCHING + prefix + "_scale",)
sr_rng_state = (
BATCHING + prefix + "_sr_rng_state_partition_axis",
BATCHING + prefix + "sr_rng_state_data_axis",
)
sr_rng_state = (BATCHING + prefix + "_sr_rng_state",)
if value_types[3].shape != [0]:
sr_rng_state = (
BATCHING + prefix + "_sr_rng_state_devices",
prefix + "sr_rng_state_data",
)
post_rht_amax = (BATCHING + prefix + "_post_rht_amax",)
rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2")
......@@ -849,7 +855,7 @@ def _quantize_dbias_impl(
if force_1x_quantization:
q_layout = QuantizeLayout.ROWWISE
sr_rng_state = None
sr_rng_state = jnp.empty((0,), jnp.uint32)
if quantizer.scaling_mode.is_nvfp4_scaling:
# Only NVFP4 scaling modes support stochastic rounding
if quantizer.stochastic_rounding_rng_state is not None:
......@@ -866,11 +872,7 @@ def _quantize_dbias_impl(
x.data,
scale,
amax,
(
sr_rng_state
if sr_rng_state is not None
else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
),
sr_rng_state,
post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
rht_matrix,
out_dtype=quantizer.q_dtype,
......@@ -880,7 +882,7 @@ def _quantize_dbias_impl(
scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
is_outer=True,
stochastic_rounding=sr_rng_state is not None,
stochastic_rounding=sr_rng_state.size != 0,
use_rht=use_rht,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
......
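
The quantize hunks replace a `None` stochastic-rounding RNG state with an empty `jnp.empty((0,), jnp.uint32)` buffer and switch the flag to `sr_rng_state.size != 0`: an empty array acts as the "disabled" sentinel while still being a valid operand for the primitive. A tiny sketch of that pattern (names and shapes here are illustrative); note that `.size` depends only on the static shape, so the Python branch also works under `jax.jit`:

```python
import jax
import jax.numpy as jnp

def quantize_step(x, sr_rng_state):
    # Operands must always be arrays; emptiness encodes "stochastic rounding off".
    if sr_rng_state.size != 0:
        # ... would consume the RNG state here ...
        return x * 1.0
    return x

x = jnp.ones((4,))
sr_off = jnp.empty((0,), jnp.uint32)       # feature disabled
sr_on = jnp.zeros((8, 2), jnp.uint32)      # e.g. [n_devices, state_per_device]

jax.jit(quantize_step)(x, sr_off)
jax.jit(quantize_step)(x, sr_on)
```
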
......@@ -108,28 +108,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t qk_head_dim, size_t v_head_dim,
int64_t window_size_left, int64_t window_size_right);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(
bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
int64_t window_size_right);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
int64_t window_size_left, int64_t window_size_right);
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
......