Unverified commit 396a6924, authored by Lianmin Zheng, committed by GitHub

Cleanup attention backend: flashinfer and triton (#611)

parent af4e7910
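This commit moves the per-backend attention setup out of `InputMetadata` methods and into two module-level helpers, `init_flashinfer_args` and `init_triton_args`, selected by `server_args.disable_flashinfer`; the Triton-specific metadata fields gain a `triton_` prefix. A minimal sketch of the resulting dispatch, with an illustrative call site (the real one is `InputMetadata.create()` in the diff below):

```python
# Illustrative only: how the two helpers introduced in this commit are selected.
if model_runner.server_args.disable_flashinfer:
    # Triton backend: precompute start locations and max lengths for the Triton kernels.
    (triton_max_seq_len, triton_max_extend_len,
     triton_start_loc, triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens)
else:
    # FlashInfer backend: build CSR-style indices and call begin_forward() on the wrappers.
    init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)
```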
"""Radix attention.""" """Radix attention."""
import numpy as np
import torch import torch
from flashinfer.cascade import merge_state from flashinfer.cascade import merge_state
from torch import nn from torch import nn
...@@ -51,13 +50,13 @@ class RadixAttention(nn.Module): ...@@ -51,13 +50,13 @@ class RadixAttention(nn.Module):
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
input_metadata.req_to_token_pool.req_to_token, input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices, input_metadata.req_pool_indices,
input_metadata.start_loc, input_metadata.triton_start_loc,
input_metadata.seq_lens, input_metadata.seq_lens,
input_metadata.prefix_lens, input_metadata.triton_prefix_lens,
input_metadata.extend_start_loc, input_metadata.extend_start_loc,
input_metadata.extend_seq_lens, input_metadata.extend_seq_lens,
input_metadata.max_seq_len, input_metadata.triton_max_seq_len,
input_metadata.max_extend_len, input_metadata.triton_max_extend_len,
sm_scale=self.scaling, sm_scale=self.scaling,
logit_cap=self.logit_cap, logit_cap=self.logit_cap,
) )

@@ -75,9 +74,9 @@ class RadixAttention(nn.Module):
             o.view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.max_seq_len,
+            input_metadata.triton_max_seq_len,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,

@@ -95,7 +94,7 @@ class RadixAttention(nn.Module):
             logits_soft_cap=self.logit_cap,
         )
-        if input_metadata.no_prefix:
+        if input_metadata.extend_no_prefix:
             o = o1
         else:
             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(

...
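In the FlashInfer extend path above, the ragged-KV prefill handles the new tokens attending to themselves, the paged-KV prefill handles the new tokens attending to the cached prefix, and the two partial results are combined with `flashinfer.cascade.merge_state` using their log-sum-exp values. A hedged sketch of that merge with dummy tensors (shapes and dtypes are assumptions, not taken from the diff):

```python
import torch
from flashinfer.cascade import merge_state

num_tokens, num_heads, head_dim = 16, 32, 128
# o1, s1: output and log-sum-exp from the ragged prefill (new tokens vs. themselves)
o1 = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
s1 = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")
# o2, s2: output and log-sum-exp from the paged prefill (new tokens vs. cached prefix)
o2 = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
s2 = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")

# merge_state re-normalizes the two partial softmax results into a single attention output.
o, s = merge_state(o1, s1, o2, s2)
```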
@@ -312,7 +312,7 @@ def token_attention_fwd(
     b_seq_len,
     max_len_in_batch,
     total_num_tokens,
-    sm_scale=None,
+    sm_scale,
     logit_cap=-1,
     att_m=None,
 ):

@@ -320,7 +320,6 @@ def token_attention_fwd(
         att_m = torch.empty(
             (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
         )
-    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
     _token_att_m_fwd(
         q,

...
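Because the `sm_scale=None` default (and the `1.0 / (Lq**0.5)` fallback) is removed above, every caller of `token_attention_fwd` must now pass the softmax scale explicitly; in `RadixAttention` it is `self.scaling`. The conventional value is the inverse square root of the head dimension, sketched here with an assumed head size:

```python
# Assumed head dimension for illustration; RadixAttention passes self.scaling instead.
head_dim = 128
sm_scale = 1.0 / (head_dim ** 0.5)  # ~= 0.0884
```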
@@ -75,6 +75,7 @@ class Req:
     """Store all information of a request."""

     def __init__(self, rid, origin_input_text, origin_input_ids):
+        # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding

@@ -97,6 +98,11 @@
         self.image_offset = 0
         self.pad_value = None

+        # Prefix info
+        self.extend_input_len = 0
+        self.prefix_indices = []
+        self.last_node = None

         # Sampling parameters
         self.sampling_params = None
         self.stream = False

@@ -105,11 +111,6 @@
         self.tokenizer = None
         self.finished_reason = None

-        # Prefix info
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None

         # Logprobs
         self.return_logprob = False
         self.logprob_start_len = 0

@@ -261,35 +262,36 @@ class Req:
 class Batch:
     """Store all information of a batch."""

+    # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
     tree_cache: RadixCache

-    # batched arguments to model runner
+    # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
+    out_cache_cont_start: int = None
+    out_cache_cont_end: int = None

-    # for processing logprobs
+    # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

-    # for multimodal
+    # For multimodal
     pixel_values: List[torch.Tensor] = None
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None

-    # other arguments for control
+    # Other arguments for control
     output_ids: torch.Tensor = None
     extend_num_tokens: int = None

-    # batched sampling params
+    # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None

@@ -312,8 +314,8 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0

-    # whether batch has at least 1 streaming request
     def has_stream(self) -> bool:
+        # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)

     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):

@@ -347,7 +349,7 @@
         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)

-        # Alloc mem
+        # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
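The allocation above requests one KV-cache slot per new (non-prefix) token in the batch. A small worked example with assumed lengths:

```python
import numpy as np

# Two requests: total lengths 7 and 5 tokens, of which 3 and 0 are already cached prefixes.
seq_lens = np.array([7, 5])
prefix_lens = np.array([3, 0])

# Only the uncached suffix of each request needs new KV-cache slots.
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
print(extend_num_tokens)  # 9
```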

@@ -703,7 +705,6 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
     return probs_sort, probs_idx

 @dataclass
 class InputMetadata:
     """Store all information of a forward pass."""

@@ -711,110 +712,37 @@ class InputMetadata:
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
-    max_seq_len: int
     req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
     seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
     positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool

-    # for extend
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    max_extend_len: int = 0
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool

+    # Output location of the KV cache
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
+    out_cache_cont_start: int = None
+    out_cache_cont_end: int = None

+    # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

-    # for flashinfer
-    qo_indptr: torch.Tensor = None
-    kv_indptr: torch.Tensor = None
-    kv_indices: torch.Tensor = None
-    kv_last_page_len: torch.Tensor = None
+    # Triton attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None

+    # FlashInfer attention backend
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if self.forward_mode == ForwardMode.DECODE:
-            paged_kernel_lens = self.seq_lens
-        else:
-            paged_kernel_lens = self.prefix_lens
-            self.no_prefix = torch.all(self.prefix_lens == 0)
-
-        kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(self.batch_size)
-            ],
-            dim=0,
-        ).contiguous()
-        kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-
-        if self.forward_mode == ForwardMode.DECODE:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
-        else:
-            # extend part
-            qo_indptr = torch.zeros(
-                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-            )
-            qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-            self.flashinfer_prefill_wrapper_ragged.end_forward()
-            self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-            # cached part
-            self.flashinfer_prefill_wrapper_paged.end_forward()
-            self.flashinfer_prefill_wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-
-    def init_extend_args(self):
-        self.extend_seq_lens = self.seq_lens - self.prefix_lens
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.max_extend_len = int(torch.max(self.extend_seq_lens))

     @classmethod
     def create(
         cls,

@@ -830,14 +758,20 @@ class InputMetadata:
         top_logprobs_nums=None,
         return_logprob=False,
     ):
+        if not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)

         batch_size = len(req_pool_indices)
-        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-        total_num_tokens = int(torch.sum(seq_lens))
-        max_seq_len = int(torch.max(seq_lens))

         if forward_mode == ForwardMode.DECODE:
             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatible with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
         else:
             seq_lens_cpu = seq_lens.cpu().numpy()
             prefix_lens_cpu = prefix_lens.cpu().numpy()

@@ -855,22 +789,27 @@
                 ),
                 device="cuda",
             )
+            extend_seq_lens = seq_lens - prefix_lens
+            extend_start_loc = torch.zeros_like(seq_lens)
+            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+            extend_no_prefix = torch.all(prefix_lens == 0)
+            total_num_tokens = int(torch.sum(seq_lens))

         ret = cls(
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
-            max_seq_len=max_seq_len,
             req_pool_indices=req_pool_indices,
-            start_loc=start_loc,
             seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
             positions=positions,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,

@@ -878,14 +817,96 @@
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
         )

-        if forward_mode == ForwardMode.EXTEND:
-            ret.init_extend_args()
-
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-                model_runner.model_config.head_dim,
-            )
+        if model_runner.server_args.disable_flashinfer:
+            (ret.triton_max_seq_len,
+             ret.triton_max_extend_len,
+             ret.triton_start_loc,
+             ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens)

         return ret

+def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens):
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+
+    if forward_mode == ForwardMode.DECODE:
+        paged_kernel_lens = seq_lens
+    else:
+        paged_kernel_lens = prefix_lens
+
+    kv_indptr = torch.zeros(
+        (batch_size + 1,), dtype=torch.int32, device="cuda"
+    )
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones(
+        (batch_size,), dtype=torch.int32, device="cuda"
+    )
+
+    if forward_mode == ForwardMode.DECODE:
+        model_runner.flashinfer_decode_wrapper.end_forward()
+        model_runner.flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros(
+            (batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+            qo_indptr,
+            qo_indptr,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+        )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
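The `kv_indptr`/`qo_indptr` tensors built in `init_flashinfer_args` are exclusive prefix sums over per-request lengths, the CSR-style layout that FlashInfer's `begin_forward` expects. A small worked example with assumed lengths:

```python
import torch

# Assumed per-request KV lengths for a batch of three requests.
paged_kernel_lens = torch.tensor([4, 2, 5], dtype=torch.int32, device="cuda")
batch_size = paged_kernel_lens.numel()

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
# kv_indptr == [0, 4, 6, 11]; request i owns kv_indices[kv_indptr[i]:kv_indptr[i+1]]
```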

+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens
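A hedged usage sketch of `init_triton_args` above, with assumed inputs, showing what the tuple unpacked into the `triton_*` fields looks like for an extend batch:

```python
import torch

seq_lens = torch.tensor([7, 5], dtype=torch.int32, device="cuda")
prefix_lens = torch.tensor([3, 0], dtype=torch.int32, device="cuda")

max_seq_len, max_extend_len, start_loc, prefix_lens_out = init_triton_args(
    ForwardMode.EXTEND, seq_lens, prefix_lens
)
# max_seq_len == 7
# max_extend_len == max(7 - 3, 5 - 0) == 5
# start_loc == tensor([0, 7]): cumulative start offset of each request's tokens
```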

@@ -182,39 +182,39 @@ class ModelRunner:
         return c

     def init_flash_infer(self):
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            from flashinfer import (
-                BatchDecodeWithPagedKVCacheWrapper,
-                BatchPrefillWithPagedKVCacheWrapper,
-                BatchPrefillWithRaggedKVCacheWrapper,
-            )
-            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-            if not _grouped_size_compiled_for_decode_kernels(
-                self.model_config.num_attention_heads // self.tp_size,
-                self.model_config.get_num_kv_heads(self.tp_size),
-            ):
-                use_tensor_cores = True
-            else:
-                use_tensor_cores = False
-
-            workspace_buffers = torch.empty(
-                2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffers[1], "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
-            )
-        else:
-            self.flashinfer_prefill_wrapper_ragged = (
-                self.flashinfer_prefill_wrapper_paged
-            ) = None
-            self.flashinfer_decode_wrapper = None
+        if self.server_args.disable_flashinfer:
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None
+            return
+
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+            BatchPrefillWithRaggedKVCacheWrapper,
+        )
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_config.num_attention_heads // self.tp_size,
+            self.model_config.get_num_kv_heads(self.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+
+        workspace_buffers = torch.empty(
+            3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+        )
+        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            workspace_buffers[0], "NHD"
+        )
+        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+            workspace_buffers[1], "NHD"
+        )
+        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+        )

     @torch.inference_mode()
     def forward_extend(self, batch: Batch):

...
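The rewritten `init_flash_infer` above gives each wrapper its own 96 MB workspace buffer (three buffers instead of two) and turns on the tensor-core decode path exactly when FlashInfer has no pre-compiled grouped-size decode kernel for the model's head configuration. A hedged sketch of that decision with assumed head counts:

```python
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

# Assumed example: 32 query heads and 8 KV heads (GQA group size 4).
num_qo_heads, num_kv_heads = 32, 8

# Fall back to the tensor-core (prefill-based) decode kernels when no specialized
# grouped-size decode kernel was compiled for this head configuration.
use_tensor_cores = not _grouped_size_compiled_for_decode_kernels(num_qo_heads, num_kv_heads)
```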