"...git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "ec034c15023ca0412a91aeddd8aad164e155b695"
Unverified Commit 08fab2b0 authored by eigen's avatar eigen Committed by GitHub
Browse files

minor: global workspace buffer for trtllm-gen mha from flashinfer (#8952)

parent 0d1e27a0
@@ -25,6 +25,9 @@ if TYPE_CHECKING:
 # Constants
 DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
 
+# Reuse this workspace buffer across all TRTLLM MHA wrappers
+global_workspace_buffer = None
+
 
 @dataclass
 class TRTLLMMHAMetadata:
@@ -69,9 +72,15 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
 
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
-        self.workspace_buffer = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )
+        # Allocate buffers
+        global global_workspace_buffer
+        if global_workspace_buffer is None:
+            global_workspace_buffer = torch.empty(
+                self.workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_workspace_buffer
 
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
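For context, the change follows a common lazy-initialization pattern: a module-level buffer is allocated on first use, and every later backend instance reuses it instead of allocating its own 128 MB workspace. Below is a minimal standalone sketch of that pattern; the helper name get_global_workspace_buffer and the leading-underscore variable are illustrative and not part of the actual code, which inlines this logic in the backend's __init__.

    import torch

    DEFAULT_WORKSPACE_SIZE_MB = 128  # same constant as in the diff above

    # Module-level singleton, shared by every wrapper that asks for a workspace.
    _global_workspace_buffer = None


    def get_global_workspace_buffer(device: torch.device) -> torch.Tensor:
        """Return the shared workspace buffer, allocating it on first call.

        Illustrative helper only; the real backend performs this check inline.
        """
        global _global_workspace_buffer
        if _global_workspace_buffer is None:
            _global_workspace_buffer = torch.empty(
                DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024,
                dtype=torch.uint8,
                device=device,
            )
        return _global_workspace_buffer


    # Usage: two callers end up sharing the same underlying storage.
    buf_a = get_global_workspace_buffer(torch.device("cpu"))
    buf_b = get_global_workspace_buffer(torch.device("cpu"))
    assert buf_a.data_ptr() == buf_b.data_ptr()

The design choice is simple memory savings: with one process hosting several TRTLLM MHA wrappers, a shared scratch buffer avoids paying the workspace allocation once per wrapper.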