"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "49b959b5408b97274e2ee423059d9239445aea26"
Unverified commit ac1f2928, authored by eigen, committed by GitHub

feat: use fast_decode_plan from flashinfer, bump flashinfer to 0.4.0rc3 (#10760)


Co-authored-by: Zihao Ye <yezihhhao@gmail.com>
Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
parent 195a59fe
@@ -62,7 +62,7 @@ dependencies = [
     "torchaudio==2.8.0",
     "torchvision",
     "cuda-python",
-    "flashinfer_python==0.4.0rc1",
+    "flashinfer_python==0.4.0rc3",
     "openai==1.99.1",
     "tiktoken",
     "anthropic>=0.20.0",
@@ -70,7 +70,7 @@ srt = [
     "torchaudio==2.8.0",
     "torchvision",
     "cuda-python",
-    "flashinfer_python==0.4.0rc1",
+    "flashinfer_python==0.4.0rc3",
 ]
 blackwell = [
@@ -80,8 +80,8 @@ blackwell = [
     "torchaudio==2.8.0",
     "torchvision",
     "cuda-python",
-    "flashinfer_python==0.4.0rc1",
-    "nvidia-cutlass-dsl==4.2.1",
+    "flashinfer_python==0.4.0rc3",
+    "nvidia-cutlass-dsl==4.2.0",
 ]
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
@@ -703,7 +703,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.4.0rc1",
+            "0.4.0rc3",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
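The `assert_pkg_version` call above guards server startup on the installed flashinfer build. As a rough illustration of such a guard (hedged: the real helper lives in sglang's utils and its exact comparison logic and signature may differ, so this stand-in is named `assert_pkg_version_sketch`):

```python
# Illustrative stand-in for a startup version guard like assert_pkg_version.
# Assumption: the real sglang helper may compare versions differently.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def assert_pkg_version_sketch(pkg: str, min_version: str, help_msg: str) -> None:
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {help_msg}")
    if parse(installed) < parse(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is older than required {min_version}. {help_msg}"
        )


# Mirrors the call in the diff above:
# assert_pkg_version_sketch(
#     "flashinfer_python",
#     "0.4.0rc3",
#     "Please uninstall the old version and reinstall the latest version "
#     "by following the instructions at https://docs.flashinfer.ai/installation.html.",
# )
```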
@@ -47,6 +47,7 @@ if is_flashinfer_available():
         BatchDecodeWithPagedKVCacheWrapper,
         BatchPrefillWithPagedKVCacheWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
+        fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
     from flashinfer.decode import _get_range_buf, get_seq_lens
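This import hunk sits under an `is_flashinfer_available()` guard so the backend degrades gracefully when flashinfer is absent. A generic sketch of that optional-import pattern (the guard function here is an illustrative stand-in, not sglang's actual implementation, which may also check for CUDA):

```python
# Optional-import guard pattern, sketched generically.
import importlib.util


def is_flashinfer_available_sketch() -> bool:
    # True if the package can be located without importing it.
    return importlib.util.find_spec("flashinfer") is not None


if is_flashinfer_available_sketch():
    from flashinfer import fast_decode_plan  # noqa: F401
else:
    fast_decode_plan = None  # callers must check before use
```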
@@ -842,23 +843,51 @@ class FlashInferIndicesUpdaterDecode:
             global_override_indptr_cpu[0] = 0
             global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)

-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
-            fixed_split_size=fixed_split_size,
-            disable_split_kv=(
-                disable_split_kv if disable_split_kv is not None else False
-            ),
-        )
+        # Check whether this specific wrapper's begin_forward has been replaced
+        # with fast_decode_plan, i.e. it is a partial function whose .func is
+        # fast_decode_plan.
+        wrapper_uses_fast_decode_plan = (
+            hasattr(wrapper.begin_forward, "func")
+            and wrapper.begin_forward.func == fast_decode_plan
+        )
+        if wrapper_uses_fast_decode_plan:
+            # begin_forward has been replaced with fast_decode_plan, which
+            # accepts global_override_indptr_cpu; pass it through.
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+                global_override_indptr_cpu=global_override_indptr_cpu,
+            )
+        else:
+            # The original begin_forward does not accept global_override_indptr_cpu.
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+            )

         if locally_override:
             global_override_indptr_cpu = None
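The `.func` check above relies on how `fast_decode_plan` is installed: the backend rebinds a wrapper instance's `begin_forward` to a `functools.partial` whose first bound argument is the wrapper itself, and a `partial` exposes its wrapped callable as `.func`. A minimal sketch of that pattern, with `Wrapper`, `plan`, and `fast_plan` as illustrative stand-ins rather than the real sglang/flashinfer classes:

```python
# Sketch of the partial-rebinding pattern the dispatch above detects.
# Names here (Wrapper, plan, fast_plan) are illustrative stand-ins.
from functools import partial


class Wrapper:
    def plan(self, n: int) -> str:
        return f"slow plan for {n}"

    # begin_forward normally aliases the stock plan method
    begin_forward = plan


def fast_plan(self: Wrapper, n: int, extra: int = 0) -> str:
    # Accepts an extra keyword argument the stock plan does not understand
    return f"fast plan for {n} (extra={extra})"


w = Wrapper()
# Rebind this instance's begin_forward to the fast path, binding `self` via partial
w.begin_forward = partial(fast_plan, w)

# The same detection used in the diff: a partial exposes the wrapped callable as .func
uses_fast = hasattr(w.begin_forward, "func") and w.begin_forward.func is fast_plan
print(uses_fast)                    # True
print(w.begin_forward(3, extra=1))  # only safe because we detected the fast path
```

This per-instance check matters because only some wrappers (e.g. those driven by the multi-step draft backend) get the fast path, so the caller cannot assume a uniform signature.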
@@ -1328,174 +1357,3 @@ def should_use_tensor_core(
         return gqa_group_size >= 4
     else:
         return False
-
-
-# Use as a fast path to override the indptr in flashinfer's plan function.
-# This is used to remove some host-to-device copy overhead.
-global_override_indptr_cpu = None
-
-
-def fast_decode_plan(
-    self,
-    indptr: torch.Tensor,
-    indices: torch.Tensor,
-    last_page_len: torch.Tensor,
-    num_qo_heads: int,
-    num_kv_heads: int,
-    head_dim: int,
-    page_size: int,
-    pos_encoding_mode: str = "NONE",
-    window_left: int = -1,
-    logits_soft_cap: Optional[float] = None,
-    q_data_type: Optional[Union[str, torch.dtype]] = None,
-    kv_data_type: Optional[Union[str, torch.dtype]] = None,
-    data_type: Optional[Union[str, torch.dtype]] = None,
-    sm_scale: Optional[float] = None,
-    rope_scale: Optional[float] = None,
-    rope_theta: Optional[float] = None,
-    non_blocking: bool = True,
-    fixed_split_size: Optional[int] = None,
-    disable_split_kv: bool = False,
-) -> None:
-    """
-    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
-    Modifications:
-    - Remove unnecessary device-to-device copy for the cuda graph buffers.
-    - Remove unnecessary host-to-device copy for the metadata buffers.
-    """
-    batch_size = len(last_page_len)
-    if logits_soft_cap is None:
-        logits_soft_cap = 0.0
-
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-        # Set fixed_split_size to -1 to avoid an assertion error in flashinfer's plan function
-        if fixed_split_size is None:
-            fixed_split_size = -1
-
-    if self.is_cuda_graph_enabled:
-        if batch_size != self._fixed_batch_size:
-            raise ValueError(
-                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
-                "mismatches the batch size set during initialization {}".format(
-                    batch_size, self._fixed_batch_size
-                )
-            )
-        if len(indices) > len(self._paged_kv_indices_buf):
-            raise ValueError(
-                "The size of indices should be less than or equal to the allocated buffer"
-            )
-    else:
-        self._paged_kv_indptr_buf = indptr
-        self._paged_kv_indices_buf = indices
-        self._paged_kv_last_page_len_buf = last_page_len
-        if self.use_tensor_cores:
-            self._qo_indptr_buf = qo_indptr_host.to(
-                self.device, non_blocking=non_blocking
-            )
-
-    # Create empty tensors for dtype info if needed
-    empty_q_data = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-        ),
-        device=self.device,
-    )
-
-    empty_kv_cache = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, kv_data_type)
-            if isinstance(kv_data_type, str)
-            else kv_data_type
-        ),
-        device=self.device,
-    )
-
-    indptr_host = (
-        global_override_indptr_cpu
-        if global_override_indptr_cpu is not None
-        else indptr.cpu()
-    )
-
-    with torch.cuda.device(self.device):
-        if self.use_tensor_cores:
-            # Also convert last_page_len to CPU
-            if page_size == 1:
-                # When page size is 1, last_page_len is always 1.
-                # Directly construct the host tensor rather than executing a device-to-host copy.
-                last_page_len_host = torch.ones(
-                    (batch_size,), dtype=torch.int32, device="cpu"
-                )
-            else:
-                last_page_len_host = last_page_len.cpu()
-
-            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
-
-            try:
-                # Pass the exact argument list expected by the tensor core plan
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    qo_indptr_host,
-                    indptr_host,
-                    kv_lens_arr_host,
-                    batch_size,  # total_num_rows
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    head_dim,
-                    head_dim,
-                    False,  # causal
-                    window_left,
-                    fixed_split_size,
-                    disable_split_kv,
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in tensor core plan: {e}")
-        else:
-            try:
-                # Pass the exact argument list expected by the standard plan
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    indptr_host,
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    window_left,
-                    logits_soft_cap,
-                    head_dim,
-                    head_dim,
-                    empty_q_data,
-                    empty_kv_cache,
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-
-    self._pos_encoding_mode = pos_encoding_mode
-    self._window_left = window_left
-    self._logits_soft_cap = logits_soft_cap
-    self._sm_scale = sm_scale
-    self._rope_scale = rope_scale
-    self._rope_theta = rope_theta
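The `global_override_indptr_cpu` fast path removed above (and now provided by flashinfer itself) avoids the `indptr.cpu()` device-to-host copy inside plan by letting the caller hand in a host-side indptr it already built. A minimal standalone sketch of the idea, assuming nothing about flashinfer's internals; the function and variable names are illustrative:

```python
# Sketch: skip a device-to-host copy when the caller already holds a CPU indptr.
# The planner only needs indptr on the host; copying it back from GPU costs a sync.
from typing import Optional

import torch


def plan_host_indptr(
    indptr_device: torch.Tensor,
    override_indptr_cpu: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Prefer the caller-provided host tensor; fall back to a D2H copy.
    if override_indptr_cpu is not None:
        return override_indptr_cpu
    return indptr_device.cpu()  # blocking device-to-host transfer


# Typical caller: build indptr once on CPU (cumsum of per-request lengths),
# reuse it for planning, and keep the device copy for the kernel itself.
seq_lens_cpu = torch.tensor([3, 5, 2], dtype=torch.int32)
indptr_cpu = torch.zeros(len(seq_lens_cpu) + 1, dtype=torch.int32)
indptr_cpu[1:] = torch.cumsum(seq_lens_cpu, dim=0)

indptr_dev = indptr_cpu  # would be indptr_cpu.to("cuda") on a GPU machine
host_view = plan_host_indptr(indptr_dev, override_indptr_cpu=indptr_cpu)
print(host_view)  # no extra copy performed
```

This mirrors how the decode-path caller fills the override buffer from `seq_lens_cpu` before `begin_forward`, which is exactly why the dispatch earlier in the diff must know whether the wrapper's plan accepts the override.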