Unverified commit 4dbf4360 authored by eigen, committed by GitHub

fix: zero_init buffer (#9065)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 3d6be1fb
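The core of the change: the FlashInfer/TRTLLM attention backends share one module-level workspace buffer, and it is now allocated with torch.zeros instead of torch.empty, so the scratch memory starts deterministically zeroed. A minimal sketch of that pattern, with illustrative names rather than the exact SGLang symbols:

import torch

# Minimal sketch of the allocate-once, zero-initialized workspace pattern.
# Names are illustrative; the real code uses global_zero_init_workspace_buffer
# inside each attention backend module.
_workspace_buffer = None

def get_workspace_buffer(size_bytes: int, device: str = "cpu") -> torch.Tensor:
    # torch.zeros (rather than torch.empty) guarantees the buffer starts as
    # all zero bytes, so kernels that expect zeroed scratch memory never read
    # stale data left over from earlier allocations.
    global _workspace_buffer
    if _workspace_buffer is None:
        _workspace_buffer = torch.zeros(size_bytes, dtype=torch.uint8, device=device)
    return _workspace_buffer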
@@ -63,7 +63,7 @@ srt = [
   "torchaudio==2.8.0",
   "torchvision",
   "cuda-python",
-  "flashinfer_python==0.2.11.post1",
+  "flashinfer_python==0.2.11.post3",
 ]
 blackwell = [
@@ -73,7 +73,7 @@ blackwell = [
   "torchaudio==2.8.0",
   "torchvision",
   "cuda-python",
-  "flashinfer_python==0.2.11.post1",
+  "flashinfer_python==0.2.11.post3",
 ]
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
...
@@ -647,7 +647,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.11.post1",
+            "0.2.11.post3",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
...
@@ -122,6 +122,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
...
@@ -81,6 +81,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
            global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
...
@@ -23,10 +23,12 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo

 # Constants
-DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+DEFAULT_WORKSPACE_SIZE_MB = (
+    512  # Memory workspace size in MB, todo(Yingyi): read from config
+)

 # Reuse this workspace buffer across all TRTLLM MHA wrappers
-global_workspace_buffer = None
+global_zero_init_workspace_buffer = None

 @dataclass
@@ -73,14 +75,14 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
         # Allocate buffers
-        global global_workspace_buffer
-        if global_workspace_buffer is None:
-            global_workspace_buffer = torch.empty(
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
                 self.workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        self.workspace_buffer = global_zero_init_workspace_buffer
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
...
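Beyond zero-initialization, the hunk above also replaces the per-instance allocation with a module-level singleton, so every TRTLLM MHA wrapper aliases one 512 MB workspace instead of allocating its own. A toy, SGLang-independent sketch of that aliasing (buffer shrunk to 1 MB just for the demo):

import torch

_shared = None  # module-level singleton, mirroring global_zero_init_workspace_buffer

class ToyBackend:
    """Mimics the allocate-once-then-alias pattern from the hunk above."""

    def __init__(self, size_bytes: int = 1024 * 1024):
        global _shared
        if _shared is None:
            _shared = torch.zeros(size_bytes, dtype=torch.uint8)
        self.workspace_buffer = _shared

a, b = ToyBackend(), ToyBackend()
assert a.workspace_buffer.data_ptr() == b.workspace_buffer.data_ptr()  # same storage
assert int(a.workspace_buffer.sum()) == 0  # still all zeros before any kernel writes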
@@ -39,6 +39,8 @@ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
 # compute the LCM with other padding constraints.
 TRTLLM_BLOCK_CONSTRAINT = 128

+global_zero_init_workspace_buffer = None
+

 @dataclass
 class TRTLLMMLADecodeMetadata:
@@ -83,9 +85,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
-        self.workspace_buffer = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
+                self.workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_zero_init_workspace_buffer
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
...
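Two smaller details of this hunk: the MLA workspace dtype moves from int8 to uint8, matching the other backends, and the device is now taken from model_runner.device instead of self.device. The dtype switch does not change the buffer's footprint, since both types are one byte per element; a quick check:

import torch

# int8 and uint8 are both one byte wide, so switching the workspace dtype only
# aligns the MLA backend with the others; it does not change the buffer size.
assert torch.empty(0, dtype=torch.int8).element_size() == 1
assert torch.empty(0, dtype=torch.uint8).element_size() == 1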
@@ -143,4 +143,4 @@
         "num_warps": 4,
         "num_stages": 3
     }
-}
\ No newline at end of file
+}
...
@@ -143,4 +143,4 @@
         "num_warps": 4,
         "num_stages": 4
     }
-}
\ No newline at end of file
+}
...
@@ -143,4 +143,4 @@
         "num_warps": 4,
         "num_stages": 3
     }
-}
\ No newline at end of file
+}
...
@@ -143,4 +143,4 @@
         "num_warps": 4,
         "num_stages": 3
     }
-}
\ No newline at end of file
+}