Unverified Commit f440baa1 authored by ykcombat, committed by GitHub

[Feature] Reuse flashinfer workspace for PD-Multiplexing. (#11540)

parent 2bc3fcd4
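
The new init_new_workspace flag defaults to False, so a backend keeps reusing the shared global_workspace_buffer; setting it on the ModelRunner before constructing a backend gives that backend a private flashinfer workspace instead. A minimal usage sketch, not part of this diff (the helper name and the prefill/decode split are assumptions; only init_new_workspace and create_flashinfer_backend come from this commit):

    def build_backends_for_pdmux(runner, create_flashinfer_backend):
        # Hypothetical helper: build two backends, giving the second one a
        # private flashinfer workspace via the new flag while the first keeps
        # reusing the shared global_workspace_buffer.
        prefill_backend = create_flashinfer_backend(runner)
        runner.init_new_workspace = True   # next backend allocates its own buffer
        decode_backend = create_flashinfer_backend(runner)
        runner.init_new_workspace = False  # restore the default (reuse the global buffer)
        return prefill_backend, decode_backend
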
@@ -34,7 +34,9 @@ def create_flashinfer_backend(runner):
                 or not runner.plan_stream_for_flashinfer
             ):
                 runner.plan_stream_for_flashinfer = torch.cuda.Stream()
-        return FlashInferAttnBackend(runner)
+        return FlashInferAttnBackend(
+            runner, init_new_workspace=runner.init_new_workspace
+        )
     else:
         from sglang.srt.layers.attention.flashinfer_mla_backend import (
             FlashInferMLAAttnBackend,
...
@@ -118,6 +118,7 @@ class FlashInferAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
         kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        init_new_workspace: bool = False,
     ):
         super().__init__()
@@ -192,7 +193,14 @@ class FlashInferAttnBackend(AttentionBackend):
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        if init_new_workspace:
+            self.workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        else:
+            self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
...
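
Restated outside the diff for clarity: the workspace selection implemented in the hunk above, as a self-contained sketch. global_config.flashinfer_workspace_size and the lazily allocated module-level global_workspace_buffer come from the surrounding file; the helper name is an assumption.

    import torch

    def pick_workspace(init_new_workspace, global_workspace_buffer, size, device):
        # Either allocate a private scratch buffer for this backend instance
        # (e.g. one that runs concurrently with another backend under
        # PD-Multiplexing), or keep sharing the process-wide buffer.
        if init_new_workspace:
            return torch.empty(size, dtype=torch.uint8, device=device)
        return global_workspace_buffer
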
@@ -284,6 +284,7 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
         self.forward_pass_id = 0
+        self.init_new_workspace = False
         # Apply the rank zero filter to logger
         if server_args.show_time_cost:
...