Commit 4099aa8e authored by yuguo
parents c520cba3 96f9c6de
......@@ -177,16 +177,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef USE_ROCM
m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM"); /// rocblas
#endif
m.def("fused_attn_fwd", &fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
py::call_guard<py::gil_scoped_release>());
m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
......@@ -194,6 +186,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
py::call_guard<py::gil_scoped_release>());
// attention kernels
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_fwd", &fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache");
m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD");
m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tesnor from BSHD to THD");
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
py::call_guard<py::gil_scoped_release>());
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
namespace transformer_engine {
namespace fused_attn {
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
template <typename scalar_t>
__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
int d_v, int b, int max_seq_len) {
// k_cache, v_cache: bshd
// batch_indices: [b]; cu_new_lens, cu_cached_lens: [b + 1]
int actual_b = b;
for (int i = 0; i < b - 1; i++) {
if (batch_indices[i + 1] < batch_indices[i]) {
actual_b = i + 1;
}
}
for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) {
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) {
int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k;
int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k;
int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v;
int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v;
for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) {
*(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i);
}
for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) {
*(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i);
}
}
}
}
template <typename scalar_t>
__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
scalar_t *v_cache, int *page_table, int *cu_new_lens,
int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
// new_k, new_v: qkv_format; k_cache, v_cache: bshd
// cu_new_lens, cu_cached_lens: [b + 1]
// page_table: [b, max_pages_per_seq]
int page_size = max_seq_len / max_pages_per_seq;
if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int new_token_offset = batch_idx * max_ctx_len;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) =
*(new_k + (new_token_offset + i) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) =
*(new_v + (new_token_offset + i) * h_kv * d_v + j);
}
}
}
} else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j);
}
}
}
} else if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) =
*(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) =
*(new_v + (cu_new_lens[batch_idx] + i) * h_kv * d_v + j);
}
}
}
}
}
} // namespace fused_attn
} // namespace transformer_engine
#endif
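As a readability aid for the kernels above (not the CUDA path itself), the same data movement can be sketched in pure PyTorch. The function names and the [t, h, d] / [b, s, h, d] shapes are illustrative assumptions; the paged destination-slot arithmetic mirrors copy_to_kv_cache_kernel.

import torch

def convert_thd_to_bshd_reference(tensor, cu_seqlens, b, max_seq_len):
    # tensor: [total_tokens, h, d] (THD); returns [b, max_seq_len, h, d] (BSHD).
    # Only the valid tokens of each sequence are copied, as in the kernel above.
    h, d = tensor.shape[-2], tensor.shape[-1]
    out = torch.zeros(b, max_seq_len, h, d, dtype=tensor.dtype, device=tensor.device)
    for batch_idx in range(b):
        start, end = int(cu_seqlens[batch_idx]), int(cu_seqlens[batch_idx + 1])
        out[batch_idx, : end - start] = tensor[start:end]
    return out

def convert_bshd_to_thd_reference(tensor, cu_seqlens, b):
    # tensor: [b, max_seq_len, h, d] (BSHD); returns [total_tokens, h, d] (THD).
    return torch.cat(
        [tensor[i, : int(cu_seqlens[i + 1] - cu_seqlens[i])] for i in range(b)], dim=0
    )

def paged_cache_slot(page_list, cached_len, new_len, i, page_size):
    # Illustrative helper: destination row for the i-th new token of a sequence in a
    # paged cache, following the token_idx computation in copy_to_kv_cache_kernel.
    pos = cached_len - new_len + i
    return page_list[pos // page_size] * page_size + pos % page_size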
......@@ -20,7 +20,7 @@ from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_module
from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager
from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.quantized_tensor import QuantizedTensor, Quantizer
......@@ -328,11 +328,14 @@ class _CheckpointFunction(torch.autograd.Function):
tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
ctx.save_for_backward(*tensor_inputs)
fp8 = FP8GlobalStateManager.is_fp8_enabled()
ctx.get_rng_state_tracker = get_rng_state_tracker
ctx.tp_group = tp_group
ctx.recompute_ctx = recompute_ctx
ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.kwargs = kwargs
return outputs
......@@ -375,6 +378,8 @@ class _CheckpointFunction(torch.autograd.Function):
detached_inputs = detach_variable(inputs)
with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
activation_recompute=True, recompute_phase=True
), fp8_autocast(
enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe
):
outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
......@@ -398,6 +403,9 @@ class _CheckpointFunction(torch.autograd.Function):
"none of output has requires_grad=True, this checkpoint() is not necessary"
)
# backward does not require entering autocast context because
# backward implementations already retrieve fp8 recipe and
# enablement from stored ctx.
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs
......@@ -694,10 +702,15 @@ def checkpoint(
# Preserve the torch autocast contexts from the forward pass during recompute phase.
torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()
fp8 = FP8GlobalStateManager.is_fp8_enabled()
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
def recompute_fn(*args, **kwargs):
with torch.autograd.enable_grad(), (
te_recompute_ctx
), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx:
), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, fp8_autocast(
enabled=fp8, fp8_recipe=fp8_recipe
):
function(*args, **kwargs)
# Initialize a new checkpoint frame for each new forward pass.
......
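For context, the checkpointing changes above capture the FP8 state during the forward pass and re-enter it when activations are recomputed. A minimal sketch of that pattern, assuming fp8_autocast and FP8GlobalStateManager behave as used in the diff:

import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_autocast

def make_recompute_fn(function):
    # Capture FP8 enablement and recipe while still inside the forward pass.
    fp8 = FP8GlobalStateManager.is_fp8_enabled()
    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None

    def recompute_fn(*args, **kwargs):
        # Recompute under the same FP8 autocast state so that quantization during
        # the backward-pass recompute matches the original forward pass.
        with torch.enable_grad(), fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
            return function(*args, **kwargs)

    return recompute_fn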
......@@ -2,52 +2,797 @@
#
# See LICENSE for license information.
"""
Inference classes for attention
"""
"""Inference"""
import logging
from collections import OrderedDict, defaultdict
from typing import Optional, List
from einops import rearrange
import torch
class InferenceParams: # pylint: disable=too-few-public-methods
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
__all__ = ["InferenceParams", "KVCacheManager", "NonPagedKVCacheManager", "PagedKVCacheManager"]
class KVCacheManager:
"""Base KV cache manager"""
def __init__(self):
"""Initialize cache manager"""
self.cache = {}
self.sequences = OrderedDict()
def reset(self):
"""Reset cache manager state"""
self.sequences = OrderedDict()
def allocate_memory(self, layer_number: int):
"""Allocate memory for the cache"""
self.cache[layer_number] = (None, None)
def pre_step(
self,
step_dict: OrderedDict, # pylint: disable=unused-argument
):
"""Update tracked sequences and prepare for step()"""
return self.sequences
def step(
self,
layer_number: int,
new_k: torch.Tensor, # pylint: disable=unused-argument
new_v: torch.Tensor, # pylint: disable=unused-argument
cu_new_seqlens: torch.Tensor, # pylint: disable=unused-argument
cu_cached_seqlens: torch.Tensor, # pylint: disable=unused-argument
qkv_format: str, # pylint: disable=unused-argument
):
"""Copy the new tokens to KV cache"""
return self.cache[layer_number]
class InferenceParams:
"""
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
KV caching for inference. The memory allocation of the caches and the copying of new tokens
to the cache take place at the following locations.::
class TransformerLayer:
class MultiHeadAttention:
if self.layer_number not in inference_params.cache_manager.cache:
inference_params.allocate_memory(self.layer_number)
class DotProductAttention:
if inference_params is not None:
k_cache, v_cache, new_qkv_format = inference_params.step(
new_k, new_v, qkv_format)
output = attention(new_q, k_cache, v_cache, new_qkv_format)
allocate_memory() can be called outside the model, independently. step() accepts three input formats,
qkv_format = {'bshd', 'sbhd', 'thd'}, and converts new_k and new_v to 'bshd' in both
NonPagedKVCacheManager and PagedKVCacheManager. The format of new_q may change depending on the
backend; if it is unchanged, new_qkv_format is one of {'bshd', 'sbhd_2bshd', 'thd_2bshd'}.
A standard KV caching workflow for inference is as follows.::
model = [TransformerLayer() for _ in range(num_layers)]
# initialize InferenceParams, e.g. with PagedKVCacheManager
inference_params = InferenceParams(..., is_paged=True)
# inference loop
for i in range(num_iters):
# get info for iteration i, e.g. seq_ids = [0, 2, 3], step_lens = [10, 1, 1]
step_dict = OrderedDict(zip(seq_ids, step_lens))
# update inference_params' state
inference_params.pre_step(step_dict)
# run iteration
output = model(
...,
attn_mask_type="padding_causal",
cu_seqlens_q=cu_seqlens_new_q,
cu_seqlens_kv=cu_seqlens_new_kv,
inference_params=inference_params,
)
# get output tokens based on qkv_format
# 'bshd': output = output[:,step_dict.values()-1]
# 'sbhd': output = output[step_dict.values()-1,:]
# 'thd' : output = output[cu_seqlens_new_q[j+1]-1], j=0,...b-1
Parameters
----------
max_batch_size : int
maximum batch size during inference.
max_sequence_length : int
maximum sequence length during inference.
max_batch_size: int
Maximum batch size in inference
max_seqlen_kv: int
Maximum sequence length in inference
num_heads_kv: int
Number of attention heads in keys and values
head_dim_k: int
Head size for keys
dtype: torch.dtype
Data type of the KV cache
head_dim_v: int, default = None
Head size for values. If None, initialized as head_dim_k.
is_paged: bool, default = False
Whether the KV cache is paged (True) or non-paged (False)
total_num_pages: int, default = None
Total number of pages in the KV cache. Required for is_paged = True.
page_size: int, default = None
Page size of the KV cache. Required for is_paged = True.
max_ctx_len: int, default = None
Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv.
qkv_format: str, default = "bshd"
Format of the incoming query/key/value tensors in current iteration
custom_cache_manager: KVCacheManager, default = None
Custom cache manager, with KVCacheManager as the base class.
"""
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
def __init__(
self,
max_batch_size: int,
max_seqlen_kv: int,
num_heads_kv: int = 16,
head_dim_k: int = 64,
dtype: torch.dtype = torch.bfloat16,
head_dim_v: int = None,
is_paged: bool = False,
total_num_pages: int = None,
page_size: int = None,
max_ctx_len: int = None,
qkv_format: str = "bshd",
custom_cache_manager: KVCacheManager = None,
):
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
self.max_seqlen_kv = max_seqlen_kv
self.num_heads_kv = num_heads_kv
self.head_dim_k = head_dim_k
self.dtype = dtype
self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k
self.is_paged = is_paged
if not self.is_paged:
cache_manager = (
custom_cache_manager if custom_cache_manager is not None else NonPagedKVCacheManager
)
self.cache_manager = cache_manager(
max_batch_size=self.max_batch_size,
max_seqlen=self.max_seqlen_kv,
num_heads=self.num_heads_kv,
head_dim_k=self.head_dim_k,
dtype=self.dtype,
head_dim_v=self.head_dim_v,
)
else:
assert page_size is not None, "Paged KV cache requires page_size is not None."
self.page_size = page_size
assert (
max_seqlen_kv % page_size == 0
), "Paged KV cache requires max_seqlen_kv % page_size = 0."
max_pages_per_seq = max_seqlen_kv // page_size
assert (
total_num_pages == self.max_batch_size * max_pages_per_seq
), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq."
self.total_num_pages = total_num_pages
cache_manager = (
custom_cache_manager if custom_cache_manager is not None else PagedKVCacheManager
)
self.cache_manager = cache_manager(
total_num_pages=self.total_num_pages,
page_size=self.page_size,
num_heads=self.num_heads_kv,
head_dim_k=self.head_dim_k,
dtype=self.dtype,
max_batch_size=self.max_batch_size,
max_seqlen=self.max_seqlen_kv,
head_dim_v=self.head_dim_v,
)
if qkv_format == "thd":
assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!"
self.max_ctx_len = max_ctx_len
self.cache_qkv_format = "bshd"
self.input_qkv_format = qkv_format
if self.input_qkv_format == self.cache_qkv_format:
self.output_qkv_format = self.cache_qkv_format
else:
self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format
self.sequences_pre_step = OrderedDict()
self.sequences = OrderedDict()
self.batch_size = 0
self.cu_seqlens_q = torch.zeros(
self.max_batch_size + 1,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
self.cu_seqlens_kv = torch.zeros(
self.max_batch_size + 1,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
def reset(self):
"""Reset InferenceParams state"""
self.sequences = OrderedDict()
self.cache_manager.reset()
def swap_key_value_dict(self, batch_indices):
def __repr__(self) -> str:
if self.is_paged:
return (
f"dtype={self.dtype}, "
f"is_paged={self.is_paged}, "
f"total_pages={self.total_num_pages}, "
f"page_size={self.page_size}, "
f"num_heads={self.num_heads_kv}, "
f"head_dim_k={self.head_dim_k}, "
f"head_dim_v={self.head_dim_v}"
)
return (
f"dtype={self.dtype}, "
f"is_paged={self.is_paged}, "
f"max_batch_size={self.max_batch_size}, "
f"max_seqlen={self.max_seqlen_kv}, "
f"num_heads={self.num_heads_kv}, "
f"head_dim_k={self.head_dim_k}, "
f"head_dim_v={self.head_dim_v}"
)
def allocate_memory(self, layer_number: int):
"""
Allocate memory for the cache. For layer layer_number,
- NonPagedKVCacheManager:
- K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k]
- V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v]
- PagedKVCacheManager:
- K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k]
- V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v]
"""
self.cache_manager.allocate_memory(layer_number)
def pre_step(
self,
step_dict: OrderedDict,
):
"""Update tracked sequences and prepare for step()"""
self.batch_size = len(step_dict)
self.sequences = self.cache_manager.pre_step(step_dict)
# track the pre-step seqlens for the next layer in the model
self.sequences_pre_step = OrderedDict()
for k, v in self.sequences.items():
self.sequences_pre_step[k] = v - step_dict[k]
seqlens_q = list(step_dict.values())
cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)]
cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size)
self.cu_seqlens_q.copy_(torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu"))
seqlens_kv = list(self.sequences.values())
cu_seqlens_kv = [0] + [sum(seqlens_kv[:i]) for i in range(1, self.batch_size + 1)]
cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * (
self.max_batch_size - self.batch_size
)
self.cu_seqlens_kv.copy_(torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu"))
def get_seqlens_pre_step(self):
"""Get cached sequence lengths before the stepping"""
return torch.Tensor(list(self.sequences_pre_step.values())).to(
dtype=torch.int32, device="cpu"
)
def convert_paged_to_nonpaged(self, layer_number: int):
"""
Reorders the KV cache using the specified batch indices.
Convert k_cache and v_cache from paged to non-paged format.
Parameters
----------
batch_indices : List[int]
Sequence of indices to reorder along the batch dimensions of
the KV cache. Must have a length equal to the batch size.
layer_number: int
Layer number of attention in the model
Returns
-------
k_cache: torch.Tensor
Non-paged key cache tensor
v_cache: torch.Tensor
Non-paged value cache tensor
"""
if len(self.key_value_memory_dict) == 0:
raise ValueError("should not swap when dict in empty")
k_cache, v_cache = self.cache_manager.cache[layer_number]
page_table = self.cache_manager.page_table
batch_size = page_table.shape[0]
new_k_cache = rearrange(
k_cache[page_table.flatten()],
"(b npages) page_size ... -> b (npages page_size) ...",
b=batch_size,
)
new_v_cache = rearrange(
v_cache[page_table.flatten()],
"(b npages) page_size ... -> b (npages page_size) ...",
b=batch_size,
)
for layer_number, inference_memory in self.key_value_memory_dict.items():
inference_key_memory, inference_value_memory = inference_memory
assert (
len(batch_indices) == inference_key_memory.shape[1]
) # make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_indices]
new_inference_value_memory = inference_value_memory[:, batch_indices]
self.key_value_memory_dict[layer_number] = (
new_inference_key_memory,
new_inference_value_memory,
new_k_cache = new_k_cache[: self.batch_size].contiguous()
new_v_cache = new_v_cache[: self.batch_size].contiguous()
return new_k_cache, new_v_cache
def step(
self,
layer_number: int,
new_k: torch.Tensor,
new_v: torch.Tensor,
qkv_format: str,
):
"""
Copy new KV tokens to the cache.
Parameters
----------
layer_number: int
Layer number of attention in the model
new_k: torch.Tensor
New key tokens for layer_number in current inference iteration
new_v: torch.Tensor
New value tokens for layer_number in current inference iteration
qkv_format: str
Format of new_q, new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
Returns
-------
k_cache: torch.Tensor
Full key tensor containing both previous and current key tokens
v_cache: torch.Tensor
Full value tensor containing both previous and current value tokens
cu_seqlens_q: torch.Tensor
Updated cumulative sequence lengths for query, [batch_size + 1]
cu_seqlens_kv: torch.Tensor
Updated cumulative sequence lengths for key and value, [batch_size + 1]
max_seqlen_q: int
Updated maximum sequence length for query
max_seqlen_kv: int
Updated maximum sequence length for key and value
qkv_format: str
Updated qkv_format, e.g. 'thd' format becomes 'thd_2bshd' after step()
"""
self.input_qkv_format = qkv_format
if self.input_qkv_format == self.cache_qkv_format:
self.output_qkv_format = self.cache_qkv_format
else:
self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format
k_cache, v_cache = self.cache_manager.step(
layer_number,
new_k,
new_v,
self.cu_seqlens_q,
self.cu_seqlens_kv,
qkv_format,
)
return (
k_cache,
v_cache,
self.cu_seqlens_q,
self.cu_seqlens_kv,
self.max_seqlen_kv,
self.output_qkv_format,
)
class NonPagedKVCacheManager(KVCacheManager):
"""Non-paged KV cache manager"""
def __init__(
self,
max_batch_size: int,
max_seqlen: int,
num_heads: int,
head_dim_k: int,
dtype: torch.dtype,
head_dim_v: Optional[int] = None,
):
"""Initialize cache manager"""
super().__init__()
self.max_batch_size = max_batch_size
self.max_seqlen = max_seqlen
self.num_heads = num_heads
self.head_dim_k = head_dim_k
self.dtype = dtype
self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k
# track sequences in the cache, {seq_id: seq_len}
self.sequences = OrderedDict()
# cache tensors, cache[layer_number] = (k_cache, v_cache)
self.cache = {}
# track sequence indices in the batch in order to re-index k_cache and v_cache
self.batch_indices = torch.zeros(
self.max_batch_size,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
# after re-indexing, batch indices are always [0, ..., b-1]
self.batch_indices_post_step = torch.range(
0,
self.max_batch_size - 1,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
def allocate_memory(self, layer_number):
"""Allocate memory for the cache"""
k_cache = torch.zeros(
self.max_batch_size,
self.max_seqlen,
self.num_heads,
self.head_dim_k,
dtype=self.dtype,
device=torch.cuda.current_device(),
)
v_cache = torch.zeros(
self.max_batch_size,
self.max_seqlen,
self.num_heads,
self.head_dim_v,
dtype=self.dtype,
device=torch.cuda.current_device(),
)
self.cache[layer_number] = (k_cache, v_cache)
def pre_step(
self,
step_dict: OrderedDict,
):
"""Update tracked sequences and prepare for step()"""
# Track unfinished sequences' indices in the batch, e.g.
# at t-1, seq_ids = [0, 1, 2, 3]; at t, seq_ids = [0, 2, 3] since seq_id 1 is finished
# step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that
# they are contiguous and match the indexing in q
prev_batch_size = len(self.sequences)
unfinished_seqs = self.sequences.keys() & step_dict.keys()
finished_seqs = self.sequences.keys() - unfinished_seqs
unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs]
finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs]
self.batch_indices.copy_(
torch.Tensor(
(
unfinished_indices
+ finished_indices
+ list(range(prev_batch_size, self.max_batch_size))
)
).to(dtype=torch.int32, device="cpu")
)
# Advance unfinished sequences
for i in unfinished_seqs:
self.sequences[i] += 1
# Remove finished sequences
for i in finished_seqs:
self.sequences.pop(i)
# Add new sequences
new_seqs = step_dict.keys() - self.sequences.keys()
for i in new_seqs:
self.sequences[i] = step_dict[i]
return self.sequences
def step(
self,
layer_number,
new_k: torch.Tensor,
new_v: torch.Tensor,
cu_new_seqlens,
cu_cached_seqlens,
qkv_format: str,
):
"""
Copy the new tokens to the non-paged KV cache.
Parameters
----------
layer_number: int
Layer number of attention in the model
new_k: torch.Tensor
New key tokens for layer_number in current inference iteration
new_v: torch.Tensor
New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor
Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1]
qkv_format: str
Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
Returns
-------
k_cache: torch.Tensor
Full key tensor containing both previous and current key tokens
v_cache: torch.Tensor
Full value tensor containing both previous and current value tokens
"""
k_cache, v_cache = self.cache[layer_number]
batch_size = self.max_batch_size
ctx_len = 1
if qkv_format == "bshd":
batch_size = new_k.shape[0]
ctx_len = new_k.shape[1]
if qkv_format == "sbhd":
batch_size = new_k.shape[1]
ctx_len = new_k.shape[0]
tex.copy_to_kv_cache(
new_k,
new_v,
k_cache,
v_cache,
self.batch_indices,
cu_new_seqlens,
cu_cached_seqlens,
QKVFormat[qkv_format],
batch_size,
ctx_len,
self.max_seqlen,
1,
True,
)
k_cache = k_cache[:batch_size]
v_cache = v_cache[:batch_size]
return k_cache, v_cache
class Page:
"""A single page"""
def __init__(self, page_id: int):
"""Initialize a page"""
self.page_id = page_id
self.allocated = 0
def allocate_page(self):
"""Allocate a page"""
self.allocated = True
def deallocate_page(self):
"""Deallocate a page"""
self.allocated = False
class PagedKVCacheManager(KVCacheManager):
"""Paged KV cache manager"""
def __init__(
self,
total_num_pages: int,
page_size: int,
num_heads: int,
head_dim_k: int,
dtype: torch.dtype,
max_batch_size: int,
max_seqlen: int,
head_dim_v: Optional[int] = None,
):
"""Initialize cache manager"""
super().__init__()
self.total_num_pages = total_num_pages
self.page_size = page_size
self.num_heads = num_heads
self.head_dim_k = head_dim_k
self.dtype = dtype
self.max_batch_size = max_batch_size
self.max_seqlen = max_seqlen
self.max_pages_per_seq = max_seqlen // self.page_size
self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k
# track sequences in the cache, {seq_id: seq_len}
self.sequences = OrderedDict()
# cache tensors, cache[layer_number] = (k_cache, v_cache)
self.cache = {}
# available pages, [Page(),...]
self.free_pages = []
for i in range(self.total_num_pages):
self.free_pages.append(Page(i))
# allocated pages, {seq_id: [page_id,...]}
self.allocated_pages = defaultdict(list)
# page table, [batch_size, max_pages_per_seq]
self.page_table = torch.zeros(
self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda"
)
def reset(self):
"""Reset cache manager state"""
self.sequences = OrderedDict()
self.free_pages = []
for i in range(self.total_num_pages):
self.free_pages.append(Page(i))
self.allocated_pages = defaultdict(list)
self.page_table.fill_(0)
def allocate_memory(self, layer_number):
"""Allocate memory for the cache"""
k_cache = torch.zeros(
self.total_num_pages,
self.page_size,
self.num_heads,
self.head_dim_k,
dtype=self.dtype,
device=torch.cuda.current_device(),
)
v_cache = torch.zeros(
self.total_num_pages,
self.page_size,
self.num_heads,
self.head_dim_v,
dtype=self.dtype,
device=torch.cuda.current_device(),
)
self.cache[layer_number] = (k_cache, v_cache)
def print_cache(self):
"""Print KV cache status"""
used_pages = [self.get_page_count(seq) for seq in self.sequences]
logger = logging.getLogger("PagedKVCacheManager")
logger.debug("Cache status:")
logger.debug(
" total pages: %s (used %s, free %s)",
self.total_num_pages,
sum(used_pages),
len(self.free_pages),
)
logger.debug(" total sequences: %s", self.get_sequence_count())
for i, seq in enumerate(self.sequences):
logger.debug(
" >> batch index %s: seq_id %s, num_tokens %s, num_pages %s, page_list %s",
i,
seq,
self.get_sequence_lengths()[i],
self.get_page_count(seq),
self.get_page_list(seq),
)
def get_sequence_count(self):
"""Get the total number of sequences in the KV cache"""
return len(self.sequences)
def get_sequence_lengths(self):
"""Get the list of sequence lengths in the KV cache"""
return list(self.sequences.values())
def has_free_page(self) -> bool:
"""Whether the page pool has any free pages left"""
return len(self.free_pages) > 0
def get_page_count(self, seq: int):
"""Get the number of pages allocated to a sequence"""
return len(self.allocated_pages[seq])
def get_page_list(self, seq: int):
"""Get the list of pages allocated to a sequence"""
return [x.page_id for x in self.allocated_pages[seq]]
def get_page_table(self, sequences: List[int]):
"""Get the page table, in shape [batch_size, max_pages_per_seq]"""
page_table = torch.Tensor(
[
self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq))
for seq in sequences
]
).to(dtype=torch.int32, device="cpu")
self.page_table[: self.get_sequence_count()].copy_(page_table)
return self.page_table
def allocate_page(self, seq: int):
"""Allocate a new page to a sequence"""
if not self.has_free_page():
raise RuntimeError("KV cache is full!")
page = self.free_pages.pop(0)
page.allocate_page()
self.allocated_pages[seq].append(page)
def allocate_sequence(self, seq: int, context_len: int):
"""Add a new sequence to the cache"""
num_pages = context_len // self.page_size
if context_len % self.page_size > 0:
num_pages = num_pages + 1
for _ in range(num_pages):
self.allocate_page(seq)
def deallocate_sequence(self, seq: int):
"""Deallocate all the pages for a sequence"""
for page in self.allocated_pages[seq]:
page.deallocate_page()
if not page.allocated:
self.free_pages.append(page)
self.allocated_pages.pop(seq)
def pre_step(
self,
step_dict: OrderedDict,
):
"""Update tracked sequences and prepare for step()"""
# Remove finished sequences and advance unfinished sequences
unfinished_seqs = self.sequences.keys() & step_dict.keys()
finished_seqs = self.sequences.keys() - unfinished_seqs
for seq in finished_seqs:
self.sequences.pop(seq)
self.deallocate_sequence(seq)
for seq in unfinished_seqs:
if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen:
self.allocate_page(seq)
self.sequences[seq] += 1
# Add new sequences
new_seqs = step_dict.keys() - self.sequences.keys()
for seq in new_seqs:
self.sequences[seq] = step_dict[seq]
self.allocate_sequence(seq, step_dict[seq])
# Get page table
self.page_table = self.get_page_table(list(self.sequences.keys()))
return self.sequences
def step(
self,
layer_number: int,
new_k: torch.Tensor,
new_v: torch.Tensor,
cu_new_seqlens,
cu_cached_seqlens,
qkv_format: str,
):
"""
Copy the new tokens to the paged KV cache.
Parameters
----------
layer_number: int
Layer number of attention in the model
new_k: torch.Tensor
New key tokens for layer_number in current inference iteration
new_v: torch.Tensor
New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor
Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1]
qkv_format: str
Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
Returns
-------
k_cache: torch.Tensor
Full key tensor containing both previous and current key tokens
v_cache: torch.Tensor
Full value tensor containing both previous and current value tokens
"""
k_cache, v_cache = self.cache[layer_number]
batch_size = self.max_batch_size
ctx_len = 1
if qkv_format == "bshd":
batch_size = new_k.shape[0]
ctx_len = new_k.shape[1]
if qkv_format == "sbhd":
batch_size = new_k.shape[1]
ctx_len = new_k.shape[0]
tex.copy_to_kv_cache(
new_k,
new_v,
k_cache,
v_cache,
self.page_table,
cu_new_seqlens,
cu_cached_seqlens,
QKVFormat[qkv_format],
batch_size,
ctx_len,
self.max_seqlen,
self.max_pages_per_seq,
False,
)
return k_cache, v_cache
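End-to-end, the classes above are used as in the following minimal non-paged sketch (a CUDA device and the transformer_engine_torch bindings are assumed; batch size, prompt length, head count, and head size are illustrative):

from collections import OrderedDict
import torch
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams

params = InferenceParams(
    max_batch_size=2, max_seqlen_kv=64, num_heads_kv=4, head_dim_k=32,
    dtype=torch.bfloat16, qkv_format="bshd",
)
params.allocate_memory(layer_number=0)

# Prefill step: two sequences, 8 new tokens each ({seq_id: new tokens this step}).
params.pre_step(OrderedDict([(0, 8), (1, 8)]))

new_k = torch.randn(2, 8, 4, 32, dtype=torch.bfloat16, device="cuda")
new_v = torch.randn_like(new_k)
k_cache, v_cache, cu_seqlens_q, cu_seqlens_kv, max_seqlen_kv, out_fmt = params.step(
    layer_number=0, new_k=new_k, new_v=new_v, qkv_format="bshd"
)
# k_cache/v_cache now hold the cached tokens in "bshd"; out_fmt == "bshd" here.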
......@@ -34,6 +34,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_O_CP,
META_DQKV_CP,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.constants import TE_DType
......@@ -91,7 +92,6 @@ class FlashAttentionUtils:
Manage Flash Attention versioning information
"""
# Detect flash-attn v2 in the environment
is_installed = False
version = PkgVersion("0")
version_required = PkgVersion("2.1.1")
......@@ -102,21 +102,25 @@ class FlashAttentionUtils:
v2_3_plus = False
v2_4_plus = False
v2_4_1_plus = False
v2_5_plus = False
v2_5_7_plus = False
v2_6_0_plus = False
v2_7_0_plus = False
warning_printed = False
v3_is_installed = False
fa3_version = PkgVersion("0")
v3_0_0_beta = False
use_v3 = False
# TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved
# https://github.com/Dao-AILab/flash-attention/issues/1452
# FA3 from FA 2.7.3+/hopper has different APIs than FA3 from 2.7.2/hopper
# Please follow these instructions to install FA3
v3_installation_steps = """\
(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py"""
(1) git clone https://github.com/Dao-AILab/flash-attention.git
(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(4) mkdir -p $python_path/flash_attn_3
(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py"""
v3_warning_printed = False
@staticmethod
def set_flash_attention_version():
......@@ -129,13 +133,11 @@ class FlashAttentionUtils:
FlashAttentionUtils.v2_3_plus = FlashAttentionUtils.version >= PkgVersion("2.3")
FlashAttentionUtils.v2_4_plus = FlashAttentionUtils.version >= PkgVersion("2.4")
FlashAttentionUtils.v2_4_1_plus = FlashAttentionUtils.version >= PkgVersion("2.4.1")
FlashAttentionUtils.v2_5_plus = FlashAttentionUtils.version >= PkgVersion("2.5.0")
FlashAttentionUtils.v2_5_7_plus = FlashAttentionUtils.version >= PkgVersion("2.5.7")
FlashAttentionUtils.v2_6_0_plus = FlashAttentionUtils.version >= PkgVersion("2.6.0")
FlashAttentionUtils.v2_7_0_plus = FlashAttentionUtils.version >= PkgVersion("2.7.0")
# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
@staticmethod
def set_flash_attention_3_params():
"""
......@@ -145,7 +147,6 @@ class FlashAttentionUtils:
FlashAttentionUtils.v3_0_0_beta = (
PkgVersion("3.0.0b") < FlashAttentionUtils.fa3_version < PkgVersion("3.0.0")
)
FlashAttentionUtils.use_v3 = True
@dataclass(eq=True)
......@@ -203,6 +204,8 @@ class AttentionParams:
Whether `DotProductAttention` is in an `fp8_autocast` region.
fp8_meta: Optional[Dict[str Any]], default = `None`
The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None`
Inference-related parameters. See InferenceParams for details.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
......@@ -228,6 +231,7 @@ class AttentionParams:
is_training: bool = True
fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None
def __eq__(self, other):
"""
......@@ -298,6 +302,7 @@ def get_attention_backend(
is_training = attention_params.is_training
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params
# Run config
logger = logging.getLogger("DotProductAttention")
......@@ -334,13 +339,19 @@ def get_attention_backend(
# regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
# necessary for performance/functionality, a warning will be issued to prompt users to
# install an appropriate FA version.
qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params)
# Filter: Environment variables
use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
use_flash_attention_2 = use_flash_attention
use_flash_attention_3 = use_flash_attention
flash_attention_backend = None
use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
if not use_flash_attention and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
if not use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0")
if not use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 due to NVTE_FLASH_ATTN=0")
if not use_fused_attention:
logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
if not use_unfused_attention:
......@@ -348,70 +359,134 @@ def get_attention_backend(
# Filter: Compute capability
if not IS_HIP_EXTENSION and device_compute_capability < (8, 0):
if use_flash_attention and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
use_flash_attention = False
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 for compute capability < sm80")
use_flash_attention_2 = False
if use_fused_attention:
logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
logger.debug("Disabling FusedAttention for compute capability < sm80")
use_fused_attention = False
if device_compute_capability < (9, 0):
if use_flash_attention and FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
FlashAttentionUtils.use_v3 = False
if device_compute_capability != (9, 0):
if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for compute capability != sm90")
use_flash_attention_3 = False
# Filter: Data type
if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
if qkv_dtype not in [torch.bfloat16, torch.float16]:
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention 2 for unsupported qkv_dtype = %s. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. ",
qkv_dtype,
)
use_flash_attention_2 = False
if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in [
torch.Tensor,
Float8Tensor,
]:
if use_flash_attention and FlashAttentionUtils.is_installed:
if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
logger.debug(
"Disabling FlashAttention due to unsupported QKV data type. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_dtype = %s.",
"Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, "
"qkv_type = {torch.Tensor, Float8Tensor}. ",
qkv_dtype,
qkv_type,
)
use_flash_attention = False
use_flash_attention_3 = False
if use_fused_attention:
logger.debug(
"Disabling FusedAttention due to unsupported QKV data type. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_dtype = %s.",
"Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, "
"qkv_type = {torch.Tensor, Float8Tensor}. ",
qkv_dtype,
qkv_type,
)
use_fused_attention = False
# Filter: Execution type
if fp8 and fp8_meta["recipe"].fp8_dpa:
if use_flash_attention and not FlashAttentionUtils.use_v3:
if FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8")
use_flash_attention = False
if use_flash_attention and FlashAttentionUtils.use_v3 and is_training:
logger.debug(
"Disabling FlashAttention as FlashAttention 3 does not support FP8 training"
)
use_flash_attention = False
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 for FP8 attention")
use_flash_attention_2 = False
if use_flash_attention_3 and is_training:
if FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for FP8 training")
use_flash_attention_3 = False
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
use_unfused_attention = False
# TODO: ROCm fused attention backends do not support FP8 yet
if IS_HIP_EXTENSION and use_fused_attention:
logger.debug("Disabling ROCm FusedAttention as it does not support FP8")
use_fused_attention = False
# Filter: KV cache
# backend | precision | KV cache | architecture | qkv_format | page_size
# ---------------------------------------------------------------------------------------
# Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1
# Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256
# Flash v3 | FP16/BF16 | non-paged/paged | sm90 | bshd,sbhd,thd | >= 1
# | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
if inference_params is not None:
if context_parallel:
logger.debug("Disabling all backends for KV caching with context parallelism")
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
if fp8_meta["recipe"].fp8_mha:
logger.debug("Disabling all backends for KV caching with FP8 MHA")
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = False
if use_flash_attention_3 and q_format != "thd":
if FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for FP8 KV caching and non-THD")
use_flash_attention_3 = False
if use_fused_attention:
logger.debug("Disabling FusedAttention for FP8 KV caching")
use_fused_attention = False
else:
if q_format == "thd" and pad_between_seqs:
logger.debug("Disabling all backends for pad_between_seqs = True and KV caching")
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = False
if inference_params.is_paged:
if use_flash_attention_2 and inference_params.page_size < 256:
if FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 for page size < 256")
use_flash_attention_2 = False
if use_flash_attention_2:
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.5")
elif not FlashAttentionUtils.v2_5_plus:
logger.debug(
"Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+"
)
use_flash_attention_2 = False
# Filter: Head dimension
if not IS_HIP_EXTENSION:
if use_flash_attention and head_dim_qk != head_dim_v:
if FlashAttentionUtils.is_installed:
if head_dim_qk != head_dim_v:
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
):
logger.debug("Disabling FlashAttention as it does not support MLA.")
use_flash_attention = False
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and qkv_layout_group != "hd_hd_hd":
logger.debug(
"Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
qkv_layout,
)
use_fused_attention = False
else:
if use_fused_attention and head_dim_qk != head_dim_v:
logger.debug("Disabling FusedAttention as it does not support MLA in rocm backend.")
use_fused_attention = False
if use_flash_attention and (
if use_flash_attention_2 and (
head_dim_qk > 256
or head_dim_qk % 8 != 0
or (
......@@ -421,7 +496,7 @@ def get_attention_backend(
):
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
"Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. "
"Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
"head_dim_qk <= 256 (>192 requires sm80/90/100+). "
"Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
......@@ -429,23 +504,21 @@ def get_attention_backend(
head_dim_v,
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention = False
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
logger.debug(
"Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
qkv_layout,
)
use_fused_attention = False
use_flash_attention_2 = False
if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128):
if FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for head_dim > 128")
use_flash_attention_3 = False
# Filter: QKV layout
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
if use_flash_attention and pad_between_seqs:
if FlashAttentionUtils.is_installed:
if pad_between_seqs:
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
):
logger.debug(
"Disabling FlashAttention for qkv_format = thd when there is "
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
......@@ -459,9 +532,9 @@ def get_attention_backend(
use_fused_attention = False
# Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention and FlashAttentionUtils.use_v3:
if attention_dropout != 0.0 and use_flash_attention_3:
logger.debug("Disabling FlashAttention 3 for dropout")
FlashAttentionUtils.use_v3 = False
use_flash_attention_3 = False
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
......@@ -480,42 +553,38 @@ def get_attention_backend(
"Disabling UnfusedDotProductAttention as it does not support context parallelism"
)
use_unfused_attention = False
if context_parallel and use_flash_attention:
if fp8 and fp8_meta["recipe"].fp8_dpa:
if FlashAttentionUtils.is_installed:
if context_parallel and (use_flash_attention_2 or use_flash_attention_3):
if FlashAttentionUtils.is_installed or FlashAttentionUtils.v3_is_installed:
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with FP8"
)
use_flash_attention = False
if "bottom_right" in attn_mask_type:
if FlashAttentionUtils.is_installed:
use_flash_attention = False
if "bottom_right" in attn_mask_type:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" causal_bottom_right masking"
)
use_flash_attention = False
elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
if FlashAttentionUtils.is_installed:
use_flash_attention = False
elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" causal masking for cross-attention"
)
use_flash_attention = False
elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
if FlashAttentionUtils.is_installed:
use_flash_attention = False
elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with bias"
" type of %s",
core_attention_bias_type,
)
use_flash_attention = False
elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
if FlashAttentionUtils.is_installed:
use_flash_attention = False
elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" attention bias for THD format"
)
use_flash_attention = False
use_flash_attention = False
if context_parallel and use_fused_attention:
if "bottom_right" in attn_mask_type:
......@@ -568,61 +637,25 @@ def get_attention_backend(
# arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention
# | [b, h, sq, skv] |
if attn_mask_type == "arbitrary":
if use_flash_attention and FlashAttentionUtils.is_installed:
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
):
logger.debug("Disabling FlashAttention for arbitrary mask")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention for arbitrary mask")
use_fused_attention = False
if (
use_flash_attention
and FlashAttentionUtils.use_v3
(use_flash_attention_2 or use_flash_attention_3)
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
logger.warning(
"Disabling FlashAttention 3 as it only supports bottom-right-diagonal "
"causal mask since flash-attn 2.1. See "
"Disabling FlashAttention as it only supports bottom-right-diagonal "
"causal mask since flash-attn 2.1 (our minimum supported version). See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
FlashAttentionUtils.use_v3 = False
if (
use_flash_attention
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
if FlashAttentionUtils.v2_1_plus:
logger.warning(
"Disabling FlashAttention as it only supports bottom-right-diagonal "
"causal mask since flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.max_version = PkgVersion("2.1")
if (
use_flash_attention
and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"]
and max_seqlen_q != max_seqlen_kv
):
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.1")
elif not FlashAttentionUtils.v2_1_plus and not FlashAttentionUtils.use_v3:
logger.warning(
"Disabling FlashAttention as it only supports top-left-diagonal "
"causal mask before flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
if (
use_flash_attention
and FlashAttentionUtils.use_v3
and fp8
and fp8_meta["recipe"].fp8_dpa
and "padding" in attn_mask_type
):
logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
FlashAttentionUtils.use_v3 = False
use_flash_attention = False
# Filter: Sliding window attention
# backend | window_size | diagonal alignment
......@@ -653,19 +686,14 @@ def get_attention_backend(
"with s_q > s_kv for cross-attention"
)
use_fused_attention = False
if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if FlashAttentionUtils.use_v3:
logger.debug(
"Disabling FlashAttention 3 as it does not support sliding window attention"
)
FlashAttentionUtils.use_v3 = False
if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.3")
elif not FlashAttentionUtils.v2_3_plus:
logger.debug(
"Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
)
use_flash_attention = False
use_flash_attention_2 = False
# Filter: Attention bias
# backend | bias types | ALiBi diagonal alignment
......@@ -676,21 +704,25 @@ def get_attention_backend(
# | | bottom_right (converts to a 'post_scale_bias' bias)
# UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
# | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias
if use_flash_attention and core_attention_bias_type == "alibi":
if FlashAttentionUtils.use_v3:
logger.debug("Disabling FlashAttention 3 for ALiBi")
FlashAttentionUtils.use_v3 = False
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.4")
elif not FlashAttentionUtils.v2_4_plus:
logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
use_flash_attention = False
if core_attention_bias_type == "alibi":
if use_flash_attention_3:
if FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for ALiBi")
use_flash_attention_3 = False
if use_flash_attention_2:
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.4")
elif not FlashAttentionUtils.v2_4_plus:
logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
use_flash_attention_2 = False
if use_flash_attention and (
if (
core_attention_bias_type not in ["no_bias", "alibi"]
or core_attention_bias_shape is not None
):
if FlashAttentionUtils.is_installed:
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
):
logger.debug("Disabling FlashAttention for pre/post_scale_bias")
use_flash_attention = False
......@@ -795,16 +827,16 @@ def get_attention_backend(
# | otherwise: no
# sub-backend 2 | no
# UnfusedDotProductAttention | yes
if use_flash_attention and deterministic:
if use_flash_attention_2 and deterministic:
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.4.1")
elif not FlashAttentionUtils.v2_4_1_plus and not FlashAttentionUtils.use_v3:
elif not FlashAttentionUtils.v2_4_1_plus:
logger.warning(
"Disabling FlashAttention as version <2.4.1 does not support deterministic "
"execution. To use FlashAttention with deterministic behavior, "
"please install flash-attn >= 2.4.1."
)
use_flash_attention = False
use_flash_attention_2 = False
if use_fused_attention and deterministic:
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
logger.debug("Disabling FusedAttention for determinism reasons")
......@@ -821,29 +853,58 @@ def get_attention_backend(
logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
# All available backends
available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
# use_flash_attention may have been set to False by the filters above
use_flash_attention_2 = use_flash_attention and use_flash_attention_2
use_flash_attention_3 = use_flash_attention and use_flash_attention_3
# `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
# When `FusedAttention` does not support the provided attention params, and `FlashAttention`
# does, we recommend users to install flash-attn if not installed already.
if not use_fused_attention and use_flash_attention and not FlashAttentionUtils.is_installed:
logger.warning(
"flash-attn may provide important feature support or performance improvement."
" Please install flash-attn %s.",
_get_supported_versions(
FlashAttentionUtils.version_required,
FlashAttentionUtils.max_version,
),
)
if use_flash_attention and not FlashAttentionUtils.is_installed:
use_flash_attention = False
available_backends[0] = False
if not use_fused_attention and _NVTE_FLASH_ATTN:
if (
use_flash_attention_3
and not FlashAttentionUtils.v3_is_installed
and not FlashAttentionUtils.v3_warning_printed
and torch.cuda.current_device() == 0
):
logger.warning(
"flash-attn v3 may provide important feature support or performance improvement."
" Please install flash-attn v3 by \n%s",
FlashAttentionUtils.v3_installation_steps,
)
FlashAttentionUtils.v3_warning_printed = True
elif (
use_flash_attention_2
and not FlashAttentionUtils.is_installed
and not FlashAttentionUtils.warning_printed
and torch.cuda.current_device() == 0
):
logger.warning(
"flash-attn may provide important feature support or performance improvement."
" Please install flash-attn %s by pip3 install flash-attn==<version>.",
_get_supported_versions(
FlashAttentionUtils.version_required,
FlashAttentionUtils.max_version,
),
)
FlashAttentionUtils.warning_printed = True
# All available backends
if use_flash_attention_2 and not FlashAttentionUtils.is_installed:
use_flash_attention_2 = False
if use_flash_attention_3 and not FlashAttentionUtils.v3_is_installed:
use_flash_attention_3 = False
use_flash_attention = use_flash_attention_2 or use_flash_attention_3
available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
if use_flash_attention_2:
flash_attention_backend = FlashAttentionUtils.version
if use_flash_attention_3:
flash_attention_backend = FlashAttentionUtils.fa3_version
logger.debug(
"Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
"Available backends = {FlashAttention=%s%s, FusedAttention=%s%s,"
" UnfusedDotProductAttention=%s}",
bool(available_backends[0]),
(f" ({str(flash_attention_backend)})" if flash_attention_backend is not None else ""),
bool(available_backends[1]),
(
f" (sub-backend {int(fused_attention_backend)})"
......@@ -854,27 +915,10 @@ def get_attention_backend(
)
# Select FusedAttention for performance
if (
use_flash_attention and (not IS_HIP_EXTENSION)
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
):
if device_compute_capability >= (9, 0):
logger.debug(
"Disabling FlashAttention to give FusedAttention preference on Hopper+ "
"for performance reasons"
)
use_flash_attention = False
if (
use_flash_attention
and use_fused_attention
and not IS_HIP_EXTENSION
and fused_attention_backend == FusedAttnBackend["FP8"]
and FlashAttentionUtils.use_v3
):
if (
use_flash_attention
and use_fused_attention
and not IS_HIP_EXTENSION
and device_compute_capability >= (9, 0)
):
logger.debug(
"Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
"in FP8 execution"
"Disabling FlashAttention to give FusedAttention preference on Hopper+ "
"for performance reasons"
)
use_flash_attention = False
......@@ -886,22 +930,16 @@ def get_attention_backend(
use_unfused_attention = False
selected_backend = "NoBackend"
if use_flash_attention:
selected_backend = "FlashAttention"
selected_backend = f"FlashAttention ({str(flash_attention_backend)})"
elif use_fused_attention:
selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
elif use_unfused_attention:
selected_backend = "UnfusedDotProductAttention"
logger.debug("Selected backend = %s", selected_backend)
"""global _attention_backends
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False"""
return (
use_flash_attention,
flash_attention_backend,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
......@@ -909,6 +947,49 @@ def get_attention_backend(
)
@torch.no_grad()
def get_padding_mask(
batch_size: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_q: int,
max_seqlen_kv: int,
):
"""Convert cu_seqlens to attention_mask"""
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
for i in range(batch_size):
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor([False] * seqlens_q[i] + [True] * (max_seqlen_q - seqlens_q[i]))
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask_kv = torch.cat(
[
attention_mask_kv,
torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i]))
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = (
attention_mask_q.to(device="cuda"),
attention_mask_kv.to(device="cuda"),
)
return attention_mask
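A minimal usage sketch for get_padding_mask (illustrative, not part of this diff; the values are made up and a CUDA device is assumed, since the returned masks are moved to cuda):
import torch

cu_seqlens_q = torch.tensor([0, 2, 5], dtype=torch.int32)   # two sequences of lengths 2 and 3
cu_seqlens_kv = torch.tensor([0, 4, 6], dtype=torch.int32)  # two sequences of lengths 4 and 2
mask_q, mask_kv = get_padding_mask(
    batch_size=2,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_kv=cu_seqlens_kv,
    max_seqlen_q=5,
    max_seqlen_kv=6,
)
# mask_q has shape [2, 1, 1, 5]; its first row is [False, False, True, True, True],
# i.e. True marks padded positions beyond each sequence's length.
# mask_kv has shape [2, 1, 1, 6]; its first row is [False]*4 + [True]*2.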
@torch.no_grad()
def get_full_mask(
max_seqlen_q: int,
......@@ -1417,11 +1498,46 @@ class UnpackTensor(torch.autograd.Function):
return None, None, _pack_tensor(indices, grad_output)
def get_qkv_format(
qkv_layout: str = "bshd_bshd_bshd",
inference_params: InferenceParams = None,
) -> str:
"""Get qkv format.
Parameters
----------
qkv_layout: str
Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details.
inference_params: InferenceParams, default = `None`
InferenceParams related to KV caching.
Returns
----------
qkv_format: str, default = `sbhd`
Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}.
q_format: str
Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}.
kv_format: str
Format of the `k` and `v` tensors, {`bshd`, `sbhd`, `thd`}.
"""
splited = qkv_layout.replace("paged_kv_", "").split("_")
if inference_params is not None:
q_format = "".join([i for i in splited[0] if i.isalpha()])
kv_format = "".join([i for i in splited[1] if i.isalpha()])
qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format
else:
qkv_format = "".join([i for i in splited[0] if i.isalpha()])
q_format = qkv_format
kv_format = qkv_format
return qkv_format, q_format, kv_format
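A quick illustration of get_qkv_format's return values (not part of this diff; the paged case is shown as a comment because it needs an InferenceParams object):
# Without KV caching, q, k and v share a single format:
assert get_qkv_format("sbh3d") == ("sbhd", "sbhd", "sbhd")
# With inference_params set and a paged layout, q and kv formats can differ:
#   get_qkv_format("paged_kv_thd_bshd_bshd", inference_params)
#   -> ("thd_2bshd", "thd", "bshd")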
def get_qkv_layout(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
qkv_format: str = "sbhd",
inference_params: InferenceParams = None,
) -> str:
"""Get qkv layout.
......@@ -1438,20 +1554,33 @@ def get_qkv_layout(
the sequence length dimension, `b` batch size, `h` the number of attention heads,
`d` head size, and `t` the total number of tokens in a batch, i.e.
`t = sum(s_i) for i = 0...b-1`.
inference_params: InferenceParams, default = `None`
InferenceParams related to KV caching.
Returns
----------
qkv_layout: str
Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
`q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
`v = kv[:,:,:,1,:]`.
Memory layout of `q`, `k` and `v`. Each `qkv_layout` maps to a pair of `q_format` and
`kv_format` in {`bshd`, `sbhd`, `thd`}. The `paged_kv_` prefix is used to indicate that
paged KV caching is in play. A few examples of the layouts are as follows.
(1) `sb3hd` means `q`, `k`, `v` are created as one chunk of memory and that they are
interleaved in the `2`nd dimension. (2) `sbhd_sbh2d` means `q` and `kv` are created in
two chunks and that `q` itself is contiguous and `k`, `v` are interleaved with each other
in the `3`rd dimension, `k = kv[:,:,:,0,:]` and `v = kv[:,:,:,1,:]`. `q_format` and
`kv_format` in this case are still both `sbhd`. (3) `paged_kv_thd_bshd_bshd` means `q` is
created in `thd` and `k`, `v` are in `sbhd`. This is likely due to the cache format in
paged KV caching.
Mapping:
`sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
`bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
`sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`, `paged_kv_sbhd_sbhd_sbhd`}
`bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`, `paged_kv_bshd_bshd_bshd`}
`thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
`sbhd_2bshd`: {`sbhd_bshd_bshd`, `paged_kv_sbhd_bshd_bshd`}
`bshd_2sbhd`: {`bshd_sbhd_sbhd`, `paged_kv_bshd_sbhd_sbhd`}
`thd_2bshd`: {`thd_bshd_bshd`, `paged_kv_thd_bshd_bshd`}
`thd_2sbhd`: {`thd_sbhd_sbhd`, `paged_kv_thd_sbhd_sbhd`}
q: torch.Tensor
Query tensor. It may be different from input `q` as we try to fit tensors to
a supported layout.
......@@ -1461,10 +1590,21 @@ def get_qkv_layout(
v: torch.Tensor
Value tensor. It may be different from input `v` as we try to fit tensors to
a supported layout.
q_format: str
Format of the query tensor, {`bshd`, `sbhd`, `thd`}.
kv_format: str
Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}.
"""
check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"
if "_2" in qkv_format:
q_format, kv_format = qkv_format.split("_2")
is_same_q_kv_format = False
else:
q_format = qkv_format
kv_format = qkv_format
is_same_q_kv_format = True
def run_iteratively(q, k, v):
# check data pointers
......@@ -1551,7 +1691,10 @@ def get_qkv_layout(
# three chunks of memory, q, k and v, which may be disjoint or consecutive, and
# when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
# check_ptrs_qk=True or check_ptrs_kv=True
qkv_layout = "_".join(list([qkv_format]) * 3)
if is_same_q_kv_format:
qkv_layout = "_".join(list([qkv_format]) * 3)
else:
qkv_layout = q_format + "_" + kv_format + "_" + kv_format
else:
qkv_layout = "not_supported"
......@@ -1565,7 +1708,10 @@ def get_qkv_layout(
if qkv_layout == "not_supported":
raise RuntimeError("The provided qkv memory layout is not supported!")
return qkv_layout, q, k, v
if inference_params is not None and inference_params.is_paged:
qkv_layout = "paged_kv_" + qkv_layout
return qkv_layout, q, k, v, q_format, kv_format
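An illustrative call (not part of this diff), assuming a CUDA device; three separately allocated, contiguous bshd tensors are expected to resolve to the three-chunk layout:
import torch

q = torch.randn(2, 128, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 128, 16, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 128, 16, 64, dtype=torch.bfloat16, device="cuda")
qkv_layout, q, k, v, q_format, kv_format = get_qkv_layout(q, k, v, qkv_format="bshd")
# qkv_layout == "bshd_bshd_bshd"; q_format == kv_format == "bshd"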
def check_set_window_size(
......
......@@ -91,6 +91,14 @@ def _make_graphed_callables(
sample_args = (sample_args,)
sample_kwargs = (sample_kwargs,)
# Check training/inference
is_training = all(c.training for c in callables)
if not is_training and any(c.training for c in callables):
assert False, (
"make_graphed_callables only supports when modules are all in training or all in"
" inference mode."
)
# Check sizes of args
if _order is None:
assert len(sample_args) == len(callables)
......@@ -255,13 +263,16 @@ def _make_graphed_callables(
outputs, _ = _tree_flatten(func(*args, **kwargs))
for hook in hooks:
hook.remove()
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
only_inputs=True,
allow_unused=allow_unused_input,
)
if is_training:
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
only_inputs=True,
allow_unused=allow_unused_input,
)
else:
grad_inputs = None
del outputs, grad_inputs
# The following code is added specifically for MCore's special requirements,
# aimed at preventing warmup from altering the control flow.
......@@ -314,22 +325,23 @@ def _make_graphed_callables(
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
retain_graph=retain_graph_in_backward,
)
if is_training:
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
retain_graph=retain_graph_in_backward,
)
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad:
if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
......@@ -366,22 +378,23 @@ def _make_graphed_callables(
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
retain_graph=retain_graph_in_backward,
)
if is_training:
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
retain_graph=retain_graph_in_backward,
)
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs that
# don't require grad. I couldn't think of a slick one-liner for this pattern.
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad:
if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
......@@ -422,7 +435,10 @@ def _make_graphed_callables(
# Copy values from new tensors into static tensors
for i in range(len_user_args):
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
if (
isinstance(static_input_surface[i], torch.Tensor)
and static_input_surface[i].data_ptr() != inputs[i].data_ptr()
):
static_input_surface[i].copy_(inputs[i])
# Replay forward graph
......
......@@ -79,7 +79,6 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias: Union[torch.Tensor, None],
weight: torch.Tensor,
bias: torch.Tensor,
use_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
......@@ -383,6 +382,17 @@ class _LayerNormLinear(torch.autograd.Function):
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.grad_added_to_main_grad:
# Passing a torch.nn.Parameter through the Torch hooks returns a plain
# torch.Tensor, since Torch strips off the Parameter wrapper. Preserve the
# original weight object so that any attributes the user set on it are kept.
# Because of this, offloading weights is not recommended when they are
# modified outside this module.
ctx.weight_object = weight
tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
weightmat,
......@@ -411,7 +421,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.use_bias = bias is not None
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp_shape
......@@ -526,8 +536,11 @@ class _LayerNormLinear(torch.autograd.Function):
# For CPU offloading, weight and weight.main_grad were offloaded to different tensors,
# so we need to connect them back into one here.
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight.main_grad = main_grad
if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
origin_weight = ctx.weight_object
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
origin_weight.main_grad = main_grad
ctx.ub_obj_gradout = None
ub_obj_dgrad = None
......@@ -742,10 +755,6 @@ class _LayerNormLinear(torch.autograd.Function):
# TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme
clear_tensor_data(ln_out_total)
# Don't return grad bias if not needed
if not ctx.use_bias:
grad_bias = None
# Synchronize tensor parallel communication
if ln_out_total_work is not None:
ln_out_total_work.wait()
......@@ -827,7 +836,6 @@ class _LayerNormLinear(torch.autograd.Function):
dbeta,
wgrad,
grad_bias,
None, # use_bias
None, # eps
None, # is_first_microbatch
None, # fp8
......@@ -1330,8 +1338,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
self.eps,
is_first_microbatch,
self.fp8,
......
......@@ -140,10 +140,8 @@ class _LayerNormMLP(torch.autograd.Function):
ln_bias: torch.Tensor,
fc1_weight: torch.Tensor,
fc1_bias: torch.Tensor,
use_fc1_bias: bool,
fc2_weight: torch.Tensor,
fc2_bias: torch.Tensor,
use_fc2_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
......@@ -368,7 +366,7 @@ class _LayerNormMLP(torch.autograd.Function):
# FC1 GEMM
# There are 2 fussions possible:
# There are 2 fusions possible:
# - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
# - bias_gelu_fusion - only for full precision.
# If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performed.
......@@ -453,8 +451,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
if not is_grad_enabled:
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
if is_grad_enabled:
else:
if cpu_offloading:
if fp8 and fc1_weight_final is not None:
set_offloading_param(fc1_weight_final, "weight_offloading", True)
......@@ -537,9 +534,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_fc1_bias = use_fc1_bias
ctx.use_fc2_bias = use_fc2_bias
ctx.use_bias = ctx.use_fc1_bias
ctx.use_bias = fc2_bias is not None
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp_shape
......@@ -774,14 +769,13 @@ class _LayerNormMLP(torch.autograd.Function):
quantization_params=None, # wgrad in high precision
layout="NT",
grad=True,
bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
bias=fc2_bias if fc2_bias_grad is None else None,
accumulate=accumulate_wgrad_into_param_main_grad,
use_split_accumulator=_2X_ACC_WGRAD,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if fc2_bias_grad is None:
fc2_bias_grad = fc2_bias_grad_
del fc2_bias_grad_
clear_tensor_data(act_out)
# bias computation
......@@ -1046,11 +1040,9 @@ class _LayerNormMLP(torch.autograd.Function):
dgamma,
dbeta,
fc1_wgrad,
fc1_bias_grad if ctx.use_fc1_bias else None,
None, # use_fc1_bias
fc1_bias_grad if fc1_bias is not None else None,
fc2_wgrad, # pylint: disable=possibly-used-before-assignment
fc2_bias_grad if ctx.use_fc2_bias else None,
None, # use_fc2_bias
fc2_bias_grad,
None, # eps
None, # is_first_microbatch
None, # fp8
......@@ -1471,10 +1463,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.layer_norm_bias,
fc1_weight,
fc1_bias,
self.use_bias,
fc2_weight,
fc2_bias,
self.apply_bias and not self.gemm_bias_unfused_add,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
self.eps,
is_first_microbatch,
self.fp8,
......
......@@ -291,6 +291,17 @@ class _Linear(torch.autograd.Function):
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.grad_added_to_main_grad:
# Passing a torch.nn.Parameter through the Torch hooks returns a plain
# torch.Tensor, since Torch strips off the Parameter wrapper. Preserve the
# original weight object so that any attributes the user set on it are kept.
# Because of this, offloading weights is not recommended when they are
# modified outside this module.
ctx.weight_object = weight
# TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat,
......@@ -392,9 +403,11 @@ class _Linear(torch.autograd.Function):
else None
)
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, weight.requires_grad)
weight.main_grad = main_grad
if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
weight = ctx.weight_object
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
weight.main_grad = main_grad
# Gather intermediate/activation tensors if needed
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
......