Unverified Commit f440baa1 authored by ykcombat's avatar ykcombat Committed by GitHub
Browse files

[Feature] Reuse flashinfer workspace for PD-Multiplexing. (#11540)

parent 2bc3fcd4
......@@ -34,7 +34,9 @@ def create_flashinfer_backend(runner):
or not runner.plan_stream_for_flashinfer
):
runner.plan_stream_for_flashinfer = torch.cuda.Stream()
return FlashInferAttnBackend(runner)
return FlashInferAttnBackend(
runner, init_new_workspace=runner.init_new_workspace
)
else:
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
......
......@@ -118,6 +118,7 @@ class FlashInferAttnBackend(AttentionBackend):
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
kv_last_page_len_buf: Optional[torch.Tensor] = None,
init_new_workspace: bool = False,
):
super().__init__()
......@@ -192,7 +193,14 @@ class FlashInferAttnBackend(AttentionBackend):
dtype=torch.uint8,
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
if init_new_workspace:
self.workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
else:
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
self.kv_indptr = [
......
......@@ -284,6 +284,7 @@ class ModelRunner:
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
self.attention_chunk_size = model_config.attention_chunk_size
self.forward_pass_id = 0
self.init_new_workspace = False
# Apply the rank zero filter to logger
if server_args.show_time_cost:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment