Unverified commit b7e2f800 authored by Lianmin Zheng, committed by GitHub

Update flashinfer to 0.0.5 (#554)

parent 09593e9b
@@ -12,7 +12,8 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
class RadixAttention(nn.Module):
def __init__(
self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
layer_id: int, logit_cap: int = -1
):
super().__init__()
self.tp_q_head_num = num_heads
@@ -20,7 +21,6 @@ class RadixAttention(nn.Module):
self.tp_v_head_num = num_kv_heads
self.head_dim = head_dim
self.layer_id = layer_id
self.logit_cap = logit_cap
assert np.allclose(scaling, 1.0 / (head_dim**0.5))
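The assert above pins the attention scale to the standard scaled dot-product factor 1 / sqrt(head_dim). A minimal check of that identity, with a hypothetical head_dim of 128 (the concrete value is not taken from this diff):

import numpy as np

head_dim = 128  # hypothetical head dimension, for illustration only
scaling = 1.0 / (head_dim ** 0.5)
# 1 / sqrt(128) == sqrt(2) / 16 ~= 0.0884, the usual softmax temperature
# for scaled dot-product attention.
assert np.allclose(scaling, np.sqrt(2) / 16)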
@@ -30,10 +30,17 @@ class RadixAttention(nn.Module):
self.prefill_forward = self.prefill_forward_flashinfer
self.extend_forward = self.prefill_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
# flashinfer only accepts a boolean logit_cap argument
if logit_cap > 0:
assert logit_cap == 30
self.logit_cap = True
else:
self.logit_cap = False
else:
self.prefill_forward = self.prefill_forward_triton
self.extend_forward = self.extend_forward_triton
self.decode_forward = self.decode_forward_triton
self.logit_cap = logit_cap
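When flashinfer is enabled, logit_cap is collapsed into a boolean and the only supported cap value is asserted to be 30. A short sketch of the tanh soft-capping that such a fixed cap conventionally denotes; the formula below is an assumption for illustration, not something this diff spells out:

import torch

def soft_cap(scores: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # Conventional tanh soft-capping: logits are squashed into (-cap, cap)
    # while staying roughly linear near zero. Assumed to be what a fixed cap
    # of 30 stands for; the kernel's exact behavior may differ.
    return cap * torch.tanh(scores / cap)

scores = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap(scores))  # extremes approach +/-30, small values are nearly unchanged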
def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
@@ -100,9 +107,10 @@ class RadixAttention(nn.Module):
def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
self.store_kv_cache(k, v, input_metadata)
o = input_metadata.prefill_wrapper.forward(
o = input_metadata.flashinfer_prefill_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
logits_cap=self.logit_cap,
)
return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -110,9 +118,10 @@ class RadixAttention(nn.Module):
def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
self.store_kv_cache(k, v, input_metadata)
o = input_metadata.decode_wrapper.forward(
o = input_metadata.flashinfer_decode_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
logits_cap=self.logit_cap,
)
return o.view(-1, self.tp_q_head_num * self.head_dim)
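Both flashinfer code paths reshape the flat hidden states into (tokens, heads, head_dim) before the call and flatten the result again afterwards. A small shape check of that round trip, with made-up sizes:

import torch

num_tokens, num_heads, head_dim = 4, 8, 128  # hypothetical sizes
q = torch.randn(num_tokens, num_heads * head_dim)

q_for_kernel = q.contiguous().view(-1, num_heads, head_dim)  # layout the wrapper consumes
o = q_for_kernel                                             # stand-in for the attention output
o_flat = o.view(-1, num_heads * head_dim)                    # layout the next linear layer expects

assert q_for_kernel.shape == (num_tokens, num_heads, head_dim)
assert o_flat.shape == (num_tokens, num_heads * head_dim)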
@@ -6,7 +6,7 @@ import logging
import pkgutil
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Type
from typing import List, Optional, Type, Any
import numpy as np
import torch
@@ -34,7 +34,6 @@ global_server_args_dict = {}
@dataclass
class InputMetadata:
model_runner: "ModelRunner"
forward_mode: ForwardMode
batch_size: int
total_num_tokens: int
@@ -65,15 +64,10 @@ class InputMetadata:
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None
kv_last_page_len: torch.Tensor = None
prefill_wrapper = None
decode_wrapper = None
def init_flashinfer_args(self, tp_size):
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
def init_flashinfer_args(self, num_attention_heads, num_key_value_heads, head_dim):
self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
@@ -93,9 +87,6 @@ class InputMetadata:
dim=0,
).contiguous()
workspace_buffer = torch.empty(
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
)
if (
self.forward_mode == ForwardMode.PREFILL
or self.forward_mode == ForwardMode.EXTEND
@@ -104,34 +95,30 @@ class InputMetadata:
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
)
args = [
self.flashinfer_prefill_wrapper.end_forward()
self.flashinfer_prefill_wrapper.begin_forward(
self.qo_indptr,
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
self.model_runner.model_config.head_dim,
]
self.prefill_wrapper.begin_forward(*args)
else:
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
num_attention_heads,
num_key_value_heads,
head_dim,
1
)
self.decode_wrapper.begin_forward(
else:
self.flashinfer_decode_wrapper.end_forward()
self.flashinfer_decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
self.model_runner.model_config.head_dim,
num_attention_heads,
num_key_value_heads,
head_dim,
1,
"NONE",
"float16",
pos_encoding_mode="NONE",
data_type="float16",
)
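The begin_forward calls above describe the paged KV cache through three index tensors plus a page size of 1. A CPU-only sketch of how those tensors relate; the request lengths and the contiguous slot numbering are assumptions for illustration, while the real runner gathers slots from its request-to-token pool:

import torch

# Hypothetical batch of three requests with KV lengths 3, 5 and 2.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
batch_size = seq_lens.numel()

# Exclusive prefix sum: request i owns kv_indices[kv_indptr[i]:kv_indptr[i + 1]].
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)

# Contiguous slot numbers, purely for illustration.
kv_indices = torch.arange(int(kv_indptr[-1]), dtype=torch.int32)

# With a page size of 1, every request's last page holds exactly one token.
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32)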
def init_extend_args(self):
@@ -155,6 +142,8 @@ class InputMetadata:
out_cache_cont_end=None,
top_logprobs_nums=None,
return_logprob=False,
flashinfer_prefill_wrapper=None,
flashinfer_decode_wrapper=None,
):
batch_size = len(req_pool_indices)
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -187,7 +176,6 @@ class InputMetadata:
other_kv_index = None
ret = cls(
model_runner=model_runner,
forward_mode=forward_mode,
batch_size=batch_size,
total_num_tokens=total_num_tokens,
@@ -205,13 +193,19 @@ class InputMetadata:
other_kv_index=other_kv_index,
return_logprob=return_logprob,
top_logprobs_nums=top_logprobs_nums,
flashinfer_prefill_wrapper=flashinfer_prefill_wrapper,
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
)
if forward_mode == ForwardMode.EXTEND:
ret.init_extend_args()
if global_server_args_dict.get("enable_flashinfer", False):
ret.init_flashinfer_args(tp_size)
ret.init_flashinfer_args(
model_runner.model_config.num_attention_heads // tp_size,
model_runner.model_config.num_key_value_heads // tp_size,
model_runner.model_config.head_dim
)
return ret
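The head counts passed to init_flashinfer_args are the per-GPU shares under tensor parallelism; head_dim itself is not sharded. For example, with a hypothetical config of 32 query heads and 8 KV heads split across 4 GPUs:

num_attention_heads, num_key_value_heads, head_dim = 32, 8, 128  # hypothetical model config
tp_size = 4
assert num_attention_heads // tp_size == 8  # query heads handled by each GPU
assert num_key_value_heads // tp_size == 2  # KV heads handled by each GPU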
@@ -234,12 +228,7 @@ class ModelRunner:
self.tp_size = tp_size
self.nccl_port = nccl_port
self.server_args = server_args
global global_server_args_dict
global_server_args_dict = {
"enable_flashinfer": server_args.enable_flashinfer,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
}
self.is_multimodal_model = is_multimodal_model(self.model_config)
# Init torch distributed
logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
@@ -269,9 +258,17 @@ class ModelRunner:
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
)
# Set some global args
global global_server_args_dict
global_server_args_dict = {
"enable_flashinfer": server_args.enable_flashinfer,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
}
# Load the model and create memory pool
self.load_model()
self.init_memory_pool(total_gpu_memory)
self.is_multimodal_model = is_multimodal_model(self.model_config)
self.init_flash_infer()
def load_model(self):
logger.info(
@@ -347,6 +344,22 @@ class ModelRunner:
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
def init_flash_infer(self):
if global_server_args_dict.get("enable_flashinfer", False):
from flashinfer import (
BatchPrefillWithPagedKVCacheWrapper,
BatchDecodeWithPagedKVCacheWrapper,
)
workspace_buffer = torch.empty(
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
)
self.flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
)
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
)
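init_flash_infer builds both wrappers once per ModelRunner on top of a single shared 32 MiB workspace, and the wrappers are then threaded through InputMetadata.create so each batch can re-plan them before the attention layers run. A hedged sketch of that lifecycle; it assumes flashinfer 0.0.5 and a CUDA device, and every name other than the two flashinfer classes is illustrative:

import torch
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
)

# One workspace allocation up front, shared by both wrappers, so per-batch
# planning does not allocate new GPU memory.
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

# Per decode batch, mirroring the calls in this diff: drop the previous plan,
# plan for the new batch, then run the attention forward passes.
#   decode_wrapper.end_forward()
#   decode_wrapper.begin_forward(kv_indptr, kv_indices, kv_last_page_len,
#                                num_qo_heads, num_kv_heads, head_dim, 1,
#                                pos_encoding_mode="NONE", data_type="float16")
#   o = decode_wrapper.forward(q_view, kv_data, logits_cap=False)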
@torch.inference_mode()
def forward_prefill(self, batch: Batch):
input_metadata = InputMetadata.create(
@@ -360,6 +373,8 @@ class ModelRunner:
out_cache_loc=batch.out_cache_loc,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
@@ -378,6 +393,8 @@ class ModelRunner:
out_cache_loc=batch.out_cache_loc,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
@@ -398,6 +415,8 @@ class ModelRunner:
out_cache_cont_end=batch.out_cache_cont_end,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
@@ -416,6 +435,8 @@ class ModelRunner:
out_cache_loc=batch.out_cache_loc,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
)
return self.model.forward(
batch.input_ids,
@@ -150,7 +150,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
if server_args.disable_disk_cache:
disable_cache()
if server_args.enable_flashinfer:
assert_pkg_version("flashinfer", "0.0.4")
assert_pkg_version("flashinfer", "0.0.5")
if server_args.chat_template:
# TODO: replace this with huggingface transformers template
load_chat_template_for_openai_api(server_args.chat_template)
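With the flashinfer path enabled, server launch now requires flashinfer 0.0.5 instead of 0.0.4. For context, a version gate of this kind can be written with importlib.metadata; the snippet below is a hypothetical stand-in, not sglang's actual assert_pkg_version, which may compare versions differently (for instance, as a minimum rather than an exact match):

from importlib.metadata import PackageNotFoundError, version

def require_pkg_version(pkg: str, required: str) -> None:
    # Hypothetical helper: fail fast if the package is missing or its
    # installed version does not match the required one.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed; run `pip install {pkg}=={required}`")
    if installed != required:
        raise RuntimeError(f"{pkg}=={required} is required, but {installed} is installed")

# Example: require_pkg_version("flashinfer", "0.0.5")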