Unverified Commit b91cb67e authored by Binyao Jiang, committed by GitHub

[Performance] Qwen3-Next: replace arange to cached query_start_loc_li… (#10553)

parent e7bc6003
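
This commit removes the lru_cache-based _get_cached_arange helper and instead precomputes two query_start_loc tensors once in init_cuda_graph_state, then copies slices of them into the preallocated per-batch buffers during CUDA graph capture and replay, so decode and target-verify steps no longer build a fresh torch.arange on every call. A minimal standalone sketch of that pattern (illustrative only, not the sglang code; max_bs, bs, and the variable names here are made up for the example):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Precompute once, at CUDA-graph initialization time: cumulative query offsets
# for the largest batch size that will ever be replayed.
max_bs = 64
cached_decode_query_start_loc = torch.arange(
    0, max_bs + 1, dtype=torch.int32, device=device
)

# Per capture/replay step: reuse a preallocated buffer and copy a slice of the
# cached tensor into it instead of allocating a new arange every time.
bs = 8
query_start_loc_buf = torch.empty(bs + 1, dtype=torch.int32, device=device)
query_start_loc_buf.copy_(cached_decode_query_start_loc[: bs + 1])

For pure decode, entry i of query_start_loc is simply i, since each request contributes exactly one query token per step.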
@@ -61,18 +61,15 @@ class MambaAttnBackend(AttentionBackend):
         self.forward_metadata: ForwardMetadata = None
         self.state_indices_list = []
         self.query_start_loc_list = []
-
-    @classmethod
-    @lru_cache(maxsize=128)
-    def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor:
-        """Cache torch.arange tensors for common batch sizes to avoid repeated allocation."""
-        device = torch.device(device_str)
-        return torch.arange(0, bs + 1, dtype=torch.int32, device=device)
+        self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
+        self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         bs = forward_batch.batch_size
         if forward_batch.forward_mode.is_decode_or_idle():
-            query_start_loc = self._get_cached_arange(bs, str(self.device))
+            query_start_loc = torch.arange(
+                0, bs + 1, dtype=torch.int32, device=self.device
+            )
         elif forward_batch.forward_mode.is_extend():
             if forward_batch.forward_mode.is_target_verify():
                 query_start_loc = torch.arange(
@@ -102,6 +99,10 @@ class MambaAttnBackend(AttentionBackend):
             )
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        assert (
+            max_num_tokens % max_bs == 0
+        ), f"max_num_tokens={max_num_tokens} must be divisible by max_bs={max_bs}"
+        verify_step = max_num_tokens / max_bs
         for i in range(max_bs):
             self.state_indices_list.append(
                 torch.full(
@@ -111,6 +112,16 @@ class MambaAttnBackend(AttentionBackend):
             self.query_start_loc_list.append(
                 torch.empty((i + 2,), dtype=torch.int32, device=self.device)
             )
+        self.cached_cuda_graph_decode_query_start_loc = torch.arange(
+            0, max_bs + 1, dtype=torch.int32, device=self.device
+        )
+        self.cached_cuda_graph_verify_query_start_loc = torch.arange(
+            0,
+            max_bs * verify_step + 1,
+            step=verify_step,
+            dtype=torch.int32,
+            device=self.device,
+        )
 
     def init_forward_metadata_capture_cuda_graph(
         self,
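
For the target-verify path, the cached tensor above is built with step=verify_step (max_num_tokens / max_bs, i.e. the number of draft tokens per request during capture), so its i-th entry is i * verify_step, exactly the values the previous code recomputed with a per-call torch.arange. A small self-contained check of that equivalence (draft_token_num, max_bs, and bs are illustrative values, not taken from the code):

import torch

draft_token_num = 4  # plays the role of verify_step in this sketch
max_bs, bs = 16, 5

cached_verify = torch.arange(
    0, max_bs * draft_token_num + 1, step=draft_token_num, dtype=torch.int32
)
per_call = torch.arange(
    0, bs * draft_token_num + 1, step=draft_token_num, dtype=torch.int32
)

# Slicing the precomputed tensor to bs + 1 entries reproduces the per-call arange.
assert torch.equal(cached_verify[: bs + 1], per_call)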
@@ -123,16 +134,12 @@ class MambaAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
-            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
+            self.query_start_loc_list[bs - 1].copy_(
+                self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
+            )
         elif forward_mode.is_target_verify():
             self.query_start_loc_list[bs - 1].copy_(
-                torch.arange(
-                    0,
-                    bs * spec_info.draft_token_num + 1,
-                    step=spec_info.draft_token_num,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
+                self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
             )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
@@ -163,23 +170,29 @@ class MambaAttnBackend(AttentionBackend):
             mamba_indices[bs - num_padding :] = -1
         self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
         if forward_mode.is_decode_or_idle():
-            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
-            if num_padding > 0:
-                self.query_start_loc_list[bs - 1][bs - num_padding :] = bs - num_padding
-        elif forward_mode.is_target_verify():
-            self.query_start_loc_list[bs - 1].copy_(
-                torch.arange(
-                    0,
-                    bs * spec_info.draft_token_num + 1,
-                    step=spec_info.draft_token_num,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
-            )
-            if num_padding > 0:
-                self.query_start_loc_list[bs - 1][bs - num_padding :] = (
-                    bs - num_padding
-                ) * spec_info.draft_token_num
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
+                )
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
+                    bs - num_padding
+                )
+        elif forward_mode.is_target_verify():
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
+                )
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
+                    (bs - num_padding) * spec_info.draft_token_num
+                )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
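
In the replay path above, padded CUDA graph batches split the copy: the first bs - num_padding offsets come from the cached prefix, and the tail is set to the count of real entries (scaled by spec_info.draft_token_num in the verify branch), which presumably leaves the padding slots as zero-length query segments. A rough standalone illustration of the resulting query_start_loc contents for the decode case (the batch sizes are made up, and plain indexing is used in place of copy_ for brevity):

import torch

bs, num_padding = 8, 3  # 5 real requests, 3 padding slots (illustrative)
real = bs - num_padding

cached_decode = torch.arange(0, bs + 1, dtype=torch.int32)
query_start_loc = torch.empty(bs + 1, dtype=torch.int32)

# Real requests get offsets 0..real-1; padded slots all point at `real`,
# so consecutive entries are equal there and span no query tokens.
query_start_loc[:real] = cached_decode[:real]
query_start_loc[real:] = real

print(query_start_loc.tolist())  # [0, 1, 2, 3, 4, 5, 5, 5, 5]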