Unverified Commit 31e9d3a5, authored by Binyao Jiang, committed by GitHub
Browse files

[Fix] Init mamba related memory pools with torch.zeros (#10400)

parent 6f4676ef
...@@ -127,7 +127,7 @@ class MambaPool: ...@@ -127,7 +127,7 @@ class MambaPool:
if speculative_num_draft_tokens is not None: if speculative_num_draft_tokens is not None:
# Cache intermediate SSM states per draft token during target verify # Cache intermediate SSM states per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V] # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
intermediate_ssm_state_cache = torch.empty( intermediate_ssm_state_cache = torch.zeros(
size=( size=(
num_mamba_layers, num_mamba_layers,
size + 1, size + 1,
...@@ -141,7 +141,7 @@ class MambaPool: ...@@ -141,7 +141,7 @@ class MambaPool:
) )
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1] # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = torch.empty( intermediate_conv_window_cache = torch.zeros(
size=( size=(
num_mamba_layers, num_mamba_layers,
size + 1, size + 1,
...@@ -240,7 +240,7 @@ class HybridReqToTokenPool(ReqToTokenPool): ...@@ -240,7 +240,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)} self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
self.device = device self.device = device
self.req_index_to_mamba_index_mapping: torch.Tensor = torch.empty( self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
size, dtype=torch.int32, device=self.device size, dtype=torch.int32, device=self.device
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment