Commit 081057de authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-ori

parents 7cf5d5c4 ba41cc90
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
"""Attention backend utils""" """Attention backend utils"""
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar, Union)
import numpy as np import numpy as np
import torch import torch
...@@ -11,6 +13,7 @@ import torch ...@@ -11,6 +13,7 @@ import torch
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState) AttentionState)
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
...@@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens( ...@@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens(
return (num_prefill_query_tokens, num_prefill_kv_tokens, return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) num_decode_query_tokens)
@dataclass
class MLADims:
q_lora_rank: Optional[int]
kv_lora_rank: int
qk_nope_head_dim: int
qk_rope_head_dim: int
v_head_dim: int
def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
v_head_dim=hf_text_config.v_head_dim,
)
...@@ -10,6 +10,9 @@ import vllm.envs as envs ...@@ -10,6 +10,9 @@ import vllm.envs as envs
from vllm.attention import AttentionType from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -87,6 +90,7 @@ class Attention(nn.Module): ...@@ -87,6 +90,7 @@ class Attention(nn.Module):
# FlashAttn doesn't support quantizing the kv-cache only # FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well. # but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32) self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention # We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer) # backends that don't support tensors (Flashinfer)
...@@ -329,17 +333,54 @@ class MultiHeadAttention(nn.Module): ...@@ -329,17 +333,54 @@ class MultiHeadAttention(nn.Module):
return out.reshape(bsz, q_len, -1) return out.reshape(bsz, q_len, -1)
def wait_for_kv_layer_from_connector(layer_name: str):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: List[torch.Tensor],
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
def unified_attention( def unified_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
layer_name: str, layer_name: str,
) -> torch.Tensor: ) -> torch.Tensor:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine] kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(self, query, key, value, kv_cache, attn_metadata) output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
def unified_attention_fake( def unified_attention_fake(
...@@ -367,6 +408,7 @@ def unified_attention_with_output( ...@@ -367,6 +408,7 @@ def unified_attention_with_output(
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
) -> None: ) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
...@@ -379,6 +421,8 @@ def unified_attention_with_output( ...@@ -379,6 +421,8 @@ def unified_attention_with_output(
attn_metadata, attn_metadata,
output=output) output=output)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_attention_with_output_fake( def unified_attention_with_output_fake(
query: torch.Tensor, query: torch.Tensor,
......
...@@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata: ...@@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata:
block_usage: Optional[torch.Tensor] block_usage: Optional[torch.Tensor]
block_indices: Optional[torch.Tensor] block_indices: Optional[torch.Tensor]
block_offsets: Optional[torch.Tensor] block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor] block_groups: Optional[torch.Tensor]
......
...@@ -16,831 +16,778 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8 ...@@ -16,831 +16,778 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8
# To check compatibility # To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5) IS_TURING = current_platform.get_device_capability() == (7, 5)
if triton.__version__ >= "2.1.0":
# Here's an example autotuner config for this kernel. This config does provide
@triton.jit # a performance improvement, but dramatically increases first call latency in
def _fwd_kernel( # triton 3.2. Because of this tradeoff, it's currently commented out.
Q, # @triton.autotune(
K, # configs=[
V, # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
K_cache, # "num_unroll_cache": 4, \
V_cache, # "num_unroll_request": 1 } | \
B_Loc, # ({"kpack": 2, "waves_per_eu": 2} \
sm_scale, # if current_platform.is_rocm() else {}), \
k_scale, # num_warps=4, \
v_scale, # num_stages=1)
B_Start_Loc, # ],
B_Seqlen, # key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
block_size, # )
x, @triton.jit
Out, def _fwd_kernel(Q,
stride_b_loc_b, K,
stride_b_loc_s, V,
stride_qbs, K_cache,
stride_qh, V_cache,
stride_qd, B_Loc,
stride_kbs, sm_scale,
stride_kh, k_scale,
stride_kd, v_scale,
stride_vbs, B_Start_Loc,
stride_vh, B_Seqlen,
stride_vd, x: tl.constexpr,
stride_obs, Out,
stride_oh, stride_b_loc_b,
stride_od, stride_b_loc_s,
stride_k_cache_bs, stride_qbs,
stride_k_cache_h, stride_qh,
stride_k_cache_d, stride_qd,
stride_k_cache_bl, stride_kbs,
stride_k_cache_x, stride_kh,
stride_v_cache_bs, stride_kd,
stride_v_cache_h, stride_vbs,
stride_v_cache_d, stride_vh,
stride_v_cache_bl, stride_vd,
num_queries_per_kv: int, stride_obs,
IN_PRECISION: tl.constexpr, stride_oh,
BLOCK_M: tl.constexpr, stride_od,
BLOCK_DMODEL: tl.constexpr, # head size stride_k_cache_bs,
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 stride_k_cache_h,
BLOCK_N: tl.constexpr, stride_k_cache_d,
SLIDING_WINDOW: tl.constexpr, stride_k_cache_bl: tl.constexpr,
SKIP_DECODE: tl.constexpr, stride_k_cache_x,
): stride_v_cache_bs,
stride_v_cache_h,
cur_batch = tl.program_id(0) stride_v_cache_d,
cur_head = tl.program_id(1) stride_v_cache_bl,
start_m = tl.program_id(2) num_queries_per_kv: tl.constexpr,
IN_PRECISION: tl.constexpr,
cur_kv_head = cur_head // num_queries_per_kv BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) BLOCK_DMODEL_PADDED: tl.constexpr,
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) BLOCK_SIZE: tl.constexpr,
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) BLOCK_N: tl.constexpr,
cur_batch_query_len = (cur_batch_in_all_stop_index - SLIDING_WINDOW: tl.constexpr,
cur_batch_in_all_start_index) num_unroll_cache: tl.constexpr,
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len num_unroll_request: tl.constexpr,
SKIP_DECODE: tl.constexpr,
if SKIP_DECODE and cur_batch_query_len == 1: MAX_Q_LEN: tl.constexpr = 0,
return MAX_CTX_LEN: tl.constexpr = 0):
# start position inside of the query cur_batch = tl.program_id(0)
# generally, N goes over kv, while M goes over query_len cur_head = tl.program_id(1)
block_start_loc = BLOCK_M * start_m start_m = tl.program_id(2)
# initialize offsets cur_kv_head = cur_head // num_queries_per_kv
# [N]; starts at 0
offs_n = tl.arange(0, BLOCK_N) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
# [D]; starts at 0 cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
# [M]; starts at current position in query cur_batch_query_len = (cur_batch_in_all_stop_index -
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) cur_batch_in_all_start_index)
# [M,D] cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + if SKIP_DECODE and cur_batch_query_len == 1:
cur_head * stride_qh + offs_d[None, :] * stride_qd)
dim_mask = tl.where(
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
0).to(tl.int1) # [D]
q = tl.load(Q + off_q,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_query_len),
other=0.0) # [M,D]
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M]
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M]
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
dtype=tl.float32) # [M,D]
# compute query against context (no causal mask here)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0) # [N]
# [D,N]
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
# [N,D]
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N]
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
(start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
-10000)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1) # [M]
p = tl.exp(qk - m_ij[:, None]) # [M,N]
l_ij = tl.sum(p, 1) # [M]
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij) # [M]
alpha = tl.exp(m_i - m_i_new) # [M]
beta = tl.exp(m_ij - m_i_new) # [M]
l_i_new = alpha * l_i + beta * l_ij # [M]
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v_load = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# # update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
# block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_query_len),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk *= sm_scale
# apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
if SLIDING_WINDOW > 0:
qk = tl.where(
offs_m[:, None] - (start_n + offs_n[None, :])
< SLIDING_WINDOW, qk, -10000)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_query_len),
other=0.0)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_query_len))
return return
@triton.jit # start position inside of the query
def _fwd_kernel_flash_attn_v2( # generally, N goes over kv, while M goes over query_len
Q, block_start_loc = BLOCK_M * start_m
K,
V, # initialize offsets
K_cache, # [BLOCK_SIZE]; starts at 0
V_cache, offs_bs_n = tl.arange(0, BLOCK_SIZE)
B_Loc, # [N]; starts at 0
sm_scale, offs_n = tl.arange(0, BLOCK_N)
B_Start_Loc, # [D]; starts at 0
B_Seqlen, offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
B_Ctxlen, # [M]; starts at current position in query
block_size, offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
x, # [M,D]
Out, off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
stride_b_loc_b, cur_head * stride_qh + offs_d[None, :] * stride_qd)
stride_b_loc_s,
stride_qbs, dim_mask = tl.where(
stride_qh, tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
stride_qd, 0).to(tl.int1) # [D]
stride_kbs,
stride_kh, q = tl.load(Q + off_q,
stride_kd, mask=dim_mask[None, :] &
stride_vbs, (offs_m[:, None] < cur_batch_query_len),
stride_vh, other=0.0) # [M,D]
stride_vd,
stride_obs, # initialize pointer to m and l
stride_oh, m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
stride_od, l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
stride_k_cache_bs, acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
stride_k_cache_h,
stride_k_cache_d, # compute query against context (no causal mask here)
stride_k_cache_bl, for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
stride_k_cache_x, loop_unroll_factor=num_unroll_cache):
stride_v_cache_bs, start_n = tl.multiple_of(start_n, BLOCK_SIZE)
stride_v_cache_h, # -- compute qk ----
stride_v_cache_d, bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
stride_v_cache_bl, (start_n // BLOCK_SIZE) * stride_b_loc_s)
num_queries_per_kv: int, # [D,BLOCK_SIZE]
BLOCK_M: tl.constexpr, off_k = (
BLOCK_DMODEL: tl.constexpr, bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
BLOCK_N: tl.constexpr, (offs_d[:, None] // x) * stride_k_cache_d +
): ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl +
cur_batch = tl.program_id(0) (offs_d[:, None] % x) * stride_k_cache_x)
cur_head = tl.program_id(1)
start_m = tl.program_id(2) # [BLOCK_SIZE,D]
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head = cur_head // num_queries_per_kv cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) offs_bs_n[:, None] * stride_v_cache_bl)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
block_start_loc = BLOCK_M * start_m k_load = tl.load(
K_cache + off_k,
# initialize offsets mask=dim_mask[:, None] &
offs_n = tl.arange(0, BLOCK_N) ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
offs_d = tl.arange(0, BLOCK_DMODEL) other=0.0) # [D,N]
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) else:
off_q = ( k_load = tl.load(K_cache + off_k)
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd) if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
q = tl.load(Q + off_q, else:
mask=offs_m[:, None] k = k_load
< cur_batch_seq_len - cur_batch_ctx_len,
qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_bs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
(start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk,
-10000)
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]
# update acc
if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
v_load = tl.load(
V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
else:
v_load = tl.load(V_cache + off_v)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# # update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
# block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask)
for start_n in tl.range(0, \
block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
loop_unroll_factor=num_unroll_request):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_query_len),
other=0.0) other=0.0)
# # initialize pointer to m and l qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) qk *= sm_scale
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
for start_n in range(0, cur_batch_ctx_len, BLOCK_N): float("-inf"))
start_n = tl.multiple_of(start_n, BLOCK_N) if SLIDING_WINDOW > 0:
# -- compute qk ---- qk = tl.where(
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
((start_n + offs_n) // block_size) * stride_b_loc_s, qk, -10000)
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0) # compute running maximum
off_k = (bn[None, :] * stride_k_cache_bs + m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
cur_kv_head * stride_k_cache_h + p = tl.exp(qk - m_ij[:, None])
(offs_d[:, None] // x) * stride_k_cache_d + l_ij = tl.sum(p, axis=1)
((start_n + offs_n[None, :]) % block_size) * alpha = tl.exp(m_i - m_ij)
stride_k_cache_bl + acc = acc * alpha[:, None]
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = ( # update acc
bn[:, None] * stride_v_cache_bs + v = tl.load(v_ptrs +
cur_kv_head * stride_v_cache_h + (cur_batch_in_all_start_index + start_n) * stride_vbs,
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :])
< cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None])
< cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# acc /= l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@triton.jit
def _fwd_kernel_alibi(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
k_scale,
v_scale,
B_Start_Loc,
B_Seqlen,
Alibi_slopes,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
IN_PRECISION: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
BLOCK_N: tl.constexpr,
SKIP_DECODE: tl.constexpr,
):
# attn_bias[]
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
# cur_batch_seq_len: the length of prompts
# cur_batch_ctx_len: the length of prefix
# cur_batch_in_all_start_index: the start id of the dim=0
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
cur_batch_query_len = (cur_batch_in_all_stop_index -
cur_batch_in_all_start_index)
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
if SKIP_DECODE and cur_batch_query_len == 1:
return
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
dim_mask = tl.where(
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
q = tl.load(Q + off_q,
mask=dim_mask[None, :] & mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), ((start_n + offs_n[:, None]) < cur_batch_query_len),
other=0.0)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
return
@triton.jit
def _fwd_kernel_flash_attn_v2(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :])
< cur_batch_seq_len - cur_batch_ctx_len,
other=0.0) other=0.0)
# # initialize pointer to m and l qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") qk += tl.dot(q, k)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) qk *= sm_scale
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange( # -- compute m_ij, p, l_ij
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len m_ij = tl.max(qk, 1)
alibi_start_k = 0 m_i_new = tl.maximum(m_i, m_ij)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N): p = tl.math.exp(qk - m_i_new[:, None])
start_n = tl.multiple_of(start_n, BLOCK_N) l_ij = tl.sum(p, 1)
# -- compute qk ---- # -- update m_i and l_i
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s, alpha = tl.math.exp(m_i - m_i_new)
mask=(start_n + offs_n) < cur_batch_ctx_len, l_i_new = alpha * l_i + l_ij
other=0) # -- update output accumulator --
off_k = (bn[None, :] * stride_k_cache_bs + # scale p
cur_kv_head * stride_k_cache_h + # scale acc
(offs_d[:, None] // x) * stride_k_cache_d + acc_scale = alpha
((start_n + offs_n[None, :]) % block_size) * # acc_scale = l_i / l_i_new * alpha
stride_k_cache_bl + acc = acc * acc_scale[:, None]
(offs_d[:, None] % x) * stride_k_cache_x) # update acc
off_v = ( v = tl.load(v_ptrs +
bn[:, None] * stride_v_cache_bs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
cur_kv_head * stride_v_cache_h + mask=(start_n + offs_n[:, None])
offs_d[None, :] * stride_v_cache_d + < cur_batch_seq_len - cur_batch_ctx_len,
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) other=0.0)
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] & p = p.to(v.dtype)
((start_n + offs_n[None, :]) < cur_batch_ctx_len), acc += tl.dot(p, v)
other=0.0) # [D,N] # update m_i and l_i
l_i = l_i_new
if k_load.dtype.is_fp8(): m_i = m_i_new
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else: # acc /= l_i[:, None]
k = k_load # initialize pointers to output
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) cur_head * stride_oh + offs_d[None, :] * stride_od)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) out_ptrs = Out + off_o
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, tl.store(out_ptrs,
float("-inf")) acc,
qk *= sm_scale mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope @triton.jit
alibi = tl.where( def _fwd_kernel_alibi(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), Q,
alibi, float("-inf")) K,
qk += alibi V,
alibi_start_k += BLOCK_N K_cache,
V_cache,
# -- compute m_ij, p, l_ij B_Loc,
m_ij = tl.max(qk, 1) sm_scale,
m_i_new = tl.maximum(m_i, m_ij) k_scale,
p = tl.math.exp(qk - m_i_new[:, None]) v_scale,
l_ij = tl.sum(p, 1) B_Start_Loc,
# -- update m_i and l_i B_Seqlen,
Alibi_slopes,
alpha = tl.math.exp(m_i - m_i_new) block_size,
l_i_new = alpha * l_i + l_ij x,
# -- update output accumulator -- Out,
# scale p stride_b_loc_b,
# scale acc stride_b_loc_s,
acc_scale = alpha stride_qbs,
# acc_scale = l_i / l_i_new * alpha stride_qh,
acc = acc * acc_scale[:, None] stride_qd,
# update acc stride_kbs,
v_load = tl.load(V_cache + off_v, stride_kh,
mask=dim_mask[None, :] & stride_kd,
((start_n + offs_n[:, None]) < cur_batch_ctx_len), stride_vbs,
other=0.0) stride_vh,
if v_load.dtype.is_fp8(): stride_vd,
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) stride_obs,
else: stride_oh,
v = v_load stride_od,
p = p.to(v.dtype) stride_k_cache_bs,
stride_k_cache_h,
acc = tl.dot(p, v, acc=acc, input_precision='ieee') stride_k_cache_d,
# update m_i and l_i stride_k_cache_bl,
l_i = l_i_new stride_k_cache_x,
m_i = m_i_new stride_v_cache_bs,
stride_v_cache_h,
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + stride_v_cache_d,
offs_d[:, None] * stride_kd) stride_v_cache_bl,
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + num_queries_per_kv: int,
offs_d[None, :] * stride_vd) IN_PRECISION: tl.constexpr,
k_ptrs = K + off_k BLOCK_M: tl.constexpr,
v_ptrs = V + off_v BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
block_mask = tl.where( BLOCK_N: tl.constexpr,
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) SKIP_DECODE: tl.constexpr,
):
# init alibi # attn_bias[]
alibi_slope = tl.load(Alibi_slopes + cur_head) cur_batch = tl.program_id(0)
alibi_start_q = tl.arange( cur_head = tl.program_id(1)
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len start_m = tl.program_id(2)
alibi_start_k = cur_batch_ctx_len
# # init debugger cur_kv_head = cur_head // num_queries_per_kv
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
# offset_db_k = tl.arange(0, BLOCK_N) # cur_batch_seq_len: the length of prompts
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] # cur_batch_ctx_len: the length of prefix
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): # cur_batch_in_all_start_index: the start id of the dim=0
start_n = tl.multiple_of(start_n, BLOCK_N) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
# -- compute qk ---- cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
k = tl.load(k_ptrs + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
(cur_batch_in_all_start_index + start_n) * stride_kbs, cur_batch_query_len = (cur_batch_in_all_stop_index -
mask=dim_mask[:, None] & cur_batch_in_all_start_index)
((start_n + offs_n[None, :]) cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0) if SKIP_DECODE and cur_batch_query_len == 1:
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision='ieee')
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
return return
@torch.inference_mode() block_start_loc = BLOCK_M * start_m
def context_attention_fwd(q,
k, # initialize offsets
v, offs_n = tl.arange(0, BLOCK_N)
o, offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
kv_cache_dtype: str, offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
k_cache, off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
v_cache, cur_head * stride_qh + offs_d[None, :] * stride_qd)
b_loc,
b_start_loc, dim_mask = tl.where(
b_seq_len, tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
max_seq_len,
max_input_len, q = tl.load(Q + off_q,
k_scale: torch.Tensor, mask=dim_mask[None, :] &
v_scale: torch.Tensor, (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
alibi_slopes=None, other=0.0)
sliding_window=None,
sm_scale=None, # # initialize pointer to m and l
skip_decode=False): m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
q_dtype_is_f32 = q.dtype is torch.float32 acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = 0
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v_load = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
# init alibi
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = cur_batch_ctx_len
# # init debugger
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
# offset_db_k = tl.arange(0, BLOCK_N)
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] & ((start_n + offs_n[None, :])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision='ieee')
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] & ((start_n + offs_n[:, None])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
return
@torch.inference_mode()
def context_attention_fwd(q,
k,
v,
o,
kv_cache_dtype: str,
k_cache,
v_cache,
b_loc,
b_start_loc,
b_seq_len,
max_seq_len,
max_input_len,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
skip_decode=False):
q_dtype_is_f32 = q.dtype is torch.float32
# Turing does have tensor core for float32 multiplication
# use ieee as fallback for triton kernels work. There is also
# warning on vllm/config.py to inform users this fallback
# implementation
IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert (k_cache.dtype == torch.uint8)
assert (v_cache.dtype == torch.uint8)
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
k_cache = k_cache.view(target_dtype)
v_cache = v_cache.view(target_dtype)
if (k_cache.dtype == torch.uint8
or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
raise ValueError("kv_cache_dtype='auto' unsupported for\
FP8 KV Cache prefill kernel")
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = triton.next_power_of_2(Lk)
if sm_scale is None:
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]
assert batch + 1 == len(b_start_loc)
# 0 means "disable"
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
if alibi_slopes is not None:
# need to reduce num. blocks when using fp32 # need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory # due to increased use of GPU shared memory
# if q.dtype is torch.float32: # if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
# batch, head,
# Turing does have tensor core for float32 multiplication grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
# use ieee as fallback for triton kernels work. There is also _fwd_kernel_alibi[grid](
# warning on vllm/config.py to inform users this fallback
# implementation
IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert (k_cache.dtype == torch.uint8)
assert (v_cache.dtype == torch.uint8)
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
k_cache = k_cache.view(target_dtype)
v_cache = v_cache.view(target_dtype)
if (k_cache.dtype == torch.uint8
or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
raise ValueError("kv_cache_dtype='auto' unsupported for\
FP8 KV Cache prefill kernel")
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = triton.next_power_of_2(Lk)
if sm_scale is None:
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]
assert batch + 1 == len(b_start_loc)
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,
# 0 means "disable"
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
alibi_slopes,
v_cache.shape[3],
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4
), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SKIP_DECODE=skip_decode,
num_warps=NUM_WARPS,
num_stages=1,
)
return
_fwd_kernel[grid](
q, q,
k, k,
v, v,
...@@ -852,6 +799,7 @@ if triton.__version__ >= "2.1.0": ...@@ -852,6 +799,7 @@ if triton.__version__ >= "2.1.0":
v_scale, v_scale,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
alibi_slopes,
v_cache.shape[3], v_cache.shape[3],
k_cache.shape[4], k_cache.shape[4],
o, o,
...@@ -886,9 +834,69 @@ if triton.__version__ >= "2.1.0": ...@@ -886,9 +834,69 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL=Lk, BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode, SKIP_DECODE=skip_decode,
num_warps=NUM_WARPS, num_warps=NUM_WARPS,
num_stages=1, num_stages=1,
) )
return return
max_seq_len = 0 if max_seq_len is None else max_seq_len
extra_kargs = {}
if current_platform.is_rocm():
extra_kargs = {"kpack": 2, "waves_per_eu": 2}
grid = lambda META: (batch, head,
triton.cdiv(max_input_len, META["BLOCK_M"]))
_fwd_kernel[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size]
BLOCK_SIZE=v_cache.shape[3],
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode,
BLOCK_M=128,
BLOCK_N=64,
num_unroll_cache=4,
num_unroll_request=1,
num_warps=4,
num_stages=1,
**extra_kargs)
return
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
max_block_per_batch: int,
device: torch.device) -> tuple[torch.Tensor, ...]:
paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch,
dtype=torch.int32,
device=device)
paged_kv_indptr = torch.zeros(max_batch_size + 1,
dtype=torch.int32,
device=device)
paged_kv_last_page_lens = torch.full((max_batch_size, ),
block_size,
dtype=torch.int32)
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens
def aiter_mla_decode_fwd(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
sm_scale: float,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
logit_cap: float = 0.0,
):
from aiter.mla import mla_decode_fwd
mla_decode_fwd(q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
kv_indptr,
kv_indices,
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap)
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import aiter as rocm_aiter
import torch
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.platforms import current_platform
from vllm.utils import cdiv
FP8_DTYPE = current_platform.fp8_dtype()
class AITERPagedAttention(PagedAttention):
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype, k_scale,
v_scale)
else:
kv_cache_torch_dtype = (FP8_DTYPE
if "fp8" in kv_cache_dtype else torch.int8)
key_cache = key_cache.view(kv_cache_torch_dtype)
value_cache = value_cache.view(kv_cache_torch_dtype)
rocm_aiter.reshape_and_cache_with_pertoken_quant(
key, value, key_cache, value_cache, k_scale, v_scale,
slot_mapping.flatten(), True)
@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
return PagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
kv_cache_dtype=kv_cache_dtype,
num_kv_heads=num_kv_heads,
scale=scale,
alibi_slopes=alibi_slopes,
k_scale=k_scale,
v_scale=v_scale,
tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride,
blocksparse_block_size=blocksparse_block_size,
blocksparse_head_sliding_step=blocksparse_head_sliding_step)
if "fp8" in kv_cache_dtype:
key_cache = key_cache.view(torch.float8_e4m3fnuz)
value_cache = value_cache.view(torch.float8_e4m3fnuz)
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)
assert (blocksparse_block_size > 0 and
blocksparse_block_size % block_size == 0), \
(f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables.")
output = torch.empty_like(query)
block_size = value_cache.shape[3]
max_num_blocks_per_seq = cdiv(max_seq_len, block_size)
rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
seq_lens, max_num_blocks_per_seq, k_scale,
v_scale, output)
return output
...@@ -39,11 +39,12 @@ is_hip_ = current_platform.is_rocm() ...@@ -39,11 +39,12 @@ is_hip_ = current_platform.is_rocm()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO: Remove this when triton>=3.2.0. This issue will not affect performance # Only print the following warnings when triton version < 3.2.0.
# and accuracy. # The issue won't affect performance or accuracy.
logger.warning( if triton.__version__ < '3.2.0':
"The following error message 'operation scheduled before its operands' " logger.warning(
"can be ignored.") "The following error message 'operation scheduled before its operands' "
"can be ignored.")
@triton.jit @triton.jit
......
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Fused Attention Fused Attention
=============== ===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao This is a Triton implementation of the Flash Attention v2 algorithm
(https://tridao.me/publications/flash2/flash2.pdf) See https://tridao.me/publications/flash2/flash2.pdf
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported: Credits:
AMD Triton kernels team
OpenAI kernel team
1) Fwd with causal masking Currently only the forward kernel is supported, and contains these features:
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims 1) Fwd with causal masking
2) Arbitrary Q and KV sequence lengths
3) Arbitrary head sizes
4) Multi and grouped query attention
5) Variable sequence lengths
6) ALiBi and matrix bias
7) FP8 support
""" """
from typing import Optional
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
torch_dtype: tl.constexpr = torch.float16 from vllm import _custom_ops as ops
from vllm.platforms import current_platform
SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']
default_eight_bit_dtype_triton = tl.float8e4b8
default_eight_bit_dtype_torch = current_platform.fp8_dtype()
default_float8_info = torch.finfo(default_eight_bit_dtype_torch)
FP8_MIN = triton.language.constexpr(default_float8_info.min)
# According to https://github.com/vllm-project/vllm/blob/main
# /csrc/quantization/utils.cuh#L31,
# need to make the max for the uz datatype be 224.0 for accuracy reasons.
FP8_MAX = triton.language.constexpr(
default_float8_info.max if default_eight_bit_dtype_torch !=
torch.float8_e4m3fnuz else 224.0)
class MetaData:
cu_seqlens_q = None
cu_seqlens_k = None
max_seqlens_q = 0
max_seqlens_k = 0
bias = None
alibi_slopes = None
causal = False
num_contexts = 0
varlen = False
eight_bit = False
layout = None
return_encoded_softmax = False
eight_bit_dtype_triton = default_eight_bit_dtype_triton
eight_bit_dtype_torch = default_eight_bit_dtype_torch
output_dtype = None
# Note about layouts:
#
# thd - [num_tokens, num_heads, head_size]
# bshd - [batch_size, seq_len, num_heads, head_size]
# bhsd - [batch_size, num_heads, seq_len, head_size]
#
# This is for each tensor, all tensors must have same layout.
# Q can have num_heads and seq_len differ from from K and V,
# however K and V must agree on this.
#
# Notes about varlen and bias:
# Only one or the other is implemented, meaning can't combine
# both varlen and bias right now.
#
# Note about quantization:
# Only 8-bit quantization supported (for now) and specifically fp8.
# Scales must be tensors.
# o_scale: This is 'output scaling', but comes from parameter called
# 'input_scale', this is applied to the output from the kernel.
# o_scale should be None if none of the other quantization parameters
# are used.
#
# NOTE: Object is in a tentatively good state after initialized, however,
# to verify, call check_args(q,k,v,o) where o is the output tensor.
def __init__(
self,
sm_scale=1.0,
layout=None, # layout can be 'bshd', 'bhsd', or 'thd'
output_dtype=None,
max_seqlens_q=0,
max_seqlens_k=0,
# varlen params
cu_seqlens_q=None, # only 'thd' layout supported for varlen
cu_seqlens_k=None,
# quant params
q_descale=None,
k_descale=None,
v_descale=None,
p_scale=None,
o_scale=None,
# bias params
bias=None, # varlen not implemented for bias
seqlen_q=None,
seqlen_k=None,
# alibi params
alibi_slopes=None,
alibi_batch=None,
alibi_nheads=None,
# causal
causal=None,
):
self.sm_scale = sm_scale
self.output_dtype = output_dtype
self.max_seqlens_q = max_seqlens_q
self.max_seqlens_k = max_seqlens_k
self.layout = layout
if cu_seqlens_q is not None or cu_seqlens_k is not None:
assert cu_seqlens_q is not None and cu_seqlens_k is not None
assert layout is None or layout not in [
'bshd', 'bhsd'
], "Varlen only implemented for thd layout"
self.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
quant_params = [q_descale, k_descale, v_descale, p_scale, o_scale]
if any(x is not None for x in quant_params):
p_descale = 1.0 / p_scale if p_scale is not None else None
self.set_eight_bit_params(q_descale, k_descale, v_descale, p_scale,
p_descale, o_scale)
if bias is not None:
self.need_bias(bias, seqlen_q, seqlen_k)
if alibi_slopes is not None:
self.need_alibi(alibi_slopes, alibi_batch, alibi_nheads)
if causal is not None and causal:
self.need_causal()
def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
self.varlen = True
self.layout = 'thd'
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_k = cu_seqlens_k
# Without "varlen", there should still be one sequence.
assert len(cu_seqlens_q) >= 2
assert len(cu_seqlens_q) == len(cu_seqlens_k)
self.num_contexts = len(cu_seqlens_q) - 1
for i in range(0, self.num_contexts):
self.max_seqlens_q = max(
cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(),
self.max_seqlens_q)
self.max_seqlens_k = max(
cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(),
self.max_seqlens_k)
def set_eight_bit_params(self, q_descale, k_descale, v_descale, p_scale,
p_descale, o_scale):
self.eight_bit = True
self.q_descale = q_descale
self.k_descale = k_descale
self.v_descale = v_descale
self.p_scale = p_scale
self.p_descale = p_descale
self.o_scale = o_scale
self.use_p_scale = (p_scale is not None) and (
p_descale is not None) and (v_descale is not None)
self.eight_bit_kv = ((q_descale is None) and (k_descale is not None)
and (v_descale is not None))
self.eight_bit_dtype_torch = default_eight_bit_dtype_torch
def need_bias(self, bias, seqlen_q, seqlen_k):
assert bias is not None
assert bias.is_cuda
assert bias.dim() == 4
assert bias.shape[0] == 1
assert bias.shape[2:] == (seqlen_q, seqlen_k)
self.bias = bias
def need_alibi(self, alibi_slopes, batch, nheads):
assert alibi_slopes.is_cuda
assert alibi_slopes.dim() == 2
assert alibi_slopes.shape[0] == batch
assert alibi_slopes.shape[1] == nheads
self.alibi_slopes = alibi_slopes
def need_causal(self):
self.causal = True
def check_args(self, q, k, v, o):
assert q.dim() == k.dim() and q.dim() == v.dim()
batch, nheads_q, nheads_k, head_size = get_shape_from_layout(
q, k, self)
if self.varlen:
assert q.dim() == 3
assert self.cu_seqlens_q is not None
assert self.cu_seqlens_k is not None
assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
# TODO: Remove once bias is supported with varlen
assert self.bias is None
assert not self.return_encoded_softmax
else:
assert q.dim() == 4
assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
if self.eight_bit:
if self.eight_bit_kv:
assert (v.dtype == k.dtype
and k.dtype == self.eight_bit_dtype_torch)
assert q.dtype != k.dtype
assert (self.v_descale is not None) and (self.k_descale
is not None)
else:
assert (q.dtype == k.dtype and q.dtype == v.dtype
and q.dtype == self.eight_bit_dtype_torch)
assert (self.q_descale
is not None) and (self.k_descale
is not None) and (self.v_descale
is not None)
if self.use_p_scale:
assert (self.p_scale is not None) and (self.p_descale
is not None)
else:
assert (q.dtype == k.dtype) and (q.dtype == v.dtype)
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
assert self.layout is not None
assert self.layout == 'thd' or not self.varlen
@triton.jit @triton.jit
...@@ -38,40 +244,85 @@ def max_fn(x, y): ...@@ -38,40 +244,85 @@ def max_fn(x, y):
return tl.math.max(x, y) return tl.math.max(x, y)
# Convenience function to load with optional boundary checks.
# "First" is the major dim, "second" is the minor dim.
@triton.jit @triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): def masked_load(ptrs, offset_first, offset_second, boundary_first,
ms = tl.arange(0, m) boundary_second):
ns = tl.arange(0, n) if offset_first is not None and offset_second is not None:
return philox_offset + ms[:, None] * stride + ns[None, :] mask = (offset_first[:, None] < boundary_first) & \
(offset_second[None, :] < boundary_second)
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_first is not None:
mask = offset_first[:, None] < boundary_first
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_second is not None:
mask = offset_second[None, :] < boundary_second
tensor = tl.load(ptrs, mask=mask, other=0.0)
else:
tensor = tl.load(ptrs)
return tensor
@triton.jit @triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): def compute_alibi_block(alibi_slope,
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, seqlen_q,
stride).to(tl.uint32) seqlen_k,
# TODO: use tl.randint for better performance offs_m,
return tl.rand(philox_seed, rng_offsets) offs_n,
transpose=False):
# when seqlen_k and seqlen_q are different we want the diagonal to stick to
# the bottom right of the attention matrix
# for casual mask we want something like this where (1 is kept and 0 is
# masked)
# seqlen_q = 2 and seqlen_k = 5
# 1 1 1 1 0
# 1 1 1 1 1
# seqlen_q = 5 and seqlen_k = 2
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
# for alibi the diagonal is 0 indicating no penalty for attending to that
# spot and increasing penalty for attending further from the diagonal
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5,
# offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
# 1. offs_m[:,None] = [[0],
# [1],
# 2. offs_m[:,None] + seqlen_k = [[5],
# [6],
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
# [4],
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] =
# [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], [4], [ 4, 3, 2, 1, 0]]
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
# [ -4, -3, -2, -1, 0]],
relative_pos_block = (offs_m[:, None] + seqlen_k - seqlen_q -
offs_n[None, :])
alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
if transpose:
return alibi_block.T
else:
return alibi_block
@triton.jit def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k):
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): q_idx = torch.arange(seqlen_q, dtype=torch.int32,
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1)
stride) k_idx = torch.arange(seqlen_k, dtype=torch.int32,
rng_keep = rng_output > dropout_p device="cuda").unsqueeze(0) # (1, N_CTX_K)
return rng_keep relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q -
k_idx) # (N_CTX_Q, N_CTX_K)
return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(
-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K)
@triton.jit @triton.jit
def load_fn(block_ptr, first, second, pad): def quant_fp8(x, scale):
if first and second: x *= scale
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) x = tl.clamp(x, FP8_MIN, FP8_MAX)
elif first: return x
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit @triton.jit
...@@ -80,58 +331,68 @@ def _attn_fwd_inner( ...@@ -80,58 +331,68 @@ def _attn_fwd_inner(
l_i, l_i,
m_i, m_i,
q, q,
K_block_ptr, k_ptrs,
V_block_ptr, v_ptrs,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m, start_m,
actual_seqlen_k, actual_seqlen_k,
dropout_p, actual_seqlen_q,
philox_seed, philox_seed,
batch_philox_offset, batch_philox_offset,
encoded_softmax_block_ptr, encoded_sm_ptrs,
block_min, block_min,
block_max, block_max,
offs_n_causal, offs_n_causal,
masked_blocks, masked_blocks,
n_extra_tokens, n_extra_tokens,
bias_ptr, alibi_slope,
q_descale,
k_descale,
v_descale,
p_scale,
IS_CAUSAL: tl.constexpr, IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr, OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr, SHOULD_PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr, SHOULD_MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_PADDED_HEAD: tl.constexpr,
PADDED_HEAD: tl.constexpr, IS_ACTUAL_BLOCK_DMODEL: tl.constexpr,
QK_SCALE: tl.constexpr,
IS_EIGHT_BIT_GEMM: tl.constexpr,
USE_P_SCALE: tl.constexpr,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
): ):
# loop over k, v, and update accumulator # loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N): for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if # For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range. # we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn( k_offs_n = start_n + tl.arange(0,
K_block_ptr, BLOCK_N) if SHOULD_MASK_STEPS else None
PADDED_HEAD, k_offs_k = None if not USE_PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
MASK_STEPS and (n_extra_tokens != 0), k = masked_load(k_ptrs, k_offs_k, k_offs_n, IS_ACTUAL_BLOCK_DMODEL,
"zero", actual_seqlen_k)
) if SHOULD_PRE_LOAD_V:
if PRE_LOAD_V: # We can use the same offsets as k, just with dims transposed.
v = load_fn( v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k,
V_block_ptr, IS_ACTUAL_BLOCK_DMODEL)
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need # We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n # to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block. # TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS: # noqa: SIM102 if SHOULD_MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to # If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size # mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not
# if not is_modulo_mn. last step might get wasted but that is okay. # is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case. # check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M], boundary_m = tl.full([BLOCK_M],
...@@ -144,167 +405,276 @@ def _attn_fwd_inner( ...@@ -144,167 +405,276 @@ def _attn_fwd_inner(
causal_boundary = start_n + offs_n_causal causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf")) qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ---- # -- compute qk ----
qk += tl.dot(q, k) if IS_EIGHT_BIT_GEMM:
if bias_ptr is not None: qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) *
bias = load_fn(bias_ptr, False, MASK_STEPS QK_SCALE)
and (n_extra_tokens != 0), "zero") else:
# While bias is added after multiplying qk with sm_scale, our if IS_EIGHT_BIT_KV:
# optimization to use 2^x instead of e^x results in an additional k = (k * k_descale).to(q.type.element_ty)
# scale factor of log2(e) which we must also multiply the bias with. qk += (tl.dot(q, k) * QK_SCALE)
qk += bias * 1.44269504089
if bias_ptrs is not None:
bias_offs_n = start_n + tl.arange(
0, BLOCK_N) if SHOULD_MASK_STEPS else None
bias = masked_load(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q,
actual_seqlen_k)
# While bias is added after multiplying qk with sm_scale,
# our optimization to use 2^x instead of e^x results in an
# additional scale factor of log2(e) which we must also multiply
# the bias with.
qk += (bias * 1.44269504089)
if alibi_slope is not None:
# Compute the global position of each token within the sequence
global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
global_n_positions = start_n + tl.arange(0, BLOCK_N)
alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q,
actual_seqlen_k,
global_m_positions,
global_n_positions)
qk += (alibi_block * 1.44269504089) # scale factor of log2(e)
# softmax
m_ij = tl.maximum(m_i, tl.max(qk, 1)) m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None] qk = qk - m_ij[:, None]
p = tl.math.exp2(qk) p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout # CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1) l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT: if SHOULD_RETURN_ENCODED_SOFTMAX:
philox_offset = (batch_philox_offset + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty))
start_m * BLOCK_M * actual_seqlen_k + start_n -
BLOCK_N)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p,
-p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator -- # -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij) alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None] acc = acc * alpha[:, None]
if not PRE_LOAD_V: if not SHOULD_PRE_LOAD_V:
v = load_fn( v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k,
V_block_ptr, IS_ACTUAL_BLOCK_DMODEL)
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i # -- update m_i and l_i
l_i = l_i * alpha + l_ij l_i = l_i * alpha + l_ij
# update m_i and l_i # update m_i and l_i
m_i = m_ij m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) if IS_EIGHT_BIT_GEMM:
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) if USE_P_SCALE:
if bias_ptr is not None: p = quant_fp8(p, p_scale).to(QUANT_DTYPE)
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) acc += tl.dot(p, v)
if RETURN_ENCODED_SOFTMAX: else:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, # v is in eight_bit but p is not, we want the gemm in p's type
(0, BLOCK_N)) acc += tl.dot(p, v.to(p.type.element_ty))
else:
if IS_EIGHT_BIT_KV:
v = (v * v_descale).to(p.type.element_ty)
acc += tl.dot(p.to(v.type.element_ty), v)
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if bias_ptrs is not None:
bias_ptrs += BLOCK_N * stride_bn
if SHOULD_RETURN_ENCODED_SOFTMAX:
encoded_sm_ptrs += BLOCK_N
return acc, l_i, m_i return acc, l_i, m_i
@triton.autotune( def get_cdna_autotune_configs():
configs=[ return [
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 128,
'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 3,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config( triton.Config(
{ {
"BLOCK_M": 256, 'BLOCK_M': 128,
"BLOCK_N": 64, 'BLOCK_N': 64,
"waves_per_eu": 2, 'waves_per_eu': 1,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=4),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 128,
"BLOCK_N": 128, 'BLOCK_N': 32,
"waves_per_eu": 2, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=4),
), ], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def get_rdna_autotune_configs():
return [
triton.Config( triton.Config(
{ {
"BLOCK_M": 256, 'BLOCK_M': 32,
"BLOCK_N": 128, 'BLOCK_N': 32,
"waves_per_eu": 2, 'waves_per_eu': 4,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 32,
"BLOCK_N": 64, 'BLOCK_N': 32,
"waves_per_eu": 1, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 32,
"BLOCK_N": 64, 'BLOCK_N': 16,
"waves_per_eu": 3, 'waves_per_eu': 4,
"PRE_LOAD_V": True, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 32,
"BLOCK_N": 64, 'BLOCK_N': 16,
"waves_per_eu": 3, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 64, 'BLOCK_M': 16,
"BLOCK_N": 64, 'BLOCK_N': 16,
"waves_per_eu": 4, 'waves_per_eu': 4,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 32, 'BLOCK_M': 16,
"BLOCK_N": 32, 'BLOCK_N': 16,
"waves_per_eu": 4, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=2),
), # Fall-back config.
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config( triton.Config(
{ {
"BLOCK_M": 16, 'BLOCK_M': 16,
"BLOCK_N": 16, 'BLOCK_N': 16,
"waves_per_eu": 1, 'waves_per_eu': 1,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=2),
), ], [
], 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def get_general_autotune_configs():
return [
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 128,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 64,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 32,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def has_cdna_target():
ROCM_CDNA_TARGETS = ["gfx940", "gfx941", "gfx942", "gfx90a", "gfx908"]
return triton.runtime.driver.active.get_current_target(
).arch in ROCM_CDNA_TARGETS
def is_rocm_cdna():
return current_platform.is_rocm() and has_cdna_target()
def get_autotune_configs():
if is_rocm_cdna():
return get_cdna_autotune_configs()
elif current_platform.is_rocm():
return get_rdna_autotune_configs()
else:
return get_general_autotune_configs()
autotune_configs, autotune_keys = get_autotune_configs()
@triton.autotune(
configs=autotune_configs,
key=autotune_keys,
use_cuda_graph=True,
) )
@triton.jit @triton.jit
def attn_fwd( def attn_fwd(
...@@ -312,38 +682,53 @@ def attn_fwd( ...@@ -312,38 +682,53 @@ def attn_fwd(
K, K,
V, V,
bias, bias,
sm_scale, SM_SCALE: tl.constexpr,
L, L,
Out, Out,
stride_qz, stride_qz: tl.int64,
stride_qh, stride_qh: tl.int64,
stride_qm, stride_qm: tl.int64,
stride_qk, stride_qk: tl.int64,
stride_kz, stride_kz: tl.int64,
stride_kh, stride_kh: tl.int64,
stride_kn, stride_kn: tl.int64,
stride_kk, stride_kk: tl.int64,
stride_vz, stride_vz: tl.int64,
stride_vh, stride_vh: tl.int64,
stride_vk, stride_vk: tl.int64,
stride_vn, stride_vn: tl.int64,
stride_oz, stride_oz: tl.int64,
stride_oh, stride_oh: tl.int64,
stride_om, stride_om: tl.int64,
stride_on, stride_on: tl.int64,
stride_bz, stride_bz: tl.int64,
stride_bh, stride_bh: tl.int64,
stride_bm, stride_bm: tl.int64,
stride_bn, stride_bn: tl.int64,
stride_az: tl.int64,
stride_ah: tl.int64,
q_descale_ptr,
k_descale_ptr,
p_scale_ptr,
p_descale_ptr,
o_descale_ptr,
v_descale_ptr,
q_descale_has_singleton: tl.constexpr,
k_descale_has_singleton: tl.constexpr,
p_descale_has_singleton: tl.constexpr,
v_descale_has_singleton: tl.constexpr,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
dropout_p,
philox_seed, philox_seed,
NUM_CU: tl.constexpr,
GRID_CU_MULTIP: tl.constexpr,
B: tl.constexpr,
philox_offset_base, philox_offset_base,
encoded_softmax, encoded_softmax,
alibi_slopes,
HQ: tl.constexpr, HQ: tl.constexpr,
HK: tl.constexpr, HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr, IS_ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr, VARLEN: tl.constexpr,
...@@ -351,24 +736,39 @@ def attn_fwd( ...@@ -351,24 +736,39 @@ def attn_fwd(
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr, SHOULD_PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr, USE_BIAS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr,
IS_EIGHT_BIT: tl.constexpr,
USE_P_SCALE: tl.constexpr,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
): ):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1) if o_descale_ptr is not None:
off_z = tl.program_id(2) o_descale = tl.load(o_descale_ptr)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N) start_m: tl.int64 = tl.program_id(0)
off_h_q: tl.int64 = tl.program_id(1)
off_z: tl.int64 = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64)
offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64)
# as we can't have return statements inside while loop in Triton
continue_condition = True
if VARLEN: if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too # We have a one-size-fits-all grid in id(0). Some seqlens might be
# small for all start_m so for those we return early. # too small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q: if start_m * BLOCK_M > seqlen_q:
return continue_condition = False
# return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
...@@ -378,444 +778,598 @@ def attn_fwd( ...@@ -378,444 +778,598 @@ def attn_fwd(
seqlen_q = MAX_SEQLENS_Q seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking. if continue_condition:
# This is because for seqlen_q > seqlen_k, M rows of the attn scores # Now we compute whether we need to exit early due to causal
# are completely masked, resulting in 0s written to the output, and # masking. This is because for seqlen_q > seqlen_k, M rows of the
# inf written to LSE. We don't need to do any GEMMs in this case. # attn scores are completely masked, resulting in 0s written to the
# This block of code determines what N is, and if this WG is operating # output, and inf written to LSE. We don't need to do any GEMMs in
# on those M rows. # this case. This block of code determines what N is, and if this
n_blocks = cdiv_fn(seqlen_k, BLOCK_N) # WG is operating on those M rows.
if IS_CAUSAL: n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
# If seqlen_q == seqlen_k, the attn scores are a square matrix. if (IS_CAUSAL):
# If seqlen_q != seqlen_k, attn scores are rectangular which means # If seqlen_q == seqlen_k, the attn scores are a square matrix.
# the causal mask boundary is bottom right aligned, and ends at either # If seqlen_q != seqlen_k, attn scores are rectangular which
# the top edge (seqlen_q < seqlen_k) or left edge. # means the causal mask boundary is bottom right aligned, and
# This captures the decrease in n_blocks if we have a rectangular attn # ends at either the top edge (seqlen_q < seqlen_k) or left
# matrix # edge. This captures the decrease in n_blocks if we have a
n_blocks_seqlen = cdiv_fn( # rectangular attn matrix
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) n_blocks_seqlen = cdiv_fn(
# This is what adjusts the block_max for the current WG, only (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks # This is what adjusts the block_max for the current WG, only
n_blocks = min(n_blocks, n_blocks_seqlen) # if IS_CAUSAL. Otherwise we want to always iterate through all
# If we have no blocks after adjusting for seqlen deltas, this WG is # n_blocks
# part of the blocks that are all 0. We exit early. n_blocks = min(n_blocks, n_blocks_seqlen)
if n_blocks <= 0: # If we have no blocks after adjusting for seqlen deltas, this
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + # WG is part of the blocks that are all 0. We exit early.
off_h_q * stride_oh) if n_blocks <= 0:
O_block_ptr = tl.make_block_ptr( o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh +
base=Out + o_offset, cu_seqlens_q_start * stride_om)
shape=(seqlen_q, BLOCK_DMODEL), o_ptrs = (o_offset + offs_m[:, None] * stride_om +
strides=(stride_om, stride_on), offs_d[None, :] * stride_on)
offsets=(start_m * BLOCK_M, 0), acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
block_shape=(BLOCK_M, BLOCK_DMODEL), o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to(
order=(1, 0), [BLOCK_M, BLOCK_DMODEL])
) # We still need to write 0s to the result
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) tl.store(o_ptrs, acc, mask=o_ptrs_mask)
# We still need to write 0s to the result # The tensor allocated for L is based on MAX_SEQLENS_Q as
# tl.store(O_block_ptr, # that is statically known.
# acc.to(Out.type.element_ty), boundary_check=(0,1)) l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q +
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q off_h_q * MAX_SEQLENS_Q + offs_m)
# + offs_m # We store inf to LSE, not -inf because in the bwd pass,
# We store inf to LSE, not -inf because in the bwd pass, # we subtract this from qk which makes it -inf, such that
# we subtract this # exp(qk - inf) = 0 for these masked blocks.
# from qk which makes it -inf, such that exp(qk - inf) = 0 l_value = tl.full([BLOCK_M],
# for these masked blocks. value=float("inf"),
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) dtype=tl.float32)
# tl.store(l_ptrs, l) l_ptrs_mask = offs_m < MAX_SEQLENS_Q
# TODO: Should dropout and return encoded softmax be handled here? tl.store(l_ptrs, l_value, mask=l_ptrs_mask)
return # TODO: Should dropout and return encoded softmax be
# handled here too?
# If MQA / GQA, set the K and V head offsets appropriately. continue_condition = False
GROUP_SIZE: tl.constexpr = HQ // HK # return
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
if continue_condition:
n_extra_tokens = 0 # If MQA / GQA, set the K and V head offsets appropriately.
if seqlen_k < BLOCK_N: GROUP_SIZE: tl.constexpr = HQ // HK
n_extra_tokens = BLOCK_N - seqlen_k off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
elif seqlen_k % BLOCK_N: n_extra_tokens = 0
n_extra_tokens = seqlen_k % BLOCK_N if seqlen_k < BLOCK_N:
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
# Compute pointers for all the tensors used in this kernel. n_extra_tokens = seqlen_k % BLOCK_N
q_offset = (off_z * stride_qz + off_h_q * stride_qh + USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL
cu_seqlens_q_start * stride_qm) != BLOCK_DMODEL)
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset, # Compute pointers for all the tensors used in this kernel.
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh +
strides=(stride_qm, stride_qk), cu_seqlens_q_start * stride_qm)
offsets=(start_m * BLOCK_M, 0), q_ptrs = (q_offset + offs_m[:, None] * stride_qm +
block_shape=(BLOCK_M, BLOCK_DMODEL), offs_d[None, :] * stride_qk)
order=(1, 0), k_offset = (K + off_z * stride_kz + off_h_k * stride_kh +
) cu_seqlens_k_start * stride_kn)
k_offset = (off_z * stride_kz + off_h_k * stride_kh + k_ptrs = (k_offset + offs_d[:, None] * stride_kk +
cu_seqlens_k_start * stride_kn) offs_n[None, :] * stride_kn)
K_block_ptr = tl.make_block_ptr( v_offset = (V + off_z * stride_vz + off_h_k * stride_vh +
base=K + k_offset, cu_seqlens_k_start * stride_vk)
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), v_ptrs = (v_offset + offs_n[:, None] * stride_vk +
strides=(stride_kk, stride_kn), offs_d[None, :] * stride_vn)
offsets=(0, 0), # Compute pointers for all scale tensors used in this kernel.
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1), IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & (
) not IS_EIGHT_BIT_KV)
v_offset = (off_z * stride_vz + off_h_k * stride_vh + if IS_EIGHT_BIT:
cu_seqlens_k_start * stride_vk) if k_descale_has_singleton:
V_block_ptr = tl.make_block_ptr( k_descale_ptrs = k_descale_ptr
base=V + v_offset, else:
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), k_descale_ptrs = k_descale_ptr + off_h_k
strides=(stride_vk, stride_vn),
offsets=(0, 0), if v_descale_has_singleton:
block_shape=(BLOCK_N, BLOCK_DMODEL), v_descale_ptrs = v_descale_ptr
order=(1, 0), else:
) v_descale_ptrs = v_descale_ptr + off_h_k
if BIAS_TYPE != 0:
bias_ptr = tl.make_block_ptr( if not IS_EIGHT_BIT_KV:
base=bias + off_h_q * stride_bh, if q_descale_has_singleton:
shape=(seqlen_q, seqlen_k), q_descale_ptrs = q_descale_ptr
strides=(stride_bm, stride_bn), else:
offsets=(start_m * BLOCK_M, 0), q_descale_ptrs = q_descale_ptr + off_h_q
block_shape=(BLOCK_M, BLOCK_N), if USE_P_SCALE:
order=(1, 0), if p_descale_has_singleton:
) p_scale_ptrs = p_scale_ptr
else: p_descale_ptrs = p_descale_ptr
bias_ptr = None else:
if ENABLE_DROPOUT: p_scale_ptrs = p_scale_ptr + off_h_q
batch_philox_offset = philox_offset_base \ p_descale_ptrs = p_descale_ptr + off_h_q
+ (off_z * HQ + off_h_q) \
* seqlen_q * seqlen_k if USE_BIAS:
else: bias_offset = off_h_q * stride_bh
batch_philox_offset = 0 bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm +
# We can ask to return the dropout mask without actually doing any dropout. offs_n[None, :] * stride_bn)
# In this case, we return an invalid pointer so indicate the mask is not i else:
# valid. bias_ptrs = None
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if RETURN_ENCODED_SOFTMAX: if USE_ALIBI:
encoded_softmax_block_ptr = tl.make_block_ptr( a_offset = off_z * stride_az + off_h_q * stride_ah
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, alibi_slope = tl.load(alibi_slopes + a_offset)
shape=(seqlen_q, seqlen_k), else:
strides=(seqlen_k, 1), alibi_slope = None
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N), batch_philox_offset = 0
order=(1, 0), # We can ask to return the dropout mask without doing any
) # dropout. In this case, we return an invalid pointer so
else: # indicate the mask is not valid.
encoded_softmax_block_ptr = 0 if SHOULD_RETURN_ENCODED_SOFTMAX:
# initialize pointer to m and l encoded_sm_base = (encoded_softmax +
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) off_h_q * seqlen_q * seqlen_k)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) encoded_sm_ptrs = (encoded_sm_base +
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) offs_m[:, None] * seqlen_k +
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not offs_n[None, :])
# have native e^x support in HW. else:
qk_scale = sm_scale * 1.44269504089 encoded_sm_ptrs = None
# Q is loaded once at the beginning and shared by all N blocks. # initialize pointer to m and l
q = load_fn(Q_block_ptr, True, padded_head, "zero") m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
q = (q * qk_scale).to(Q_block_ptr.type.element_ty) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# Here we compute how many full and masked blocks we have. # scale sm_scale by log_2(e) and use 2^x in the loop as we do
padded_block_k = n_extra_tokens != 0 # not have native e^x support in HW.
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089
if IS_CAUSAL: # Q is loaded once at the beginning and shared by all N blocks.
# There are always at least BLOCK_M // BLOCK_N masked blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q
# Additionally there might be one more due to dissimilar seqlens. if USE_PADDED_HEAD:
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) q_ptrs_mask = q_ptrs_mask & (offs_d[None, :]
else: < IS_ACTUAL_BLOCK_DMODEL)
# Padding on Q does not need to be masked in the FA loop. q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional if IS_EIGHT_BIT:
# block. In this case we might exceed n_blocks so pick the min. k_descale = tl.load(k_descale_ptrs)
masked_blocks = min(masked_blocks, n_blocks) v_descale = tl.load(v_descale_ptrs)
n_full_blocks = n_blocks - masked_blocks q_descale = None if IS_EIGHT_BIT_KV else tl.load(
block_min = 0 q_descale_ptrs)
block_max = n_blocks * BLOCK_N if USE_P_SCALE:
# Compute for full blocks. Here we set causal to false regardless of its p_scale = tl.load(p_scale_ptrs)
# value because there is no masking. Similarly we do not need padding. p_descale = tl.load(p_descale_ptrs)
if n_full_blocks > 0: else:
block_max = (n_blocks - masked_blocks) * BLOCK_N p_scale = None
acc, l_i, m_i = _attn_fwd_inner( p_descale = None
acc, else:
l_i, q_descale = None
m_i, k_descale = None
q, v_descale = None
K_block_ptr, p_scale = None
V_block_ptr, p_descale = None
start_m, # Here we compute how many full and masked blocks we have.
seqlen_k, padded_block_k = n_extra_tokens != 0
dropout_p, is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
philox_seed, if IS_CAUSAL:
batch_philox_offset, # There are always at least BLOCK_M // BLOCK_N masked
encoded_softmax_block_ptr, # blocks. Additionally there might be one more due to
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ # dissimilar seqlens.
block_min, masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
block_max, else:
0, # Padding on Q does not need to be masked in the FA loop.
0, masked_blocks = padded_block_k
0, # if IS_CAUSAL, not is_modulo_mn does not always result in an
bias_ptr, # additional block. In this case we might exceed n_blocks so
# IS_CAUSAL, .... # pick the min.
False, masked_blocks = min(masked_blocks, n_blocks)
BLOCK_M, n_full_blocks = n_blocks - masked_blocks
BLOCK_DMODEL, block_min = 0
BLOCK_N, block_max = n_blocks * BLOCK_N
offs_m, # Compute for full blocks. Here we set causal to false
offs_n, # regardless of its actual value because there is no masking.
# _, MASK_STEPS, ... # Similarly we do not need padding.
PRE_LOAD_V, if n_full_blocks > 0:
False, block_max = (n_blocks - masked_blocks) * BLOCK_N
ENABLE_DROPOUT, acc, l_i, m_i = _attn_fwd_inner(
RETURN_ENCODED_SOFTMAX, acc,
padded_head, l_i,
) m_i,
block_min = block_max q,
block_max = n_blocks * BLOCK_N k_ptrs,
v_ptrs,
tl.debug_barrier() bias_ptrs,
# Remaining blocks, if any, are full / not masked. stride_kn,
if masked_blocks > 0: stride_vk,
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 stride_bn,
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) start_m,
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) seqlen_k,
if bias_ptr is not None: seqlen_q,
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) philox_seed,
if RETURN_ENCODED_SOFTMAX: batch_philox_offset,
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, encoded_sm_ptrs,
(0, n_full_blocks)) # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
acc, l_i, m_i = _attn_fwd_inner( block_min,
acc, block_max,
l_i, 0,
m_i, 0,
q, 0,
K_block_ptr, alibi_slope,
V_block_ptr, q_descale,
start_m, k_descale,
seqlen_k, v_descale,
dropout_p, p_scale,
philox_seed, # IS_CAUSAL, ....
batch_philox_offset, False,
encoded_softmax_block_ptr, BLOCK_M,
block_min, BLOCK_DMODEL,
block_max, BLOCK_N,
offs_n_causal, offs_m,
masked_blocks, offs_n,
n_extra_tokens, # _, SHOULD_MASK_STEPS, ...
bias_ptr, SHOULD_PRE_LOAD_V,
IS_CAUSAL, False,
BLOCK_M, SHOULD_RETURN_ENCODED_SOFTMAX,
BLOCK_DMODEL, USE_PADDED_HEAD,
BLOCK_N, IS_ACTUAL_BLOCK_DMODEL,
offs_m, QK_SCALE,
offs_n, IS_EIGHT_BIT_GEMM,
# _, MASK_STEPS, ... USE_P_SCALE,
PRE_LOAD_V, IS_EIGHT_BIT_KV,
True, QUANT_DTYPE)
ENABLE_DROPOUT, block_min = block_max
RETURN_ENCODED_SOFTMAX, block_max = n_blocks * BLOCK_N
padded_head,
) tl.debug_barrier()
# epilogue # Remaining blocks, if any, are full / not masked.
acc = acc / l_i[:, None] if (masked_blocks > 0):
if ENABLE_DROPOUT: if IS_CAUSAL:
acc = acc / (1 - dropout_p) offs_n_causal = offs_n + (seqlen_q - seqlen_k)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, else:
# then we have one block with a row of all NaNs which come from computing offs_n_causal = 0
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here k_ptrs += n_full_blocks * BLOCK_N * stride_kn
# and store 0s where there are NaNs as these rows should've been zeroed out. v_ptrs += n_full_blocks * BLOCK_N * stride_vk
end_m_idx = (start_m + 1) * BLOCK_M if USE_BIAS:
start_m_idx = start_m * BLOCK_M bias_ptrs += n_full_blocks * BLOCK_N * stride_bn
causal_start_idx = seqlen_q - seqlen_k if SHOULD_RETURN_ENCODED_SOFTMAX:
acc = acc.to(Out.type.element_ty) encoded_sm_ptrs += n_full_blocks * BLOCK_N
if IS_CAUSAL: # noqa: SIM102 acc, l_i, m_i = _attn_fwd_inner(
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: acc,
out_mask_boundary = tl.full((BLOCK_DMODEL, ), l_i,
causal_start_idx, m_i,
dtype=tl.int32) q,
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) k_ptrs,
out_ptrs_mask = (mask_m_offsets[:, None] v_ptrs,
>= out_mask_boundary[None, :]) bias_ptrs,
z = 0.0 stride_kn,
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) stride_vk,
# write back LSE stride_bn,
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m start_m,
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last seqlen_k,
# few rows. This is only true for the last M block. For others, seqlen_q,
# overflow_size will be -ve philox_seed,
# overflow_size = end_m_idx - seqlen_q batch_philox_offset,
# if overflow_size > 0: encoded_sm_ptrs,
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) block_min,
# # This is a > check because mask being 0 blocks the store. block_max,
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) offs_n_causal,
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) masked_blocks,
# else: n_extra_tokens,
# tl.store(l_ptrs, m_i + tl.math.log2(l_i)) alibi_slope,
q_descale,
# write back O k_descale,
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + v_descale,
off_h_q * stride_oh) p_scale,
O_block_ptr = tl.make_block_ptr( IS_CAUSAL,
base=Out + o_offset, BLOCK_M,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), BLOCK_DMODEL,
strides=(stride_om, stride_on), BLOCK_N,
offsets=(start_m * BLOCK_M, 0), offs_m,
block_shape=(BLOCK_M, BLOCK_DMODEL), offs_n,
order=(1, 0), # _, SHOULD_MASK_STEPS, ...
) SHOULD_PRE_LOAD_V,
# Need boundary check on this to make sure the padding from the True,
# Q and KV tensors in both dims are not part of what we store back. SHOULD_RETURN_ENCODED_SOFTMAX,
# TODO: Do the boundary check optionally. USE_PADDED_HEAD,
tl.store(O_block_ptr, acc, boundary_check=(0, 1)) IS_ACTUAL_BLOCK_DMODEL,
QK_SCALE,
IS_EIGHT_BIT_GEMM,
def check_args( USE_P_SCALE,
q, IS_EIGHT_BIT_KV,
k, QUANT_DTYPE)
v,
o, if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV:
varlen=True, if USE_P_SCALE:
max_seqlens=None, acc *= p_descale
cu_seqlens_q=None, acc *= v_descale
cu_seqlens_k=None,
): # epilogue
assert q.dim() == k.dim() and q.dim() == v.dim() # This helps the compiler do Newton Raphson on l_i vs on acc
if varlen: # which is much larger.
assert q.dim() == 3 l_recip = 1 / l_i[:, None]
total_q, nheads_q, head_size = q.shape acc = acc * l_recip
total_k, nheads_k, _ = k.shape
assert cu_seqlens_q is not None # If seqlen_q > seqlen_k but the delta is not a multiple of
assert cu_seqlens_k is not None # BLOCK_M, then we have one block with a row of all NaNs which
assert len(cu_seqlens_q) == len(cu_seqlens_k) # come from computing softmax over a row of all
else: # -infs (-inf - inf = NaN). We check for that here and store 0s
assert q.dim() == 4 # where there are NaNs as these rows should've been zeroed out.
batch, nheads_q, seqlen_q, head_size = q.shape end_m_idx = (start_m + 1) * BLOCK_M
_, nheads_k, seqlen_k, _ = k.shape start_m_idx = start_m * BLOCK_M
assert max_seqlens > 0 causal_start_idx = seqlen_q - seqlen_k
assert k.shape == v.shape if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: # noqa: SIM102
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] if o_descale_ptr is not None:
# TODO: Change assert if we support qkl f8 and v f16 acc = quant_fp8(acc, o_descale)
assert q.dtype == k.dtype and q.dtype == v.dtype
assert head_size <= 256 acc = acc.to(Out.type.element_ty)
assert o.shape == q.shape if IS_CAUSAL: # noqa: SIM102
assert (nheads_q % nheads_k) == 0 if (causal_start_idx > start_m_idx
and causal_start_idx < end_m_idx):
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
causal_start_idx,
dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None]
>= out_mask_boundary[None, :])
z = tl.zeros((1, ), tl.float32)
acc = tl.where(out_ptrs_mask, acc,
z.to(acc.type.element_ty))
# write back LSE
l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q +
off_h_q * MAX_SEQLENS_Q + offs_m)
# If seqlen_q not multiple of BLOCK_M, we need to mask out the
# last few rows. This is only true for the last M block.
# For others, overflow_size will be -ve
overflow_size = end_m_idx - seqlen_q
if overflow_size > 0:
boundary = tl.full((BLOCK_M, ),
BLOCK_M - overflow_size,
dtype=tl.int32)
l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary
tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
else:
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh +
cu_seqlens_q_start * stride_om)
o_ptrs = (o_offset + offs_m[:, None] * stride_om +
offs_d[None, :] * stride_on)
o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1)
if overflow_size > 0:
o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q)
if USE_PADDED_HEAD:
o_ptrs_mask = o_ptrs_mask & (offs_d[None, :]
< IS_ACTUAL_BLOCK_DMODEL)
tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)
def get_shape_from_layout(q, k, metadata):
assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
if metadata.layout == 'thd':
nheads_q, nheads_k = q.shape[1], k.shape[1]
head_size = q.shape[-1]
batch = metadata.num_contexts
elif metadata.layout == 'bhsd':
batch, nheads_q, _, head_size = q.shape
nheads_k = k.shape[1]
elif metadata.layout == 'bshd':
batch, _, nheads_q, head_size = q.shape
nheads_k = k.shape[2]
return batch, nheads_q, nheads_k, head_size
def get_strides_from_layout(q, k, v, o, metadata):
assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
STRIDE_PERMUTATIONS = {
'thd': (None, 1, 0, 2),
'bhsd': (0, 1, 2, 3),
'bshd': (0, 2, 1, 3),
}
perm = STRIDE_PERMUTATIONS[metadata.layout]
stride = lambda x, p: (0 if p is None else x.stride(p))
strides = lambda x: (stride(x, p) for p in perm)
return tuple(strides(x) for x in [q, k, v, o])
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(ctx, q, k, v, o, metadata: MetaData):
ctx, # NOTE: a large bias tensor leads to overflow during pointer arithmetic
q, if (metadata.bias is not None):
k, assert (metadata.bias.numel() < 2**31)
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
):
if o is None: if o is None:
o = torch.empty_like(q, dtype=v.dtype) if metadata.eight_bit:
o = torch.empty_like(
q,
dtype=metadata.output_dtype if metadata.output_dtype
is not None else metadata.eight_bit_dtype_torch)
else:
o = torch.empty_like(q, dtype=q.dtype)
check_args( metadata.check_args(q, k, v, o)
q,
k, batch, nheads_q, nheads_k, head_size = get_shape_from_layout(
v, q, k, metadata)
o, q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(
varlen=True, q, k, v, o, metadata)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32. # Get closest power of 2 over or equal to 32.
unpadded_head_dims = {32, 64, 128, 256} padded_d_model = 1 << (head_size - 1).bit_length()
if head_size not in unpadded_head_dims: # Smallest head_dim supported is 16. If smaller, the tile in the
padded_d_model = None # kernel is padded - there is no padding in memory for any dims.
for i in unpadded_head_dims: padded_d_model = max(padded_d_model, 16)
if i > head_size:
padded_d_model = i
break
assert padded_d_model is not None
else:
padded_d_model = head_size
grid = lambda META: ( # encoded_softmax is used to validate dropout behavior vs the
triton.cdiv(max_seqlens_q, META["BLOCK_M"]), # PyTorch SDPA math backend reference. We zero this out to give a
nheads_q, # consistent starting point and then populate it with the output of
batch, # softmax with the sign bit set according to the dropout mask.
) # The resulting return allows this mask to be fed into the reference
# implementation for testing only. This return holds no useful output
# aside from debugging.
if metadata.return_encoded_softmax:
encoded_softmax = torch.zeros(
(q.shape[0], q.shape[1], q.shape[2], k.shape[2]),
device=q.device,
dtype=torch.float32)
else:
encoded_softmax = None
encoded_softmax = None M = torch.empty((batch, nheads_q, metadata.max_seqlens_q),
device=q.device,
dtype=torch.float32)
# Seed the RNG so we get reproducible results for testing. # Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52 philox_seed = 0x1BF52
philox_offset = 0x1D4B42 philox_offset = 0x1D4B42
if bias is not None: if metadata.bias is not None:
bias_strides = ( bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1),
bias.stride(0), metadata.bias.stride(2), metadata.bias.stride(3))
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else: else:
bias_strides = (0, 0, 0, 0) bias_strides = (0, 0, 0, 0)
if metadata.alibi_slopes is not None:
alibi_strides = (metadata.alibi_slopes.stride(0),
metadata.alibi_slopes.stride(1))
else:
alibi_strides = (0, 0)
if metadata.eight_bit:
q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = (
metadata.q_descale, metadata.k_descale, metadata.p_scale,
metadata.p_descale, metadata.v_descale, metadata.o_scale)
o_descale = 1.0 / o_scale if o_scale is not None else None
else:
q_descale = k_descale = p_scale = None
p_descale = v_descale = o_descale = None
# number of compute units available
NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count
grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META[
'BLOCK_M']), nheads_q, batch)
attn_fwd[grid]( attn_fwd[grid](
q, q,
k, k,
v, v,
bias, metadata.bias,
sm_scale, metadata.sm_scale,
None, M,
o, o,
*q_strides, *q_strides,
*k_strides, *k_strides,
*v_strides, *v_strides,
*o_strides, *o_strides,
*bias_strides, *bias_strides,
cu_seqlens_q, *alibi_strides,
cu_seqlens_k, q_descale,
dropout_p=0.0, k_descale,
p_scale,
p_descale,
o_descale,
v_descale,
q_descale.numel() == 1 if q_descale is not None else False,
k_descale.numel() == 1 if k_descale is not None else False,
p_descale.numel() == 1 if p_descale is not None else False,
v_descale.numel() == 1 if v_descale is not None else False,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
philox_seed=philox_seed, philox_seed=philox_seed,
philox_offset_base=philox_offset, philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax, encoded_softmax=encoded_softmax,
alibi_slopes=metadata.alibi_slopes,
HQ=nheads_q, HQ=nheads_q,
HK=nheads_k, HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size, IS_ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_Q=metadata.max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, MAX_SEQLENS_K=metadata.max_seqlens_k,
IS_CAUSAL=causal, IS_CAUSAL=metadata.causal,
VARLEN=True, VARLEN=metadata.varlen,
BLOCK_DMODEL=padded_d_model, BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if bias is None else 1, USE_BIAS=metadata.bias is not None,
ENABLE_DROPOUT=False, USE_ALIBI=metadata.alibi_slopes is not None,
RETURN_ENCODED_SOFTMAX=False, SHOULD_RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax,
) IS_EIGHT_BIT=metadata.eight_bit,
USE_P_SCALE=metadata.eight_bit and metadata.use_p_scale,
IS_EIGHT_BIT_KV=metadata.eight_bit and metadata.eight_bit_kv,
NUM_CU=NUM_CU,
B=batch,
QUANT_DTYPE=metadata.eight_bit_dtype_triton)
ctx.grid = grid ctx.grid = grid
ctx.sm_scale = sm_scale ctx.sm_scale = metadata.sm_scale
ctx.BLOCK_DMODEL = head_size ctx.BLOCK_DMODEL = head_size
ctx.causal = causal ctx.causal = metadata.causal
ctx.dropout_p = 0.0 ctx.alibi_slopes = metadata.alibi_slopes
ctx.philox_seed = philox_seed ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = False ctx.return_encoded_softmax = metadata.return_encoded_softmax
return o, encoded_softmax return o, encoded_softmax
triton_attention = _attention.apply triton_attention_rocm = _attention.apply
def scale_fp8(t, scale=None):
t_scaled, scale_out = ops.scaled_fp8_quant(t.reshape(-1, t.shape[-1]),
scale)
return t_scaled.reshape(t.shape), scale_out
def maybe_quantize_fp8(t, scale):
eight_bit_dtype = current_platform.fp8_dtype()
if t.dtype != eight_bit_dtype:
t, _ = scale_fp8(t, scale)
return t
def check_and_maybe_quantize_qkv(q, k, v, fp8_scales):
(q_scale, k_scale, v_scale, p_scale) = fp8_scales
q = maybe_quantize_fp8(q, q_scale)
k = maybe_quantize_fp8(k, k_scale)
v = maybe_quantize_fp8(v, v_scale)
return q, k, v
# query - [num_tokens, num_heads, head_size]
# key - [num_tokens, num_kv_heads, head_size]
# value - [num_tokens, num_kv_heads, head_size
# output - [num_tokens, num_heads, head_size]
def triton_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlens_q: int,
max_seqlens_k: int,
causal: bool = False,
sm_scale: float = 1.0,
bias: Optional[torch.Tensor] = None,
fp8_scales: Optional[tuple[float, ...]] = None,
input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if fp8_scales is not None:
q_descale, k_descale, v_descale, p_scale = fp8_scales
else:
q_descale = k_descale = v_descale = p_scale = None
attn_metadata = MetaData(sm_scale=sm_scale,
max_seqlens_q=max_seqlens_q,
max_seqlens_k=max_seqlens_k,
causal=causal,
bias=bias,
output_dtype=q.dtype,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
p_scale=p_scale,
o_scale=input_scale)
if fp8_scales is not None:
q, k, v = check_and_maybe_quantize_qkv(q, k, v, fp8_scales)
return triton_attention_rocm(q, k, v, o, attn_metadata)
...@@ -66,7 +66,10 @@ def merge_attn_states_kernel( ...@@ -66,7 +66,10 @@ def merge_attn_states_kernel(
max_lse = tl.maximum(p_lse, s_lse) max_lse = tl.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse p_lse = p_lse - max_lse
s_lse = s_lse - max_lse s_lse = s_lse - max_lse
out_se = (tl.exp(p_lse) + tl.exp(s_lse)) # Will reuse precomputed Exp values for scale factor computation.
p_se = tl.exp(p_lse)
s_se = tl.exp(s_lse)
out_se = (p_se + s_se)
if OUTPUT_LSE: if OUTPUT_LSE:
out_lse = tl.log(out_se) + max_lse out_lse = tl.log(out_se) + max_lse
...@@ -84,8 +87,8 @@ def merge_attn_states_kernel( ...@@ -84,8 +87,8 @@ def merge_attn_states_kernel(
# NOTE(woosuk): Be careful with the numerical stability. # NOTE(woosuk): Be careful with the numerical stability.
# We should compute the scale first, and then multiply it with the output. # We should compute the scale first, and then multiply it with the output.
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
p_scale = tl.exp(p_lse) / out_se p_scale = p_se / out_se
s_scale = tl.exp(s_lse) / out_se s_scale = s_se / out_se
out = p_out * p_scale + s_out * s_scale out = p_out * p_scale + s_out * s_scale
tl.store(output + token_idx * num_heads * HEAD_SIZE + tl.store(output + token_idx * num_heads * HEAD_SIZE +
head_idx * HEAD_SIZE + head_arange, head_idx * HEAD_SIZE + head_arange,
......
...@@ -38,9 +38,18 @@ class BeamSearchOutput: ...@@ -38,9 +38,18 @@ class BeamSearchOutput:
class BeamSearchInstance: class BeamSearchInstance:
def __init__(self, prompt_tokens: list[int]): def __init__(
self,
prompt_tokens: list[int],
logprobs: Optional[list[dict[int, Logprob]]] = None,
**kwargs,
):
self.beams: list[BeamSearchSequence] = [ self.beams: list[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens, logprobs=[]) BeamSearchSequence(
tokens=prompt_tokens,
logprobs=[] if logprobs is None else list(logprobs),
**kwargs,
)
] ]
self.completed: list[BeamSearchSequence] = [] self.completed: list[BeamSearchSequence] = []
......
# SPDX-License-Identifier: Apache-2.0
"""
This module defines a framework for sampling benchmark requests from various
datasets. Each dataset subclass of BenchmarkDataset must implement sample
generation. Supported dataset types include:
- ShareGPT
- Random (synthetic)
- Sonnet
- BurstGPT
- HuggingFace
- VisionArena
TODO: Implement CustomDataset to parse a JSON file and convert its contents into
SampleRequest instances, similar to the approach used in ShareGPT.
"""
import base64
import io
import json
import logging
import random
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from functools import cache
from io import BytesIO
from typing import Any, Callable, Optional, Union
import numpy as np
from PIL import Image
from transformers import PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Data Classes
# -----------------------------------------------------------------------------
@dataclass
class SampleRequest:
"""
Represents a single inference request for benchmarking.
"""
prompt: Union[str, Any]
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
lora_request: Optional[LoRARequest] = None
# -----------------------------------------------------------------------------
# Benchmark Dataset Base Class
# -----------------------------------------------------------------------------
class BenchmarkDataset(ABC):
DEFAULT_SEED = 0
def __init__(
self,
dataset_path: Optional[str] = None,
random_seed: int = DEFAULT_SEED,
) -> None:
"""
Initialize the BenchmarkDataset with an optional dataset path and random
seed.
Args:
dataset_path (Optional[str]): Path to the dataset. If None, it
indicates that a default or random dataset might be used.
random_seed (int): Seed value for reproducible shuffling or
sampling. Defaults to DEFAULT_SEED.
"""
self.dataset_path = dataset_path
# Set the random seed, ensuring that a None value is replaced with the
# default seed.
self.random_seed = (random_seed
if random_seed is not None else self.DEFAULT_SEED)
self.data = None
def apply_multimodal_chat_transformation(
self,
prompt: str,
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
"""
Transform a prompt and optional multimodal content into a chat format.
This method is used for chat models that expect a specific conversation
format.
"""
content = [{"text": prompt, "type": "text"}]
if mm_content is not None:
content.append(mm_content)
return [{"role": "user", "content": content}]
def load_data(self) -> None:
"""
Load data from the dataset path into self.data.
This method must be overridden by subclasses since the method to load
data will vary depending on the dataset format and source.
Raises:
NotImplementedError: If a subclass does not implement this method.
"""
# TODO (jenniferzhao): add support for downloading data
raise NotImplementedError(
"load_data must be implemented in subclasses.")
def get_random_lora_request(
self,
tokenizer: PreTrainedTokenizerBase,
max_loras: Optional[int] = None,
lora_path: Optional[str] = None,
) -> tuple[Optional[LoRARequest], AnyTokenizer]:
"""
Optionally select a random LoRA request and return its associated
tokenizer.
This method is used when LoRA parameters are provided. It randomly
selects a LoRA based on max_loras and retrieves a cached tokenizer for
that LoRA if available. Otherwise, it returns the base tokenizer.
Args:
tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
LoRA is selected. max_loras (Optional[int]): The maximum number of
LoRAs available. If None, LoRA is not used. lora_path
(Optional[str]): Path to the LoRA parameters on disk. If None, LoRA
is not used.
Returns:
tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first
element is a LoRARequest (or None if not applicable) and the second
element is the tokenizer associated with the LoRA request (or the
base tokenizer).
"""
if max_loras is None or lora_path is None:
return None, tokenizer
# Generate a random LoRA ID in the range [1, max_loras].
lora_id = random.randint(1, max_loras)
lora_request = LoRARequest(
lora_name=str(lora_id),
lora_int_id=lora_id,
lora_path=lora_path_on_disk(lora_path),
)
if lora_id not in lora_tokenizer_cache:
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
# Return lora_request and the cached tokenizer if available; otherwise,
# return the base tokenizer
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
@abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase,
num_requests: int) -> list[SampleRequest]:
"""
Abstract method to generate sample requests from the dataset.
Subclasses must override this method to implement dataset-specific logic
for generating a list of SampleRequest objects.
Args:
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
Returns:
list[SampleRequest]: A list of sample requests generated from the
dataset.
"""
raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(self, requests: list[SampleRequest],
num_requests: int) -> None:
"""
Oversamples the list of requests if its size is less than the desired
number.
Args:
requests (List[SampleRequest]): The current list of sampled
requests. num_requests (int): The target number of requests.
"""
if len(requests) < num_requests:
random.seed(self.random_seed)
additional = random.choices(requests,
k=num_requests - len(requests))
requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.",
num_requests)
# -----------------------------------------------------------------------------
# Utility Functions and Global Caches
# -----------------------------------------------------------------------------
def is_valid_sequence(
prompt_len: int,
output_len: int,
min_len: int = 4,
max_prompt_len: int = 1024,
max_total_len: int = 2048,
skip_min_output_len_check: bool = False,
) -> bool:
"""
Validate a sequence based on prompt and output lengths.
Default pruning criteria are copied from the original `sample_hf_requests`
and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as
from `sample_requests` in benchmark_throughput.py.
"""
# Check for invalid conditions
prompt_too_short = prompt_len < min_len
output_too_short = (not skip_min_output_len_check) and (output_len
< min_len)
prompt_too_long = prompt_len > max_prompt_len
combined_too_long = (prompt_len + output_len) > max_total_len
# Return True if none of the invalid conditions are met
return not (prompt_too_short or output_too_short or prompt_too_long
or combined_too_long)
@cache
def lora_path_on_disk(lora_path: str) -> str:
return get_adapter_absolute_path(lora_path)
# Global cache for LoRA tokenizers.
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
def process_image(image: Any) -> Mapping[str, Any]:
"""
Process a single image input and return a multimedia content dictionary.
Supports three input types:
1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key
containing raw image data. - Loads the bytes as a PIL.Image.Image.
2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as
a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns
a dictionary with the image as a base64 data URL.
3. String input: - Treats the string as a URL or local file path. -
Prepends "file://" if the string doesn't start with "http://" or
"file://". - Returns a dictionary with the image URL.
Raises:
ValueError: If the input is not a supported type.
"""
if isinstance(image, dict) and 'bytes' in image:
image = Image.open(BytesIO(image['bytes']))
if isinstance(image, Image.Image):
image = image.convert("RGB")
with io.BytesIO() as image_data:
image.save(image_data, format="JPEG")
image_base64 = base64.b64encode(
image_data.getvalue()).decode("utf-8")
return {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
},
}
if isinstance(image, str):
image_url = (image if image.startswith(
("http://", "file://")) else f"file://{image}")
return {"type": "image_url", "image_url": {"url": image_url}}
raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
" or str or dictionary with raw image bytes.")
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
class RandomDataset(BenchmarkDataset):
# Default values copied from benchmark_serving.py for the random dataset.
DEFAULT_PREFIX_LEN = 0
DEFAULT_RANGE_RATIO = 0.0
DEFAULT_INPUT_LEN = 1024
DEFAULT_OUTPUT_LEN = 128
def __init__(
self,
**kwargs,
) -> None:
super().__init__(**kwargs)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
prefix_len: int = DEFAULT_PREFIX_LEN,
range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
**kwargs,
) -> list[SampleRequest]:
# Enforce range_ratio < 1
assert range_ratio < 1.0, (
"random_range_ratio must be < 1.0 to ensure a valid sampling range"
)
vocab_size = tokenizer.vocab_size
prefix_token_ids = (np.random.randint(
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(input_len * (1 - range_ratio))
input_high = int(input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio))
# Add logging for debugging
logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
logger.info("Sampling output_len from [%s, %s]", output_low,
output_high)
input_lens = np.random.randint(input_low,
input_high + 1,
size=num_requests)
output_lens = np.random.randint(output_low,
output_high + 1,
size=num_requests)
offsets = np.random.randint(0, vocab_size, size=num_requests)
requests = []
for i in range(num_requests):
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence)
total_input_len = prefix_len + int(input_lens[i])
requests.append(
SampleRequest(
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
))
return requests
# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
# -----------------------------------------------------------------------------
class ShareGPTDataset(BenchmarkDataset):
"""
Implements the ShareGPT dataset. Loads data from a JSON file and generates
sample requests based on conversation turns.
"""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.load_data()
def load_data(self) -> None:
if self.dataset_path is None:
raise ValueError("dataset_path must be provided for loading data.")
with open(self.dataset_path, encoding="utf-8") as f:
self.data = json.load(f)
# Filter entries with at least two conversation turns.
self.data = [
entry for entry in self.data
if "conversations" in entry and len(entry["conversations"]) >= 2
]
random.seed(self.random_seed)
random.shuffle(self.data)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
lora_path: Optional[str] = None,
max_loras: Optional[int] = None,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs,
) -> list:
samples: list = []
for entry in self.data:
if len(samples) >= num_requests:
break
prompt, completion = (
entry["conversations"][0]["value"],
entry["conversations"][1]["value"],
)
lora_request, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
new_output_len = (len(completion_ids)
if output_len is None else output_len)
if not is_valid_sequence(prompt_len,
new_output_len,
skip_min_output_len_check=output_len
is not None):
continue
if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation(
prompt, None)
samples.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=new_output_len,
lora_request=lora_request,
))
self.maybe_oversample_requests(samples, num_requests)
return samples
# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------
class SonnetDataset(BenchmarkDataset):
"""
Simplified implementation of the Sonnet dataset. Loads poem lines from a
text file and generates sample requests. Default values here copied from
`benchmark_serving.py` for the sonnet dataset.
"""
DEFAULT_PREFIX_LEN = 200
DEFAULT_INPUT_LEN = 550
DEFAULT_OUTPUT_LEN = 150
def __init__(
self,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.load_data()
def load_data(self) -> None:
if not self.dataset_path:
raise ValueError("dataset_path must be provided.")
with open(self.dataset_path, encoding="utf-8") as f:
self.data = f.readlines()
def sample(
self,
tokenizer,
num_requests: int,
prefix_len: int = DEFAULT_PREFIX_LEN,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
return_prompt_formatted: bool = False,
**kwargs,
) -> list:
# Calculate average token length for a poem line.
tokenized_lines = [tokenizer(line).input_ids for line in self.data]
avg_len = sum(len(tokens)
for tokens in tokenized_lines) / len(tokenized_lines)
# Build the base prompt.
base_prompt = "Pick as many lines as you can from these poem lines:\n"
base_msg = [{"role": "user", "content": base_prompt}]
base_fmt = tokenizer.apply_chat_template(base_msg,
add_generation_prompt=True,
tokenize=False)
base_offset = len(tokenizer(base_fmt).input_ids)
if input_len <= base_offset:
raise ValueError(
f"'input_len' must be higher than the base prompt length "
f"({base_offset}).")
# Determine how many poem lines to use.
num_input_lines = round((input_len - base_offset) / avg_len)
num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
prefix_lines = self.data[:num_prefix_lines]
samples = []
while len(samples) < num_requests:
extra_lines = random.choices(self.data,
k=num_input_lines - num_prefix_lines)
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
msg = [{"role": "user", "content": prompt}]
prompt_formatted = tokenizer.apply_chat_template(
msg, add_generation_prompt=True, tokenize=False)
prompt_len = len(tokenizer(prompt_formatted).input_ids)
if prompt_len <= input_len:
samples.append(
SampleRequest(
prompt=prompt_formatted
if return_prompt_formatted else prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
return samples
# -----------------------------------------------------------------------------
# BurstGPT Dataset Implementation
# -----------------------------------------------------------------------------
class BurstGPTDataset(BenchmarkDataset):
"""
Implements the BurstGPT dataset. Loads data from a CSV file and generates
sample requests based on synthetic prompt generation. Only rows with Model
"GPT-4" and positive response tokens are used.
"""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.load_data()
def load_data(self, ):
if self.dataset_path is None:
raise ValueError("dataset_path must be provided for loading data.")
try:
import pandas as pd
except ImportError as e:
raise ImportError(
"Pandas is required for BurstGPTDataset. Please install it "
"using `pip install pandas`.") from e
df = pd.read_csv(self.dataset_path)
# Filter to keep only GPT-4 rows.
gpt4_df = df[df["Model"] == "GPT-4"]
# Remove failed requests (where Response tokens is 0 or less).
gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
# Sample the desired number of rows.
self.data = gpt4_df
def _sample_loaded_data(self, num_requests: int) -> list:
if num_requests <= len(self.data):
data = self.data.sample(n=num_requests,
random_state=self.random_seed)
else:
data = self.data.sample(
n=num_requests,
random_state=self.random_seed,
replace=True,
)
# Convert the dataframe to a list of lists.
return data.values.tolist()
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
max_loras: Optional[int] = None,
lora_path: Optional[str] = None,
**kwargs,
) -> list[SampleRequest]:
samples = []
data = self._sample_loaded_data(num_requests=num_requests)
for i in range(num_requests):
input_len = int(data[i][2])
output_len = int(data[i][3])
lora_req, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
vocab_size = tokenizer.vocab_size
# Generate a synthetic prompt: a list of token IDs computed as (i +
# j) modulo vocab_size.
token_ids = [(i + j) % vocab_size for j in range(input_len)]
prompt = tokenizer.decode(token_ids)
samples.append(
SampleRequest(
prompt=prompt,
prompt_len=input_len,
expected_output_len=output_len,
lora_request=lora_req,
))
return samples
# -----------------------------------------------------------------------------
# HuggingFace Dataset Base Implementation
# -----------------------------------------------------------------------------
class HuggingFaceDataset(BenchmarkDataset):
"""Base class for datasets hosted on HuggingFace."""
SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set()
def __init__(
self,
dataset_path: str,
dataset_split: str,
dataset_subset: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(dataset_path=dataset_path, **kwargs)
self.dataset_split = dataset_split
self.dataset_subset = dataset_subset
self.load_data()
def load_data(self) -> None:
"""Load data from HuggingFace datasets."""
try:
from datasets import load_dataset
except ImportError as e:
raise ImportError(
"Hugging Face datasets library is required for this dataset. "
"Please install it using `pip install datasets`.") from e
self.data = load_dataset(
self.dataset_path,
name=self.dataset_subset,
split=self.dataset_split,
streaming=True,
)
self.data = self.data.shuffle(seed=self.random_seed)
# -----------------------------------------------------------------------------
# Conversation Dataset Implementation
# -----------------------------------------------------------------------------
class ConversationDataset(HuggingFaceDataset):
"""Dataset for conversation data with multimodal support."""
SUPPORTED_DATASET_PATHS = {
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
}
def sample(self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
# Filter examples with at least 2 conversations
filtered_data = self.data.filter(
lambda x: len(x["conversations"]) >= 2)
sampled_requests = []
dynamic_output = output_len is None
for item in filtered_data:
if len(sampled_requests) >= num_requests:
break
conv = item["conversations"]
prompt, completion = conv[0]["value"], conv[1]["value"]
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0
if dynamic_output and not is_valid_sequence(
prompt_len, completion_len):
continue
mm_content = process_image(
item["image"]) if "image" in item else None
if enable_multimodal_chat:
# Note: when chat is enabled the request prompt_len is no longer
# accurate and we will be using request output to count the
# actual prompt len and output len
prompt = self.apply_multimodal_chat_transformation(
prompt, mm_content)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
# -----------------------------------------------------------------------------
# Vision Arena Dataset Implementation
# -----------------------------------------------------------------------------
class VisionArenaDataset(HuggingFaceDataset):
"""
Vision Arena Dataset.
"""
DEFAULT_OUTPUT_LEN = 128
SUPPORTED_DATASET_PATHS = {
"lmarena-ai/VisionArena-Chat":
lambda x: x["conversation"][0][0]["content"],
"lmarena-ai/vision-arena-bench-v0.1":
lambda x: x["turns"][0][0]["content"]
}
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs,
) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
if parser_fn is None:
raise ValueError(
f"Unsupported dataset path: {self.dataset_path}")
prompt = parser_fn(item)
mm_content = process_image(item["images"][0])
prompt_len = len(tokenizer(prompt).input_ids)
if enable_multimodal_chat:
# Note: when chat is enabled the request prompt_len is no longer
# accurate and we will be using request output to count the
# actual prompt len
prompt = self.apply_multimodal_chat_transformation(
prompt, mm_content)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
# -----------------------------------------------------------------------------
# Instruct Coder Dataset Implementation
# -----------------------------------------------------------------------------
class InstructCoderDataset(HuggingFaceDataset):
"""
InstructCoder Dataset.
https://huggingface.co/datasets/likaixin/InstructCoder
InstructCoder is the dataset designed for general code editing. It consists
of 114,239 instruction-input-output triplets, and covers multiple distinct
code editing scenario.
"""
DEFAULT_OUTPUT_LEN = 200 # this is the average default output length
SUPPORTED_DATASET_PATHS = {
"likaixin/InstructCoder",
}
def sample(self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt = f"{item['instruction']}:\n{item['input']}"
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------
class AIMODataset(HuggingFaceDataset):
"""
Dataset class for processing a AIMO dataset with reasoning questions.
"""
SUPPORTED_DATASET_PATHS = {
"AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5",
"AI-MO/NuminaMath-CoT"
}
def sample(self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
**kwargs) -> list:
sampled_requests = []
dynamic_output = output_len is None
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt, completion = item['problem'], item["solution"]
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0
if dynamic_output and not is_valid_sequence(prompt_len,
completion_len,
max_prompt_len=2048,
max_total_len=32000):
continue
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=None,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
# SPDX-License-Identifier: Apache-2.0
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import dataclasses
import json
import os
import time
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format,
write_to_json)
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: dict[str, Any]) -> None:
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={"latency": results["latencies"]},
extra_info={k: results[k]
for k in ["avg_latency", "percentiles"]})
if pt_records:
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
write_to_json(pt_file, pt_records)
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--input-len", type=int, default=32)
parser.add_argument("--output-len", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument(
"--n",
type=int,
default=1,
help="Number of generated sequences per prompt.",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
"--num-iters-warmup",
type=int,
default=10,
help="Number of iterations to run for warmup.",
)
parser.add_argument("--num-iters",
type=int,
default=30,
help="Number of iterations to run.")
parser.add_argument(
"--profile",
action="store_true",
help="profile the generation process of a single batch",
)
parser.add_argument(
"--profile-result-dir",
type=str,
default=None,
help=("path to save the pytorch profiler output. Can be visualized "
"with ui.perfetto.dev or Tensorboard."),
)
parser.add_argument(
"--output-json",
type=str,
default=None,
help="Path to save the latency results in JSON format.",
)
parser.add_argument(
"--disable-detokenize",
action="store_true",
help=("Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"),
)
parser = EngineArgs.add_cli_args(parser)
def main(args: argparse.Namespace):
print(args)
engine_args = EngineArgs.from_cli_args(args)
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(**dataclasses.asdict(engine_args))
assert llm.llm_engine.model_config.max_model_len >= (
args.input_len +
args.output_len), ("Please ensure that max_model_len is greater than"
" the sum of input_len and output_len.")
sampling_params = SamplingParams(
n=args.n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=args.output_len,
detokenize=not args.disable_detokenize,
)
print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompts: list[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
def llm_generate():
if not args.use_beam_search:
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
else:
llm.beam_search(
dummy_prompts,
BeamSearchParams(
beam_width=args.n,
max_tokens=args.output_len,
ignore_eos=True,
),
)
def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir)),
) as p:
llm_generate()
print(p.key_averages().table(sort_by="self_cuda_time_total"))
else:
start_time = time.perf_counter()
llm_generate()
end_time = time.perf_counter()
latency = end_time - start_time
return latency
print("Warming up...")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
run_to_completion(profile_dir=None)
if args.profile:
profile_dir = args.profile_result_dir
if not profile_dir:
profile_dir = (Path(".") / "vllm_benchmark_result" /
f"latency_result_{time.time()}")
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir)
return
# Benchmark.
latencies = []
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile_dir=None))
latencies = np.array(latencies)
percentages = [10, 25, 50, 75, 90, 99]
percentiles = np.percentile(latencies, percentages)
print(f"Avg latency: {np.mean(latencies)} seconds")
for percentage, percentile in zip(percentages, percentiles):
print(f"{percentage}% percentile latency: {percentile} seconds")
# Output JSON results if specified
if args.output_json:
results = {
"avg_latency": np.mean(latencies),
"latencies": latencies.tolist(),
"percentiles": dict(zip(percentages, percentiles.tolist())),
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)
# SPDX-License-Identifier: Apache-2.0
"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
import os
import random
import time
import warnings
from typing import Any, Optional, Union
import torch
import uvloop
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset,
ConversationDataset,
InstructCoderDataset, RandomDataset,
SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format,
write_to_json)
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import merge_async_iterators
def run_vllm(
requests: list[SampleRequest],
n: int,
engine_args: EngineArgs,
disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
llm.llm_engine.model_config.max_model_len >= (
request.prompt_len + request.expected_output_len)
for request in requests), (
"Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests.")
# Add the requests to the engine.
prompts: list[Union[TextPrompt, TokensPrompt]] = []
sampling_params: list[SamplingParams] = []
for request in requests:
prompts.append(
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
multi_modal_data=request.multi_modal_data)
if "prompt_token_ids" in request.prompt else \
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
))
lora_requests: Optional[list[LoRARequest]] = None
if engine_args.enable_lora:
lora_requests = [request.lora_request for request in requests]
use_beam_search = False
outputs = None
if not use_beam_search:
start = time.perf_counter()
outputs = llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
for request in requests:
assert request.expected_output_len == output_len
start = time.perf_counter()
llm.beam_search(
prompts,
BeamSearchParams(
beam_width=n,
max_tokens=output_len,
ignore_eos=True,
))
end = time.perf_counter()
return end - start, outputs
def run_vllm_chat(
requests: list[SampleRequest],
n: int,
engine_args: EngineArgs,
disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
"""
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
multimodal models as it properly handles multimodal inputs and chat
formatting. For non-multimodal models, use run_vllm() instead.
"""
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
llm.llm_engine.model_config.max_model_len >= (
request.prompt_len + request.expected_output_len)
for request in requests), (
"Please ensure that max_model_len is greater than the sum of "
"prompt_len and expected_output_len for all requests.")
prompts = []
sampling_params: list[SamplingParams] = []
for request in requests:
prompts.append(request.prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
))
start = time.perf_counter()
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
return end - start, outputs
async def run_vllm_async(
requests: list[SampleRequest],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
disable_detokenize: bool = False,
) -> float:
from vllm import SamplingParams
async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:
assert all(
llm.model_config.max_model_len >= (request.prompt_len +
request.expected_output_len)
for request in requests), (
"Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests.")
# Add the requests to the engine.
prompts: list[Union[TextPrompt, TokensPrompt]] = []
sampling_params: list[SamplingParams] = []
lora_requests: list[Optional[LoRARequest]] = []
for request in requests:
prompts.append(
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
multi_modal_data=request.multi_modal_data)
if "prompt_token_ids" in request.prompt else \
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
))
lora_requests.append(request.lora_request)
generators = []
start = time.perf_counter()
for i, (prompt, sp,
lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
generator = llm.generate(prompt,
sp,
lora_request=lr,
request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
pass
end = time.perf_counter()
return end - start
def run_hf(
requests: list[SampleRequest],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
max_batch_size: int,
trust_remote_code: bool,
disable_detokenize: bool = False,
) -> float:
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
llm = llm.cuda()
pbar = tqdm(total=len(requests))
start = time.perf_counter()
batch: list[str] = []
max_prompt_len = 0
max_output_len = 0
for i in range(len(requests)):
prompt = requests[i].prompt
prompt_len = requests[i].prompt_len
output_len = requests[i].expected_output_len
# Add the prompt to the batch.
batch.append(prompt)
max_prompt_len = max(max_prompt_len, prompt_len)
max_output_len = max(max_output_len, output_len)
if len(batch) < max_batch_size and i != len(requests) - 1:
# Check if we can add more requests to the batch.
next_prompt_len = requests[i + 1].prompt_len
next_output_len = requests[i + 1].expected_output_len
if (max(max_prompt_len, next_prompt_len) +
max(max_output_len, next_output_len)) <= 2048:
# We can add more requests to the batch.
continue
# Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt",
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=True,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,
use_cache=True,
max_new_tokens=max_output_len,
)
if not disable_detokenize:
# Include the decoding time.
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
pbar.update(len(batch))
# Clear the batch.
batch = []
max_prompt_len = 0
max_output_len = 0
end = time.perf_counter()
return end - start
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: dict[str, Any]) -> None:
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"requests_per_second": [results["requests_per_second"]],
"tokens_per_second": [results["tokens_per_second"]],
},
extra_info={
k: results[k]
for k in ["elapsed_time", "num_requests", "total_num_tokens"]
})
if pt_records:
# Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
write_to_json(pt_file, pt_records)
def get_requests(args, tokenizer):
# Common parameters for all dataset types.
common_kwargs = {
"dataset_path": args.dataset_path,
"random_seed": args.seed,
}
sample_kwargs = {
"tokenizer": tokenizer,
"lora_path": args.lora_path,
"max_loras": args.max_loras,
"num_requests": args.num_prompts,
"input_len": args.input_len,
"output_len": args.output_len,
}
if args.dataset_path is None or args.dataset_name == "random":
sample_kwargs["range_ratio"] = args.random_range_ratio
sample_kwargs["prefix_len"] = args.prefix_len
dataset_cls = RandomDataset
elif args.dataset_name == "sharegpt":
dataset_cls = ShareGPTDataset
if args.backend == "vllm-chat":
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_name == "sonnet":
assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset.")
dataset_cls = SonnetDataset
sample_kwargs["prefix_len"] = args.prefix_len
sample_kwargs["return_prompt_formatted"] = True
elif args.dataset_name == "burstgpt":
dataset_cls = BurstGPTDataset
elif args.dataset_name == "hf":
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = InstructCoderDataset
common_kwargs['dataset_split'] = "train"
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = ConversationDataset
common_kwargs['dataset_subset'] = args.hf_subset
common_kwargs['dataset_split'] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_cls = AIMODataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
def validate_args(args):
"""
Validate command-line arguments.
"""
# === Deprecation and Defaulting ===
if args.dataset is not None:
warnings.warn(
"The '--dataset' argument will be deprecated in the next release. "
"Please use '--dataset-name' and '--dataset-path' instead.",
stacklevel=2)
args.dataset_path = args.dataset
if not getattr(args, "tokenizer", None):
args.tokenizer = args.model
# === Backend Validation ===
valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
if args.backend not in valid_backends:
raise ValueError(f"Unsupported backend: {args.backend}")
# === Dataset Configuration ===
if not args.dataset and not args.dataset_path:
print(
"When dataset path is not set, it will default to random dataset")
args.dataset_name = 'random'
if args.input_len is None:
raise ValueError("input_len must be provided for a random dataset")
# === Dataset Name Specific Checks ===
# --hf-subset and --hf-split: only used
# when dataset_name is 'hf'
if args.dataset_name != "hf" and (
getattr(args, "hf_subset", None) is not None
or getattr(args, "hf_split", None) is not None):
warnings.warn("--hf-subset and --hf-split will be ignored \
since --dataset-name is not 'hf'.",
stacklevel=2)
elif args.dataset_name == "hf":
if args.dataset_path in (
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
| ConversationDataset.SUPPORTED_DATASET_PATHS):
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
| AIMODataset.SUPPORTED_DATASET_PATHS):
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501
else:
raise ValueError(
f"{args.dataset_path} is not supported by hf dataset.")
# --random-range-ratio: only used when dataset_name is 'random'
if args.dataset_name != 'random' and args.random_range_ratio is not None:
warnings.warn("--random-range-ratio will be ignored since \
--dataset-name is not 'random'.",
stacklevel=2)
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
# set.
if args.dataset_name not in {"random", "sonnet", None
} and args.prefix_len is not None:
warnings.warn("--prefix-len will be ignored since --dataset-name\
is not 'random', 'sonnet', or not set.",
stacklevel=2)
# === LoRA Settings ===
if getattr(args, "enable_lora", False) and args.backend != "vllm":
raise ValueError(
"LoRA benchmarking is only supported for vLLM backend")
if getattr(args, "enable_lora", False) and args.lora_path is None:
raise ValueError("LoRA path must be provided when enable_lora is True")
# === Backend-specific Validations ===
if args.backend == "hf" and args.hf_max_batch_size is None:
raise ValueError("HF max batch size is required for HF backend")
if args.backend != "hf" and args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
if args.backend in {"hf", "mii"} and getattr(args, "quantization",
None) is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.backend == "mii" and args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
if args.backend == "mii" and args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.backend == "mii" and args.tokenizer != args.model:
raise ValueError(
"Tokenizer must be the same as the model for MII backend.")
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii", "vllm-chat"],
default="vllm")
parser.add_argument(
"--dataset-name",
type=str,
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
help="Name of the dataset to benchmark on.",
default="sharegpt")
parser.add_argument(
"--dataset",
type=str,
default=None,
help="Path to the ShareGPT dataset, will be deprecated in\
the next release. The dataset is expected to "
"be a json in form of list[dict[..., conversations: "
"list[dict[..., value: <prompt_or_response>]]]]")
parser.add_argument("--dataset-path",
type=str,
default=None,
help="Path to the dataset")
parser.add_argument("--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument("--async-engine",
action='store_true',
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument(
"--disable-detokenize",
action="store_true",
help=("Do not detokenize the response (i.e. do not include "
"detokenization time in the measurement)"))
# LoRA
parser.add_argument(
"--lora-path",
type=str,
default=None,
help="Path to the lora adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.")
parser.add_argument(
"--prefix-len",
type=int,
default=0,
help="Number of fixed prefix tokens before the random "
"context in a request (default: 0).",
)
# random dataset
parser.add_argument(
"--random-range-ratio",
type=float,
default=0.0,
help="Range ratio for sampling input/output length, "
"used only for RandomDataset. Must be in the range [0, 1) to define "
"a symmetric sampling range "
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
)
# hf dtaset
parser.add_argument("--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.")
parser.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
parser = AsyncEngineArgs.add_cli_args(parser)
def main(args: argparse.Namespace):
if args.tokenizer is None:
args.tokenizer = args.model
validate_args(args)
if args.seed is None:
args.seed = 0
print(args)
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
requests = get_requests(args, tokenizer)
is_multi_modal = any(request.multi_modal_data is not None
for request in requests)
request_outputs: Optional[list[RequestOutput]] = None
if args.backend == "vllm":
if args.async_engine:
elapsed_time = uvloop.run(
run_vllm_async(
requests,
args.n,
AsyncEngineArgs.from_cli_args(args),
args.disable_frontend_multiprocessing,
args.disable_detokenize,
))
else:
elapsed_time, request_outputs = run_vllm(
requests, args.n, EngineArgs.from_cli_args(args),
args.disable_detokenize)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.hf_max_batch_size, args.trust_remote_code,
args.disable_detokenize)
elif args.backend == "vllm-chat":
elapsed_time, request_outputs = run_vllm_chat(
requests, args.n, EngineArgs.from_cli_args(args),
args.disable_detokenize)
else:
raise ValueError(f"Unknown backend: {args.backend}")
if request_outputs:
# Note: with the vllm and vllm-chat backends,
# we have request_outputs, which we use to count tokens.
total_prompt_tokens = 0
total_output_tokens = 0
for ro in request_outputs:
if not isinstance(ro, RequestOutput):
continue
total_prompt_tokens += len(
ro.prompt_token_ids) if ro.prompt_token_ids else 0
total_output_tokens += sum(
len(o.token_ids) for o in ro.outputs if o)
total_num_tokens = total_prompt_tokens + total_output_tokens
else:
total_num_tokens = sum(r.prompt_len + r.expected_output_len
for r in requests)
total_output_tokens = sum(r.expected_output_len for r in requests)
total_prompt_tokens = total_num_tokens - total_output_tokens
if is_multi_modal and args.backend != "vllm-chat":
print("\033[91mWARNING\033[0m: Multi-modal request with "
f"{args.backend} backend detected. The "
"following metrics are not accurate because image tokens are not"
" counted. See vllm-project/vllm/issues/9778 for details.")
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
# vllm-chat backend counts the image tokens now
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
print(f"Total num prompt tokens: {total_prompt_tokens}")
print(f"Total num output tokens: {total_output_tokens}")
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)
...@@ -282,13 +282,21 @@ def get_vllm_version(): ...@@ -282,13 +282,21 @@ def get_vllm_version():
if __version__ == "dev": if __version__ == "dev":
return "N/A (dev)" return "N/A (dev)"
version_str = __version_tuple__[-1]
if len(__version_tuple__) == 4: # dev build if isinstance(version_str, str) and version_str.startswith('g'):
git_sha = __version_tuple__[-1][1:] # type: ignore # it's a dev build
return f"{__version__} (git sha: {git_sha}" if '.' in version_str:
# it's a dev build containing local changes
git_sha = version_str.split('.')[0][1:]
date = version_str.split('.')[-1][1:]
return f"{__version__} (git sha: {git_sha}, date: {date})"
else:
# it's a dev build without local changes
git_sha = version_str[1:] # type: ignore
return f"{__version__} (git sha: {git_sha})"
return __version__ return __version__
def summarize_vllm_build_flags(): def summarize_vllm_build_flags():
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format( return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format(
...@@ -502,7 +510,9 @@ def get_pip_packages(run_lambda, patterns=None): ...@@ -502,7 +510,9 @@ def get_pip_packages(run_lambda, patterns=None):
print("uv is set") print("uv is set")
cmd = ["uv", "pip", "list", "--format=freeze"] cmd = ["uv", "pip", "list", "--format=freeze"]
else: else:
raise RuntimeError("Could not collect pip list output (pip or uv module not available)") raise RuntimeError(
"Could not collect pip list output (pip or uv module not available)"
)
out = run_and_read_all(run_lambda, cmd) out = run_and_read_all(run_lambda, cmd)
return "\n".join(line for line in out.splitlines() return "\n".join(line for line in out.splitlines()
...@@ -535,13 +545,12 @@ def is_xnnpack_available(): ...@@ -535,13 +545,12 @@ def is_xnnpack_available():
else: else:
return "N/A" return "N/A"
def get_env_vars(): def get_env_vars():
env_vars = '' env_vars = ''
secret_terms=('secret', 'token', 'api', 'access', 'password') secret_terms = ('secret', 'token', 'api', 'access', 'password')
report_prefix = ("TORCH", "NCCL", "PYTORCH", report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN",
"CUDA", "CUBLAS", "CUDNN", "OMP_", "MKL_", "NVIDIA")
"OMP_", "MKL_",
"NVIDIA")
for k, v in os.environ.items(): for k, v in os.environ.items():
if any(term in k.lower() for term in secret_terms): if any(term in k.lower() for term in secret_terms):
continue continue
...@@ -552,6 +561,7 @@ def get_env_vars(): ...@@ -552,6 +561,7 @@ def get_env_vars():
return env_vars return env_vars
def get_env_info(): def get_env_info():
run_lambda = run run_lambda = run
pip_version, pip_list_output = get_pip_packages(run_lambda) pip_version, pip_list_output = get_pip_packages(run_lambda)
......
...@@ -110,10 +110,14 @@ class CompilerManager: ...@@ -110,10 +110,14 @@ class CompilerManager:
compiled_graph = self.load(graph, example_inputs, graph_index, compiled_graph = self.load(graph, example_inputs, graph_index,
runtime_shape) runtime_shape)
if compiled_graph is not None: if compiled_graph is not None:
if graph_index == 0: if graph_index == num_graphs - 1:
# adds some info logging for the first graph # after loading the last graph for this shape, record the time.
logger.info("Directly load the compiled graph for shape %s " # there can be multiple graphs due to piecewise compilation.
"from the cache", str(runtime_shape)) # noqa now = time.time()
elapsed = now - compilation_start_time
logger.info(
"Directly load the compiled graph(s) for shape %s "
"from the cache, took %.3f s", str(runtime_shape), elapsed)
return compiled_graph return compiled_graph
# no compiler cached the graph, or the cache is disabled, # no compiler cached the graph, or the cache is disabled,
...@@ -335,7 +339,7 @@ class VllmBackend: ...@@ -335,7 +339,7 @@ class VllmBackend:
def configure_post_pass(self): def configure_post_pass(self):
config = self.compilation_config config = self.compilation_config
self.post_grad_pass_manager.configure(config.pass_config) self.post_grad_pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass # Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager. # hook. If a pass for that hook exists, add it to the pass manager.
......
...@@ -11,9 +11,12 @@ import torch ...@@ -11,9 +11,12 @@ import torch
import torch._inductor.compile_fx import torch._inductor.compile_fx
import torch.fx as fx import torch.fx as fx
import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils import is_torch_equal_or_newer from vllm.utils import is_torch_equal_or_newer
from .inductor_pass import pass_context
class CompilerInterface: class CompilerInterface:
""" """
...@@ -167,8 +170,7 @@ class InductorAdaptor(CompilerInterface): ...@@ -167,8 +170,7 @@ class InductorAdaptor(CompilerInterface):
compiler_config: Dict[str, Any], compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None runtime_shape: Optional[int] = None
) -> Tuple[Optional[Callable], Optional[Any]]: ) -> Tuple[Optional[Callable], Optional[Any]]:
from torch._inductor import config current_config = {}
current_config = config.get_config_copy()
from torch._inductor.compile_fx import compile_fx from torch._inductor.compile_fx import compile_fx
# disable remote cache # disable remote cache
...@@ -196,7 +198,6 @@ class InductorAdaptor(CompilerInterface): ...@@ -196,7 +198,6 @@ class InductorAdaptor(CompilerInterface):
hash_str, file_path = None, None hash_str, file_path = None, None
from torch._inductor.codecache import (FxGraphCache, from torch._inductor.codecache import (FxGraphCache,
compiled_fx_graph_hash) compiled_fx_graph_hash)
if torch.__version__.startswith("2.5"): if torch.__version__.startswith("2.5"):
original_load = FxGraphCache.load original_load = FxGraphCache.load
original_load_name = "torch._inductor.codecache.FxGraphCache.load" original_load_name = "torch._inductor.codecache.FxGraphCache.load"
...@@ -281,6 +282,16 @@ class InductorAdaptor(CompilerInterface): ...@@ -281,6 +282,16 @@ class InductorAdaptor(CompilerInterface):
patch("torch._inductor.codecache.FxGraphCache._get_shape_env", patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env)) _get_shape_env))
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache)
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
_get_shape_env))
# for forcing the graph to be cached # for forcing the graph to be cached
stack.enter_context( stack.enter_context(
patch( patch(
...@@ -290,16 +301,34 @@ class InductorAdaptor(CompilerInterface): ...@@ -290,16 +301,34 @@ class InductorAdaptor(CompilerInterface):
# Dynamo metrics context, see method for more details. # Dynamo metrics context, see method for more details.
stack.enter_context(self.metrics_context()) stack.enter_context(self.metrics_context())
compiled_graph = compile_fx( # Disable remote caching. When these are on, on remote cache-hit,
graph, # the monkey-patched functions never actually get called.
example_inputs, # vLLM today assumes and requires the monkey-patched functions to
inner_compile=hijacked_compile_fx_inner, # get hit.
config_patches=current_config) # TODO(zou3519): we're going to replace this all with
# standalone_compile sometime.
assert hash_str is not None, ( if is_torch_equal_or_newer("2.6"):
"failed to get the hash of the compiled graph") stack.enter_context(
assert file_path is not None, ( torch._inductor.config.patch(fx_graph_remote_cache=False))
"failed to get the file path of the compiled graph") stack.enter_context(
torch._functorch.config.patch(
enable_remote_autograd_cache=False))
with pass_context(runtime_shape):
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config)
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
# compilation cache. So turn off the checks if we disable the
# compilation cache.
if not envs.VLLM_DISABLE_COMPILE_CACHE:
assert hash_str is not None, (
"failed to get the hash of the compiled graph")
assert file_path is not None, (
"failed to get the file path of the compiled graph")
return compiled_graph, (hash_str, file_path) return compiled_graph, (hash_str, file_path)
def load(self, def load(self,
...@@ -313,11 +342,19 @@ class InductorAdaptor(CompilerInterface): ...@@ -313,11 +342,19 @@ class InductorAdaptor(CompilerInterface):
assert isinstance(handle[1], str) assert isinstance(handle[1], str)
hash_str = handle[0] hash_str = handle[0]
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache)
from torch._inductor.codecache import FxGraphCache from torch._inductor.codecache import FxGraphCache
with ExitStack() as exit_stack: with ExitStack() as exit_stack:
exit_stack.enter_context( exit_stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache._get_shape_env", patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv())) lambda *args, **kwargs: AlwaysHitShapeEnv()))
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
exit_stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()))
# Dynamo metrics context, see method for more details. # Dynamo metrics context, see method for more details.
exit_stack.enter_context(self.metrics_context()) exit_stack.enter_context(self.metrics_context())
......
...@@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized ...@@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload from torch._ops import OpOverload
from vllm.config import CompilationConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -531,7 +531,7 @@ class FusionPass(VllmInductorPass): ...@@ -531,7 +531,7 @@ class FusionPass(VllmInductorPass):
_instance: 'Optional[FusionPass]' = None _instance: 'Optional[FusionPass]' = None
@classmethod @classmethod
def instance(cls, config: CompilationConfig.PassConfig): def instance(cls, config: VllmConfig):
""" """
Get the singleton instance of the FusionPass. Get the singleton instance of the FusionPass.
If the instance exists, the config is updated but If the instance exists, the config is updated but
...@@ -540,10 +540,10 @@ class FusionPass(VllmInductorPass): ...@@ -540,10 +540,10 @@ class FusionPass(VllmInductorPass):
if cls._instance is None: if cls._instance is None:
cls._instance = FusionPass(config) cls._instance = FusionPass(config)
else: else:
cls._instance.config = config cls._instance.pass_config = config.compilation_config.pass_config
return cls._instance return cls._instance
def __init__(self, config: CompilationConfig.PassConfig): def __init__(self, config: VllmConfig):
assert self.__class__._instance is None, \ assert self.__class__._instance is None, \
"FusionPass singleton instance already exists" "FusionPass singleton instance already exists"
super().__init__(config) super().__init__(config)
......
...@@ -12,6 +12,22 @@ def is_func(node: fx.Node, target) -> bool: ...@@ -12,6 +12,22 @@ def is_func(node: fx.Node, target) -> bool:
return node.op == "call_function" and node.target == target return node.op == "call_function" and node.target == target
# Returns the first specified node with the given op (if it exists)
def find_specified_fn_maybe(nodes: Iterable[fx.Node],
op: OpOverload) -> Optional[fx.Node]:
for node in nodes:
if node.target == op:
return node
return None
# Returns the first specified node with the given op
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
node = find_specified_fn_maybe(nodes, op)
assert node is not None, f"Could not find {op} in nodes {nodes}"
return node
# Returns the first auto_functionalized node with the given op (if it exists) # Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[fx.Node], def find_auto_fn_maybe(nodes: Iterable[fx.Node],
op: OpOverload) -> Optional[fx.Node]: op: OpOverload) -> Optional[fx.Node]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment