Unverified Commit e0ce171d authored by Ke Bao, committed by GitHub

Fix triton backend eagle illegal memory access (#9344)

parent fe43e889
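
The whole patch is a dtype widening: every `kv_indices` buffer moves from torch.int32 to torch.int64. A plausible reading of the fix (the commit message does not spell it out) is that with EAGLE speculative decoding the flattened index space these buffers address, on the order of batch_size * topk * max_context_len, can exceed the int32 range, and a wrapped negative index used as an offset inside the Triton kernel faults with an illegal memory access. A minimal sketch of that failure mode, with a hypothetical slot id:

    import torch

    # Hypothetical slot id just past the int32 range (2**31 - 1).
    slot_id = torch.tensor([2**31 + 511], dtype=torch.int64)

    # Downcasting wraps silently instead of raising; the result is negative.
    wrapped = slot_id.to(torch.int32)
    print(slot_id.item())  # 2147484159
    print(wrapped.item())  # -2147483137: out of bounds if used as an index

int64 indices double the size of these buffers, but they are small relative to the KV cache itself, so the widening is a cheap way to rule out overflow.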
@@ -172,7 +172,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -238,7 +238,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                kv_indptr[-1], dtype=torch.int32, device=self.device
+                kv_indptr[-1], dtype=torch.int64, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -289,6 +289,7 @@ class TritonAttnBackend(AttentionBackend):
                     self.req_to_token,
                 )
             )
+            kv_indices = kv_indices.to(torch.int64)
             mask_indptr = None
             # TODO(FIXME): This will trigger an invalid Eagle tree when using
             # `max(spec_info.accept_length_cpu)`.
@@ -304,7 +305,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
                 forward_batch.extend_prefix_lens.sum().item(),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
@@ -379,7 +380,7 @@ class TritonAttnBackend(AttentionBackend):
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_num_tokens * self.max_context_len),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
         else:
@@ -396,7 +397,7 @@ class TritonAttnBackend(AttentionBackend):
             if kv_indices_buf is None:
                 self.cuda_graph_window_kv_indices = torch.zeros(
                     (max_num_tokens * self.sliding_window_size),
-                    dtype=torch.int32,
+                    dtype=torch.int64,
                     device=self.device,
                 )
             else:
@@ -888,7 +889,7 @@ class TritonMultiStepDraftBackend:
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
             ),
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=self.device,
         )
@@ -906,7 +907,7 @@ class TritonMultiStepDraftBackend:
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_num_tokens * self.max_context_len),
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=self.device,
         )
         for i in range(self.speculative_num_steps):
@@ -1015,7 +1016,7 @@ def update_sliding_window_buffer(
     window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
     window_kv_indptr = window_kv_indptr[: bs + 1]
     window_kv_indices = torch.empty(
-        window_kv_indptr[-1], dtype=torch.int32, device=device
+        window_kv_indptr[-1], dtype=torch.int64, device=device
     )
     window_kv_start_idx = seq_lens - window_kv_lens
     create_flashinfer_kv_indices_triton[(bs,)](
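
Note that the CUDA-graph buffers (`cuda_graph_kv_indices`, `cuda_graph_window_kv_indices`) are widened in the same patch: graph replay writes into the buffers captured at graph-build time, so the eager path and the captured path must agree on dtype. For debugging, a guard like the following (a hypothetical helper, not part of this patch; `pool_size` stands in for the real KV pool capacity) turns a kernel-side fault into a readable assertion:

    import torch

    def check_kv_indices(kv_indices: torch.Tensor, pool_size: int) -> None:
        # Hypothetical debug guard: fail fast on the host instead of
        # faulting inside the Triton kernel.
        assert kv_indices.dtype == torch.int64, f"expected int64, got {kv_indices.dtype}"
        if kv_indices.numel() > 0:
            assert int(kv_indices.min()) >= 0
            assert int(kv_indices.max()) < pool_size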