Commit 081057de authored by zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-ori

parents 7cf5d5c4 ba41cc90
......@@ -2,8 +2,10 @@
"""Attention backend utils"""
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar, Union)
import numpy as np
import torch
......@@ -11,6 +13,7 @@ import torch
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.attention.backends.abstract import AttentionType
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
......@@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens(
return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens)
@dataclass
class MLADims:
    """Bundle of MLA-related dimensions taken from a model's HF text config.

    Instances are built by get_mla_dims(), which copies every field from the
    config attribute of the same name.
    NOTE(review): "MLA" presumably stands for Multi-head Latent Attention --
    confirm against the model code that consumes these values.
    """
    # None when the HF text config has no `q_lora_rank` attribute.
    q_lora_rank: Optional[int]
    # The remaining fields are read unconditionally by get_mla_dims(), so the
    # config is expected to define each of them.
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int
def get_mla_dims(model_config: ModelConfig) -> MLADims:
    """Build an MLADims from ``model_config.hf_text_config``.

    ``q_lora_rank`` is optional and becomes ``None`` when the config does not
    define it. The other four attributes are mandatory; a missing one raises
    ``AttributeError``.
    """
    text_cfg = model_config.hf_text_config

    # Only q_lora_rank may be absent from the config.
    dims = {"q_lora_rank": getattr(text_cfg, "q_lora_rank", None)}

    # Read the mandatory attributes in declaration order.
    mandatory = (
        "kv_lora_rank",
        "qk_nope_head_dim",
        "qk_rope_head_dim",
        "v_head_dim",
    )
    for attr_name in mandatory:
        dims[attr_name] = getattr(text_cfg, attr_name)

    return MLADims(**dims)
......@@ -10,6 +10,9 @@ import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
......@@ -87,6 +90,7 @@ class Attention(nn.Module):
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer)
......@@ -329,17 +333,54 @@ class MultiHeadAttention(nn.Module):
return out.reshape(bsz, q_len, -1)
def wait_for_kv_layer_from_connector(layer_name: str):
    """Call ``wait_for_layer_load(layer_name)`` on the KV-transfer connector.

    Does nothing when there is no KV transfer group, when the group is not a
    v1 group, or when the current forward context has no attention metadata.
    """
    # Both conditions must hold; the second is only evaluated if the first
    # is true.
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()

    # Without attention metadata there is nothing to wait for.
    if get_forward_context().attn_metadata is None:
        return

    kv_connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Pass one layer's KV cache to the KV-transfer connector, if applicable.

    Calls ``save_kv_layer(layer_name, kv_cache_layer, attn_metadata)`` on the
    connector. Does nothing when there is no KV transfer group, when the
    group is not a v1 group, or when the current forward context has no
    attention metadata.
    """
    # Both conditions must hold; the second is only evaluated if the first
    # is true.
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()
    metadata = get_forward_context().attn_metadata

    # Without attention metadata there is nothing to save.
    if metadata is None:
        return

    kv_connector.save_kv_layer(layer_name, kv_cache_layer, metadata)
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the attention layer registered under ``layer_name``.

    Steps, in order:
      1. Ask the KV-transfer connector (if any) to wait for this layer's load.
      2. Look up the layer object and its KV cache for the current virtual
         engine from the forward context.
      3. Call the layer's backend implementation.
      4. Hand the layer's KV cache to the KV-transfer connector (if any).

    Args:
        query: Query tensor forwarded to the backend implementation.
        key: Key tensor forwarded to the backend implementation.
        value: Value tensor forwarded to the backend implementation.
        layer_name: Key into ``forward_context.no_compile_layers``.

    Returns:
        The tensor returned by ``self.impl.forward``.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # `self` is the attention layer object, fetched by name because this
    # function is not a bound method.
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]

    # The result is kept in a local rather than returned directly, so that
    # the save hook below still runs after the forward pass.
    output = self.impl.forward(self, query, key, value, kv_cache,
                               attn_metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output
def unified_attention_fake(
......@@ -367,6 +408,7 @@ def unified_attention_with_output(
output: torch.Tensor,
layer_name: str,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name]
......@@ -379,6 +421,8 @@ def unified_attention_with_output(
attn_metadata,
output=output)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_attention_with_output_fake(
query: torch.Tensor,
......
......@@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata:
block_usage: Optional[torch.Tensor]
block_indices: Optional[torch.Tensor]
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
......
......@@ -16,831 +16,778 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8
# To check compatibility: True only when the current platform reports a device
# capability of exactly (7, 5).
IS_TURING = current_platform.get_device_capability() == (7, 5)
if triton.__version__ >= "2.1.0":
# Prefix-prefill attention kernel.
#
# Each program instance handles one (sequence, query head, BLOCK_M-row query
# tile) triple, taken from program ids 0, 1 and 2 respectively. The tile is
# attended in two phases that share one running softmax state (m_i, l_i, acc):
#   1. against the sequence's already-cached context, read from
#      K_cache / V_cache through B_Loc (indexed by sequence and by
#      position // block_size to obtain a cache block index);
#   2. against the new K / V tokens of the same sequence, under a causal mask.
#
# Both `p` and `acc` are rescaled by the updated running sum on every
# iteration, so `acc` is already normalised when it is stored and no final
# division is performed.
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    k_scale,
    v_scale,
    B_Start_Loc,
    B_Seqlen,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    IN_PRECISION: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,  # head size
    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
    BLOCK_N: tl.constexpr,
    SLIDING_WINDOW: tl.constexpr,
    SKIP_DECODE: tl.constexpr,
):
    # Program ids: 0 -> sequence in the batch, 1 -> query head,
    # 2 -> BLOCK_M-row tile of the query.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # num_queries_per_kv consecutive query heads share one KV head.
    cur_kv_head = cur_head // num_queries_per_kv

    # Query length is the gap between consecutive B_Start_Loc entries; the
    # remainder of the sequence length is the cached context.
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
    cur_batch_query_len = (cur_batch_in_all_stop_index -
                           cur_batch_in_all_start_index)
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

    # Optionally leave single-token (decode) sequences untouched.
    if SKIP_DECODE and cur_batch_query_len == 1:
        return

    # start position inside of the query
    # generally, N goes over kv, while M goes over query_len
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    # [N]; starts at 0
    offs_n = tl.arange(0, BLOCK_N)
    # [D]; starts at 0
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    # [M]; starts at current position in query
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # [M,D]
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
        cur_head * stride_qh + offs_d[None, :] * stride_qd)

    # True for real head-dim lanes, False for the power-of-2 padding lanes.
    dim_mask = tl.where(
        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
        0).to(tl.int1)  # [D]

    # Padding lanes and rows past the query length are loaded as 0.0.
    q = tl.load(Q + off_q,
                mask=dim_mask[None, :] &
                (offs_m[:, None] < cur_batch_query_len),
                other=0.0)  # [M,D]

    # initialize pointer to m and l
    # m_i: running row maximum, l_i: running softmax denominator,
    # acc: running (already normalised) output.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
                   dtype=tl.float32)  # [M,D]

    # compute query against context (no causal mask here)
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        # Cache block index for each of the BLOCK_N context positions;
        # positions past the context length read index 0 and are masked later.
        bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                     ((start_n + offs_n) // block_size) * stride_b_loc_s,
                     mask=(start_n + offs_n) < cur_batch_ctx_len,
                     other=0)  # [N]
        # [D,N]
        # The K cache addresses the head dim in two parts: offs_d // x along
        # stride_k_cache_d and offs_d % x along stride_k_cache_x.
        off_k = (bn[None, :] * stride_k_cache_bs +
                 cur_kv_head * stride_k_cache_h +
                 (offs_d[:, None] // x) * stride_k_cache_d +
                 ((start_n + offs_n[None, :]) % block_size) *
                 stride_k_cache_bl +
                 (offs_d[:, None] % x) * stride_k_cache_x)
        # [N,D]
        off_v = (
            bn[:, None] * stride_v_cache_bs +
            cur_kv_head * stride_v_cache_h +
            offs_d[None, :] * stride_v_cache_d +
            (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
        k_load = tl.load(K_cache + off_k,
                         mask=dim_mask[:, None] &
                         ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                         other=0.0)  # [D,N]

        # fp8 cache entries are dequantised with k_scale and cast to q's
        # dtype; other dtypes are used as loaded.
        if k_load.dtype.is_fp8():
            k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
        else:
            k = k_load

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        # Columns past the context length must not contribute to the softmax.
        qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                      float("-inf"))
        qk *= sm_scale
        if SLIDING_WINDOW > 0:
            # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
            # Q entries in sequence
            # (start_n + offs_n[None, :]) are the positions of
            # KV entries in sequence
            # So the condition makes sure each entry in Q only attends
            # to KV entries not more than SLIDING_WINDOW away.
            #
            # We can't use -inf here, because the
            # sliding window may lead to the entire row being masked.
            # This then makes m_ij contain -inf, which causes NaNs in
            # exp().
            qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
                          (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
                          -10000)

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)  # [M]
        p = tl.exp(qk - m_ij[:, None])  # [M,N]
        l_ij = tl.sum(p, 1)  # [M]
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)  # [M]
        alpha = tl.exp(m_i - m_i_new)  # [M]
        beta = tl.exp(m_ij - m_i_new)  # [M]
        l_i_new = alpha * l_i + beta * l_ij  # [M]

        # -- update output accumulator --
        # scale p
        # p becomes this tile's share of the updated denominator.
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        # Re-express the previous normalised acc under the new max and
        # denominator.
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v_load = tl.load(V_cache + off_v,
                         mask=dim_mask[None, :] &
                         ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)  # [N,D]
        # Same fp8 handling as for K, using v_scale.
        if v_load.dtype.is_fp8():
            v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
        else:
            v = v_load
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Second phase reads the new K/V tokens directly from K and V (not from
    # the cache). Offsets are relative; the sequence start is added per load.
    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
             offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # block_mask is 0 when we're already past the current query length
    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

    # compute query against itself (with causal mask)
    # The upper bound stops at the end of this query tile; multiplying by
    # block_mask makes the range empty for tiles past the query length.
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=dim_mask[:, None] &
                    ((start_n + offs_n[None, :]) < cur_batch_query_len),
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
        qk *= sm_scale
        # apply causal mask
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))
        if SLIDING_WINDOW > 0:
            # Same finite sentinel as in the context loop, applied to
            # query-relative positions.
            qk = tl.where(
                offs_m[:, None] - (start_n + offs_n[None, :])
                < SLIDING_WINDOW, qk, -10000)

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=dim_mask[None, :] &
                    ((start_n + offs_n[:, None]) < cur_batch_query_len),
                    other=0.0)
        p = p.to(v.dtype)

        acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
        cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    # Only real head-dim lanes of rows inside the query are written.
    tl.store(out_ptrs,
             acc,
             mask=dim_mask[None, :] &
             (offs_m[:, None] < cur_batch_query_len))
# Here's an example autotuner config for this kernel. This config does provide
# a performance improvement, but dramatically increases first call latency in
# triton 3.2. Because of this tradeoff, it's currently commented out.
# @triton.autotune(
# configs=[
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
# "num_unroll_cache": 4, \
# "num_unroll_request": 1 } | \
# ({"kpack": 2, "waves_per_eu": 2} \
# if current_platform.is_rocm() else {}), \
# num_warps=4, \
# num_stages=1)
# ],
# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
# )
@triton.jit
def _fwd_kernel(Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
k_scale,
v_scale,
B_Start_Loc,
B_Seqlen,
x: tl.constexpr,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl: tl.constexpr,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: tl.constexpr,
IN_PRECISION: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DMODEL_PADDED: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr,
SLIDING_WINDOW: tl.constexpr,
num_unroll_cache: tl.constexpr,
num_unroll_request: tl.constexpr,
SKIP_DECODE: tl.constexpr,
MAX_Q_LEN: tl.constexpr = 0,
MAX_CTX_LEN: tl.constexpr = 0):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
cur_batch_query_len = (cur_batch_in_all_stop_index -
cur_batch_in_all_start_index)
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
if SKIP_DECODE and cur_batch_query_len == 1:
return
@triton.jit
def _fwd_kernel_flash_attn_v2(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(Q + off_q,
mask=offs_m[:, None]
< cur_batch_seq_len - cur_batch_ctx_len,
# start position inside of the query
# generally, N goes over kv, while M goes over query_len
block_start_loc = BLOCK_M * start_m
# initialize offsets
# [BLOCK_SIZE]; starts at 0
offs_bs_n = tl.arange(0, BLOCK_SIZE)
# [N]; starts at 0
offs_n = tl.arange(0, BLOCK_N)
# [D]; starts at 0
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
# [M]; starts at current position in query
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# [M,D]
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
dim_mask = tl.where(
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
0).to(tl.int1) # [D]
q = tl.load(Q + off_q,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_query_len),
other=0.0) # [M,D]
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
# compute query against context (no causal mask here)
for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
loop_unroll_factor=num_unroll_cache):
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
(start_n // BLOCK_SIZE) * stride_b_loc_s)
# [D,BLOCK_SIZE]
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
# [BLOCK_SIZE,D]
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
offs_bs_n[:, None] * stride_v_cache_bl)
if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
k_load = tl.load(
K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
else:
k_load = tl.load(K_cache + off_k)
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_bs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
(start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk,
-10000)
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]
# update acc
if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
v_load = tl.load(
V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
else:
v_load = tl.load(V_cache + off_v)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# # update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
# block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask)
for start_n in tl.range(0, \
block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
loop_unroll_factor=num_unroll_request):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_query_len),
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :])
< cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None])
< cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# acc /= l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@triton.jit
def _fwd_kernel_alibi(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
k_scale,
v_scale,
B_Start_Loc,
B_Seqlen,
Alibi_slopes,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
IN_PRECISION: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
BLOCK_N: tl.constexpr,
SKIP_DECODE: tl.constexpr,
):
# attn_bias[]
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
# cur_batch_seq_len: the length of prompts
# cur_batch_ctx_len: the length of prefix
# cur_batch_in_all_start_index: the start id of the dim=0
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
cur_batch_query_len = (cur_batch_in_all_stop_index -
cur_batch_in_all_start_index)
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
if SKIP_DECODE and cur_batch_query_len == 1:
return
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
dim_mask = tl.where(
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
q = tl.load(Q + off_q,
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk *= sm_scale
# apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
if SLIDING_WINDOW > 0:
qk = tl.where(
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
qk, -10000)
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
((start_n + offs_n[:, None]) < cur_batch_query_len),
other=0.0)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
return
@triton.jit
def _fwd_kernel_flash_attn_v2(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :])
< cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = 0
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v_load = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
# init alibi
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = cur_batch_ctx_len
# # init debugger
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
# offset_db_k = tl.arange(0, BLOCK_N)
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision='ieee')
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None])
< cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# acc /= l_i[:, None]
# initialize pointers to output
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@triton.jit
def _fwd_kernel_alibi(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
k_scale,
v_scale,
B_Start_Loc,
B_Seqlen,
Alibi_slopes,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
IN_PRECISION: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
BLOCK_N: tl.constexpr,
SKIP_DECODE: tl.constexpr,
):
# attn_bias[]
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
# cur_batch_seq_len: the length of prompts
# cur_batch_ctx_len: the length of prefix
# cur_batch_in_all_start_index: the start id of the dim=0
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
cur_batch_query_len = (cur_batch_in_all_stop_index -
cur_batch_in_all_start_index)
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
if SKIP_DECODE and cur_batch_query_len == 1:
return
@torch.inference_mode()
def context_attention_fwd(q,
k,
v,
o,
kv_cache_dtype: str,
k_cache,
v_cache,
b_loc,
b_start_loc,
b_seq_len,
max_seq_len,
max_input_len,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
skip_decode=False):
q_dtype_is_f32 = q.dtype is torch.float32
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
dim_mask = tl.where(
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
q = tl.load(Q + off_q,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = 0
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v_load = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
# init alibi
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = cur_batch_ctx_len
# # init debugger
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
# offset_db_k = tl.arange(0, BLOCK_N)
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] & ((start_n + offs_n[None, :])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision='ieee')
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] & ((start_n + offs_n[:, None])
< cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
return
@torch.inference_mode()
def context_attention_fwd(q,
k,
v,
o,
kv_cache_dtype: str,
k_cache,
v_cache,
b_loc,
b_start_loc,
b_seq_len,
max_seq_len,
max_input_len,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
skip_decode=False):
q_dtype_is_f32 = q.dtype is torch.float32
# Turing does not have tensor cores for float32 multiplication, so
# use 'ieee' input precision as a fallback so the Triton kernels work.
# There is also a warning in vllm/config.py to inform users of this
# fallback implementation.
IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert (k_cache.dtype == torch.uint8)
assert (v_cache.dtype == torch.uint8)
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
k_cache = k_cache.view(target_dtype)
v_cache = v_cache.view(target_dtype)
if (k_cache.dtype == torch.uint8
or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
raise ValueError("kv_cache_dtype='auto' unsupported for\
FP8 KV Cache prefill kernel")
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = triton.next_power_of_2(Lk)
if sm_scale is None:
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]
assert batch + 1 == len(b_start_loc)
# 0 means "disable"
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
if alibi_slopes is not None:
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
# Turing does not have tensor cores for float32 multiplication, so
# use 'ieee' input precision as a fallback so the Triton kernels work.
# There is also a warning in vllm/config.py to inform users of this
# fallback implementation.
IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert (k_cache.dtype == torch.uint8)
assert (v_cache.dtype == torch.uint8)
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
k_cache = k_cache.view(target_dtype)
v_cache = v_cache.view(target_dtype)
if (k_cache.dtype == torch.uint8
or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
raise ValueError("kv_cache_dtype='auto' unsupported for\
FP8 KV Cache prefill kernel")
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = triton.next_power_of_2(Lk)
if sm_scale is None:
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]
assert batch + 1 == len(b_start_loc)
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,
# 0 means "disable"
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
alibi_slopes,
v_cache.shape[3],
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4
), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SKIP_DECODE=skip_decode,
num_warps=NUM_WARPS,
num_stages=1,
)
return
_fwd_kernel[grid](
# batch, head,
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
_fwd_kernel_alibi[grid](
q,
k,
v,
......@@ -852,6 +799,7 @@ if triton.__version__ >= "2.1.0":
v_scale,
b_start_loc,
b_seq_len,
alibi_slopes,
v_cache.shape[3],
k_cache.shape[4],
o,
......@@ -886,9 +834,69 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode,
num_warps=NUM_WARPS,
num_stages=1,
)
return
max_seq_len = 0 if max_seq_len is None else max_seq_len
extra_kargs = {}
if current_platform.is_rocm():
extra_kargs = {"kpack": 2, "waves_per_eu": 2}
grid = lambda META: (batch, head,
triton.cdiv(max_input_len, META["BLOCK_M"]))
_fwd_kernel[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size]
BLOCK_SIZE=v_cache.shape[3],
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode,
BLOCK_M=128,
BLOCK_N=64,
num_unroll_cache=4,
num_unroll_request=1,
num_warps=4,
num_stages=1,
**extra_kargs)
return
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
                           max_block_per_batch: int,
                           device: torch.device) -> tuple[torch.Tensor, ...]:
    """Allocate the paged-KV bookkeeping tensors for the AITER MLA path.

    Args:
        max_batch_size: Number of sequences the buffers are sized for.
        block_size: Value every entry of the last-page-length tensor is
            filled with.
        max_block_per_batch: Number of index slots reserved per sequence.
        device: Device on which the indices and indptr tensors are created.

    Returns:
        ``(paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens)``,
        all int32, holding ``max_batch_size * max_block_per_batch``,
        ``max_batch_size + 1`` and ``max_batch_size`` elements respectively.
        The first two are zero-filled.
    """
    total_index_slots = max_batch_size * max_block_per_batch
    indices = torch.zeros(total_index_slots,
                          dtype=torch.int32,
                          device=device)
    indptr = torch.zeros(max_batch_size + 1,
                         dtype=torch.int32,
                         device=device)
    # NOTE(review): unlike the two tensors above, this one is created without
    # an explicit ``device`` and so lands on torch's default device — confirm
    # that callers expect this before changing it.
    last_page_lens = torch.full((max_batch_size, ),
                                block_size,
                                dtype=torch.int32)
    return indices, indptr, last_page_lens
def aiter_mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    sm_scale: float,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    logit_cap: float = 0.0,
):
    """Forward a decode call to ``aiter.mla.mla_decode_fwd``.

    ``kv_buffer`` is viewed as ``(-1, 1, 1, q.shape[-1])`` before being
    handed to the AITER kernel; every other argument is passed through
    unchanged. Nothing is returned by this wrapper; ``o`` is presumably
    written in place by the kernel — confirm against the AITER API.
    """
    # Imported at call time rather than at module import time, presumably so
    # this module can be imported where ``aiter`` is not installed.
    from aiter.mla import mla_decode_fwd

    last_dim = q.shape[-1]
    flat_kv = kv_buffer.view(-1, 1, 1, last_dim)
    mla_decode_fwd(q,
                   flat_kv,
                   o,
                   kv_indptr,
                   kv_indices,
                   kv_last_page_lens,
                   sm_scale=sm_scale,
                   logit_cap=logit_cap)
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import aiter as rocm_aiter
import torch
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.platforms import current_platform
from vllm.utils import cdiv
FP8_DTYPE = current_platform.fp8_dtype()
class AITERPagedAttention(PagedAttention):
    """PagedAttention variant that uses AITER kernels for quantized caches.

    When ``kv_cache_dtype`` is ``"int8"``, ``"fp8"`` or ``"fp8_e4m3"`` the
    cache write and the decode attention are dispatched to ``rocm_aiter``;
    for every other dtype both methods delegate to ``PagedAttention``.
    """

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        """Write ``key``/``value`` into the paged caches.

        Quantized cache dtypes go through
        ``rocm_aiter.reshape_and_cache_with_pertoken_quant`` after the caches
        are reinterpreted as ``FP8_DTYPE`` (fp8 variants) or ``torch.int8``;
        all other dtypes use ``PagedAttention.write_to_paged_cache``.
        """
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache, slot_mapping,
                                                kv_cache_dtype, k_scale,
                                                v_scale)
        else:
            kv_cache_torch_dtype = (FP8_DTYPE
                                    if "fp8" in kv_cache_dtype else torch.int8)
            key_cache = key_cache.view(kv_cache_torch_dtype)
            value_cache = value_cache.view(kv_cache_torch_dtype)

            rocm_aiter.reshape_and_cache_with_pertoken_quant(
                key, value, key_cache, value_cache, k_scale, v_scale,
                slot_mapping.flatten(), True)

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        """Run decode attention over the paged KV cache.

        Non-quantized cache dtypes are delegated unchanged to
        ``PagedAttention.forward_decode``. Quantized dtypes are handled by
        ``rocm_aiter.pa_fwd_asm``, which fills a freshly allocated tensor
        shaped like ``query``.

        Returns:
            The attention output tensor.

        Raises:
            AssertionError: If block-sparse attention is requested
                (``blocksparse_vert_stride > 1``) and
                ``blocksparse_block_size`` is not a positive multiple of the
                cache block size.
        """
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            return PagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_tables=block_tables,
                seq_lens=seq_lens,
                max_seq_len=max_seq_len,
                kv_cache_dtype=kv_cache_dtype,
                num_kv_heads=num_kv_heads,
                scale=scale,
                alibi_slopes=alibi_slopes,
                k_scale=k_scale,
                v_scale=v_scale,
                tp_rank=tp_rank,
                blocksparse_local_blocks=blocksparse_local_blocks,
                blocksparse_vert_stride=blocksparse_vert_stride,
                blocksparse_block_size=blocksparse_block_size,
                blocksparse_head_sliding_step=blocksparse_head_sliding_step)

        if "fp8" in kv_cache_dtype:
            # Reinterpret the cache with the same fp8 dtype that
            # write_to_paged_cache used when storing it, instead of a
            # hard-coded format that may differ from the platform's.
            key_cache = key_cache.view(FP8_DTYPE)
            value_cache = value_cache.view(FP8_DTYPE)

        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of "
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        max_num_blocks_per_seq = cdiv(max_seq_len, block_size)

        rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
                              seq_lens, max_num_blocks_per_seq, k_scale,
                              v_scale, output)
        return output
......@@ -39,11 +39,12 @@ is_hip_ = current_platform.is_rocm()
logger = logging.getLogger(__name__)
# TODO: Remove this when triton>=3.2.0. This issue will not affect performance
# and accuracy.
logger.warning(
"The following error message 'operation scheduled before its operands' "
"can be ignored.")
# Only print the following warnings when triton version < 3.2.0.
# The issue won't affect performance or accuracy.
if triton.__version__ < '3.2.0':
logger.warning(
"The following error message 'operation scheduled before its operands' "
"can be ignored.")
@triton.jit
......
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
This is a Triton implementation of the Flash Attention v2 algorithm
See https://tridao.me/publications/flash2/flash2.pdf
Features supported:
Credits:
AMD Triton kernels team
OpenAI kernel team
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
Currently only the forward kernel is supported, and contains these features:
1) Non power of two head dims
1) Fwd with causal masking
2) Arbitrary Q and KV sequence lengths
3) Arbitrary head sizes
4) Multi and grouped query attention
5) Variable sequence lengths
6) ALiBi and matrix bias
7) FP8 support
"""
from typing import Optional
import torch
import triton
import triton.language as tl
torch_dtype: tl.constexpr = torch.float16
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
# Layout names recognised by this module (described in the layout notes on
# MetaData).
SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

# Default 8-bit dtypes used for quantized attention.
# NOTE(review): the Triton dtype is fixed here while the torch dtype is
# reported by the current platform — confirm the two always denote the same
# fp8 format.
default_eight_bit_dtype_triton = tl.float8e4b8
default_eight_bit_dtype_torch = current_platform.fp8_dtype()
default_float8_info = torch.finfo(default_eight_bit_dtype_torch)

# Lower clamp bound for fp8 quantization, exposed as a Triton constexpr.
FP8_MIN = triton.language.constexpr(default_float8_info.min)

# According to https://github.com/vllm-project/vllm/blob/main
# /csrc/quantization/utils.cuh#L31,
# need to make the max for the uz datatype be 224.0 for accuracy reasons.
FP8_MAX = triton.language.constexpr(
    default_float8_info.max if default_eight_bit_dtype_torch !=
    torch.float8_e4m3fnuz else 224.0)
class MetaData:
    """Configuration bundle describing one attention call.

    Holds the softmax scale, tensor layout, variable-length offsets,
    8-bit quantization scales, bias / ALiBi tensors and the causal flag.
    ``check_args`` validates the stored settings against concrete
    ``q``/``k``/``v``/``o`` tensors.
    """

    cu_seqlens_q = None
    cu_seqlens_k = None
    max_seqlens_q = 0
    max_seqlens_k = 0
    bias = None
    alibi_slopes = None
    causal = False
    num_contexts = 0
    varlen = False
    eight_bit = False
    layout = None
    return_encoded_softmax = False
    eight_bit_dtype_triton = default_eight_bit_dtype_triton
    eight_bit_dtype_torch = default_eight_bit_dtype_torch
    output_dtype = None

    # Note about layouts:
    #
    # thd - [num_tokens, num_heads, head_size]
    # bshd - [batch_size, seq_len, num_heads, head_size]
    # bhsd - [batch_size, num_heads, seq_len, head_size]
    #
    # This is for each tensor, all tensors must have same layout.
    # Q can have num_heads and seq_len differ from K and V,
    # however K and V must agree on this.
    #
    # Notes about varlen and bias:
    # Only one or the other is implemented, meaning can't combine
    # both varlen and bias right now.
    #
    # Note about quantization:
    # Only 8-bit quantization supported (for now) and specifically fp8.
    # Scales must be tensors.
    # o_scale: This is 'output scaling', but comes from parameter called
    # 'input_scale', this is applied to the output from the kernel.
    # o_scale should be None if none of the other quantization parameters
    # are used.
    #
    # NOTE: Object is in a tentatively good state after initialized, however,
    # to verify, call check_args(q,k,v,o) where o is the output tensor.

    def __init__(
            self,
            sm_scale=1.0,
            layout=None,  # layout can be 'bshd', 'bhsd', or 'thd'
            output_dtype=None,
            max_seqlens_q=0,
            max_seqlens_k=0,
            # varlen params
            cu_seqlens_q=None,  # only 'thd' layout supported for varlen
            cu_seqlens_k=None,
            # quant params
            q_descale=None,
            k_descale=None,
            v_descale=None,
            p_scale=None,
            o_scale=None,
            # bias params
            bias=None,  # varlen not implemented for bias
            seqlen_q=None,
            seqlen_k=None,
            # alibi params
            alibi_slopes=None,
            alibi_batch=None,
            alibi_nheads=None,
            # causal
            causal=None,
    ):
        """Store the basic settings and enable each optional feature.

        Varlen mode is enabled when either cumulative-length tensor is given
        (both must then be given), 8-bit mode when any quantization scale is
        given, and bias / ALiBi / causal masking when their respective
        arguments are supplied.
        """
        self.sm_scale = sm_scale
        self.output_dtype = output_dtype
        self.max_seqlens_q = max_seqlens_q
        self.max_seqlens_k = max_seqlens_k
        self.layout = layout
        if cu_seqlens_q is not None or cu_seqlens_k is not None:
            assert cu_seqlens_q is not None and cu_seqlens_k is not None
            assert layout is None or layout not in [
                'bshd', 'bhsd'
            ], "Varlen only implemented for thd layout"
            self.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
        quant_params = [q_descale, k_descale, v_descale, p_scale, o_scale]
        if any(x is not None for x in quant_params):
            p_descale = 1.0 / p_scale if p_scale is not None else None
            self.set_eight_bit_params(q_descale, k_descale, v_descale, p_scale,
                                      p_descale, o_scale)
        if bias is not None:
            self.need_bias(bias, seqlen_q, seqlen_k)
        if alibi_slopes is not None:
            self.need_alibi(alibi_slopes, alibi_batch, alibi_nheads)
        if causal is not None and causal:
            self.need_causal()

    def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
        """Switch to variable-length ('thd') mode.

        Records the cumulative sequence offsets, derives the number of
        sequences, and raises ``max_seqlens_q`` / ``max_seqlens_k`` to the
        longest sequence found in the offsets (never lowering the values
        already stored).
        """
        self.varlen = True
        self.layout = 'thd'
        self.cu_seqlens_q = cu_seqlens_q
        self.cu_seqlens_k = cu_seqlens_k
        # Without "varlen", there should still be one sequence.
        assert len(cu_seqlens_q) >= 2
        assert len(cu_seqlens_q) == len(cu_seqlens_k)
        self.num_contexts = len(cu_seqlens_q) - 1
        # Reduce each offsets tensor once instead of calling ``.item()`` four
        # times per sequence; when the offsets live on an accelerator each
        # ``.item()`` is a separate device-to-host transfer.
        q_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
        k_lens = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
        self.max_seqlens_q = max(q_lens.max().item(), self.max_seqlens_q)
        self.max_seqlens_k = max(k_lens.max().item(), self.max_seqlens_k)

    def set_eight_bit_params(self, q_descale, k_descale, v_descale, p_scale,
                             p_descale, o_scale):
        """Enable 8-bit mode and store the quantization scales.

        ``use_p_scale`` is set when ``p_scale``, ``p_descale`` and
        ``v_descale`` are all given; ``eight_bit_kv`` is set when only K and V
        carry descale factors (``q_descale`` is None).
        """
        self.eight_bit = True
        self.q_descale = q_descale
        self.k_descale = k_descale
        self.v_descale = v_descale
        self.p_scale = p_scale
        self.p_descale = p_descale
        self.o_scale = o_scale
        self.use_p_scale = (p_scale is not None) and (
            p_descale is not None) and (v_descale is not None)
        self.eight_bit_kv = ((q_descale is None) and (k_descale is not None)
                             and (v_descale is not None))
        self.eight_bit_dtype_torch = default_eight_bit_dtype_torch

    def need_bias(self, bias, seqlen_q, seqlen_k):
        """Attach a 4-D CUDA bias of shape (1, *, seqlen_q, seqlen_k)."""
        assert bias is not None
        assert bias.is_cuda
        assert bias.dim() == 4
        assert bias.shape[0] == 1
        assert bias.shape[2:] == (seqlen_q, seqlen_k)
        self.bias = bias

    def need_alibi(self, alibi_slopes, batch, nheads):
        """Attach CUDA ALiBi slopes of shape (batch, nheads)."""
        assert alibi_slopes.is_cuda
        assert alibi_slopes.dim() == 2
        assert alibi_slopes.shape[0] == batch
        assert alibi_slopes.shape[1] == nheads
        self.alibi_slopes = alibi_slopes

    def need_causal(self):
        """Enable causal masking."""
        self.causal = True

    def check_args(self, q, k, v, o):
        """Assert that the stored settings are consistent with the tensors.

        Checks ranks, varlen vs. fixed-length requirements, matching K/V
        shapes and head sizes, dtype rules for the 8-bit modes, the head-size
        limit of 256, output shape, head-count divisibility and layout.

        Raises:
            AssertionError: If any consistency check fails.
        """
        assert q.dim() == k.dim() and q.dim() == v.dim()

        batch, nheads_q, nheads_k, head_size = get_shape_from_layout(
            q, k, self)
        if self.varlen:
            assert q.dim() == 3
            assert self.cu_seqlens_q is not None
            assert self.cu_seqlens_k is not None
            assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
            # TODO: Remove once bias is supported with varlen
            assert self.bias is None
            assert not self.return_encoded_softmax
        else:
            assert q.dim() == 4
            assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
            assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
        assert k.shape == v.shape
        assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
        # TODO: Change assert if we support qkl f8 and v f16
        if self.eight_bit:
            if self.eight_bit_kv:
                assert (v.dtype == k.dtype
                        and k.dtype == self.eight_bit_dtype_torch)
                assert q.dtype != k.dtype
                assert (self.v_descale is not None) and (self.k_descale
                                                         is not None)
            else:
                assert (q.dtype == k.dtype and q.dtype == v.dtype
                        and q.dtype == self.eight_bit_dtype_torch)
                assert (self.q_descale
                        is not None) and (self.k_descale
                                          is not None) and (self.v_descale
                                                            is not None)
                if self.use_p_scale:
                    assert (self.p_scale is not None) and (self.p_descale
                                                           is not None)
        else:
            assert (q.dtype == k.dtype) and (q.dtype == v.dtype)
        assert head_size <= 256
        assert o.shape == q.shape
        assert (nheads_q % nheads_k) == 0
        assert self.layout is not None
        assert self.layout == 'thd' or not self.varlen
@triton.jit
......@@ -38,40 +244,85 @@ def max_fn(x, y):
return tl.math.max(x, y)
# Convenience function to load with optional boundary checks.
# "First" is the major dim, "second" is the minor dim.
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Return the (m, n) grid ``philox_offset + row * stride + col``.

    ``philox_seed`` and ``dropout_p`` are accepted but not used here.
    """
    row_ids = tl.arange(0, m)
    col_ids = tl.arange(0, n)
    return philox_offset + row_ids[:, None] * stride + col_ids[None, :]
def masked_load(ptrs, offset_first, offset_second, boundary_first,
                boundary_second):
    """Load from ``ptrs``, bounds-checking only the dims whose offsets are given.

    ``offset_first`` indexes the major (row) dim and ``offset_second`` the
    minor (column) dim. A dim whose offsets are ``None`` is not checked;
    out-of-bounds elements of a checked dim are read as 0.0.
    """
    if offset_first is not None and offset_second is not None:
        in_bounds = (offset_first[:, None] < boundary_first) & \
                    (offset_second[None, :] < boundary_second)
        loaded = tl.load(ptrs, mask=in_bounds, other=0.0)
    elif offset_first is not None:
        in_bounds = offset_first[:, None] < boundary_first
        loaded = tl.load(ptrs, mask=in_bounds, other=0.0)
    elif offset_second is not None:
        in_bounds = offset_second[None, :] < boundary_second
        loaded = tl.load(ptrs, mask=in_bounds, other=0.0)
    else:
        # Nothing to check: plain unmasked load.
        loaded = tl.load(ptrs)
    return loaded
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Draw an (m, n) block of random values via ``tl.rand``.

    The per-element offsets come from ``dropout_offsets`` and are cast to
    uint32 before being passed to the generator together with
    ``philox_seed``.
    """
    offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
                              stride)
    offsets_u32 = offsets.to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, offsets_u32)
def compute_alibi_block(alibi_slope,
                        seqlen_q,
                        seqlen_k,
                        offs_m,
                        offs_n,
                        transpose=False):
    """Return the ALiBi bias ``-alibi_slope * |relative position|`` for a block.

    ``offs_m`` are query positions and ``offs_n`` key positions; the result
    has one row per ``offs_m`` entry and one column per ``offs_n`` entry, or
    the transpose of that when ``transpose`` is true.
    """
    # When seqlen_k and seqlen_q differ, the diagonal sticks to the bottom
    # right of the attention matrix. For a causal mask that looks like this
    # (1 is kept, 0 is masked):
    #
    #   seqlen_q = 2, seqlen_k = 5        seqlen_q = 5, seqlen_k = 2
    #     1 1 1 1 0                         0 0
    #     1 1 1 1 1                         0 0
    #                                       0 0
    #                                       1 0
    #                                       1 1
    #
    # For ALiBi the diagonal is 0 (no penalty for attending to that spot) and
    # the penalty grows with the distance from the diagonal. Example with
    # alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1],
    # offs_n = [0, 1, 2, 3, 4], transpose = False:
    #
    #   offs_m[:, None] + seqlen_k - seqlen_q       = [[3], [4]]
    #   ... - offs_n[None, :]                       = [[3, 2, 1, 0, -1],
    #                                                  [4, 3, 2, 1,  0]]
    #   -1 * alibi_slope * tl.abs(...)              = [[-3, -2, -1,  0, -1],
    #                                                  [-4, -3, -2, -1,  0]]
    query_pos = offs_m[:, None]
    key_pos = offs_n[None, :]
    signed_distance = query_pos + seqlen_k - seqlen_q - key_pos
    bias = -1 * alibi_slope * tl.abs(signed_distance)
    if transpose:
        return bias.T
    else:
        return bias
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Return an (m, n) keep-mask: true where the random draw exceeds ``dropout_p``."""
    random_vals = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
                              stride)
    return random_vals > dropout_p
def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k):
    """Build the dense ALiBi bias tensor for the given slopes.

    The bias for query position ``i`` and key position ``j`` is
    ``-slope * |i + seqlen_k - seqlen_q - j|``, i.e. the zero-penalty
    diagonal is aligned to the bottom-right corner when the two lengths
    differ.

    Args:
        alibi_slopes: Tensor of slopes; two trailing singleton dims are
            appended for broadcasting, so a (Z, H) input yields a
            (Z, H, seqlen_q, seqlen_k) result.
        seqlen_q: Number of query positions.
        seqlen_k: Number of key positions.

    Returns:
        The bias tensor, on the same device as ``alibi_slopes``.
    """
    # Build the index tensors on the slopes' own device rather than a
    # hard-coded "cuda", so the helper also works for CPU tensors and for
    # GPUs other than the current default.
    device = alibi_slopes.device
    q_idx = torch.arange(seqlen_q, dtype=torch.int32,
                         device=device).unsqueeze(-1)  # (N_CTX_Q, 1)
    k_idx = torch.arange(seqlen_k, dtype=torch.int32,
                         device=device).unsqueeze(0)  # (1, N_CTX_K)
    relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q -
                             k_idx)  # (N_CTX_Q, N_CTX_K)
    return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(
        -1) * relative_pos  # (Z, H, N_CTX_Q, N_CTX_K)
@triton.jit
def load_fn(block_ptr, first, second, pad):
    """Load through a block pointer, boundary-checking the flagged dims.

    ``first`` / ``second`` select whether dim 0 / dim 1 is boundary-checked;
    ``pad`` is forwarded as the padding option for checked loads. With
    neither flag set a plain load is issued.
    """
    if first and second:
        loaded = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
    elif first:
        loaded = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
    elif second:
        loaded = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
    else:
        loaded = tl.load(block_ptr)
    return loaded
def quant_fp8(x, scale):
    """Scale ``x`` by ``scale`` and clamp the result to [FP8_MIN, FP8_MAX].

    No dtype conversion happens here; callers cast the returned value.
    """
    scaled = x * scale
    return tl.clamp(scaled, FP8_MIN, FP8_MAX)
@triton.jit
......@@ -80,58 +331,68 @@ def _attn_fwd_inner(
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
k_ptrs,
v_ptrs,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m,
actual_seqlen_k,
dropout_p,
actual_seqlen_q,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
encoded_sm_ptrs,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
alibi_slope,
q_descale,
k_descale,
v_descale,
p_scale,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
PADDED_HEAD: tl.constexpr,
SHOULD_PRE_LOAD_V: tl.constexpr,
SHOULD_MASK_STEPS: tl.constexpr,
SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr,
USE_PADDED_HEAD: tl.constexpr,
IS_ACTUAL_BLOCK_DMODEL: tl.constexpr,
QK_SCALE: tl.constexpr,
IS_EIGHT_BIT_GEMM: tl.constexpr,
USE_P_SCALE: tl.constexpr,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
):
# loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn(
K_block_ptr,
PADDED_HEAD,
MASK_STEPS and (n_extra_tokens != 0),
"zero",
)
if PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
k_offs_n = start_n + tl.arange(0,
BLOCK_N) if SHOULD_MASK_STEPS else None
k_offs_k = None if not USE_PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
k = masked_load(k_ptrs, k_offs_k, k_offs_n, IS_ACTUAL_BLOCK_DMODEL,
actual_seqlen_k)
if SHOULD_PRE_LOAD_V:
# We can use the same offsets as k, just with dims transposed.
v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k,
IS_ACTUAL_BLOCK_DMODEL)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS: # noqa: SIM102
if SHOULD_MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
# if not is_modulo_mn. last step might get wasted but that is okay.
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not
# is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M],
......@@ -144,167 +405,276 @@ def _attn_fwd_inner(
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ----
qk += tl.dot(q, k)
if bias_ptr is not None:
bias = load_fn(bias_ptr, False, MASK_STEPS
and (n_extra_tokens != 0), "zero")
# While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk += bias * 1.44269504089
if IS_EIGHT_BIT_GEMM:
qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) *
QK_SCALE)
else:
if IS_EIGHT_BIT_KV:
k = (k * k_descale).to(q.type.element_ty)
qk += (tl.dot(q, k) * QK_SCALE)
if bias_ptrs is not None:
bias_offs_n = start_n + tl.arange(
0, BLOCK_N) if SHOULD_MASK_STEPS else None
bias = masked_load(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q,
actual_seqlen_k)
# While bias is added after multiplying qk with sm_scale,
# our optimization to use 2^x instead of e^x results in an
# additional scale factor of log2(e) which we must also multiply
# the bias with.
qk += (bias * 1.44269504089)
if alibi_slope is not None:
# Compute the global position of each token within the sequence
global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
global_n_positions = start_n + tl.arange(0, BLOCK_N)
alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q,
actual_seqlen_k,
global_m_positions,
global_n_positions)
qk += (alibi_block * 1.44269504089) # scale factor of log2(e)
# softmax
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = (batch_philox_offset +
start_m * BLOCK_M * actual_seqlen_k + start_n -
BLOCK_N)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p,
-p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
if SHOULD_RETURN_ENCODED_SOFTMAX:
tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty))
# -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
if not PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
if not SHOULD_PRE_LOAD_V:
v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k,
IS_ACTUAL_BLOCK_DMODEL)
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, BLOCK_N))
if IS_EIGHT_BIT_GEMM:
if USE_P_SCALE:
p = quant_fp8(p, p_scale).to(QUANT_DTYPE)
acc += tl.dot(p, v)
else:
# v is in eight_bit but p is not, we want the gemm in p's type
acc += tl.dot(p, v.to(p.type.element_ty))
else:
if IS_EIGHT_BIT_KV:
v = (v * v_descale).to(p.type.element_ty)
acc += tl.dot(p.to(v.type.element_ty), v)
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if bias_ptrs is not None:
bias_ptrs += BLOCK_N * stride_bn
if SHOULD_RETURN_ENCODED_SOFTMAX:
encoded_sm_ptrs += BLOCK_N
return acc, l_i, m_i
@triton.autotune(
configs=[
def get_cdna_autotune_configs():
return [
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 128,
'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 3,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 64,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 1,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=8,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 32,
'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4,
),
num_warps=4),
], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def get_rdna_autotune_configs():
return [
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=8,
),
num_warps=2),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4,
),
num_warps=2),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": True,
'BLOCK_M': 32,
'BLOCK_N': 16,
'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4,
),
num_warps=2),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": False,
'BLOCK_M': 32,
'BLOCK_N': 16,
'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4,
),
num_warps=2),
triton.Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=8,
),
num_warps=2),
triton.Config(
{
"BLOCK_M": 32,
"BLOCK_N": 32,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=8,
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
num_warps=2),
# Fall-back config.
triton.Config(
{
"BLOCK_M": 16,
"BLOCK_N": 16,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 1,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4,
),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
num_warps=2),
], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def get_general_autotune_configs():
    """Build the generic autotune search space for the attention kernel.

    Returns:
        A ``(configs, keys)`` pair. ``configs`` holds three
        ``triton.Config`` candidates that differ only in ``BLOCK_N``
        (128, 64, 32); ``keys`` is the list of kernel argument names
        handed to ``triton.autotune`` as its ``key``.
    """
    # Only BLOCK_N varies between candidates; every other setting is shared.
    candidate_configs = [
        triton.Config(
            {
                'BLOCK_M': 128,
                'BLOCK_N': block_n,
                'SHOULD_PRE_LOAD_V': False,
                'GRID_CU_MULTIP': 2
            },
            num_stages=1,
            num_warps=4) for block_n in (128, 64, 32)
    ]
    tuning_keys = [
        'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
        'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
    ]
    return candidate_configs, tuning_keys
def has_cdna_target():
    """Return True when Triton's active target arch is a listed CDNA arch."""
    active_target = triton.runtime.driver.active.get_current_target()
    cdna_archs = ("gfx940", "gfx941", "gfx942", "gfx90a", "gfx908")
    return active_target.arch in cdna_archs
def is_rocm_cdna():
    """Return a truthy value only on ROCm with a CDNA target.

    The target query is skipped entirely when the platform is not ROCm.
    """
    on_rocm = current_platform.is_rocm()
    return on_rocm and has_cdna_target()
def get_autotune_configs():
    """Select the autotune ``(configs, keys)`` pair for the current platform.

    ROCm CDNA targets get the CDNA list, any other ROCm target gets the
    RDNA list, and everything else falls back to the general list.
    """
    if is_rocm_cdna():
        return get_cdna_autotune_configs()
    if current_platform.is_rocm():
        return get_rdna_autotune_configs()
    return get_general_autotune_configs()
autotune_configs, autotune_keys = get_autotune_configs()
@triton.autotune(
configs=autotune_configs,
key=autotune_keys,
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(
......@@ -312,38 +682,53 @@ def attn_fwd(
K,
V,
bias,
sm_scale,
SM_SCALE: tl.constexpr,
L,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
stride_bz,
stride_bh,
stride_bm,
stride_bn,
stride_qz: tl.int64,
stride_qh: tl.int64,
stride_qm: tl.int64,
stride_qk: tl.int64,
stride_kz: tl.int64,
stride_kh: tl.int64,
stride_kn: tl.int64,
stride_kk: tl.int64,
stride_vz: tl.int64,
stride_vh: tl.int64,
stride_vk: tl.int64,
stride_vn: tl.int64,
stride_oz: tl.int64,
stride_oh: tl.int64,
stride_om: tl.int64,
stride_on: tl.int64,
stride_bz: tl.int64,
stride_bh: tl.int64,
stride_bm: tl.int64,
stride_bn: tl.int64,
stride_az: tl.int64,
stride_ah: tl.int64,
q_descale_ptr,
k_descale_ptr,
p_scale_ptr,
p_descale_ptr,
o_descale_ptr,
v_descale_ptr,
q_descale_has_singleton: tl.constexpr,
k_descale_has_singleton: tl.constexpr,
p_descale_has_singleton: tl.constexpr,
v_descale_has_singleton: tl.constexpr,
cu_seqlens_q,
cu_seqlens_k,
dropout_p,
philox_seed,
NUM_CU: tl.constexpr,
GRID_CU_MULTIP: tl.constexpr,
B: tl.constexpr,
philox_offset_base,
encoded_softmax,
alibi_slopes,
HQ: tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
IS_ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr,
......@@ -351,24 +736,39 @@ def attn_fwd(
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
SHOULD_PRE_LOAD_V: tl.constexpr,
USE_BIAS: tl.constexpr,
SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr,
USE_ALIBI: tl.constexpr,
IS_EIGHT_BIT: tl.constexpr,
USE_P_SCALE: tl.constexpr,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
if o_descale_ptr is not None:
o_descale = tl.load(o_descale_ptr)
start_m: tl.int64 = tl.program_id(0)
off_h_q: tl.int64 = tl.program_id(1)
off_z: tl.int64 = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64)
offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64)
# as we can't have return statements inside while loop in Triton
continue_condition = True
if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
# We have a one-size-fits-all grid in id(0). Some seqlens might be
# too small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q:
return
continue_condition = False
# return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
......@@ -378,444 +778,598 @@ def attn_fwd(
seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if IS_CAUSAL:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn
# matrix
n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is
# part of the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
off_h_q * stride_oh)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0
# for these masked blocks.
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
# tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset = (off_z * stride_qz + off_h_q * stride_qh +
cu_seqlens_q_start * stride_qm)
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
cu_seqlens_k_start * stride_kn)
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
cu_seqlens_k_start * stride_vk)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
if BIAS_TYPE != 0:
bias_ptr = tl.make_block_ptr(
base=bias + off_h_q * stride_bh,
shape=(seqlen_q, seqlen_k),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = philox_offset_base \
+ (off_z * HQ + off_h_q) \
* seqlen_q * seqlen_k
else:
batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout.
# In this case, we return an invalid pointer so indicate the mask is not i
# valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.make_block_ptr(
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
shape=(seqlen_q, seqlen_k),
strides=(seqlen_k, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
encoded_softmax_block_ptr = 0
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale = sm_scale * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q = load_fn(Q_block_ptr, True, padded_head, "zero")
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
# block. In this case we might exceed n_blocks so pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min,
block_max,
0,
0,
0,
bias_ptr,
# IS_CAUSAL, ....
False,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
False,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
padded_head,
)
block_min = block_max
block_max = n_blocks * BLOCK_N
tl.debug_barrier()
# Remaining blocks, if any, are full / not masked.
if masked_blocks > 0:
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, n_full_blocks))
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
True,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
padded_head,
)
# epilogue
acc = acc / l_i[:, None]
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
causal_start_idx,
dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None]
>= out_mask_boundary[None, :])
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
# overflow_size = end_m_idx - seqlen_q
# if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
# else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
off_h_q * stride_oh)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
def check_args(
    q,
    k,
    v,
    o,
    varlen=True,
    max_seqlens=None,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
):
    """Validate the tensors handed to the attention kernel.

    Args:
        q, k, v: query / key / value tensors. They must share rank, dtype
            and last (head) dimension; ``k`` and ``v`` must have identical
            shapes.
        o: output tensor; must have the same shape as ``q``.
        varlen: when True the inputs are rank-3 with the head count on
            dim 1 and both cumulative-sequence-length arguments are
            required. When False the inputs are rank-4, unpacked as
            (batch, heads, seqlen, head_size), and ``max_seqlens`` is
            required.
        max_seqlens: positive maximum sequence length (fixed-length path).
        cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths of equal
            length (varlen path).

    Raises:
        AssertionError: if any of the constraints above is violated, if
            the head size exceeds 256, or if the number of query heads is
            not a multiple of the number of key heads.
    """
    assert q.dim() == k.dim() and q.dim() == v.dim()
    if varlen:
        assert q.dim() == 3
        _, nheads_q, head_size = q.shape
        nheads_k = k.shape[1]
        assert cu_seqlens_q is not None
        assert cu_seqlens_k is not None
        assert len(cu_seqlens_q) == len(cu_seqlens_k)
    else:
        assert q.dim() == 4
        _, nheads_q, _, head_size = q.shape
        nheads_k = k.shape[1]
        # Check for None explicitly: comparing the default None with 0
        # would raise TypeError instead of a validation failure.
        assert max_seqlens is not None and max_seqlens > 0
    assert k.shape == v.shape
    assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
    # TODO: Change assert if we support qkl f8 and v f16
    assert q.dtype == k.dtype and q.dtype == v.dtype
    assert head_size <= 256
    assert o.shape == q.shape
    # Query heads must divide evenly across the key/value heads.
    assert (nheads_q % nheads_k) == 0
if continue_condition:
# Now we compute whether we need to exit early due to causal
# masking. This is because for seqlen_q > seqlen_k, M rows of the
# attn scores are completely masked, resulting in 0s written to the
# output, and inf written to LSE. We don't need to do any GEMMs in
# this case. This block of code determines what N is, and if this
# WG is operating on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if (IS_CAUSAL):
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which
# means the causal mask boundary is bottom right aligned, and
# ends at either the top edge (seqlen_q < seqlen_k) or left
# edge. This captures the decrease in n_blocks if we have a
# rectangular attn matrix
n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all
# n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this
# WG is part of the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh +
cu_seqlens_q_start * stride_om)
o_ptrs = (o_offset + offs_m[:, None] * stride_om +
offs_d[None, :] * stride_on)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to(
[BLOCK_M, BLOCK_DMODEL])
# We still need to write 0s to the result
tl.store(o_ptrs, acc, mask=o_ptrs_mask)
# The tensor allocated for L is based on MAX_SEQLENS_Q as
# that is statically known.
l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q +
off_h_q * MAX_SEQLENS_Q + offs_m)
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this from qk which makes it -inf, such that
# exp(qk - inf) = 0 for these masked blocks.
l_value = tl.full([BLOCK_M],
value=float("inf"),
dtype=tl.float32)
l_ptrs_mask = offs_m < MAX_SEQLENS_Q
tl.store(l_ptrs, l_value, mask=l_ptrs_mask)
# TODO: Should dropout and return encoded softmax be
# handled here too?
continue_condition = False
# return
if continue_condition:
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL
!= BLOCK_DMODEL)
# Compute pointers for all the tensors used in this kernel.
q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh +
cu_seqlens_q_start * stride_qm)
q_ptrs = (q_offset + offs_m[:, None] * stride_qm +
offs_d[None, :] * stride_qk)
k_offset = (K + off_z * stride_kz + off_h_k * stride_kh +
cu_seqlens_k_start * stride_kn)
k_ptrs = (k_offset + offs_d[:, None] * stride_kk +
offs_n[None, :] * stride_kn)
v_offset = (V + off_z * stride_vz + off_h_k * stride_vh +
cu_seqlens_k_start * stride_vk)
v_ptrs = (v_offset + offs_n[:, None] * stride_vk +
offs_d[None, :] * stride_vn)
# Compute pointers for all scale tensors used in this kernel.
IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & (
not IS_EIGHT_BIT_KV)
if IS_EIGHT_BIT:
if k_descale_has_singleton:
k_descale_ptrs = k_descale_ptr
else:
k_descale_ptrs = k_descale_ptr + off_h_k
if v_descale_has_singleton:
v_descale_ptrs = v_descale_ptr
else:
v_descale_ptrs = v_descale_ptr + off_h_k
if not IS_EIGHT_BIT_KV:
if q_descale_has_singleton:
q_descale_ptrs = q_descale_ptr
else:
q_descale_ptrs = q_descale_ptr + off_h_q
if USE_P_SCALE:
if p_descale_has_singleton:
p_scale_ptrs = p_scale_ptr
p_descale_ptrs = p_descale_ptr
else:
p_scale_ptrs = p_scale_ptr + off_h_q
p_descale_ptrs = p_descale_ptr + off_h_q
if USE_BIAS:
bias_offset = off_h_q * stride_bh
bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm +
offs_n[None, :] * stride_bn)
else:
bias_ptrs = None
if USE_ALIBI:
a_offset = off_z * stride_az + off_h_q * stride_ah
alibi_slope = tl.load(alibi_slopes + a_offset)
else:
alibi_slope = None
batch_philox_offset = 0
# We can ask to return the dropout mask without doing any
# dropout. In this case, we return an invalid pointer so
# indicate the mask is not valid.
if SHOULD_RETURN_ENCODED_SOFTMAX:
encoded_sm_base = (encoded_softmax +
off_h_q * seqlen_q * seqlen_k)
encoded_sm_ptrs = (encoded_sm_base +
offs_m[:, None] * seqlen_k +
offs_n[None, :])
else:
encoded_sm_ptrs = None
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do
# not have native e^x support in HW.
QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q_ptrs_mask = offs_m[:, None] < seqlen_q
if USE_PADDED_HEAD:
q_ptrs_mask = q_ptrs_mask & (offs_d[None, :]
< IS_ACTUAL_BLOCK_DMODEL)
q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
if IS_EIGHT_BIT:
k_descale = tl.load(k_descale_ptrs)
v_descale = tl.load(v_descale_ptrs)
q_descale = None if IS_EIGHT_BIT_KV else tl.load(
q_descale_ptrs)
if USE_P_SCALE:
p_scale = tl.load(p_scale_ptrs)
p_descale = tl.load(p_descale_ptrs)
else:
p_scale = None
p_descale = None
else:
q_descale = None
k_descale = None
v_descale = None
p_scale = None
p_descale = None
# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked
# blocks. Additionally there might be one more due to
# dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an
# additional block. In this case we might exceed n_blocks so
# pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false
# regardless of its actual value because there is no masking.
# Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
k_ptrs,
v_ptrs,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m,
seqlen_k,
seqlen_q,
philox_seed,
batch_philox_offset,
encoded_sm_ptrs,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min,
block_max,
0,
0,
0,
alibi_slope,
q_descale,
k_descale,
v_descale,
p_scale,
# IS_CAUSAL, ....
False,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, SHOULD_MASK_STEPS, ...
SHOULD_PRE_LOAD_V,
False,
SHOULD_RETURN_ENCODED_SOFTMAX,
USE_PADDED_HEAD,
IS_ACTUAL_BLOCK_DMODEL,
QK_SCALE,
IS_EIGHT_BIT_GEMM,
USE_P_SCALE,
IS_EIGHT_BIT_KV,
QUANT_DTYPE)
block_min = block_max
block_max = n_blocks * BLOCK_N
tl.debug_barrier()
# Remaining blocks, if any, are full / not masked.
if (masked_blocks > 0):
if IS_CAUSAL:
offs_n_causal = offs_n + (seqlen_q - seqlen_k)
else:
offs_n_causal = 0
k_ptrs += n_full_blocks * BLOCK_N * stride_kn
v_ptrs += n_full_blocks * BLOCK_N * stride_vk
if USE_BIAS:
bias_ptrs += n_full_blocks * BLOCK_N * stride_bn
if SHOULD_RETURN_ENCODED_SOFTMAX:
encoded_sm_ptrs += n_full_blocks * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
k_ptrs,
v_ptrs,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m,
seqlen_k,
seqlen_q,
philox_seed,
batch_philox_offset,
encoded_sm_ptrs,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
alibi_slope,
q_descale,
k_descale,
v_descale,
p_scale,
IS_CAUSAL,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, SHOULD_MASK_STEPS, ...
SHOULD_PRE_LOAD_V,
True,
SHOULD_RETURN_ENCODED_SOFTMAX,
USE_PADDED_HEAD,
IS_ACTUAL_BLOCK_DMODEL,
QK_SCALE,
IS_EIGHT_BIT_GEMM,
USE_P_SCALE,
IS_EIGHT_BIT_KV,
QUANT_DTYPE)
if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV:
if USE_P_SCALE:
acc *= p_descale
acc *= v_descale
# epilogue
# This helps the compiler do Newton Raphson on l_i vs on acc
# which is much larger.
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
# If seqlen_q > seqlen_k but the delta is not a multiple of
# BLOCK_M, then we have one block with a row of all NaNs which
# come from computing softmax over a row of all
# -infs (-inf - inf = NaN). We check for that here and store 0s
# where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: # noqa: SIM102
if o_descale_ptr is not None:
acc = quant_fp8(acc, o_descale)
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if (causal_start_idx > start_m_idx
and causal_start_idx < end_m_idx):
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
causal_start_idx,
dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None]
>= out_mask_boundary[None, :])
z = tl.zeros((1, ), tl.float32)
acc = tl.where(out_ptrs_mask, acc,
z.to(acc.type.element_ty))
# write back LSE
l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q +
off_h_q * MAX_SEQLENS_Q + offs_m)
# If seqlen_q not multiple of BLOCK_M, we need to mask out the
# last few rows. This is only true for the last M block.
# For others, overflow_size will be -ve
overflow_size = end_m_idx - seqlen_q
if overflow_size > 0:
boundary = tl.full((BLOCK_M, ),
BLOCK_M - overflow_size,
dtype=tl.int32)
l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary
tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
else:
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh +
cu_seqlens_q_start * stride_om)
o_ptrs = (o_offset + offs_m[:, None] * stride_om +
offs_d[None, :] * stride_on)
o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1)
if overflow_size > 0:
o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q)
if USE_PADDED_HEAD:
o_ptrs_mask = o_ptrs_mask & (offs_d[None, :]
< IS_ACTUAL_BLOCK_DMODEL)
tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)
def get_shape_from_layout(q, k, metadata):
    """Read batch size, head counts and head size for ``metadata.layout``.

    Args:
        q: query tensor; its shape is interpreted according to the layout.
        k: key tensor; only its head-count dimension is read.
        metadata: object exposing ``layout`` and, for the 'thd' layout,
            ``num_contexts`` (used as the batch size).

    Returns:
        The tuple ``(batch, nheads_q, nheads_k, head_size)``.
    """
    layout = metadata.layout
    assert layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
    if layout == 'bshd':
        batch, _, nheads_q, head_size = q.shape
        nheads_k = k.shape[2]
    elif layout == 'bhsd':
        batch, nheads_q, _, head_size = q.shape
        nheads_k = k.shape[1]
    elif layout == 'thd':
        # No batch dimension in the tensor itself; the batch size comes
        # from the metadata.
        nheads_q = q.shape[1]
        nheads_k = k.shape[1]
        head_size = q.shape[-1]
        batch = metadata.num_contexts
    return batch, nheads_q, nheads_k, head_size
def get_strides_from_layout(q, k, v, o, metadata):
    """Return the strides of q, k, v and o in one layout-independent order.

    For each tensor the strides are emitted in the order given by the
    layout's permutation below. Judging by the layout names this is
    (batch, head, sequence, head_dim). The 'thd' layout has no batch
    dimension, so its first stride is reported as 0.

    Args:
        q, k, v, o: tensors exposing ``.stride(dim)``.
        metadata: object exposing ``layout``, one of the supported layouts.

    Returns:
        A 4-tuple ``(q_strides, k_strides, v_strides, o_strides)``. Each
        element is a concrete 4-tuple of ints, so it can be unpacked,
        indexed or iterated more than once.

    Raises:
        AssertionError: if ``metadata.layout`` is not a supported layout.
    """
    assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
    # Maps each output position to the tensor dimension whose stride goes
    # there; None marks a dimension the layout does not have.
    STRIDE_PERMUTATIONS = {
        'thd': (None, 1, 0, 2),
        'bhsd': (0, 1, 2, 3),
        'bshd': (0, 2, 1, 3),
    }
    perm = STRIDE_PERMUTATIONS[metadata.layout]
    # Materialize real tuples: a lazy generator per tensor would be
    # exhausted after its first use.
    return tuple(
        tuple(0 if dim is None else x.stride(dim) for dim in perm)
        for x in (q, k, v, o))
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
):
def forward(ctx, q, k, v, o, metadata: MetaData):
# NOTE: a large bias tensor leads to overflow during pointer arithmetic
if (metadata.bias is not None):
assert (metadata.bias.numel() < 2**31)
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
if metadata.eight_bit:
o = torch.empty_like(
q,
dtype=metadata.output_dtype if metadata.output_dtype
is not None else metadata.eight_bit_dtype_torch)
else:
o = torch.empty_like(q, dtype=q.dtype)
check_args(
q,
k,
v,
o,
varlen=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
metadata.check_args(q, k, v, o)
batch, nheads_q, nheads_k, head_size = get_shape_from_layout(
q, k, metadata)
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(
q, k, v, o, metadata)
# Get closest power of 2 over or equal to 32.
unpadded_head_dims = {32, 64, 128, 256}
if head_size not in unpadded_head_dims:
padded_d_model = None
for i in unpadded_head_dims:
if i > head_size:
padded_d_model = i
break
assert padded_d_model is not None
else:
padded_d_model = head_size
padded_d_model = 1 << (head_size - 1).bit_length()
# Smallest head_dim supported is 16. If smaller, the tile in the
# kernel is padded - there is no padding in memory for any dims.
padded_d_model = max(padded_d_model, 16)
grid = lambda META: (
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
nheads_q,
batch,
)
# encoded_softmax is used to validate dropout behavior vs the
# PyTorch SDPA math backend reference. We zero this out to give a
# consistent starting point and then populate it with the output of
# softmax with the sign bit set according to the dropout mask.
# The resulting return allows this mask to be fed into the reference
# implementation for testing only. This return holds no useful output
# aside from debugging.
if metadata.return_encoded_softmax:
encoded_softmax = torch.zeros(
(q.shape[0], q.shape[1], q.shape[2], k.shape[2]),
device=q.device,
dtype=torch.float32)
else:
encoded_softmax = None
encoded_softmax = None
M = torch.empty((batch, nheads_q, metadata.max_seqlens_q),
device=q.device,
dtype=torch.float32)
# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
if bias is not None:
bias_strides = (
bias.stride(0),
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
if metadata.bias is not None:
bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1),
metadata.bias.stride(2), metadata.bias.stride(3))
else:
bias_strides = (0, 0, 0, 0)
if metadata.alibi_slopes is not None:
alibi_strides = (metadata.alibi_slopes.stride(0),
metadata.alibi_slopes.stride(1))
else:
alibi_strides = (0, 0)
if metadata.eight_bit:
q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = (
metadata.q_descale, metadata.k_descale, metadata.p_scale,
metadata.p_descale, metadata.v_descale, metadata.o_scale)
o_descale = 1.0 / o_scale if o_scale is not None else None
else:
q_descale = k_descale = p_scale = None
p_descale = v_descale = o_descale = None
# number of compute units available
NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count
grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META[
'BLOCK_M']), nheads_q, batch)
attn_fwd[grid](
q,
k,
v,
bias,
sm_scale,
None,
metadata.bias,
metadata.sm_scale,
M,
o,
*q_strides,
*k_strides,
*v_strides,
*o_strides,
*bias_strides,
cu_seqlens_q,
cu_seqlens_k,
dropout_p=0.0,
*alibi_strides,
q_descale,
k_descale,
p_scale,
p_descale,
o_descale,
v_descale,
q_descale.numel() == 1 if q_descale is not None else False,
k_descale.numel() == 1 if k_descale is not None else False,
p_descale.numel() == 1 if p_descale is not None else False,
v_descale.numel() == 1 if v_descale is not None else False,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
alibi_slopes=metadata.alibi_slopes,
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
IS_CAUSAL=causal,
VARLEN=True,
IS_ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=metadata.max_seqlens_q,
MAX_SEQLENS_K=metadata.max_seqlens_k,
IS_CAUSAL=metadata.causal,
VARLEN=metadata.varlen,
BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if bias is None else 1,
ENABLE_DROPOUT=False,
RETURN_ENCODED_SOFTMAX=False,
)
USE_BIAS=metadata.bias is not None,
USE_ALIBI=metadata.alibi_slopes is not None,
SHOULD_RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax,
IS_EIGHT_BIT=metadata.eight_bit,
USE_P_SCALE=metadata.eight_bit and metadata.use_p_scale,
IS_EIGHT_BIT_KV=metadata.eight_bit and metadata.eight_bit_kv,
NUM_CU=NUM_CU,
B=batch,
QUANT_DTYPE=metadata.eight_bit_dtype_triton)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.sm_scale = metadata.sm_scale
ctx.BLOCK_DMODEL = head_size
ctx.causal = causal
ctx.dropout_p = 0.0
ctx.causal = metadata.causal
ctx.alibi_slopes = metadata.alibi_slopes
ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = False
ctx.return_encoded_softmax = metadata.return_encoded_softmax
return o, encoded_softmax
triton_attention = _attention.apply
triton_attention_rocm = _attention.apply
def scale_fp8(t, scale=None):
    """Quantize ``t`` through ``ops.scaled_fp8_quant``.

    The tensor is flattened to two dimensions (the last dimension is kept)
    for the quantization call and restored to its original shape afterwards.

    Args:
        t: Tensor to quantize.
        scale: Optional scale forwarded unchanged to the quantization op.

    Returns:
        A ``(quantized, scale_out)`` pair where ``quantized`` has the shape
        of ``t`` and ``scale_out`` is the scale returned by the op.
    """
    original_shape = t.shape
    flattened = t.reshape(-1, original_shape[-1])
    quantized, scale_out = ops.scaled_fp8_quant(flattened, scale)
    return quantized.reshape(original_shape), scale_out
def maybe_quantize_fp8(t, scale):
    """Return ``t`` in the platform FP8 dtype, quantizing it only if needed.

    Args:
        t: Input tensor.
        scale: Scale passed to ``scale_fp8`` when quantization is required.

    Returns:
        ``t`` itself when it already has the platform FP8 dtype, otherwise
        the quantized tensor (the scale produced by quantization is dropped).
    """
    if t.dtype == current_platform.fp8_dtype():
        # Already in the target dtype: nothing to do.
        return t
    quantized, _ = scale_fp8(t, scale)
    return quantized
def check_and_maybe_quantize_qkv(q, k, v, fp8_scales):
    """Quantize query, key and value to FP8 where they are not already FP8.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        fp8_scales: Four-element sequence ``(q_scale, k_scale, v_scale,
            p_scale)``; the fourth entry is not used here.

    Returns:
        The ``(q, k, v)`` tuple after the optional quantization.
    """
    q_scale, k_scale, v_scale, _unused_p_scale = fp8_scales
    return (
        maybe_quantize_fp8(q, q_scale),
        maybe_quantize_fp8(k, k_scale),
        maybe_quantize_fp8(v, v_scale),
    )
# query  - [num_tokens, num_heads, head_size]
# key    - [num_tokens, num_kv_heads, head_size]
# value  - [num_tokens, num_kv_heads, head_size]
# output - [num_tokens, num_heads, head_size]
def triton_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlens_q: int,
    max_seqlens_k: int,
    causal: bool = False,
    sm_scale: float = 1.0,
    bias: Optional[torch.Tensor] = None,
    fp8_scales: Optional[tuple[float, ...]] = None,
    input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Pack the arguments into a ``MetaData`` object and run the ROCm kernel.

    Args:
        q, k, v: Query, key and value tensors.
        o: Output tensor handed to the kernel wrapper.
        cu_seqlens_q, cu_seqlens_k: Cumulative sequence lengths.
        max_seqlens_q, max_seqlens_k: Maximum sequence lengths.
        causal: Whether causal masking is requested.
        sm_scale: Softmax scale.
        bias: Optional attention bias.
        fp8_scales: Optional ``(q, k, v, p)`` scales. When given, q/k/v are
            quantized to FP8 before the kernel call.
        input_scale: Forwarded to ``MetaData`` as ``o_scale``.

    Returns:
        Whatever ``triton_attention_rocm`` returns for these inputs.
    """
    scales = (None, None, None, None) if fp8_scales is None else fp8_scales
    q_descale, k_descale, v_descale, p_scale = scales

    # The metadata is built before any quantization so that ``output_dtype``
    # records the dtype of the caller's original query tensor.
    attn_metadata = MetaData(
        sm_scale=sm_scale,
        max_seqlens_q=max_seqlens_q,
        max_seqlens_k=max_seqlens_k,
        causal=causal,
        bias=bias,
        output_dtype=q.dtype,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        p_scale=p_scale,
        o_scale=input_scale,
    )

    if fp8_scales is not None:
        q, k, v = check_and_maybe_quantize_qkv(q, k, v, fp8_scales)

    return triton_attention_rocm(q, k, v, o, attn_metadata)
......@@ -66,7 +66,10 @@ def merge_attn_states_kernel(
max_lse = tl.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
out_se = (tl.exp(p_lse) + tl.exp(s_lse))
# Will reuse precomputed Exp values for scale factor computation.
p_se = tl.exp(p_lse)
s_se = tl.exp(s_lse)
out_se = (p_se + s_se)
if OUTPUT_LSE:
out_lse = tl.log(out_se) + max_lse
......@@ -84,8 +87,8 @@ def merge_attn_states_kernel(
# NOTE(woosuk): Be careful with the numerical stability.
# We should compute the scale first, and then multiply it with the output.
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
p_scale = tl.exp(p_lse) / out_se
s_scale = tl.exp(s_lse) / out_se
p_scale = p_se / out_se
s_scale = s_se / out_se
out = p_out * p_scale + s_out * s_scale
tl.store(output + token_idx * num_heads * HEAD_SIZE +
head_idx * HEAD_SIZE + head_arange,
......
......@@ -38,9 +38,18 @@ class BeamSearchOutput:
class BeamSearchInstance:
def __init__(self, prompt_tokens: list[int]):
def __init__(
self,
prompt_tokens: list[int],
logprobs: Optional[list[dict[int, Logprob]]] = None,
**kwargs,
):
self.beams: list[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
BeamSearchSequence(
tokens=prompt_tokens,
logprobs=[] if logprobs is None else list(logprobs),
**kwargs,
)
]
self.completed: list[BeamSearchSequence] = []
......
# SPDX-License-Identifier: Apache-2.0
"""
This module defines a framework for sampling benchmark requests from various
datasets. Each dataset subclass of BenchmarkDataset must implement sample
generation. Supported dataset types include:
- ShareGPT
- Random (synthetic)
- Sonnet
- BurstGPT
- HuggingFace
- VisionArena
TODO: Implement CustomDataset to parse a JSON file and convert its contents into
SampleRequest instances, similar to the approach used in ShareGPT.
"""
import base64
import io
import json
import logging
import random
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from functools import cache
from io import BytesIO
from typing import Any, Callable, Optional, Union
import numpy as np
from PIL import Image
from transformers import PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Data Classes
# -----------------------------------------------------------------------------
@dataclass
class SampleRequest:
    """
    Represents a single inference request for benchmarking.
    """
    # Prompt payload: a plain string or any other prompt object.
    prompt: Union[str, Any]
    # Length recorded for the prompt (the dataset samplers in this module
    # store a token count here).
    prompt_len: int
    # Number of output tokens the benchmark expects for this request.
    expected_output_len: int
    # Optional multimodal payload that accompanies the prompt.
    multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
    # Optional LoRA request associated with this sample.
    lora_request: Optional[LoRARequest] = None
# -----------------------------------------------------------------------------
# Benchmark Dataset Base Class
# -----------------------------------------------------------------------------
class BenchmarkDataset(ABC):
    """Abstract base class for datasets that produce benchmark requests.

    Subclasses implement ``sample`` (and usually ``load_data``); this class
    provides the shared seed handling, chat formatting, LoRA selection and
    oversampling helpers.
    """

    DEFAULT_SEED = 0

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        random_seed: int = DEFAULT_SEED,
    ) -> None:
        """
        Initialize the dataset with an optional path and a random seed.

        Args:
            dataset_path (Optional[str]): Path to the dataset. If None, it
                indicates that a default or random dataset might be used.
            random_seed (int): Seed value for reproducible shuffling or
                sampling. Defaults to DEFAULT_SEED.
        """
        self.dataset_path = dataset_path
        # A caller may pass None explicitly; fall back to the default seed.
        self.random_seed = (random_seed
                            if random_seed is not None else self.DEFAULT_SEED)
        self.data = None

    def apply_multimodal_chat_transformation(
            self,
            prompt: str,
            mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
        """
        Wrap a prompt and optional multimodal content in a chat message list.

        Returns:
            list[dict]: A single-element list holding one "user" message
            whose content is the text part followed by ``mm_content`` when
            that is not None.
        """
        content = [{"text": prompt, "type": "text"}]
        if mm_content is not None:
            content.append(mm_content)
        return [{"role": "user", "content": content}]

    def load_data(self) -> None:
        """
        Load data from the dataset path into self.data.

        This method must be overridden by subclasses since the way data is
        loaded depends on the dataset format and source.

        Raises:
            NotImplementedError: If a subclass does not implement this method.
        """
        # TODO (jenniferzhao): add support for downloading data
        raise NotImplementedError(
            "load_data must be implemented in subclasses.")

    def get_random_lora_request(
        self,
        tokenizer: PreTrainedTokenizerBase,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
    ) -> tuple[Optional[LoRARequest], AnyTokenizer]:
        """
        Optionally select a random LoRA request and its tokenizer.

        Args:
            tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if
                no LoRA is selected or the LoRA has no tokenizer of its own.
            max_loras (Optional[int]): The maximum number of LoRAs available.
                If None, LoRA is not used.
            lora_path (Optional[str]): Path to the LoRA parameters on disk.
                If None, LoRA is not used.

        Returns:
            tuple[Optional[LoRARequest], AnyTokenizer]: The LoRARequest (or
            None when LoRA is not used) and the tokenizer to use with it:
            the cached LoRA tokenizer when one is available, otherwise the
            tokenizer that was passed in.
        """
        if max_loras is None or lora_path is None:
            return None, tokenizer

        # Generate a random LoRA ID in the range [1, max_loras].
        lora_id = random.randint(1, max_loras)
        lora_request = LoRARequest(
            lora_name=str(lora_id),
            lora_int_id=lora_id,
            lora_path=lora_path_on_disk(lora_path),
        )
        # Tokenizers are cached per LoRA ID in the module-level cache.
        if lora_id not in lora_tokenizer_cache:
            lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
        # A falsy cache entry means no LoRA tokenizer; use the given one.
        return lora_request, lora_tokenizer_cache[lora_id] or tokenizer

    @abstractmethod
    def sample(self, tokenizer: PreTrainedTokenizerBase,
               num_requests: int) -> list[SampleRequest]:
        """
        Generate sample requests from the dataset.

        Args:
            tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
                for processing the dataset's text.
            num_requests (int): The number of sample requests to generate.

        Returns:
            list[SampleRequest]: The generated sample requests.
        """
        raise NotImplementedError("sample must be implemented in subclasses.")

    def maybe_oversample_requests(self, requests: list[SampleRequest],
                                  num_requests: int) -> None:
        """
        Extend ``requests`` in place until it holds ``num_requests`` items.

        Extra items are drawn with replacement from the existing requests
        after seeding the global ``random`` module with ``self.random_seed``.
        Nothing happens when the list is already long enough.

        Args:
            requests (list[SampleRequest]): The current list of sampled
                requests; modified in place.
            num_requests (int): The target number of requests.
        """
        missing = num_requests - len(requests)
        if missing <= 0:
            return
        if not requests:
            # random.choices raises an opaque IndexError on an empty
            # population; report the real problem and leave the list empty.
            logger.warning(
                "No valid requests were sampled; cannot oversample to reach "
                "%d total samples.", num_requests)
            return
        random.seed(self.random_seed)
        requests.extend(random.choices(requests, k=missing))
        logger.info("Oversampled requests to reach %d total samples.",
                    num_requests)
# -----------------------------------------------------------------------------
# Utility Functions and Global Caches
# -----------------------------------------------------------------------------
def is_valid_sequence(
    prompt_len: int,
    output_len: int,
    min_len: int = 4,
    max_prompt_len: int = 1024,
    max_total_len: int = 2048,
    skip_min_output_len_check: bool = False,
) -> bool:
    """
    Decide whether a prompt/output length pair passes the pruning criteria.

    A pair is rejected when the prompt is shorter than ``min_len`` or longer
    than ``max_prompt_len``, when the output is shorter than ``min_len``
    (unless ``skip_min_output_len_check`` is set), or when the combined
    length exceeds ``max_total_len``.

    Default pruning criteria are copied from the original `sample_hf_requests`
    and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as
    from `sample_requests` in benchmark_throughput.py.
    """
    # Prompt must lie within [min_len, max_prompt_len].
    if prompt_len < min_len or prompt_len > max_prompt_len:
        return False
    # The minimum output length is only enforced when not explicitly skipped.
    if not skip_min_output_len_check and output_len < min_len:
        return False
    # Finally, the combined length must fit the total budget.
    return prompt_len + output_len <= max_total_len
@cache
def lora_path_on_disk(lora_path: str) -> str:
    """Return ``get_adapter_absolute_path(lora_path)``.

    The result is memoized per ``lora_path`` by ``functools.cache``, so the
    lookup runs only once for each distinct path.
    """
    return get_adapter_absolute_path(lora_path)
# Global cache for LoRA tokenizers, keyed by the integer LoRA ID.
# It lives for the whole process and is shared by every dataset instance.
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
def process_image(image: Any) -> Mapping[str, Any]:
    """
    Process a single image input and return a multimedia content dictionary.

    Supports three input types:

    1. String input: treated as a URL or local file path. "file://" is
       prepended unless the string already starts with "http://",
       "https://" or "file://". Returns a dictionary with the image URL.
    2. Dictionary with raw image bytes: expects a dict with a 'bytes' key
       containing raw image data, which is loaded as a PIL.Image.Image and
       then handled like case 3.
    3. PIL.Image.Image input: the image is converted to RGB, saved as a
       JPEG in memory, base64-encoded and returned as a data URL.

    Raises:
        ValueError: If the input is not a supported type.
    """
    # Strings can never match the dict / PIL branches, so handle them first.
    if isinstance(image, str):
        # "https://" must be recognised as well; otherwise a secure URL
        # would be turned into the invalid "file://https://...".
        url_prefixes = ("http://", "https://", "file://")
        image_url = (image
                     if image.startswith(url_prefixes) else f"file://{image}")
        return {"type": "image_url", "image_url": {"url": image_url}}

    if isinstance(image, dict) and 'bytes' in image:
        image = Image.open(BytesIO(image['bytes']))

    if isinstance(image, Image.Image):
        image = image.convert("RGB")
        with io.BytesIO() as image_data:
            image.save(image_data, format="JPEG")
            image_base64 = base64.b64encode(
                image_data.getvalue()).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{image_base64}"
            },
        }

    raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
                     " or str or dictionary with raw image bytes.")
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
class RandomDataset(BenchmarkDataset):
    """Synthetic dataset whose prompts are decoded from generated token IDs."""

    # Default values copied from benchmark_serving.py for the random dataset.
    DEFAULT_PREFIX_LEN = 0
    DEFAULT_RANGE_RATIO = 0.0
    DEFAULT_INPUT_LEN = 1024
    DEFAULT_OUTPUT_LEN = 128

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Forward all keyword arguments to ``BenchmarkDataset``."""
        super().__init__(**kwargs)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        range_ratio: float = DEFAULT_RANGE_RATIO,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Build ``num_requests`` synthetic requests.

        Input and output lengths are drawn uniformly from
        ``[X * (1 - range_ratio), X * (1 + range_ratio)]`` around the given
        ``input_len`` / ``output_len``. Every prompt starts with the same
        random prefix of ``prefix_len`` tokens (if any), followed by a run of
        consecutive token IDs (modulo the vocabulary size) starting at a
        random offset.
        """
        # Enforce range_ratio < 1
        assert range_ratio < 1.0, (
            "random_range_ratio must be < 1.0 to ensure a valid sampling range"
        )

        vocab_size = tokenizer.vocab_size

        # Shared random prefix (only drawn when a prefix is requested).
        if prefix_len > 0:
            prefix_token_ids = np.random.randint(0, vocab_size,
                                                 size=prefix_len).tolist()
        else:
            prefix_token_ids = []

        # Sampling bounds: [X * (1 - b), X * (1 + b)]
        shrink, grow = 1 - range_ratio, 1 + range_ratio
        input_low, input_high = int(input_len * shrink), int(input_len * grow)
        output_low = int(output_len * shrink)
        output_high = int(output_len * grow)

        logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
        logger.info("Sampling output_len from [%s, %s]", output_low,
                    output_high)

        # The upper bound of randint is exclusive, hence the + 1.
        input_lens = np.random.randint(input_low,
                                       input_high + 1,
                                       size=num_requests)
        output_lens = np.random.randint(output_low,
                                        output_high + 1,
                                        size=num_requests)
        offsets = np.random.randint(0, vocab_size, size=num_requests)

        requests = []
        for idx, (offset, in_len, out_len) in enumerate(
                zip(offsets, input_lens, output_lens)):
            body_ids = ((offset + idx + np.arange(in_len)) %
                        vocab_size).tolist()
            prompt_text = tokenizer.decode(prefix_token_ids + body_ids)
            requests.append(
                SampleRequest(
                    prompt=prompt_text,
                    prompt_len=prefix_len + int(in_len),
                    expected_output_len=int(out_len),
                ))
        return requests
# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
# -----------------------------------------------------------------------------
class ShareGPTDataset(BenchmarkDataset):
    """
    Implements the ShareGPT dataset. Loads data from a JSON file and generates
    sample requests based on conversation turns.
    """

    def __init__(self, **kwargs) -> None:
        """Forward keyword arguments to the base class and load the data."""
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Read the JSON file at ``dataset_path`` into ``self.data``.

        Only entries with at least two conversation turns are kept, and the
        result is shuffled after seeding the global ``random`` module with
        ``self.random_seed``.

        Raises:
            ValueError: If ``dataset_path`` is None.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = json.load(f)
        # Filter entries with at least two conversation turns.
        self.data = [
            entry for entry in self.data
            if "conversations" in entry and len(entry["conversations"]) >= 2
        ]
        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        lora_path: Optional[str] = None,
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """
        Build up to ``num_requests`` requests from the loaded conversations.

        The first turn of each conversation is the prompt and the second is
        the reference completion. When ``output_len`` is None the completion
        length is used as the expected output length; otherwise the given
        ``output_len`` is used and the minimum-output-length check is
        skipped. Entries failing ``is_valid_sequence`` are dropped, and the
        result is oversampled if fewer than ``num_requests`` entries remain.
        """
        samples: list = []
        for entry in self.data:
            if len(samples) >= num_requests:
                break
            prompt, completion = (
                entry["conversations"][0]["value"],
                entry["conversations"][1]["value"],
            )

            # Keep the base ``tokenizer`` untouched: rebinding it here would
            # make a previously selected LoRA tokenizer the fallback for
            # later entries instead of the base tokenizer.
            lora_request, request_tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
            prompt_ids = request_tokenizer(prompt).input_ids
            completion_ids = request_tokenizer(completion).input_ids
            prompt_len = len(prompt_ids)
            new_output_len = (len(completion_ids)
                              if output_len is None else output_len)
            if not is_valid_sequence(prompt_len,
                                     new_output_len,
                                     skip_min_output_len_check=output_len
                                     is not None):
                continue
            if enable_multimodal_chat:
                prompt = self.apply_multimodal_chat_transformation(
                    prompt, None)
            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=new_output_len,
                    lora_request=lora_request,
                ))
        self.maybe_oversample_requests(samples, num_requests)
        return samples
# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------
class SonnetDataset(BenchmarkDataset):
    """
    Simplified implementation of the Sonnet dataset. Loads poem lines from a
    text file and generates sample requests. Default values here copied from
    `benchmark_serving.py` for the sonnet dataset.
    """

    DEFAULT_PREFIX_LEN = 200
    DEFAULT_INPUT_LEN = 550
    DEFAULT_OUTPUT_LEN = 150

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Forward keyword arguments to the base class and load the data."""
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Read every line of the file at ``dataset_path`` into ``self.data``.

        Raises:
            ValueError: If ``dataset_path`` is missing or empty.
        """
        if not self.dataset_path:
            raise ValueError("dataset_path must be provided.")
        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = f.readlines()

    def sample(
        self,
        tokenizer,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        return_prompt_formatted: bool = False,
        **kwargs,
    ) -> list:
        """
        Build ``num_requests`` prompts out of poem lines.

        Each prompt is the fixed instruction, a shared prefix of poem lines
        and randomly chosen extra lines. Candidates whose chat-formatted
        token count exceeds ``input_len`` are discarded and redrawn.

        Raises:
            ValueError: If ``input_len`` does not exceed the token length of
                the chat-formatted base prompt.
        """
        # Average token length of a poem line.
        line_token_counts = [
            len(tokenizer(line).input_ids) for line in self.data
        ]
        avg_len = sum(line_token_counts) / len(line_token_counts)

        # Token cost of the bare instruction once chat-formatted.
        base_prompt = "Pick as many lines as you can from these poem lines:\n"
        base_fmt = tokenizer.apply_chat_template(
            [{
                "role": "user",
                "content": base_prompt
            }],
            add_generation_prompt=True,
            tokenize=False)
        base_offset = len(tokenizer(base_fmt).input_ids)
        if input_len <= base_offset:
            raise ValueError(
                f"'input_len' must be higher than the base prompt length "
                f"({base_offset}).")

        # Determine how many poem lines to use.
        num_input_lines = round((input_len - base_offset) / avg_len)
        num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
        prefix_lines = self.data[:num_prefix_lines]
        num_extra_lines = num_input_lines - num_prefix_lines

        samples = []
        while len(samples) < num_requests:
            extra_lines = random.choices(self.data, k=num_extra_lines)
            poem_text = ''.join(prefix_lines + extra_lines)
            prompt = f"{base_prompt}{poem_text}"
            prompt_formatted = tokenizer.apply_chat_template(
                [{
                    "role": "user",
                    "content": prompt
                }],
                add_generation_prompt=True,
                tokenize=False)
            prompt_len = len(tokenizer(prompt_formatted).input_ids)
            if prompt_len > input_len:
                # Too long once formatted; draw another set of lines.
                continue
            samples.append(
                SampleRequest(
                    prompt=prompt_formatted
                    if return_prompt_formatted else prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                ))
        return samples
# -----------------------------------------------------------------------------
# BurstGPT Dataset Implementation
# -----------------------------------------------------------------------------
class BurstGPTDataset(BenchmarkDataset):
    """
    Implements the BurstGPT dataset. Loads data from a CSV file and generates
    sample requests based on synthetic prompt generation. Only rows with Model
    "GPT-4" and positive response tokens are used.
    """

    def __init__(self, **kwargs) -> None:
        """Forward keyword arguments to the base class and load the data."""
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """
        Read the CSV at ``dataset_path`` and keep the usable GPT-4 rows.

        Raises:
            ValueError: If ``dataset_path`` is None.
            ImportError: If pandas is not installed.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        try:
            import pandas as pd
        except ImportError as e:
            raise ImportError(
                "Pandas is required for BurstGPTDataset. Please install it "
                "using `pip install pandas`.") from e

        df = pd.read_csv(self.dataset_path)
        # Filter to keep only GPT-4 rows.
        gpt4_df = df[df["Model"] == "GPT-4"]
        # Remove failed requests (where Response tokens is 0 or less).
        gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
        self.data = gpt4_df

    def _sample_loaded_data(self, num_requests: int) -> list:
        """
        Draw ``num_requests`` rows from the loaded frame as a list of lists.

        Rows are drawn with replacement only when more rows are requested
        than are available.
        """
        data = self.data.sample(
            n=num_requests,
            random_state=self.random_seed,
            replace=num_requests > len(self.data),
        )
        # Convert the dataframe to a list of lists.
        return data.values.tolist()

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Build ``num_requests`` synthetic requests from sampled CSV rows.

        Columns 2 and 3 of each row are read as the input and output token
        counts (presumably the request/response token columns of the CSV --
        verify against the dataset schema).
        """
        samples = []
        data = self._sample_loaded_data(num_requests=num_requests)
        for i in range(num_requests):
            input_len = int(data[i][2])
            output_len = int(data[i][3])
            # Keep the base ``tokenizer`` untouched: rebinding it here would
            # make a previously selected LoRA tokenizer the fallback for
            # later rows instead of the base tokenizer.
            lora_req, request_tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
            vocab_size = request_tokenizer.vocab_size
            # Generate a synthetic prompt: a list of token IDs computed as (i +
            # j) modulo vocab_size.
            token_ids = [(i + j) % vocab_size for j in range(input_len)]
            prompt = request_tokenizer.decode(token_ids)
            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=input_len,
                    expected_output_len=output_len,
                    lora_request=lora_req,
                ))
        return samples
# -----------------------------------------------------------------------------
# HuggingFace Dataset Base Implementation
# -----------------------------------------------------------------------------
class HuggingFaceDataset(BenchmarkDataset):
    """Base class for datasets hosted on HuggingFace."""

    # Dataset paths a subclass knows how to handle; subclasses override this
    # with either a set of paths or a mapping from path to a callable.
    SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set()

    def __init__(
        self,
        dataset_path: str,
        dataset_split: str,
        dataset_subset: Optional[str] = None,
        **kwargs,
    ) -> None:
        """
        Record the split/subset and load the dataset immediately.

        Args:
            dataset_path (str): Dataset path passed to ``load_dataset``.
            dataset_split (str): Split passed to ``load_dataset``.
            dataset_subset (Optional[str]): Passed to ``load_dataset`` as
                its ``name`` argument.
            **kwargs: Forwarded to ``BenchmarkDataset.__init__``.
        """
        super().__init__(dataset_path=dataset_path, **kwargs)

        self.dataset_split = dataset_split
        self.dataset_subset = dataset_subset
        self.load_data()

    def load_data(self) -> None:
        """
        Load data from HuggingFace datasets.

        The dataset is opened in streaming mode and shuffled with
        ``self.random_seed``.

        Raises:
            ImportError: If the ``datasets`` library is not installed.
        """
        try:
            from datasets import load_dataset
        except ImportError as e:
            raise ImportError(
                "Hugging Face datasets library is required for this dataset. "
                "Please install it using `pip install datasets`.") from e

        streamed = load_dataset(
            self.dataset_path,
            name=self.dataset_subset,
            split=self.dataset_split,
            streaming=True,
        )
        self.data = streamed.shuffle(seed=self.random_seed)
# -----------------------------------------------------------------------------
# Conversation Dataset Implementation
# -----------------------------------------------------------------------------
class ConversationDataset(HuggingFaceDataset):
    """Dataset for conversation data with multimodal support."""

    SUPPORTED_DATASET_PATHS = {
        'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
    }

    def sample(self,
               tokenizer: PreTrainedTokenizerBase,
               num_requests: int,
               output_len: Optional[int] = None,
               enable_multimodal_chat: bool = False,
               **kwargs) -> list:
        """
        Build up to ``num_requests`` requests from two-turn conversations.

        The first turn is the prompt and the second the reference
        completion. When ``output_len`` is None the completion length is
        used as the expected output length and items failing
        ``is_valid_sequence`` are skipped; otherwise the fixed ``output_len``
        is used for every request. An "image" field, when present, is
        attached as multimodal content. The result is oversampled if too few
        items qualify.
        """
        # Filter examples with at least 2 conversations
        filtered_data = self.data.filter(
            lambda x: len(x["conversations"]) >= 2)
        sampled_requests = []
        dynamic_output = output_len is None

        for item in filtered_data:
            if len(sampled_requests) >= num_requests:
                break
            conv = item["conversations"]
            prompt, completion = conv[0]["value"], conv[1]["value"]

            prompt_ids = tokenizer(prompt).input_ids
            completion_ids = tokenizer(completion).input_ids
            prompt_len = len(prompt_ids)
            completion_len = len(completion_ids)
            if dynamic_output:
                # Filter before asserting: an empty completion must be
                # skipped like any other invalid item, not abort the run
                # with an AssertionError.
                if not is_valid_sequence(prompt_len, completion_len):
                    continue
                output_len = completion_len
            assert isinstance(output_len, int) and output_len > 0
            mm_content = process_image(
                item["image"]) if "image" in item else None
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len and output len
                prompt = self.apply_multimodal_chat_transformation(
                    prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                ))
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests
# -----------------------------------------------------------------------------
# Vision Arena Dataset Implementation
# -----------------------------------------------------------------------------
class VisionArenaDataset(HuggingFaceDataset):
    """
    Vision Arena Dataset.
    """

    DEFAULT_OUTPUT_LEN = 128
    # Maps each supported dataset path to the function that extracts the
    # prompt text from one item of that dataset.
    SUPPORTED_DATASET_PATHS = {
        "lmarena-ai/VisionArena-Chat":
        lambda x: x["conversation"][0][0]["content"],
        "lmarena-ai/vision-arena-bench-v0.1":
        lambda x: x["turns"][0][0]["content"]
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        """
        Build up to ``num_requests`` image+text requests.

        Each item contributes its parsed prompt text and its first image.
        ``output_len`` falls back to ``DEFAULT_OUTPUT_LEN`` when None, and
        the result is oversampled if the dataset yields too few items.

        Raises:
            ValueError: If ``dataset_path`` has no registered prompt parser
                (raised while processing the first item).
        """
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN

        collected = []
        for item in self.data:
            if len(collected) >= num_requests:
                break
            parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
            if parser_fn is None:
                raise ValueError(
                    f"Unsupported dataset path: {self.dataset_path}")
            text_prompt = parser_fn(item)
            mm_content = process_image(item["images"][0])
            prompt_len = len(tokenizer(text_prompt).input_ids)
            # Note: when chat is enabled the request prompt_len is no longer
            # accurate and we will be using request output to count the
            # actual prompt len
            request_prompt = (self.apply_multimodal_chat_transformation(
                text_prompt, mm_content)
                              if enable_multimodal_chat else text_prompt)
            collected.append(
                SampleRequest(
                    prompt=request_prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                ))
        self.maybe_oversample_requests(collected, num_requests)
        return collected
# -----------------------------------------------------------------------------
# Instruct Coder Dataset Implementation
# -----------------------------------------------------------------------------
class InstructCoderDataset(HuggingFaceDataset):
    """
    InstructCoder Dataset.
    https://huggingface.co/datasets/likaixin/InstructCoder

    InstructCoder is a dataset designed for general code editing. It consists
    of 114,239 instruction-input-output triplets and covers multiple distinct
    code editing scenarios.
    """

    DEFAULT_OUTPUT_LEN = 200  # this is the average default output length
    SUPPORTED_DATASET_PATHS = {
        "likaixin/InstructCoder",
    }

    def sample(self,
               tokenizer: PreTrainedTokenizerBase,
               num_requests: int,
               output_len: Optional[int] = None,
               enable_multimodal_chat: bool = False,
               **kwargs) -> list:
        """
        Build up to ``num_requests`` requests from instruction/input pairs.

        Each prompt is the item's instruction, a colon and newline, then its
        input. ``output_len`` falls back to ``DEFAULT_OUTPUT_LEN`` when None.
        ``enable_multimodal_chat`` is accepted but not used here. The result
        is oversampled if the dataset yields too few items.
        """
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN

        collected = []
        for item in self.data:
            if len(collected) >= num_requests:
                break
            prompt_text = f"{item['instruction']}:\n{item['input']}"
            token_count = len(tokenizer(prompt_text).input_ids)
            collected.append(
                SampleRequest(
                    prompt=prompt_text,
                    prompt_len=token_count,
                    expected_output_len=output_len,
                ))
        self.maybe_oversample_requests(collected, num_requests)
        return collected
# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------
class AIMODataset(HuggingFaceDataset):
    """
    Dataset class for processing an AIMO dataset with reasoning questions.
    """

    SUPPORTED_DATASET_PATHS = {
        "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5",
        "AI-MO/NuminaMath-CoT"
    }

    def sample(self,
               tokenizer: PreTrainedTokenizerBase,
               num_requests: int,
               output_len: Optional[int] = None,
               **kwargs) -> list:
        """
        Build up to ``num_requests`` requests from problem/solution pairs.

        The item's "problem" is the prompt and its "solution" the reference
        completion. When ``output_len`` is None the solution length is used
        as the expected output length and items failing ``is_valid_sequence``
        (prompt limit 2048, total limit 32000) are skipped; otherwise the
        fixed ``output_len`` is used for every request. The result is
        oversampled if too few items qualify.
        """
        sampled_requests = []
        dynamic_output = output_len is None

        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt, completion = item['problem'], item["solution"]

            prompt_ids = tokenizer(prompt).input_ids
            completion_ids = tokenizer(completion).input_ids
            prompt_len = len(prompt_ids)
            completion_len = len(completion_ids)
            if dynamic_output:
                # Filter before asserting: an empty solution must be skipped
                # like any other invalid item, not abort the run with an
                # AssertionError.
                if not is_valid_sequence(prompt_len,
                                         completion_len,
                                         max_prompt_len=2048,
                                         max_total_len=32000):
                    continue
                output_len = completion_len
            assert isinstance(output_len, int) and output_len > 0
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=None,
                ))
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests
# SPDX-License-Identifier: Apache-2.0
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import dataclasses
import json
import os
import time
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format,
write_to_json)
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
                                     results: dict[str, Any]) -> None:
    """Write the latency results in PyTorch benchmark format, if any.

    The records are produced by ``convert_to_pytorch_benchmark_format`` and,
    when non-empty, written next to ``args.output_json`` with the extension
    replaced by ``.pytorch.json``.

    Args:
        args: Parsed CLI arguments; ``output_json`` is read from it.
        results: Must contain "latencies", "avg_latency" and "percentiles".
    """
    metrics = {"latency": results["latencies"]}
    extra_info = {key: results[key] for key in ["avg_latency", "percentiles"]}
    pt_records = convert_to_pytorch_benchmark_format(args=args,
                                                     metrics=metrics,
                                                     extra_info=extra_info)
    if not pt_records:
        return
    stem, _ = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", pt_records)
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the latency-benchmark options, then the engine options.

    Args:
        parser: The parser to extend. The benchmark options are added to it
            first, after which it is handed to ``EngineArgs.add_cli_args``.
    """
    add = parser.add_argument

    # Workload shape.
    add("--input-len", type=int, default=32)
    add("--output-len", type=int, default=128)
    add("--batch-size", type=int, default=8)
    add("--n",
        type=int,
        default=1,
        help="Number of generated sequences per prompt.")
    add("--use-beam-search", action="store_true")

    # Iteration counts.
    add("--num-iters-warmup",
        type=int,
        default=10,
        help="Number of iterations to run for warmup.")
    add("--num-iters",
        type=int,
        default=30,
        help="Number of iterations to run.")

    # Profiling.
    add("--profile",
        action="store_true",
        help="profile the generation process of a single batch")
    add("--profile-result-dir",
        type=str,
        default=None,
        help=("path to save the pytorch profiler output. Can be visualized "
              "with ui.perfetto.dev or Tensorboard."))

    # Output handling.
    add("--output-json",
        type=str,
        default=None,
        help="Path to save the latency results in JSON format.")
    add("--disable-detokenize",
        action="store_true",
        help=("Do not detokenize responses (i.e. do not include "
              "detokenization time in the latency measurement)"))

    parser = EngineArgs.add_cli_args(parser)
def main(args: argparse.Namespace):
    """Run the single-batch latency benchmark described by ``args``.

    Builds the engine, warms it up, then either profiles one batch (when
    ``--profile`` is set) or times ``--num-iters`` batches and reports the
    average and percentile latencies, optionally saving them as JSON.
    """
    print(args)

    engine_args = EngineArgs.from_cli_args(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))
    assert llm.llm_engine.model_config.max_model_len >= (
        args.input_len +
        args.output_len), ("Please ensure that max_model_len is greater than"
                           " the sum of input_len and output_len.")

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
        detokenize=not args.disable_detokenize,
    )
    print(sampling_params)

    # One row of random token IDs per prompt in the batch.
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: list[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    def llm_generate():
        """Process the dummy batch once, with sampling or beam search."""
        if args.use_beam_search:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ),
            )
            return
        llm.generate(dummy_prompts,
                     sampling_params=sampling_params,
                     use_tqdm=False)

    def run_to_completion(profile_dir: Optional[str] = None):
        """Return the batch latency in seconds, or profile and return None."""
        if not profile_dir:
            started = time.perf_counter()
            llm_generate()
            return time.perf_counter() - started

        with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    str(profile_dir)),
        ) as p:
            llm_generate()
        print(p.key_averages().table(sort_by="self_cuda_time_total"))
        return None

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = (Path(".") / "vllm_benchmark_result" /
                           f"latency_result_{time.time()}")
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = np.array([
        run_to_completion(profile_dir=None)
        for _ in tqdm(range(args.num_iters), desc="Profiling iterations")
    ])
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f"Avg latency: {np.mean(latencies)} seconds")
    for percentage, percentile in zip(percentages, percentiles):
        print(f"{percentage}% percentile latency: {percentile} seconds")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "avg_latency": np.mean(latencies),
        "latencies": latencies.tolist(),
        "percentiles": dict(zip(percentages, percentiles.tolist())),
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
    save_to_pytorch_benchmark_format(args, results)
# SPDX-License-Identifier: Apache-2.0
"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
import os
import random
import time
import warnings
from typing import Any, Optional, Union
import torch
import uvloop
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset,
ConversationDataset,
InstructCoderDataset, RandomDataset,
SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format,
write_to_json)
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import merge_async_iterators
def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]:
    """Time offline generation through the synchronous ``LLM`` class.

    Args:
        requests: Sampled benchmark requests; each supplies the prompt, its
            length, the expected output length and optional multi-modal /
            LoRA data.
        n: Number of sequences per prompt (used as the beam width on the
            beam-search path).
        engine_args: Engine arguments, expanded into ``LLM(...)`` keywords.
        disable_detokenize: When True, ``detokenize=False`` is passed in the
            sampling parameters.

    Returns:
        ``(elapsed_seconds, outputs)``. ``outputs`` is the return value of
        ``llm.generate`` and stays ``None`` on the beam-search path.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len >= (
            request.prompt_len + request.expected_output_len)
        for request in requests), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests.")

    # Add the requests to the engine.
    prompts: list[Union[TextPrompt, TokensPrompt]] = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        # Pre-tokenized prompts are dicts carrying "prompt_token_ids";
        # everything else is passed through as text.
        if "prompt_token_ids" in request.prompt:
            prompts.append(
                TokensPrompt(
                    prompt_token_ids=request.prompt["prompt_token_ids"],
                    multi_modal_data=request.multi_modal_data))
        else:
            prompts.append(
                TextPrompt(prompt=request.prompt,
                           multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            ))

    lora_requests: Optional[list[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # Beam search is currently hard-wired off; the branch is kept so it can
    # be re-enabled by flipping this flag.
    use_beam_search = False

    outputs = None
    if not use_beam_search:
        start = time.perf_counter()
        outputs = llm.generate(prompts,
                               sampling_params,
                               lora_request=lora_requests,
                               use_tqdm=True)
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        # SampleRequest is accessed by attribute everywhere else in this
        # file, so read the field by name rather than by positional index.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start, outputs
def run_vllm_chat(
        requests: list[SampleRequest],
        n: int,
        engine_args: EngineArgs,
        disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
    """
    Time a batch of requests through ``LLM.chat``.

    Recommended ONLY for benchmarking multimodal models, since it properly
    handles multimodal inputs and chat formatting. For non-multimodal models,
    use run_vllm() instead.

    Returns the elapsed wall-clock seconds of the ``llm.chat`` call together
    with the outputs it produced.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))

    # Every request must fit within the model's context window.
    assert not any(
        llm.llm_engine.model_config.max_model_len <
        (req.prompt_len + req.expected_output_len) for req in requests), (
            "Please ensure that max_model_len is greater than the sum of "
            "prompt_len and expected_output_len for all requests.")

    prompts = [req.prompt for req in requests]
    sampling_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=req.expected_output_len,
            detokenize=not disable_detokenize,
        ) for req in requests
    ]

    # Only the chat call itself is timed.
    t_begin = time.perf_counter()
    outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
    t_finish = time.perf_counter()
    return t_finish - t_begin, outputs
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Time a batch of requests through the async engine client.

    All requests are submitted up front and their output streams are drained
    concurrently; the return value is the elapsed wall-clock time in seconds
    from the first submission until every stream is exhausted.
    """
    from vllm import SamplingParams

    def _as_prompt(req) -> Union[TextPrompt, TokensPrompt]:
        # Pre-tokenized prompts carry a "prompt_token_ids" entry.
        if "prompt_token_ids" in req.prompt:
            return TokensPrompt(
                prompt_token_ids=req.prompt["prompt_token_ids"],
                multi_modal_data=req.multi_modal_data)
        return TextPrompt(prompt=req.prompt,
                          multi_modal_data=req.multi_modal_data)

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:
        assert not any(
            llm.model_config.max_model_len <
            (req.prompt_len + req.expected_output_len) for req in requests), (
                "Please ensure that max_model_len is greater than the sum of"
                " prompt_len and expected_output_len for all requests.")

        # Build every engine input before the clock starts.
        prompts: list[Union[TextPrompt, TokensPrompt]] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[Optional[LoRARequest]] = []
        for req in requests:
            prompts.append(_as_prompt(req))
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=req.expected_output_len,
                    detokenize=not disable_detokenize,
                ))
            lora_requests.append(req.lora_request)

        start = time.perf_counter()
        generators = []
        for idx, (prompt, params, lora) in enumerate(
                zip(prompts, sampling_params, lora_requests)):
            generators.append(
                llm.generate(prompt,
                             params,
                             lora_request=lora,
                             request_id=f"test{idx}"))

        # Drain all streams; the produced items themselves are not needed.
        async for _idx, _result in merge_async_iterators(*generators):
            pass
        end = time.perf_counter()
        return end - start
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
) -> float:
    """Time generation with a Hugging Face causal LM on CUDA.

    Requests are grouped greedily into batches: a batch grows until it holds
    ``max_batch_size`` prompts, the requests run out, or adding the next
    request would push (longest prompt + longest output) past 2048.

    Returns the elapsed wall-clock seconds for processing all batches.
    """
    hf_model = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if hf_model.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    hf_model = hf_model.cuda()

    progress = tqdm(total=len(requests))
    start = time.perf_counter()

    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    final_index = len(requests) - 1
    for idx, current in enumerate(requests):
        # Add the prompt to the batch and track the batch-wide maxima.
        batch.append(current.prompt)
        max_prompt_len = max(max_prompt_len, current.prompt_len)
        max_output_len = max(max_output_len, current.expected_output_len)

        if len(batch) < max_batch_size and idx != final_index:
            # Peek at the next request to see whether it still fits.
            upcoming = requests[idx + 1]
            projected_len = (
                max(max_prompt_len, upcoming.prompt_len) +
                max(max_output_len, upcoming.expected_output_len))
            if projected_len <= 2048:
                # Keep accumulating; do not generate yet.
                continue

        # Generate the sequences for the accumulated batch.
        encoded = tokenizer(batch, return_tensors="pt", padding=True)
        input_ids = encoded.input_ids
        generated = hf_model.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(generated, skip_special_tokens=True)
        progress.update(len(batch))

        # Start a fresh batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0

    end = time.perf_counter()
    return end - start
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
                                     results: dict[str, Any]) -> None:
    """Write the results in PyTorch benchmark format next to the JSON output.

    The two throughput figures are passed as metrics and the run totals as
    extra info. Nothing is written when the conversion yields no records.
    """
    metric_keys = ("requests_per_second", "tokens_per_second")
    info_keys = ("elapsed_time", "num_requests", "total_num_tokens")

    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={key: [results[key]] for key in metric_keys},
        extra_info={key: results[key] for key in info_keys},
    )
    if not pt_records:
        return

    # Don't use json suffix here as we don't want CI to pick it up
    stem, _ext = os.path.splitext(args.output_json)
    pt_file = f"{stem}.pytorch.json"
    write_to_json(pt_file, pt_records)
def get_requests(args, tokenizer):
    """Pick the dataset class for ``args`` and sample benchmark requests.

    Args:
        args: Parsed CLI namespace; reads the dataset, LoRA, length and seed
            options.
        tokenizer: Tokenizer handed to the dataset's ``sample`` method.

    Returns:
        Whatever ``dataset_cls(...).sample(...)`` returns for the selected
        dataset.

    Raises:
        ValueError: If the dataset name is unknown, or if the dataset name is
            ``hf`` and the dataset path is not supported by any of the
            known HF dataset classes.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        assert tokenizer.chat_template or tokenizer.default_chat_template, (
            "Tokenizer/model must have chat template for sonnet dataset.")
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs['dataset_subset'] = None
            common_kwargs['dataset_split'] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs['dataset_split'] = "train"
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs['dataset_subset'] = args.hf_subset
            common_kwargs['dataset_split'] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs['dataset_subset'] = None
            common_kwargs['dataset_split'] = "train"
        else:
            # Without this branch dataset_cls would be unbound and the
            # return below would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"{args.dataset_path} is not supported by hf dataset.")
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")

    # Remove None values so the dataset's own defaults apply.
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    return dataset_cls(**common_kwargs).sample(**sample_kwargs)
def validate_args(args):
    """
    Validate command-line arguments.

    Mutates ``args`` in place: copies the deprecated ``--dataset`` value into
    ``dataset_path``, defaults ``tokenizer`` to ``model``, and switches
    ``dataset_name`` to ``'random'`` when no dataset path is given.

    Emits warnings for options that will be ignored with the selected
    dataset, and raises ``ValueError`` (or fails an assertion for the
    HF dataset/backend pairing) on invalid combinations.
    """
    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2)
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    if not args.dataset and not args.dataset_path:
        print(
            "When dataset path is not set, it will default to random dataset")
        args.dataset_name = 'random'
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    if args.dataset_name != "hf" and (
            getattr(args, "hf_subset", None) is not None
            or getattr(args, "hf_split", None) is not None):
        # Adjacent literals instead of a backslash continuation inside the
        # string, so the message does not depend on source-line layout.
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2)
    elif args.dataset_name == "hf":
        if args.dataset_path in (
                VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
                | ConversationDataset.SUPPORTED_DATASET_PATHS):
            assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend."  #noqa: E501
        elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
                                   | AIMODataset.SUPPORTED_DATASET_PATHS):
            assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend."  #noqa: E501
        else:
            raise ValueError(
                f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != 'random' and args.random_range_ratio is not None:
        warnings.warn(
            "--random-range-ratio will be ignored since "
            "--dataset-name is not 'random'.",
            stacklevel=2)

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if args.dataset_name not in {"random", "sonnet", None
                                 } and args.prefix_len is not None:
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'sonnet', or not set.",
            stacklevel=2)

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError(
            "LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if args.backend in {"hf", "mii"} and getattr(args, "quantization",
                                                 None) is not None:
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError(
            "Tokenizer must be the same as the model for MII backend.")
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the throughput-benchmark options on ``parser``.

    Adds the backend, dataset, length, output, async-engine, LoRA and
    dataset-specific options, then lets ``AsyncEngineArgs`` add the engine
    options to the same parser.
    """
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii", "vllm-chat"],
                        default="vllm")
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt")
    # Adjacent literals instead of a backslash continuation inside the
    # string, so the help text does not depend on source-line layout.
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the dataset")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=("Do not detokenize the response (i.e. do not include "
              "detokenization time in the measurement)"))

    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.")
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=0,
        help="Number of fixed prefix tokens before the random "
        "context in a request (default: 0).",
    )

    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range ratio for sampling input/output length, "
        "used only for RandomDataset. Must be in the range [0, 1) to define "
        "a symmetric sampling range "
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )

    # hf dataset
    parser.add_argument("--hf-subset",
                        type=str,
                        default=None,
                        help="Subset of the HF dataset.")
    parser.add_argument("--hf-split",
                        type=str,
                        default=None,
                        help="Split of the HF dataset.")

    # Engine options are registered on the same parser object.
    parser = AsyncEngineArgs.add_cli_args(parser)
def main(args: argparse.Namespace):
    """Run the throughput benchmark described by ``args`` and report results.

    Validates the arguments, samples the requests, dispatches to the runner
    for the selected backend, prints throughput and token totals, and
    optionally writes the results to ``args.output_json``.
    """
    if args.tokenizer is None:
        args.tokenizer = args.model
    validate_args(args)
    if args.seed is None:
        args.seed = 0
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(req.multi_modal_data is not None for req in requests)

    # Dispatch to the backend-specific runner.
    request_outputs: Optional[list[RequestOutput]] = None
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
                    args.disable_detokenize,
                ))
        else:
            elapsed_time, request_outputs = run_vllm(
                requests, args.n, EngineArgs.from_cli_args(args),
                args.disable_detokenize)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code,
                              args.disable_detokenize)
    elif args.backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests, args.n, EngineArgs.from_cli_args(args),
            args.disable_detokenize)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # Note: with the vllm and vllm-chat backends,
        # we have request_outputs, which we use to count tokens.
        total_prompt_tokens = 0
        total_output_tokens = 0
        for output in request_outputs:
            if isinstance(output, RequestOutput):
                if output.prompt_token_ids:
                    total_prompt_tokens += len(output.prompt_token_ids)
                total_output_tokens += sum(
                    len(completion.token_ids)
                    for completion in output.outputs if completion)
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        # No outputs available: fall back to the lengths the requests
        # were sampled with.
        total_num_tokens = sum(req.prompt_len + req.expected_output_len
                               for req in requests)
        total_output_tokens = sum(req.expected_output_len for req in requests)
        total_prompt_tokens = total_num_tokens - total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print("\033[91mWARNING\033[0m: Multi-modal request with "
              f"{args.backend} backend detected. The "
              "following metrics are not accurate because image tokens are not"
              " counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    num_requests = len(requests)
    print(f"Throughput: {num_requests / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
    print(f"Total num prompt tokens:  {total_prompt_tokens}")
    print(f"Total num output tokens:  {total_output_tokens}")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "elapsed_time": elapsed_time,
        "num_requests": num_requests,
        "total_num_tokens": total_num_tokens,
        "requests_per_second": num_requests / elapsed_time,
        "tokens_per_second": total_num_tokens / elapsed_time,
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
    save_to_pytorch_benchmark_format(args, results)
......@@ -282,13 +282,21 @@ def get_vllm_version():
if __version__ == "dev":
return "N/A (dev)"
if len(__version_tuple__) == 4: # dev build
git_sha = __version_tuple__[-1][1:] # type: ignore
return f"{__version__} (git sha: {git_sha}"
version_str = __version_tuple__[-1]
if isinstance(version_str, str) and version_str.startswith('g'):
# it's a dev build
if '.' in version_str:
# it's a dev build containing local changes
git_sha = version_str.split('.')[0][1:]
date = version_str.split('.')[-1][1:]
return f"{__version__} (git sha: {git_sha}, date: {date})"
else:
# it's a dev build without local changes
git_sha = version_str[1:] # type: ignore
return f"{__version__} (git sha: {git_sha})"
return __version__
def summarize_vllm_build_flags():
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format(
......@@ -502,7 +510,9 @@ def get_pip_packages(run_lambda, patterns=None):
print("uv is set")
cmd = ["uv", "pip", "list", "--format=freeze"]
else:
raise RuntimeError("Could not collect pip list output (pip or uv module not available)")
raise RuntimeError(
"Could not collect pip list output (pip or uv module not available)"
)
out = run_and_read_all(run_lambda, cmd)
return "\n".join(line for line in out.splitlines()
......@@ -535,13 +545,12 @@ def is_xnnpack_available():
else:
return "N/A"
def get_env_vars():
env_vars = ''
secret_terms=('secret', 'token', 'api', 'access', 'password')
report_prefix = ("TORCH", "NCCL", "PYTORCH",
"CUDA", "CUBLAS", "CUDNN",
"OMP_", "MKL_",
"NVIDIA")
secret_terms = ('secret', 'token', 'api', 'access', 'password')
report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN",
"OMP_", "MKL_", "NVIDIA")
for k, v in os.environ.items():
if any(term in k.lower() for term in secret_terms):
continue
......@@ -552,6 +561,7 @@ def get_env_vars():
return env_vars
def get_env_info():
run_lambda = run
pip_version, pip_list_output = get_pip_packages(run_lambda)
......
......@@ -110,10 +110,14 @@ class CompilerManager:
compiled_graph = self.load(graph, example_inputs, graph_index,
runtime_shape)
if compiled_graph is not None:
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Directly load the compiled graph for shape %s "
"from the cache", str(runtime_shape)) # noqa
if graph_index == num_graphs - 1:
# after loading the last graph for this shape, record the time.
# there can be multiple graphs due to piecewise compilation.
now = time.time()
elapsed = now - compilation_start_time
logger.info(
"Directly load the compiled graph(s) for shape %s "
"from the cache, took %.3f s", str(runtime_shape), elapsed)
return compiled_graph
# no compiler cached the graph, or the cache is disabled,
......@@ -335,7 +339,7 @@ class VllmBackend:
def configure_post_pass(self):
config = self.compilation_config
self.post_grad_pass_manager.configure(config.pass_config)
self.post_grad_pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager.
......
......@@ -11,9 +11,12 @@ import torch
import torch._inductor.compile_fx
import torch.fx as fx
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.utils import is_torch_equal_or_newer
from .inductor_pass import pass_context
class CompilerInterface:
"""
......@@ -167,8 +170,7 @@ class InductorAdaptor(CompilerInterface):
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
) -> Tuple[Optional[Callable], Optional[Any]]:
from torch._inductor import config
current_config = config.get_config_copy()
current_config = {}
from torch._inductor.compile_fx import compile_fx
# disable remote cache
......@@ -196,7 +198,6 @@ class InductorAdaptor(CompilerInterface):
hash_str, file_path = None, None
from torch._inductor.codecache import (FxGraphCache,
compiled_fx_graph_hash)
if torch.__version__.startswith("2.5"):
original_load = FxGraphCache.load
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
......@@ -281,6 +282,16 @@ class InductorAdaptor(CompilerInterface):
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env))
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache)
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
_get_shape_env))
# for forcing the graph to be cached
stack.enter_context(
patch(
......@@ -290,16 +301,34 @@ class InductorAdaptor(CompilerInterface):
# Dynamo metrics context, see method for more details.
stack.enter_context(self.metrics_context())
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config)
assert hash_str is not None, (
"failed to get the hash of the compiled graph")
assert file_path is not None, (
"failed to get the file path of the compiled graph")
# Disable remote caching. When these are on, on remote cache-hit,
# the monkey-patched functions never actually get called.
# vLLM today assumes and requires the monkey-patched functions to
# get hit.
# TODO(zou3519): we're going to replace this all with
# standalone_compile sometime.
if is_torch_equal_or_newer("2.6"):
stack.enter_context(
torch._inductor.config.patch(fx_graph_remote_cache=False))
stack.enter_context(
torch._functorch.config.patch(
enable_remote_autograd_cache=False))
with pass_context(runtime_shape):
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config)
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
# compilation cache. So turn off the checks if we disable the
# compilation cache.
if not envs.VLLM_DISABLE_COMPILE_CACHE:
assert hash_str is not None, (
"failed to get the hash of the compiled graph")
assert file_path is not None, (
"failed to get the file path of the compiled graph")
return compiled_graph, (hash_str, file_path)
def load(self,
......@@ -313,11 +342,19 @@ class InductorAdaptor(CompilerInterface):
assert isinstance(handle[1], str)
hash_str = handle[0]
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache)
from torch._inductor.codecache import FxGraphCache
with ExitStack() as exit_stack:
exit_stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()))
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
exit_stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()))
# Dynamo metrics context, see method for more details.
exit_stack.enter_context(self.metrics_context())
......
......@@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
from vllm.config import CompilationConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -531,7 +531,7 @@ class FusionPass(VllmInductorPass):
_instance: 'Optional[FusionPass]' = None
@classmethod
def instance(cls, config: CompilationConfig.PassConfig):
def instance(cls, config: VllmConfig):
"""
Get the singleton instance of the FusionPass.
If the instance exists, the config is updated but
......@@ -540,10 +540,10 @@ class FusionPass(VllmInductorPass):
if cls._instance is None:
cls._instance = FusionPass(config)
else:
cls._instance.config = config
cls._instance.pass_config = config.compilation_config.pass_config
return cls._instance
def __init__(self, config: CompilationConfig.PassConfig):
def __init__(self, config: VllmConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
super().__init__(config)
......
......@@ -12,6 +12,22 @@ def is_func(node: fx.Node, target) -> bool:
return node.op == "call_function" and node.target == target
# Returns the first specified node with the given op (if it exists)
def find_specified_fn_maybe(nodes: Iterable[fx.Node],
                            op: OpOverload) -> Optional[fx.Node]:
    """Return the first node in ``nodes`` whose ``target`` equals ``op``.

    Returns ``None`` when no node matches.
    """
    matching = (candidate for candidate in nodes if candidate.target == op)
    return next(matching, None)
# Returns the first specified node with the given op
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """Return the first node in ``nodes`` whose ``target`` equals ``op``.

    Unlike ``find_specified_fn_maybe``, a missing match is treated as a
    programming error and fails an assertion.
    """
    match = find_specified_fn_maybe(nodes, op)
    assert match is not None, f"Could not find {op} in nodes {nodes}"
    return match
# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
op: OpOverload) -> Optional[fx.Node]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment