Unverified commit b7e2f800 authored by Lianmin Zheng, committed by GitHub

Update flashinfer to 0.0.5 (#554)

parent 09593e9b
@@ -12,7 +12,8 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 class RadixAttention(nn.Module):
     def __init__(
-        self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
+        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
+        layer_id: int, logit_cap: int = -1
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -20,7 +21,6 @@ class RadixAttention(nn.Module):
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
-        self.logit_cap = logit_cap
 
         assert np.allclose(scaling, 1.0 / (head_dim**0.5))
@@ -30,10 +30,17 @@ class RadixAttention(nn.Module):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
+
+            # flashinfer only accepts a boolean logit_cap argument
+            if logit_cap > 0:
+                assert logit_cap == 30
+                self.logit_cap = True
+            else:
+                self.logit_cap = False
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+            self.logit_cap = logit_cap
 
     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -100,9 +107,10 @@ class RadixAttention(nn.Module):
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.prefill_wrapper.forward(
+        o = input_metadata.flashinfer_prefill_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -110,9 +118,10 @@ class RadixAttention(nn.Module):
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.decode_wrapper.forward(
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
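The hunks above adapt the RadixAttention layer: the constructor gains type annotations, the flashinfer wrappers are now read from `input_metadata.flashinfer_prefill_wrapper` / `flashinfer_decode_wrapper`, and a `logits_cap` flag is passed to their `forward()` calls. Because this flashinfer interface only takes a boolean cap flag, the integer `logit_cap` is collapsed to True/False, with the only supported cap value pinned to 30 by the assert. For intuition, a logit cap usually means squashing raw attention scores with a tanh so they stay inside a fixed range; the sketch below is my own illustration of that idea (assuming the usual `cap * tanh(x / cap)` form), not flashinfer's fused kernel.

```python
import torch

def soft_cap_scores(scores: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    """Illustrative tanh soft-cap: squashes raw attention logits into (-cap, cap).

    A sketch of what a logit cap typically does; the real work happens inside
    flashinfer's kernels when the boolean logits_cap flag is set.
    """
    return cap * torch.tanh(scores / cap)

# Large raw scores are smoothly clamped instead of growing without bound.
raw = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap_scores(raw))  # all values stay strictly inside (-30, 30)
```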
@@ -6,7 +6,7 @@ import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Any
 
 import numpy as np
 import torch
@@ -34,7 +34,6 @@ global_server_args_dict = {}
 
 @dataclass
 class InputMetadata:
-    model_runner: "ModelRunner"
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
@@ -65,15 +64,10 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-    prefill_wrapper = None
-    decode_wrapper = None
+    flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
-    def init_flashinfer_args(self, tp_size):
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
-
+    def init_flashinfer_args(self, num_attention_heads, num_key_value_heads, head_dim):
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -93,9 +87,6 @@ class InputMetadata:
             dim=0,
         ).contiguous()
 
-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
@@ -104,34 +95,30 @@ class InputMetadata:
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
-            )
-            args = [
+            self.flashinfer_prefill_wrapper.end_forward()
+            self.flashinfer_prefill_wrapper.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
-            ]
-            self.prefill_wrapper.begin_forward(*args)
+                num_attention_heads,
+                num_key_value_heads,
+                head_dim,
+                1
+            )
         else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
-            )
-            self.decode_wrapper.begin_forward(
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
+                num_attention_heads,
+                num_key_value_heads,
+                head_dim,
                 1,
-                "NONE",
-                "float16",
+                pos_encoding_mode="NONE",
+                data_type="float16",
             )
 
     def init_extend_args(self):
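Still in the model-runner file: `init_flashinfer_args` no longer builds a fresh wrapper (and a fresh workspace buffer) per batch. It reuses long-lived wrappers, first calling `end_forward()` to drop the previous plan and then `begin_forward(...)` with the current batch's `qo_indptr` / `kv_indptr` / `kv_indices` / `kv_last_page_len`, a page size of 1, and head counts that the caller has already divided by the tensor-parallel size. A small worked example of that per-rank arithmetic (the model dimensions below are hypothetical, chosen only to make the numbers concrete):

```python
# Hypothetical model config, used only to illustrate the head partitioning that
# InputMetadata.create now performs before calling init_flashinfer_args.
num_attention_heads = 32   # total query heads in the model config
num_key_value_heads = 8    # total KV heads (grouped-query attention)
head_dim = 128
tp_size = 4                # tensor-parallel world size

heads_per_rank = num_attention_heads // tp_size      # 8 query heads per GPU
kv_heads_per_rank = num_key_value_heads // tp_size   # 2 KV heads per GPU

# These per-rank values (plus head_dim and page size 1) are what get forwarded
# to begin_forward() on the prefill or decode wrapper.
print(heads_per_rank, kv_heads_per_rank, head_dim)
```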
@@ -155,6 +142,8 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        flashinfer_prefill_wrapper=None,
+        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -187,7 +176,6 @@ class InputMetadata:
             other_kv_index = None
 
         ret = cls(
-            model_runner=model_runner,
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
@@ -205,13 +193,19 @@ class InputMetadata:
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper=flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )
 
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
         if global_server_args_dict.get("enable_flashinfer", False):
-            ret.init_flashinfer_args(tp_size)
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.num_key_value_heads // tp_size,
+                model_runner.model_config.head_dim
+            )
 
         return ret
@@ -234,12 +228,7 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-
-        global global_server_args_dict
-        global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
 
         # Init torch distributed
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
@@ -269,9 +258,17 @@ class ModelRunner:
                 "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
             )
 
+        # Set some global args
+        global global_server_args_dict
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
+        # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.init_flash_infer()
 
     def load_model(self):
         logger.info(
@@ -347,6 +344,22 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
+    def init_flash_infer(self):
+        if global_server_args_dict.get("enable_flashinfer", False):
+            from flashinfer import (
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchDecodeWithPagedKVCacheWrapper,
+            )
+
+            workspace_buffer = torch.empty(
+                32 * 1024 * 1024, dtype=torch.int8, device="cuda"
+            )
+            self.flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
@@ -360,6 +373,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -378,6 +393,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -398,6 +415,8 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -416,6 +435,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids,
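Taken together, the model-runner changes centralize the flashinfer setup: `ModelRunner.init_flash_infer()` allocates one 32 MB workspace buffer and constructs the prefill/decode wrappers a single time, `InputMetadata.create(...)` receives them as arguments, and every forward path re-plans them per batch instead of re-creating them. Below is a condensed sketch of that ownership pattern, guarded so it only runs when flashinfer and a CUDA device are actually present; the per-batch `begin_forward` call is left as a comment because its index tensors come from the live radix KV cache.

```python
import torch

try:
    from flashinfer import (
        BatchPrefillWithPagedKVCacheWrapper,
        BatchDecodeWithPagedKVCacheWrapper,
    )
except ImportError:  # flashinfer not installed; sglang falls back to the Triton kernels
    BatchPrefillWithPagedKVCacheWrapper = BatchDecodeWithPagedKVCacheWrapper = None

if BatchDecodeWithPagedKVCacheWrapper is not None and torch.cuda.is_available():
    # One-time setup, owned by the ModelRunner in the diff above.
    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
    prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
    decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

    # Per batch (decode path shown), as done in InputMetadata.init_flashinfer_args:
    # decode_wrapper.end_forward()
    # decode_wrapper.begin_forward(kv_indptr, kv_indices, kv_last_page_len,
    #                              num_qo_heads, num_kv_heads, head_dim, 1,
    #                              pos_encoding_mode="NONE", data_type="float16")
```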
@@ -150,7 +150,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.disable_disk_cache:
         disable_cache()
     if server_args.enable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.4")
+        assert_pkg_version("flashinfer", "0.0.5")
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
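The last hunk bumps the pinned flashinfer version from 0.0.4 to 0.0.5 when the server is launched with `server_args.enable_flashinfer` set. `assert_pkg_version` is sglang's own helper; a rough standard-library stand-in for this kind of version pin (my sketch, not sglang's actual implementation, which may accept newer versions) could look like:

```python
from importlib.metadata import PackageNotFoundError, version

def check_pkg_version(name: str, required: str) -> None:
    """Minimal stand-in for a version-pin check (not sglang's assert_pkg_version)."""
    try:
        installed = version(name)
    except PackageNotFoundError:
        raise RuntimeError(f"{name} is not installed; try `pip install {name}=={required}`")
    if installed != required:
        raise RuntimeError(f"{name}=={required} is required, but found {installed}")

# Mirrors the pin in the diff above:
# check_pkg_version("flashinfer", "0.0.5")
```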