Unverified Commit 08fab2b0 authored by eigen's avatar eigen Committed by GitHub
Browse files

minor: global workspace buffer for trtllm-gen mha from flashinfer (#8952)

parent 0d1e27a0
......@@ -25,6 +25,9 @@ if TYPE_CHECKING:
# Constants
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
# Reuse this workspace buffer across all TRTLLM MHA wrappers
global_workspace_buffer = None
@dataclass
class TRTLLMMHAMetadata:
......@@ -69,9 +72,15 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
# Workspace allocation
self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
self.workspace_buffer = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
# Allocate buffers
global global_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
self.workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
# CUDA graph state
self.decode_cuda_graph_metadata = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment