Commit 970620a5 authored by wenjh

merge nv_release_v2.10 to release_v2.10


Signed-off-by: wenjh <wenjh@sugon.com>
parents c1a1c04e 769ed778
@@ -127,12 +127,18 @@ def _run_layer_with_overlap(
os.environ["PYTORCH_JIT"] = "0"
os.environ["NVTE_TORCH_COMPILE"] = "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
if te.get_device_compute_capability() <= (8, 0):
# We've experienced numerical discrepancies in Flash Attention
# backward when running with Userbuffers on A100s. This does
# not show up in more recent GPUs.
os.environ["NVTE_FLASH_ATTN"] = "0"
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
os.unsetenv("PYTORCH_JIT")
os.unsetenv("NVTE_TORCH_COMPILE")
os.unsetenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
os.unsetenv("NVTE_FLASH_ATTN")
if (
result.returncode != 0
...
@@ -7,7 +7,7 @@ import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear, GroupedLinear
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
@@ -19,7 +19,9 @@ model_configs = {
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize(
"module", ["TransformerLayer", "DotProductAttention", "Linear", "GroupedLinear"]
)
def test_current_device(model, module):
"""Test cases where current device is different from tensor device"""
@@ -42,7 +44,29 @@ def test_current_device(model, module):
self_attn_mask_type="padding",
device=f"cuda:{tensor_device}",
)
seqlens_q = torch.randint(
1,
config.max_seqlen_q,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_q = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
seqlens_kv = torch.randint(
1,
config.max_seqlen_kv,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_kv = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
num_tokens = cu_seqlens_q[-1]
args = [
torch.randn(
(num_tokens, config.hidden_size),
@@ -51,37 +75,55 @@ def test_current_device(model, module):
requires_grad=True,
)
]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
elif module == "DotProductAttention":
model = DotProductAttention(
config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
)
seqlens_q = torch.randint(
1,
config.max_seqlen_q,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_q = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
seqlens_kv = torch.randint(
1,
config.max_seqlen_kv,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_kv = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
num_tokens = cu_seqlens_q[-1]
args = [
torch.randn(
num_tokens,
config.num_heads,
config.head_dim_qk,
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
)
for _ in range(3)
]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
bwd_args = [
torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=f"cuda:{tensor_device}")
]
elif module == "Linear":
model = Linear(
config.hidden_size,
@@ -97,6 +139,24 @@ def test_current_device(model, module):
requires_grad=True,
)
]
elif module == "GroupedLinear":
num_gemms = 4
model = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
params_dtype=dtype,
device=f"cuda:{tensor_device}",
)
args = [
torch.randn(
(config.max_seqlen_q * config.batch_size * (num_gemms - 1), config.hidden_size),
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
),
[0] + [config.max_seqlen_q * config.batch_size] * (num_gemms - 1), # Empty first split.
]
current_device_before = torch.cuda.current_device()
out = model(*args, **kwargs)
...
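A note on the packed "thd" inputs built in the updated test above: per-sample sequence lengths are accumulated into a cumulative-offsets tensor (cu_seqlens), and the total number of packed tokens is simply its last entry. A minimal standalone sketch of that relationship (plain PyTorch; the sizes chosen here are illustrative, not the test's config values):

import torch

batch_size, max_seqlen = 4, 8
# Per-sample lengths in [1, max_seqlen), as in the test.
seqlens = torch.randint(1, max_seqlen, (batch_size,), dtype=torch.int32)
# cu_seqlens[i] is the offset of sample i's first token in the packed buffer.
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
# The total token count equals the final cumulative offset.
num_tokens = int(cu_seqlens[-1])
assert num_tokens == int(seqlens.sum())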
@@ -913,15 +913,15 @@ class TestBasicOps:
dtype=dtype,
accumulate_into_main_grad=accumulate_into_main_grad,
)
forward = te_ops.Sequential(
te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
op,
te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
)
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
forward = te_ops.Sequential(
te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
op,
te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
)
with te.autocast(enabled=quantized_compute, recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
...
@@ -46,7 +46,6 @@ from transformer_engine.pytorch import (
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states
@@ -2757,7 +2756,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
general_gemm(
A[i],
B[i],
get_workspace(),
dtype,
grad=grad,
accumulate=accumulate,
@@ -2772,7 +2770,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
B,
out,
dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
grad=grad,
accumulate=accumulate,
@@ -2832,7 +2829,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
quantized_out, *_ = general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
quantization_params=out_quantizer,
bias=None,
@@ -2842,7 +2838,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
out, *_ = general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
quantization_params=None,
bias=None,
@@ -2918,7 +2913,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
general_gemm(
A_fp8[i],
B_fp8[i],
get_workspace(),
dtype,
out=out_ref[i],
accumulate=accumulate,
@@ -2928,7 +2922,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
B_fp8,
out,
dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
accumulate=accumulate,
)
...
@@ -37,7 +37,6 @@ from transformer_engine.pytorch import (
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from utils import ModelConfig
@@ -961,7 +960,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
inp = torch.reshape(scratchpad[offset:-offset], (N, N))
weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
_ = general_gemm(A=weight, B=inp)
torch.cuda.synchronize()
@@ -985,7 +984,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
bias=None,
use_split_accumulator=False,
...
@@ -8,6 +8,7 @@ import logging
import os
from contextlib import contextmanager
from typing import Optional, Tuple, Dict, Any, List
from packaging.version import Version as PkgVersion
import torch
@@ -210,6 +211,7 @@ class ModelConfig:
max_ctx_len: int = None,
num_layers: int = 1,
eps: float = 1e-5,
num_splits=1,
):
self.batch_size = batch_size
self.max_seqlen_q = max_seqlen_q
@@ -239,6 +241,7 @@ class ModelConfig:
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
self.eps = eps
self.num_splits = num_splits
@contextmanager
@@ -321,6 +324,9 @@ def get_available_attention_backends(
inference_params=inference_params,
softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
# allow all backends to pass so they can be used for testing;
# check for FA3 availability later
num_splits=1,
)
(
use_flash_attention,
@@ -330,6 +336,10 @@ def get_available_attention_backends(
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Check if FA3 is an available backend when num_splits != 1
if available_backends[0]:
if config.num_splits != 1 and not flash_attention_backend > PkgVersion("3.0.0b"):
available_backends[0] = False
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
...
@@ -278,7 +278,7 @@ void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, in
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are of shape [b + 1]; cu_cached_lens include the added lens
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
...
@@ -131,7 +131,7 @@ enum NVTE_Mask_Type {
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter of shape [H].
*/
enum NVTE_Softmax_Type {
/*! Vanilla softmax */
...
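As a reading aid only (the header comment above is the authoritative definition): for one query row with scores s_j and learnable per-head offset alpha_h, the three variants can be written as

\text{vanilla:}\quad p_i = \frac{e^{s_i}}{\sum_j e^{s_j}}
\qquad
\text{off-by-one:}\quad p_i = \frac{e^{s_i}}{1 + \sum_j e^{s_j}}
\qquad
\text{learnable:}\quad p_i = \frac{e^{s_i}}{e^{\alpha_h} + \sum_j e^{s_j}}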
@@ -50,7 +50,7 @@ class MMParams:
Parameters
----------
use_split_accumulator : bool, default = True
Use FP8 fast accumulation on Hopper or Ada. For more details,
see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
"""
@@ -159,7 +159,7 @@ class DelayedScaling(Recipe):
recipe: DelayedScaling) -> Tensor
where `Tensor` is a framework tensor type.
reduce_amax: bool, default = True
By default, if `torch.distributed` is initialized, the `amax` value for FP8
tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
call). This keeps the amaxes and scaling factors synced across the given
@@ -167,13 +167,13 @@ class DelayedScaling(Recipe):
GPU maintains local amaxes and scaling factors. To ensure results are
numerically identical across checkpointing boundaries in this case, all
ranks must checkpoint in order to store the local tensors.
fp8_dpa: bool, default = False
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA currently is only supported in the
`FusedAttention` backend.
fp8_mha: bool, default = False
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
operations mentioned above at the DPA boundaries. Currently only standard MHA modules
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
@@ -422,11 +422,11 @@ class NVFP4BlockScaling(Recipe):
----------
fp4_format : {Format.E2M1}, default = Format.E2M1
FP4 data type.
disable_rht : bool, default = False
If set to `True`, random Hadamard transforms are not applied to any tensor.
disable_stochastic_rounding : bool, default = False
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
disable_2d_quantization : bool, default = False
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
"""
@@ -492,17 +492,19 @@ class CustomRecipe(Recipe):
Parameters
----------
qfactory : Callable
Factory callable that returns a quantizer instance for a
given semantic tensor role.
The callable is typically invoked as::
qfactory(
role: str,
)
Where `role` is one of the following strings for e.g. te.Linear
(stable public contract):
- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
"""
qfactory: Callable[..., Any]
...
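To make the qfactory contract documented above concrete, a minimal sketch of a role-based factory follows. The quantizer class here is a placeholder of this sketch's own invention, not a Transformer Engine API; a real factory would return TE quantizer instances appropriate to each role.

class _DummyQuantizer:
    """Stand-in for a quantizer instance; real code would return a TE quantizer."""
    def __init__(self, role: str):
        self.role = role

def my_qfactory(role: str) -> _DummyQuantizer:
    forward_roles = {"linear_input", "linear_weight", "linear_output"}
    backward_roles = {"linear_grad_output", "linear_grad_input"}
    if role not in forward_roles | backward_roles:
        raise ValueError(f"Unexpected tensor role: {role}")
    # A real factory would typically choose different quantization settings
    # for forward tensors than for gradients; here we only record the role.
    return _DummyQuantizer(role)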
@@ -736,12 +736,17 @@ int nvte_is_non_tn_fp8_gemm_supported() {
#if USE_ROCM
return true;
#else
int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
std::call_once(flags[device_id], [&]() {
int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id);
// Note: this is temporary restriction and should be lifted in the future.
// (remove the note once it's done.)
cache[device_id] = (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
deviceComputeCapability >= 130;
});
return cache[device_id];
#endif
}
@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_QKV_Format
from transformer_engine_jax import nvte_get_qkv_format
from transformer_engine_jax import NVTE_Softmax_Type
from . import cpp_extensions as tex
@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
]
class AttnSoftmaxType(Enum):
"""
VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)),
LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [H].
"""
VANILLA_SOFTMAX = NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX
OFF_BY_ONE_SOFTMAX = NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX
LEARNABLE_SOFTMAX = NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX
@classmethod
def from_str(cls, softmax_type: str) -> "AttnSoftmaxType":
"""Convert string to AttnSoftmaxType: 'vanilla', 'off_by_one', or 'learnable'."""
softmax_type_map = {
"vanilla": cls.VANILLA_SOFTMAX,
"off_by_one": cls.OFF_BY_ONE_SOFTMAX,
"learnable": cls.LEARNABLE_SOFTMAX,
}
result = softmax_type_map.get(softmax_type)
if result is None:
raise ValueError(
f"Unknown softmax_type: {softmax_type}. "
"Valid options: 'vanilla', 'off_by_one', 'learnable'"
)
return result
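A small usage sketch of the helper added above (assuming the enum is imported from transformer_engine.jax.attention, the path this diff uses elsewhere); it simply checks that the string spellings map onto the enum members defined in this hunk:

from transformer_engine.jax.attention import AttnSoftmaxType

assert AttnSoftmaxType.from_str("vanilla") is AttnSoftmaxType.VANILLA_SOFTMAX
assert AttnSoftmaxType.from_str("off_by_one") is AttnSoftmaxType.OFF_BY_ONE_SOFTMAX
assert AttnSoftmaxType.from_str("learnable") is AttnSoftmaxType.LEARNABLE_SOFTMAX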
class QKVFormat(Enum):
"""
SBHD: q,k,v memory layout with [s, b, ..., h, d]
@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_probability,
q_num_heads,
kv_num_heads,
@@ -313,6 +344,7 @@ is_fused_attn_kernel_available(
"""
To check whether the fused attention kernel is supported
"""
window_size_tuple = (-1, -1) if window_size is None else window_size
def make_helper(attn_mask_type):
return tex.FusedAttnHelper(
@@ -322,6 +354,7 @@ is_fused_attn_kernel_available(
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_probability,
q_num_heads,
kv_num_heads,
@@ -329,7 +362,7 @@ is_fused_attn_kernel_available(
kv_max_seqlen,
head_dim_qk,
head_dim_v,
window_size_tuple,
)
return make_helper(attn_mask_type).is_fused_attn_kernel_available()
@@ -497,6 +530,11 @@ def _segment_ids_pos_to_seqlens_offsets(
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.
# For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation
# using the segment ids and pos along with mask type (causal or brcm) is sufficient.
# It does not need to involve SW for this mask's creation
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if (attn_mask_type.is_causal() and window_size is None) or (
window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
# TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets
swa_mask = (
make_swa_mask(
segment_pos_q,
segment_pos_kv,
window_size,
dtype=jnp.bool,
segment_ids_q=segment_ids_q,
segment_ids_kv=segment_ids_kv,
)
if attn_mask_type.is_bottom_right()
else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
)
attn_mask = jnp.logical_and(attn_mask, swa_mask)
attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
attn_mask_with_id, max_segments_per_seq
@@ -786,6 +809,7 @@ def _legacy_fused_attn(
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
@@ -793,6 +817,7 @@ _legacy_fused_attn(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
softmax_offset: Optional[jnp.ndarray] = None,
):
"""
Perform non-THD (non-packed) cuDNN fused attention.
@@ -815,6 +840,7 @@ _legacy_fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
@@ -863,10 +889,12 @@ _legacy_fused_attn(
output = _fused_attn(
qkv,
bias,
softmax_offset,
SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)),
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
@@ -900,6 +928,7 @@ def fused_attn_thd(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
softmax_offset: Optional[jnp.ndarray] = None,
):
"""
Deprecated THD fused attn, please use fusd_attn with SequenceDescriptor
@@ -937,6 +966,7 @@ fused_attn_thd(
output = _fused_attn(
qkv,
bias,
softmax_offset,
SequenceDescriptor.from_seqlens_and_offsets(
(q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets)
),
@@ -945,6 +975,7 @@ fused_attn_thd(
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
@@ -957,15 +988,17 @@ fused_attn_thd(
return output
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
softmax_offset: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray],
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
@@ -979,11 +1012,13 @@ _fused_attn(
output, _ = _fused_attn_fwd_rule(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type,
attn_mask_type,
qkv_layout,
softmax_type,
scaling_factor,
dropout_probability,
is_training,
@@ -1000,11 +1035,13 @@ _fused_attn(
def _fused_attn_fwd_rule(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type,
attn_mask_type,
qkv_layout,
softmax_type,
scaling_factor,
dropout_probability,
is_training,
@@ -1018,10 +1055,12 @@ _fused_attn_fwd_rule(
output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
@@ -1041,6 +1080,7 @@ _fused_attn_fwd_rule(
sequence_descriptor,
softmax_aux,
rng_state,
softmax_offset,
output,
)
@@ -1049,6 +1089,7 @@ def _fused_attn_bwd_rule(
attn_bias_type,
attn_mask_type,
qkv_layout,
softmax_type,
scaling_factor,
dropout_probability,
is_training,
@@ -1068,11 +1109,13 @@ _fused_attn_bwd_rule(
sequence_descriptor,
softmax_aux,
rng_state,
softmax_offset,
output,
) = ctx
grad_qkv, grad_bias, grad_softmax_offset = tex.fused_attn_bwd(
qkv,
bias,
softmax_offset,
softmax_aux,
rng_state,
output,
@@ -1080,6 +1123,7 @@ _fused_attn_bwd_rule(
sequence_descriptor,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
@@ -1092,9 +1136,12 @@ _fused_attn_bwd_rule(
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
if softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX:
grad_softmax_offset = None
return (
grad_qkv,
grad_bias,
grad_softmax_offset,
None,
None,
)
@@ -1111,6 +1158,7 @@ def fused_attn(
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
@@ -1120,6 +1168,7 @@ fused_attn(
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
context_checkpoint_name: str = "context",
softmax_offset: Optional[jnp.ndarray] = None,
):
"""
Perform cuDNN fused attention.
@@ -1139,6 +1188,7 @@ fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
@@ -1153,6 +1203,9 @@ fused_attn(
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
[1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
If provided, this parameter will receive gradients during backpropagation.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
@@ -1200,6 +1253,7 @@ fused_attn(
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
@@ -1208,15 +1262,18 @@ fused_attn(
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
softmax_offset=softmax_offset,
)
output = _fused_attn(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
softmax_type=softmax_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
...
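To tie the new arguments together: with the signature changes above, a caller that wants the learnable-softmax path supplies both a softmax_type and a softmax_offset. A hedged sketch of just the new pieces (the offset shape comes from the docstring added in this hunk; all other fused_attn arguments are unchanged and omitted here):

import jax.numpy as jnp
from transformer_engine.jax.attention import AttnSoftmaxType

num_heads = 16
# Learnable per-head offset; the docstring above specifies shape [1, num_heads, 1, 1].
softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)

# Extra keyword arguments threaded through fused_attn by this change.
extra_kwargs = dict(
    softmax_type=AttnSoftmaxType.LEARNABLE_SOFTMAX,
    softmax_offset=softmax_offset,
)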
@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
QuantizeLayout,
)
...
@@ -20,11 +20,13 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
AttnSoftmaxType,
QKVLayout,
QKVFormat,
CPStrategy,
SequenceDescriptor,
)
from ..sharding import with_sharding_constraint_by_logical_axes, HEAD_AXES
from .base import BasePrimitive, register_primitive
from .misc import (
@@ -61,6 +63,7 @@ __all__ = [
meta_fields=[
"attn_bias_type",
"attn_mask_type",
"softmax_type",
"qkv_layout",
"scaling_factor",
"dropout_probability",
@@ -80,6 +83,7 @@ class _FusedAttnConfig:
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
softmax_type: AttnSoftmaxType
qkv_layout: QKVLayout
scaling_factor: float
dropout_probability: float
@@ -103,6 +107,7 @@ class FusedAttnHelper:
qkv_layout: QKVLayout
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
softmax_type: AttnSoftmaxType
dropout_probability: float
q_num_heads: int
kv_num_heads: int
@@ -125,6 +130,7 @@ class FusedAttnHelper:
self.qkv_layout.value,
self.attn_bias_type.value,
self.attn_mask_type.value,
self.softmax_type.value,
self.dropout_probability,
self.q_num_heads,
self.kv_num_heads,
@@ -254,7 +260,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
name = "te_fused_attn_forward_ffi"
multiple_results = True
impl_static_args = (14,)
inner_primitive = None
outer_primitive = None
@@ -264,6 +270,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k_aval,
v_aval,
bias_aval,
softmax_offset_aval,
seed_aval,
q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval,
@@ -312,6 +319,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config.qkv_layout,
config.attn_bias_type,
config.attn_mask_type,
config.softmax_type,
config.dropout_probability,
attn_heads,
num_gqa_groups,
@@ -375,6 +383,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config.dropout_probability,
config.attn_bias_type.value,
config.attn_mask_type.value,
config.softmax_type.value,
config.qkv_layout.value,
jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training,
@@ -386,6 +395,12 @@ class FusedAttnFwdPrimitive(BasePrimitive):
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
assert softmax_offset_aval.dtype == jnp.float32
if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
assert softmax_offset_aval.shape == (1, attn_heads, 1, 1)
else:
assert softmax_offset_aval.shape == (0,)
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
@staticmethod
@@ -405,6 +420,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k,
v,
bias,
softmax_offset,
seed,
q_cu_seqlen,
kv_cu_seqlen,
@@ -453,6 +469,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k,
v,
bias,
softmax_offset,
seed,
q_cu_seqlen,
kv_cu_seqlen,
@@ -481,6 +498,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left,
window_size_right=window_size_right,
softmax_type=int(config.softmax_type.value),
)
@staticmethod
@@ -489,6 +507,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k,
v,
bias,
softmax_offset,
seed,
q_seqlen,
kv_seqlen,
@@ -579,6 +598,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k,
v,
bias,
softmax_offset,
seed,
q_cu_seqlen,
kv_cu_seqlen,
@@ -596,7 +616,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
def batcher(batched_args, batch_dims, *, config):
check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims
out_bdims = q_bdim, q_bdim, seed_bdim
return (
@@ -662,7 +682,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[5] = seed_sharding
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
@@ -710,7 +730,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
name = "te_fused_attn_backward_ffi"
multiple_results = True
impl_static_args = (17,)
inner_primitive = None
outer_primitive = None
@@ -720,6 +740,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_aval,
v_aval,
bias_aval,
softmax_offset_aval,
softmax_aux_aval,
rng_state_aval,
output_aval,
@@ -781,6 +802,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config.dropout_probability,
config.attn_bias_type.value,
config.attn_mask_type.value,
config.softmax_type.value,
config.qkv_layout.value,
jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training,
@@ -798,15 +820,39 @@ class FusedAttnBwdPrimitive(BasePrimitive):
shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
)
# Validate incoming softmax_offset shape and dtype
assert (
softmax_offset_aval.dtype == jnp.float32
), f"Incorrect softmax_offset dtype: {softmax_offset_aval.dtype}, expected: {jnp.float32}"
if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), (
f"Incorrect softmax_offset shape for {config.softmax_type}:"
f" {softmax_offset_aval.shape}, expected: (1, {attn_heads}, 1, 1)"
)
else:
assert softmax_offset_aval.shape == (0,), (
f"Incorrect softmax_offset shape for {config.softmax_type}:"
f" {softmax_offset_aval.shape}, expected: (0,)"
)
if config.softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
dsoftmax_offset_aval = q_aval.update(
shape=softmax_offset_aval.shape, dtype=softmax_offset_aval.dtype
)
else:
dsoftmax_offset_aval = q_aval.update(shape=(1, attn_heads, 1, 1), dtype=jnp.float32)
return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, wkspace_aval
@staticmethod @staticmethod
def outer_abstract(*args, **kwargs): def outer_abstract(*args, **kwargs):
""" """
Fused attention fwd outer primitive abstract Fused attention fwd outer primitive abstract
""" """
dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs) dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, _ = (
return dq_aval, dk_aval, dv_aval, dbias_aval FusedAttnBwdPrimitive.abstract(*args, **kwargs)
)
return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval
@staticmethod @staticmethod
def lowering( def lowering(
...@@ -815,6 +861,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -815,6 +861,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -866,6 +913,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -866,6 +913,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -897,6 +945,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -897,6 +945,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left, window_size_left=window_size_left,
window_size_right=window_size_right, window_size_right=window_size_right,
softmax_type=int(config.softmax_type.value),
) )
@staticmethod @staticmethod
...@@ -905,6 +954,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -905,6 +954,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -993,11 +1043,12 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -993,11 +1043,12 @@ class FusedAttnBwdPrimitive(BasePrimitive):
q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())
dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind( dq, dk, dv, dbias, dsoftmax_offset, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
q, q,
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1012,15 +1063,15 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -1012,15 +1063,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_kv_segment_pos, _kv_segment_pos,
config=config, config=config,
) )
return dq, dk, dv, dbias return dq, dk, dv, dbias, dsoftmax_offset
@staticmethod @staticmethod
def batcher(batched_args, batch_dims, *, config): def batcher(batched_args, batch_dims, *, config):
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert FusedAttnBwdPrimitive.outer_primitive is not None assert FusedAttnBwdPrimitive.outer_primitive is not None
q_bdim, k_bdim, v_bdim, *_ = batch_dims q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims
out_bdims = q_bdim, k_bdim, v_bdim, q_bdim out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim
return ( return (
FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),
out_bdims, out_bdims,
...@@ -1033,11 +1084,13 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -1033,11 +1084,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2]) v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3]) bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding, dsoftmax_offset_sharding)
@staticmethod @staticmethod
def partition(config, mesh, arg_infos, result_infos): def partition(config, mesh, arg_infos, result_infos):
...@@ -1046,21 +1099,30 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -1046,21 +1099,30 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2]) v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3]) bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
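# Ensure segment_pos gets same sharding as ID.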
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) out_shardings = (
dq_sharding,
dk_sharding,
dv_sharding,
dbias_sharding,
dsoftmax_offset_sharding,
)
def sharded_impl( def sharded_impl(
q, q,
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1074,36 +1136,43 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -1074,36 +1136,43 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_q_segment_pos, _q_segment_pos,
_kv_segment_pos, _kv_segment_pos,
): ):
local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( local_dq, local_dk, local_dv, local_dbias, local_dsoftmax_offset = (
q, FusedAttnBwdPrimitive.impl(
k, q,
v, k,
bias, v,
softmax_aux, bias,
rng_state, softmax_offset,
output, softmax_aux,
doutput, rng_state,
q_cu_seqlen, output,
kv_cu_seqlen, doutput,
q_seq_offsets, q_cu_seqlen,
k_seq_offsets, kv_cu_seqlen,
_q_segment_ids, q_seq_offsets,
_kv_segment_ids, k_seq_offsets,
_q_segment_pos, _q_segment_ids,
_kv_segment_pos, _kv_segment_ids,
config=config, _q_segment_pos,
_kv_segment_pos,
config=config,
)
) )
global_dbias = local_dbias global_dbias = local_dbias
if config.attn_bias_type is not AttnBiasType.NO_BIAS: if config.attn_bias_type is not AttnBiasType.NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
return local_dq, local_dk, local_dv, global_dbias
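# Like dbias, the learnable softmax offset is replicated across the dp/fsdp axes,
# so its per-shard gradients must be summed across those axes.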
global_dsoftmax_offset = local_dsoftmax_offset
if config.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
global_dsoftmax_offset = all_reduce_sum_along_dp_fsdp(local_dsoftmax_offset, mesh)
return local_dq, local_dk, local_dv, global_dbias, global_dsoftmax_offset
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod @staticmethod
def shardy_sharding_rule(config, mesh, value_types, result_types): def shardy_sharding_rule(config, mesh, value_types, result_types):
del config, mesh del config, mesh
# We only care about the four first arguments.
# Keep in sync with `infer_sharding_from_operands`. # Keep in sync with `infer_sharding_from_operands`.
input_spec = tuple((f"…{x}",) for x in range(len(value_types))) input_spec = tuple((f"…{x}",) for x in range(len(value_types)))
output_spec = tuple((f"…{x}",) for x in range(len(result_types))) output_spec = tuple((f"…{x}",) for x in range(len(result_types)))
...@@ -1229,6 +1298,11 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1229,6 +1298,11 @@ class _FusedAttnCPWithAllGatherHelper:
if self.config.dropout_probability != 0.0: if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout") raise ValueError(f"{header} does not support dropout")
if self.config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
raise ValueError(
f"{header} only supports VANILLA_SOFTMAX, got: {self.config.softmax_type}"
)
def get_adjusted_mask(self): def get_adjusted_mask(self):
"""Converts the mask for context parallelism.""" """Converts the mask for context parallelism."""
if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK: if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
...@@ -1240,6 +1314,7 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1240,6 +1314,7 @@ class _FusedAttnCPWithAllGatherHelper:
return _FusedAttnConfig( return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type, attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(), attn_mask_type=self.get_adjusted_mask(),
softmax_type=self.config.softmax_type,
qkv_layout=self.config.qkv_layout, qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor, scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability, dropout_probability=self.config.dropout_probability,
...@@ -1376,7 +1451,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1376,7 +1451,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
mesh, PartitionSpec(get_all_mesh_axes(), None) mesh, PartitionSpec(get_all_mesh_axes(), None)
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[5] = seed_sharding
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
...@@ -1385,6 +1460,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1385,6 +1460,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -1404,7 +1480,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1404,7 +1480,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
# meeting the expectation of the SPMD model. # meeting the expectation of the SPMD model.
# TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
# mask/sequence length tensor to avoid this unrolled loop. # mask/sequence length tensor to avoid this unrolled loop.
def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed):
kv_max_seqlen = k.shape[1] kv_max_seqlen = k.shape[1]
kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"
...@@ -1431,6 +1507,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1431,6 +1507,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k_unmasked, k_unmasked,
v_unmasked, v_unmasked,
bias, bias,
softmax_offset,
seed, seed,
q_seqlen_for_step, q_seqlen_for_step,
kv_seqlen_for_step, kv_seqlen_for_step,
...@@ -1453,7 +1530,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1453,7 +1530,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k_ag, v_ag = helper.all_gather_kv(k, v) k_ag, v_ag = helper.all_gather_kv(k, v)
functions = [ functions = [
partial(_cross_attn, idx, q, k_ag, v_ag, bias, q_seqlen, kv_seqlen, seed) partial(
_cross_attn, idx, q, k_ag, v_ag, bias, softmax_offset, q_seqlen, kv_seqlen, seed
)
for idx in range(cp_size) for idx in range(cp_size)
] ]
...@@ -1492,18 +1571,27 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1492,18 +1571,27 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2]) v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3]) bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) out_shardings = (
dq_sharding,
dk_sharding,
dv_sharding,
dbias_sharding,
dsoftmax_offset_sharding,
)
def impl( def impl(
q, q,
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1527,6 +1615,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1527,6 +1615,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1562,11 +1651,12 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1562,11 +1651,12 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks
dq_local, dk_local, dv_local, dbias_local = FusedAttnBwdPrimitive.impl( dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl(
q_split[sub_idx], q_split[sub_idx],
k_unmasked, k_unmasked,
v_unmasked, v_unmasked,
bias, bias,
softmax_offset,
softmax_aux_split[sub_idx], softmax_aux_split[sub_idx],
rng_state, rng_state,
output_split[sub_idx], output_split[sub_idx],
...@@ -1604,6 +1694,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1604,6 +1694,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_ag, k_ag,
v_ag, v_ag,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1621,7 +1712,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1621,7 +1712,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions) dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions)
dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local) dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)
return dq, dk, dv, dbias # Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(softmax_offset)
return dq, dk, dv, dbias, dummy_dsoftmax_offset
return mesh, impl, out_shardings, arg_shardings return mesh, impl, out_shardings, arg_shardings
...@@ -1679,6 +1772,11 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1679,6 +1772,11 @@ class _FusedAttnCPWithP2PHelper:
if self.config.dropout_probability != 0.0: if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout") raise ValueError(f"{header} does not support dropout")
if self.config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
raise ValueError(
f"{header} only supports VANILLA_SOFTMAX, got: {self.config.softmax_type}"
)
# We want to encourage use of scan loop to minimize unrolling and ensure more # We want to encourage use of scan loop to minimize unrolling and ensure more
# predictable scheduling from XLA. The unrolled flavor will be supported but # predictable scheduling from XLA. The unrolled flavor will be supported but
# not the preferred implementation. # not the preferred implementation.
...@@ -1703,6 +1801,7 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1703,6 +1801,7 @@ class _FusedAttnCPWithP2PHelper:
return _FusedAttnConfig( return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type, attn_bias_type=self.config.attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
softmax_type=self.config.softmax_type,
qkv_layout=QKVLayout.BSHD_BS2HD, qkv_layout=QKVLayout.BSHD_BS2HD,
scaling_factor=self.config.scaling_factor, scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability, dropout_probability=self.config.dropout_probability,
...@@ -1783,7 +1882,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1783,7 +1882,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
mesh, PartitionSpec(get_all_mesh_axes(), None) mesh, PartitionSpec(get_all_mesh_axes(), None)
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[5] = seed_sharding
# Ensure segment_pos gets same sharding as ID. # Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
...@@ -1795,6 +1894,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1795,6 +1894,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -1840,6 +1940,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1840,6 +1940,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen_per_step, q_seqlen_per_step,
kv_seqlen_per_step, kv_seqlen_per_step,
...@@ -1865,6 +1966,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1865,6 +1966,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv_part, kv_part,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen_per_step, q_seqlen_per_step,
kv_seqlen_per_step, kv_seqlen_per_step,
...@@ -1887,6 +1989,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1887,6 +1989,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen_per_step, q_seqlen_per_step,
kv_seqlen_per_step, kv_seqlen_per_step,
...@@ -1990,18 +2093,24 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1990,18 +2093,24 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2]) v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3]) bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
# Ring attention doesn't use dsoftmax_offset, but we need to return it for arity matching
dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
# Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) dq_sharding,
dk_sharding,
dv_sharding,
dbias_sharding,
dsoftmax_offset_sharding,
)
helper = _FusedAttnCPWithP2PHelper(mesh, config) helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported() helper.check_supported()
...@@ -2011,6 +2120,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2011,6 +2120,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2054,11 +2164,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2054,11 +2164,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
def mask_compute(attn_mask_type): def mask_compute(attn_mask_type):
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q, q,
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2082,11 +2193,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2082,11 +2193,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2 kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1) kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q, q,
kv_part, kv_part,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2120,11 +2232,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2120,11 +2232,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2 softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2
) )
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q_part, q_part,
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux_part, softmax_aux_part,
rng_state, rng_state,
output_part, output_part,
...@@ -2184,7 +2297,9 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2184,7 +2297,9 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dk_dv) dk, dv = helper.unstack_kv(dk_dv)
return dq, dk, dv, global_dbias # Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(_softmax_offset)
return dq, dk, dv, global_dbias, dummy_dsoftmax_offset
return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings
...@@ -2273,7 +2388,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2273,7 +2388,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
mesh, PartitionSpec(get_all_mesh_axes(), None) mesh, PartitionSpec(get_all_mesh_axes(), None)
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[5] = seed_sharding
# Ensure segment_pos gets same sharding as ID. # Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
...@@ -2285,6 +2400,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2285,6 +2400,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -2336,6 +2452,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2336,6 +2452,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -2345,7 +2462,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2345,7 +2462,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv_segment_ids, kv_segment_ids,
q_segment_pos, q_segment_pos,
kv_segment_pos, kv_segment_pos,
config, config=config,
) )
if config.window_size != (-1, -1): if config.window_size != (-1, -1):
...@@ -2420,8 +2537,8 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2420,8 +2537,8 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
# dq, dk, dv, dbias sharding = q, k, v, bias sharding # dq, dk, dv, dbias, dsoftmax_offset sharding = q, k, v, bias, softmax_offset sharding
out_shardings = tuple(arg.sharding for arg in arg_infos[:4]) out_shardings = tuple(arg.sharding for arg in arg_infos[:5])
helper = _FusedAttnCPWithP2PHelper(mesh, config) helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported() helper.check_supported()
...@@ -2431,6 +2548,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2431,6 +2548,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2478,11 +2596,12 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2478,11 +2596,12 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm) kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)
def compute(config): def compute(config):
dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dkv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q, q,
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2536,7 +2655,9 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2536,7 +2655,9 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dkv) dk, dv = helper.unstack_kv(dkv)
return dq, dk, dv, global_dbias # Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(_softmax_offset)
return dq, dk, dv, global_dbias, dummy_dsoftmax_offset
return mesh, bwd_impl, out_shardings, arg_shardings return mesh, bwd_impl, out_shardings, arg_shardings
...@@ -2557,10 +2678,12 @@ def _maybe_context_parallel_axis(cp_axis: str): ...@@ -2557,10 +2678,12 @@ def _maybe_context_parallel_axis(cp_axis: str):
def fused_attn_fwd( def fused_attn_fwd(
qkv: Tuple[jnp.ndarray, ...], qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray], bias: Optional[jnp.ndarray],
softmax_offset: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor, sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray], seed: Optional[jnp.ndarray],
attn_bias_type: AttnBiasType, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, attn_mask_type: AttnMaskType,
softmax_type: AttnSoftmaxType,
qkv_layout: QKVLayout, qkv_layout: QKVLayout,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
...@@ -2585,6 +2708,7 @@ def fused_attn_fwd( ...@@ -2585,6 +2708,7 @@ def fused_attn_fwd(
query has a different shape (e.g., cross-attention). query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors. - `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores. bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_offset (Optional[jnp.ndarray]): An optional per-head softmax offset tensor of shape [1, num_heads, 1, 1] in float32. Required when softmax_type is LEARNABLE_SOFTMAX; if None, a zero offset is constructed for OFF_BY_ONE_SOFTMAX.
q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,]. q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,]. kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
q_seq_offsets (jnp.ndarray): q_seq_offsets (jnp.ndarray):
...@@ -2594,6 +2718,7 @@ def fused_attn_fwd( ...@@ -2594,6 +2718,7 @@ def fused_attn_fwd(
seed (Optional[jnp.ndarray]): Optional random seed for dropout. seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias. attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask. attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of softmax to apply (vanilla, off-by-one, or learnable).
qkv_layout (QKVLayout): Layout of the QKV tensors. qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
...@@ -2633,10 +2758,36 @@ def fused_attn_fwd( ...@@ -2633,10 +2758,36 @@ def fused_attn_fwd(
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype) bias = jnp.zeros(0, dtype=qkv[0].dtype)
if softmax_offset is None:
assert (
softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX
), f"Softmax type {softmax_type} is not supported when softmax_offset is None"
if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
num_heads = qkv[0].shape[-2]
# Create tensor [1, h, 1, 1] filled with zeros (logit value = 0)
# This adds exp(0 - x_max) = exp(-x_max) to the denominator,
# which contributes exactly 1 after normalization, giving: exp(x_i) / (sum(exp(x_j)) + 1)
softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
# Shard by heads dimension
softmax_offset = with_sharding_constraint_by_logical_axes(
softmax_offset, (None, HEAD_AXES, None, None)
)
else:
assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX
softmax_offset = jnp.zeros(0, dtype=jnp.float32)
else:
assert softmax_offset.dtype == jnp.float32
# Shard by heads dimension if not VANILLA_SOFTMAX
if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
softmax_offset = with_sharding_constraint_by_logical_axes(
softmax_offset, (None, HEAD_AXES, None, None)
)
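# A worked example of what the zero offset above produces (illustrative only,
# nothing here is passed to the primitive): for a single row of logits
# x = [0, ln 3], vanilla softmax gives [1, 3] / (1 + 3) = [0.25, 0.75], while
# the off-by-one form exp(x_i) / (sum_j exp(x_j) + 1) gives
# [1, 3] / (1 + 3 + 1) = [0.2, 0.6], so 0.2 of the probability mass is absorbed
# by the implicit zero logit.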
fused_config = _FusedAttnConfig( fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
softmax_type=softmax_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
...@@ -2662,6 +2813,7 @@ def fused_attn_fwd( ...@@ -2662,6 +2813,7 @@ def fused_attn_fwd(
output, softmax_aux, rng_state = primitive.bind( output, softmax_aux, rng_state = primitive.bind(
*qkv_for_primitive, *qkv_for_primitive,
bias, bias,
softmax_offset,
seed, seed,
*seq_desc_flatten, *seq_desc_flatten,
config=fused_config, config=fused_config,
...@@ -2673,6 +2825,7 @@ def fused_attn_fwd( ...@@ -2673,6 +2825,7 @@ def fused_attn_fwd(
def fused_attn_bwd( def fused_attn_bwd(
qkv: Tuple[jnp.ndarray, ...], qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray], bias: Optional[jnp.ndarray],
softmax_offset: Optional[jnp.ndarray],
softmax_aux: jnp.ndarray, softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray, rng_state: jnp.ndarray,
output: jnp.ndarray, output: jnp.ndarray,
...@@ -2681,6 +2834,7 @@ def fused_attn_bwd( ...@@ -2681,6 +2834,7 @@ def fused_attn_bwd(
attn_bias_type: AttnBiasType, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout, qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
...@@ -2702,6 +2856,7 @@ def fused_attn_bwd( ...@@ -2702,6 +2856,7 @@ def fused_attn_bwd(
query has a different shape (e.g., cross-attention). query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors. - `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores. bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_offset (Optional[jnp.ndarray]): An optional per-head softmax offset tensor of shape [1, num_heads, 1, 1] in float32, matching the offset used in the forward pass. Required when softmax_type is LEARNABLE_SOFTMAX.
softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass. softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass. rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
output (jnp.ndarray): The output tensor from the forward pass. output (jnp.ndarray): The output tensor from the forward pass.
...@@ -2714,6 +2869,7 @@ def fused_attn_bwd( ...@@ -2714,6 +2869,7 @@ def fused_attn_bwd(
The offsets in the sequence dim for the query, with shape [batch + 1,]. The offsets in the sequence dim for the query, with shape [batch + 1,].
attn_bias_type (AttnBiasType): Type of attention bias. attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask. attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of softmax to apply (vanilla, off-by-one, or learnable).
qkv_layout (QKVLayout): Layout of the QKV tensors. qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
...@@ -2755,6 +2911,28 @@ def fused_attn_bwd( ...@@ -2755,6 +2911,28 @@ def fused_attn_bwd(
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype) bias = jnp.zeros(0, dtype=qkv[0].dtype)
if softmax_offset is None:
assert softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX, f"softmax_offset is required for {softmax_type=}"
if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
num_heads = qkv[0].shape[-2]
# Create tensor [1, h, 1, 1] filled with zeros
softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
# Shard by heads dimension
softmax_offset = with_sharding_constraint_by_logical_axes(
softmax_offset, (None, HEAD_AXES, None, None)
)
elif softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
softmax_offset = jnp.zeros(0, dtype=jnp.float32)
else:
raise NotImplementedError(f"Unknown {softmax_type=}")
else:
softmax_offset = softmax_offset.astype(jnp.float32)
# Shard by heads dimension if not VANILLA_SOFTMAX
if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
softmax_offset = with_sharding_constraint_by_logical_axes(
softmax_offset, (None, HEAD_AXES, None, None)
)
# TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# sm100+ # sm100+
compute_capabilities = get_all_device_compute_capability() compute_capabilities = get_all_device_compute_capability()
...@@ -2767,6 +2945,7 @@ def fused_attn_bwd( ...@@ -2767,6 +2945,7 @@ def fused_attn_bwd(
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
softmax_type=softmax_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
...@@ -2788,9 +2967,10 @@ def fused_attn_bwd( ...@@ -2788,9 +2967,10 @@ def fused_attn_bwd(
primitive = FusedRingAttnBwdPrimitive.outer_primitive primitive = FusedRingAttnBwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
*qkv_grads, bias_grad = primitive.bind( *qkv_grads, bias_grad, softmax_offset_grad = primitive.bind(
*qkv_for_primitive, *qkv_for_primitive,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2798,4 +2978,4 @@ def fused_attn_bwd( ...@@ -2798,4 +2978,4 @@ def fused_attn_bwd(
*seq_desc_flatten, *seq_desc_flatten,
config=fused_config, config=fused_config,
) )
return tuple(qkv_grads[: len(qkv)]), bias_grad return tuple(qkv_grads[: len(qkv)]), bias_grad, softmax_offset_grad
...@@ -39,12 +39,12 @@ from ..quantize import ( ...@@ -39,12 +39,12 @@ from ..quantize import (
Quantizer, Quantizer,
GroupedQuantizer, GroupedQuantizer,
QuantizerSet, QuantizerSet,
QuantizeLayout,
noop_quantizer_set, noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv, apply_padding_to_scale_inv,
get_quantize_config_with_recipe, get_quantize_config_with_recipe,
get_global_quantize_recipe, get_global_quantize_recipe,
QuantizeLayout,
) )
from .misc import get_padded_spec, is_all_reduce_in_float32 from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import ( from ..sharding import (
......
...@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1): ...@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
transpose. Note, transpose_axis should be greater than static_axis_boundary transpose. Note, transpose_axis should be greater than static_axis_boundary
examples: examples:
X in shape (dim0, dim1, dim2, dim3, dim4) X of shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis == 2 static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1) Xt = (dim2, dim3, dim4, dim0, dim1)
......
...@@ -35,9 +35,9 @@ from ..sharding import ( ...@@ -35,9 +35,9 @@ from ..sharding import (
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import ( from ..quantize import (
Quantizer, Quantizer,
QuantizeLayout,
DelayedScaleQuantizer, DelayedScaleQuantizer,
ScalingMode, ScalingMode,
QuantizeLayout,
) )
......
...@@ -40,11 +40,11 @@ from ..quantize import ( ...@@ -40,11 +40,11 @@ from ..quantize import (
GroupedScaledTensor1x, GroupedScaledTensor1x,
Quantizer, Quantizer,
GroupedQuantizer, GroupedQuantizer,
QuantizeLayout,
ScalingMode, ScalingMode,
compute_scale_from_amax, compute_scale_from_amax,
NoScaleTensor, NoScaleTensor,
get_rht_matrix, get_rht_matrix,
QuantizeLayout,
) )
...@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
amax_spec = get_padded_spec(arg_infos[2]) amax_spec = get_padded_spec(arg_infos[2])
sr_rng_state_spec = get_padded_spec(arg_infos[3])
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*x_spec), PartitionSpec(*x_spec),
...@@ -551,11 +552,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -551,11 +552,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
) )
arg_shardings = list(arg_i.sharding for arg_i in arg_infos) arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[3] = NamedSharding( if len(sr_rng_state_spec) > 1:
mesh, # sr_rng_state shape [n_devices, state_per_device]
PartitionSpec(tuple(x for x in x_spec if x is not None), None), sr_rng_state_spec = (*tuple(x for x in x_spec if x is not None), None)
desc="BaseDBiasQuantizePrimitive.sr_rng_state", arg_shardings[3] = NamedSharding(
) mesh,
PartitionSpec(*sr_rng_state_spec),
desc="BaseDBiasQuantizePrimitive.sr_rng_state",
)
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
...@@ -654,10 +658,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -654,10 +658,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",) dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",)
amax = (BATCHING + prefix + "_amax",) amax = (BATCHING + prefix + "_amax",)
scale = (BATCHING + prefix + "_scale",) scale = (BATCHING + prefix + "_scale",)
sr_rng_state = ( sr_rng_state = (BATCHING + prefix + "_sr_rng_state",)
BATCHING + prefix + "_sr_rng_state_partition_axis", if value_types[3].shape != [0]:
BATCHING + prefix + "sr_rng_state_data_axis", sr_rng_state = (
) BATCHING + prefix + "_sr_rng_state_devices",
prefix + "sr_rng_state_data",
)
post_rht_amax = (BATCHING + prefix + "_post_rht_amax",) post_rht_amax = (BATCHING + prefix + "_post_rht_amax",)
rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2") rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2")
...@@ -849,7 +855,7 @@ def _quantize_dbias_impl( ...@@ -849,7 +855,7 @@ def _quantize_dbias_impl(
if force_1x_quantization: if force_1x_quantization:
q_layout = QuantizeLayout.ROWWISE q_layout = QuantizeLayout.ROWWISE
sr_rng_state = None sr_rng_state = jnp.empty((0,), jnp.uint32)
if quantizer.scaling_mode.is_nvfp4_scaling: if quantizer.scaling_mode.is_nvfp4_scaling:
# Only NVFP4 scaling modes support stochastic rounding # Only NVFP4 scaling modes support stochastic rounding
if quantizer.stochastic_rounding_rng_state is not None: if quantizer.stochastic_rounding_rng_state is not None:
...@@ -866,11 +872,7 @@ def _quantize_dbias_impl( ...@@ -866,11 +872,7 @@ def _quantize_dbias_impl(
x.data, x.data,
scale, scale,
amax, amax,
( sr_rng_state,
sr_rng_state
if sr_rng_state is not None
else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
),
post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32), post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
rht_matrix, rht_matrix,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
...@@ -880,7 +882,7 @@ def _quantize_dbias_impl( ...@@ -880,7 +882,7 @@ def _quantize_dbias_impl(
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False, is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
is_outer=True, is_outer=True,
stochastic_rounding=sr_rng_state is not None, stochastic_rounding=sr_rng_state.size != 0,
use_rht=use_rht, use_rht=use_rht,
) )
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
......
...@@ -11,10 +11,11 @@ import jax ...@@ -11,10 +11,11 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes, ffi from jax import dtypes, ffi
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from .attention import AttnSoftmaxType
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .misc import get_padded_spec, check_valid_batch_dims from .misc import get_padded_spec, check_valid_batch_dims
from ..softmax import SoftmaxType from ..softmax import SoftmaxFusionType
__all__ = [ __all__ = [
...@@ -32,7 +33,8 @@ __all__ = [ ...@@ -32,7 +33,8 @@ __all__ = [
def is_softmax_kernel_available( def is_softmax_kernel_available(
softmax_type: SoftmaxType, softmax_fusion_type: SoftmaxFusionType,
softmax_type: AttnSoftmaxType,
batch: int, batch: int,
heads: int, heads: int,
q_seqlen: int, q_seqlen: int,
...@@ -40,15 +42,18 @@ def is_softmax_kernel_available( ...@@ -40,15 +42,18 @@ def is_softmax_kernel_available(
dtype: jnp.dtype, dtype: jnp.dtype,
): ):
"""check softmax available""" """check softmax available"""
if softmax_type is SoftmaxType.SCALED: if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
return False
if softmax_fusion_type is SoftmaxFusionType.SCALED:
return ScaledSoftmaxFwdPrimitive.is_kernel_available( return ScaledSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype batch, heads, q_seqlen, k_seqlen, dtype
) )
if softmax_type is SoftmaxType.SCALED_MASKED: if softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available( return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype batch, heads, q_seqlen, k_seqlen, dtype
) )
if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: if softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available( return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype batch, heads, q_seqlen, k_seqlen, dtype
) )
...@@ -792,26 +797,77 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -792,26 +797,77 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float): def jax_scaled_softmax(
logits: jnp.ndarray, scale_factor: float, softmax_offset: jnp.ndarray | float | None = None
):
""" """
JAX based implementation of scaled softmax JAX based implementation of scaled softmax
""" """
if softmax_offset is not None:
return jax_general_softmax(scale_factor * logits, offset=softmax_offset)
return jax.nn.softmax(scale_factor * logits) return jax.nn.softmax(scale_factor * logits)
def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float): def jax_scaled_masked_softmax(
logits: jnp.ndarray,
mask: jnp.ndarray,
scale_factor: float,
softmax_offset: jnp.ndarray | float | None = None,
):
""" """
JAX based implementation of scaled and masked softmax JAX based implementation of scaled and masked softmax
""" """
if softmax_offset is not None:
return jax_general_softmax(logits * scale_factor, offset=softmax_offset, where=mask != 1)
return jax.nn.softmax(logits * scale_factor, where=mask != 1) return jax.nn.softmax(logits * scale_factor, where=mask != 1)
def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float): def jax_scaled_upper_triang_masked_softmax(
logits: jnp.ndarray, scale_factor: float, softmax_offset: jnp.ndarray | float | None = None
):
""" """
JAX based implementation of scaled and upper triangle masked softmax JAX based implementation of scaled and upper triangle masked softmax
""" """
mask = 1 - jnp.tril(jnp.ones_like(logits)) mask = 1 - jnp.tril(jnp.ones_like(logits))
return jax_scaled_masked_softmax(logits, mask, scale_factor) return jax_scaled_masked_softmax(logits, mask, scale_factor, softmax_offset)
def jax_general_softmax(
x: jnp.ndarray,
axis: int = -1,
where: jnp.ndarray | None = None,
initial: jnp.ndarray = -jnp.inf,
offset: jnp.ndarray | float | None = None,
) -> jnp.ndarray:
"""
JAX based implementation of general softmax with optional masking and offset.
"""
# Compute max of x
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
if offset is not None:
# Cast offset to x.dtype to prevent type promotion
if isinstance(offset, (int, float)):
offset = jnp.array(offset, dtype=x.dtype)
else:
offset = offset.astype(x.dtype)
# Include offset in max: x_max = max(x_max, offset)
# This is equivalent to computing max over [x..., offset]
x_max = jnp.maximum(x_max, offset)
x_safe = x if where is None else jnp.where(where, x, initial)
unnormalized = jnp.exp(x_safe - x_max)
denominator = jnp.sum(unnormalized, axis, where=where, keepdims=True)
if offset is not None:
# Add exp(offset - x_max) to denominator
denominator = denominator + jnp.exp(offset - x_max)
result = unnormalized / denominator
if where is not None:
result = jnp.where(where, result, 0)
return result
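# Minimal reference sketch (the helper name below is assumed, not part of the
# public API): jax_general_softmax(x, offset=0.0) is expected to match a plain
# softmax over the logits with one extra zero logit appended and then dropped,
# which can serve as an independent check in unit tests.
def _off_by_one_softmax_reference(logits: jnp.ndarray) -> jnp.ndarray:
    """Softmax over [logits, 0] along the last axis, dropping the extra column."""
    zero = jnp.zeros((*logits.shape[:-1], 1), dtype=logits.dtype)
    return jax.nn.softmax(jnp.concatenate([logits, zero], axis=-1))[..., :-1]
# e.g. jnp.allclose(jax_general_softmax(x, offset=0.0), _off_by_one_softmax_reference(x))
# is expected to hold for float32 inputs.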
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
......
...@@ -108,28 +108,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); ...@@ -108,28 +108,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
size_t q_num_heads, size_t kv_num_heads, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
size_t q_max_seqlen, size_t kv_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
size_t qk_head_dim, size_t v_head_dim, int64_t window_size_right);
int64_t window_size_left, int64_t window_size_right);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes( pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
int64_t window_size_right); int64_t window_size_left, int64_t window_size_right);
// GEMM // GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
......
...@@ -11,14 +11,12 @@ ...@@ -11,14 +11,12 @@
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
size_t q_attn_heads, size_t kv_attn_heads, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
size_t q_max_seqlen, size_t kv_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
size_t qk_head_dim, size_t v_head_dim, int64_t window_size_right) {
int64_t window_size_left, int64_t window_size_right) {
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
...@@ -39,7 +37,8 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t ...@@ -39,7 +37,8 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
const size_t kv_max_seqlen, DType dtype, const size_t kv_max_seqlen, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend, NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
void *softmax_buf, void *rng_state_buf = nullptr, void *softmax_buf, void *rng_state_buf = nullptr,
void *bias_buf = nullptr) { void *bias_buf = nullptr,
void *softmax_offset_buf = nullptr) {
// all backends need softmax but expect different shapes/dtypes // all backends need softmax but expect different shapes/dtypes
// start with the max512 sequence length softmax shape/dtype and correct later // start with the max512 sequence length softmax shape/dtype and correct later
tensor_pack->size = 1; tensor_pack->size = 1;
...@@ -67,10 +66,12 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t ...@@ -67,10 +66,12 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
softmax_aux_data.shape.data[3] = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1} softmax_aux_data.shape.data[3] = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1}
softmax_aux_data.dtype = static_cast<NVTEDType>(DType::kFloat32); softmax_aux_data.dtype = static_cast<NVTEDType>(DType::kFloat32);
int size = 2; // Start at 2 (we have softmax and rng_state at indices 0, 1)
// include bias if enabled // include bias if enabled
if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) { if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
tensor_pack->size = 3; NVTETensor &bias_aux = tensor_pack->tensors[size];
NVTETensor &bias_aux = tensor_pack->tensors[2]; size++;
NVTEBasicTensor bias_aux_data; NVTEBasicTensor bias_aux_data;
bias_aux_data.data_ptr = bias_buf; bias_aux_data.data_ptr = bias_buf;
bias_aux_data.shape.ndim = 4; bias_aux_data.shape.ndim = 4;
...@@ -81,6 +82,24 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t ...@@ -81,6 +82,24 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
bias_aux_data.dtype = static_cast<NVTEDType>(dtype); bias_aux_data.dtype = static_cast<NVTEDType>(dtype);
nvte_set_tensor_param(&bias_aux, kNVTERowwiseData, &bias_aux_data); nvte_set_tensor_param(&bias_aux, kNVTERowwiseData, &bias_aux_data);
} }
// include softmax_offset if provided
if (softmax_offset_buf != nullptr) {
NVTETensor &softmax_offset_aux = tensor_pack->tensors[size];
size++;
NVTEBasicTensor softmax_offset_aux_data;
softmax_offset_aux_data.data_ptr = softmax_offset_buf;
softmax_offset_aux_data.shape.ndim = 4;
softmax_offset_aux_data.shape.data[0] = 1;
softmax_offset_aux_data.shape.data[1] = attn_heads;
softmax_offset_aux_data.shape.data[2] = 1;
softmax_offset_aux_data.shape.data[3] = 1;
softmax_offset_aux_data.dtype = static_cast<NVTEDType>(DType::kFloat32);
nvte_set_tensor_param(&softmax_offset_aux, kNVTERowwiseData, &softmax_offset_aux_data);
}
// Set final size
tensor_pack->size = size;
} }
nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data); nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data);
} }
...@@ -98,14 +117,16 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_ ...@@ -98,14 +117,16 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_
const size_t bias_heads, const size_t q_max_seqlen, const size_t bias_heads, const size_t q_max_seqlen,
const size_t kv_max_seqlen, DType dtype, const size_t kv_max_seqlen, DType dtype,
NVTE_Fused_Attn_Backend backend, void *softmax_buf, NVTE_Fused_Attn_Backend backend, void *softmax_buf,
void *rng_state_buf, void *bias_buf) { void *rng_state_buf, void *bias_buf,
void *softmax_offset_buf = nullptr) {
// Backward calls put everything into the tensor pack for every backend // Backward calls put everything into the tensor pack for every backend
// so we set dummy bias_type and backend choices here to follow the correct code path // so we set dummy bias_type and backend choices here to follow the correct code path
auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
PrepareFusedAttnForwardAuxTensors(tensor_pack, input_batch, bias_batch, attn_heads, bias_heads, PrepareFusedAttnForwardAuxTensors(tensor_pack, input_batch, bias_batch, attn_heads, bias_heads,
q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type, q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type,
dummy_backend, softmax_buf, rng_state_buf, bias_buf); dummy_backend, softmax_buf, rng_state_buf, bias_buf,
softmax_offset_buf);
// correct softmax shape for max512 sequence length kernel // correct softmax shape for max512 sequence length kernel
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
...@@ -121,8 +142,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -121,8 +142,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
@@ -141,7 +163,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
  auto dummy_softmax_offset_tensor =
      TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
-  NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);
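With the hard-coded `NVTE_VANILLA_SOFTMAX` removed, the softmax variant is now chosen by the caller through the new `softmax_type` parameter. For orientation, these enum values are usually read as: vanilla softmax, "off-by-one" softmax with an extra 1 in the denominator, and a learnable per-head offset added to that denominator. The fused kernels compute this on-device; the snippet below is only a plain reference of the arithmetic under that reading, not TE's implementation:

    #include <cmath>
    #include <vector>

    // Reference-only sketch: one softmax row with an optional "sink" term in the
    // denominator. has_sink with offset == 0 models the off-by-one variant; a
    // per-head learnable offset models the learnable variant (assumed reading).
    std::vector<float> softmax_row(const std::vector<float> &logits, bool has_sink,
                                   float offset = 0.0f) {
      float denom = has_sink ? std::exp(offset) : 0.0f;
      for (float x : logits) denom += std::exp(x);
      std::vector<float> probs(logits.size());
      for (std::size_t i = 0; i < logits.size(); ++i) probs[i] = std::exp(logits[i]) / denom;
      return probs;  // with a sink term the row sums to slightly less than 1
    }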
@@ -208,18 +229,21 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
static void FusedAttnForwardImpl(
-    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *seed, void *q_cu_seqlens,
-    void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux,
-    void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen,
-    size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads,
-    size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size,
-    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
-    bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
+    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_offset, void *seed,
+    void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output,
+    void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch,
+    size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups,
+    size_t bias_heads, size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq,
+    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
+    DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
+    int64_t window_size_left, int64_t window_size_right) {
  FUSED_ATTN_IMPL_COMMON_BLOCK;
  /* Input tensors */
  auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
+  auto softmax_offset_tensor =
+      TensorWrapper(softmax_offset, std::vector<size_t>{1, attn_heads, 1, 1}, DType::kFloat32);
  if (is_ragged) {
    auto output_size = input_batch * q_max_seqlen * attn_heads * v_head_dim;
@@ -238,10 +262,6 @@ static void FusedAttnForwardImpl(
  /* Prepare RNG state */
  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
-  auto dummy_softmax_offset_tensor =
-      TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
-  NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
  auto backend = nvte_get_fused_attn_backend(
      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
@@ -254,7 +274,7 @@ static void FusedAttnForwardImpl(
  nvte_tensor_pack_create(&aux_output_tensors);
  PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
                                    bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
-                                   backend, softmax_aux);
+                                   backend, softmax_aux, softmax_offset);
  /* Call the underlying NVTE API */
  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
@@ -303,7 +323,7 @@ static void FusedAttnForwardImpl(
  nvte_fused_attn_fwd(
      q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
-      dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
+      softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
      k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
      rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
@@ -332,6 +352,8 @@ static void FusedAttnForwardImpl(
      static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type"));          \
  NVTE_Mask_Type mask_type =                                                             \
      static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type"));          \
+  NVTE_Softmax_Type softmax_type =                                                       \
+      static_cast<NVTE_Softmax_Type>(get_attr_value<int64_t>(attrs, "softmax_type"));    \
  NVTE_QKV_Layout qkv_layout =                                                           \
      static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout"));        \
  bool is_training = get_attr_value<bool>(attrs, "is_training");                         \
@@ -342,7 +364,8 @@ static void FusedAttnForwardImpl(
  DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
-                               Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type seed_buf,
+                               Buffer_Type v_buf, Buffer_Type bias_buf,
+                               Buffer_Type softmax_offset_buf, Buffer_Type seed_buf,
                               Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
                               Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
                               Variadic_Buffer_Type _unused_args, Result_Type output_buf,
@@ -352,15 +375,15 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
  FusedAttnForwardImpl(
      stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
-      bias_buf.untyped_data(), seed_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(),
-      kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
+      bias_buf.untyped_data(), softmax_offset_buf.untyped_data(), seed_buf.untyped_data(),
+      q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
+      is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
      is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(),
      softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(),
      input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
      qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor,
-      dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, is_training,
-      deterministic, window_size_left, window_size_right);
+      dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype,
+      is_training, deterministic, window_size_left, window_size_right);
  return ffi_with_cuda_error_check();
}
@@ -371,6 +394,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
                                  .Arg<Buffer_Type>()      // k
                                  .Arg<Buffer_Type>()      // v
                                  .Arg<Buffer_Type>()      // bias
+                                  .Arg<Buffer_Type>()      // softmax_offset
                                  .Arg<Buffer_Type>()      // seed_buf
                                  .Arg<Buffer_Type>()      // q_cu_seqlens
                                  .Arg<Buffer_Type>()      // kv_cu_seqlens
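Note that `XLA_FFI_DEFINE_HANDLER_SYMBOL` binds these registrations positionally: the `.Arg<Buffer_Type>()` chain has to mirror the parameter list of `FusedAttnForwardFFI` one-for-one, which is why the new `softmax_offset` entry is inserted between `bias` and `seed_buf`, exactly where the buffer appears in the signature above. Because every entry has the same `Buffer_Type`, an ordering mismatch here is not caught at compile time; it generally surfaces at runtime as buffers being interpreted as the wrong operands.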
@@ -388,9 +412,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
-    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
-    int64_t window_size_right) {
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
+    DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
+    int64_t window_size_left, int64_t window_size_right) {
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
@@ -425,9 +449,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    // For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0
    min_num_segments = input_batch * max_segments_per_seq;
  }
-  auto dummy_d_softmax_offset_tensor =
-      TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
-  NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
+  TensorWrapper dummy_d_softmax_offset_tensor;
+  if (softmax_type == NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX ||
+      softmax_type == NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX) {
+    dummy_d_softmax_offset_tensor =
+        TensorWrapper(nullptr, std::vector<size_t>{1, attn_heads, 1, 1}, DType::kFloat32);
+  }
  for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
    // the last one is the largest which will be the returned workspace size
    auto q_cu_seqlens_tensor =
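The same conditional appears again below in `FusedAttnBackwardImpl`, where the real `dsoftmax_offset` gradient buffer is wrapped instead of a null placeholder. If more call sites accumulate, a small helper could centralize the rule that offset tensors exist only for the off-by-one and learnable softmax types. A hypothetical sketch (the helper name is made up and is not part of this change):

    // Hypothetical helper, illustration only. Returns an empty wrapper for softmax
    // types without an offset, and a {1, attn_heads, 1, 1} float32 wrapper otherwise.
    static TensorWrapper MakeSoftmaxOffsetTensor(void *buf, size_t attn_heads,
                                                 NVTE_Softmax_Type softmax_type) {
      if (softmax_type == NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX ||
          softmax_type == NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX) {
        return TensorWrapper(buf, std::vector<size_t>{1, attn_heads, 1, 1}, DType::kFloat32);
      }
      return TensorWrapper();  // default-constructed wrapper: no offset tensor
    }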
@@ -457,15 +486,16 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
}
static void FusedAttnBackwardImpl(
-    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_aux, void *rng_state,
-    void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets,
-    void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace,
-    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
-    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
-    size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
-    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
-    bool deterministic, int64_t window_size_left, int64_t window_size_right) {
+    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_offset,
+    void *softmax_aux, void *rng_state, void *output, void *doutput, void *q_cu_seqlens,
+    void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *dq, void *dk, void *dv,
+    void *dbias, void *dsoftmax_offset, void *workspace, size_t input_batch, size_t bias_batch,
+    size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups,
+    size_t bias_heads, size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq,
+    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
+    DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
+    int64_t window_size_left, int64_t window_size_right) {
  FUSED_ATTN_IMPL_COMMON_BLOCK;
  /* Input tensors */
@@ -476,9 +506,13 @@ static void FusedAttnBackwardImpl(
  /* Output tensors */
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
  auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
-  auto dummy_d_softmax_offset_tensor =
-      TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
-  NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
+  TensorWrapper dsoftmax_offset_tensor;
+  if (softmax_type == NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX ||
+      softmax_type == NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX) {
+    dsoftmax_offset_tensor =
+        TensorWrapper(dsoftmax_offset, std::vector<size_t>{1, attn_heads, 1, 1}, DType::kFloat32);
+  }
  /* Auxiliary tensors (propagated from the forward pass) */
  NVTETensorPack aux_input_tensors;
@@ -490,7 +524,7 @@ static void FusedAttnBackwardImpl(
                                  false, false);
  PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                     bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
-                                    softmax_aux, rng_state, bias);
+                                    softmax_aux, rng_state, bias, softmax_offset);
  /* Call the underly NVTE API */
  // Prepare Q, K, V pointers and shapes based on layout
@@ -564,7 +598,7 @@ static void FusedAttnBackwardImpl(
      s_tensor.data(),  // not used for F16
      s_tensor.data(),  // not used for F16
      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
-      dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+      dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
      q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
      window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream);
@@ -574,26 +608,29 @@ static void FusedAttnBackwardImpl(
Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
                                Buffer_Type v_buf, Buffer_Type bias_buf,
-                                Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf,
-                                Buffer_Type output_buf, Buffer_Type doutput_buf,
-                                Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
-                                Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
-                                Variadic_Buffer_Type _unused_args, Result_Type dq_buf,
-                                Result_Type dk_buf, Result_Type dv_buf, Result_Type dbias_buf,
+                                Buffer_Type softmax_offset_buf, Buffer_Type softmax_aux_buf,
+                                Buffer_Type rng_state_buf, Buffer_Type output_buf,
+                                Buffer_Type doutput_buf, Buffer_Type q_cu_seqlens_buf,
+                                Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf,
+                                Buffer_Type k_seq_offsets_buf, Variadic_Buffer_Type _unused_args,
+                                Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf,
+                                Result_Type dbias_buf, Result_Type dsoftmax_offset_buf,
                                Result_Type workspace_buf, Dictionary attrs) {
  FUSED_ATTN_FFI_GET_ATTRS;
  FusedAttnBackwardImpl(
      stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
-      bias_buf.untyped_data(), softmax_aux_buf.untyped_data(), rng_state_buf.untyped_data(),
-      output_buf.untyped_data(), doutput_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(),
-      kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
+      bias_buf.untyped_data(), softmax_offset_buf.untyped_data(), softmax_aux_buf.untyped_data(),
+      rng_state_buf.untyped_data(), output_buf.untyped_data(), doutput_buf.untyped_data(),
+      q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
+      is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
      is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(),
      dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(),
-      workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
-      attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq,
-      wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype,
-      wkspace_dtype, is_training, deterministic, window_size_left, window_size_right);
+      dsoftmax_offset_buf->untyped_data(), workspace_buf->untyped_data(), input_batch, bias_batch,
+      q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim,
+      max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type,
+      softmax_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
+      window_size_right);
  return ffi_with_cuda_error_check();
}
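As with the forward handler, the backward binding must stay in sync across three places: the `FusedAttnBackwardFFI` parameter list, the argument order of the `FusedAttnBackwardImpl` call, and the `.Arg<>()`/`.Ret<>()` chain registered below. The new `softmax_offset_buf` input slots in right after `bias`, and the new `dsoftmax_offset_buf` result right after `dbias`, mirroring where the corresponding tensors are consumed and produced inside the implementation.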
@@ -605,6 +642,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
                                  .Arg<Buffer_Type>()      // k
                                  .Arg<Buffer_Type>()      // v
                                  .Arg<Buffer_Type>()      // bias
+                                  .Arg<Buffer_Type>()      // softmax_offset
                                  .Arg<Buffer_Type>()      // softmax_aux
                                  .Arg<Buffer_Type>()      // rng_state
                                  .Arg<Buffer_Type>()      // output
@@ -618,6 +656,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
                                  .Ret<Buffer_Type>()      // dk
                                  .Ret<Buffer_Type>()      // dv
                                  .Ret<Buffer_Type>()      // dbias
+                                  .Ret<Buffer_Type>()      // dsoftmax_offset
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attrs(),
                              FFI_CudaGraph_Traits);