Unverified commit e0ce171d authored by Ke Bao, committed by GitHub

Fix triton backend eagle illegal memory access (#9344)

parent fe43e889
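
The diff below widens every KV-index buffer in the Triton attention backend from torch.int32 to torch.int64. A plausible reading of the change is that with EAGLE speculative decoding the flattened offsets held in these buffers (for example, up to batch_size * topk * max_context_len entries) can exceed 2**31 - 1; an int32 value then wraps negative, and the Triton kernels index out of bounds, which CUDA reports as an illegal memory access. A minimal sketch of the wraparound, with illustrative values that are not taken from this commit:

import torch

# Sketch of the assumed failure mode (not from the commit): a flat
# offset past 2**31 - 1 wraps negative when stored as int32, so a
# kernel that consumes it indexes out of bounds.
flat_offset = torch.tensor([2**31 + 5])        # int64 by default
print(torch.iinfo(torch.int32).max)            # 2147483647
print(flat_offset.to(torch.int32))             # tensor([-2147483643], dtype=torch.int32)
print(flat_offset.to(torch.int64))             # tensor([2147483653])
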
@@ -172,7 +172,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -238,7 +238,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                kv_indptr[-1], dtype=torch.int32, device=self.device
+                kv_indptr[-1], dtype=torch.int64, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -289,6 +289,7 @@ class TritonAttnBackend(AttentionBackend):
                     self.req_to_token,
                 )
             )
+            kv_indices = kv_indices.to(torch.int64)
             mask_indptr = None
             # TODO(FIXME): This will trigger an invalid Eagle tree when using
             # `max(spec_info.accept_length_cpu)`.
@@ -304,7 +305,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
                 forward_batch.extend_prefix_lens.sum().item(),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
@@ -379,7 +380,7 @@ class TritonAttnBackend(AttentionBackend):
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_num_tokens * self.max_context_len),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
         else:
@@ -396,7 +397,7 @@ class TritonAttnBackend(AttentionBackend):
         if kv_indices_buf is None:
             self.cuda_graph_window_kv_indices = torch.zeros(
                 (max_num_tokens * self.sliding_window_size),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
         else:
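
One design consequence worth noting: the CUDA-graph index buffers above are preallocated at their maximum size, so widening them to int64 doubles their footprint. A rough back-of-the-envelope, with illustrative sizes that are assumed here rather than taken from the commit:

max_num_tokens, max_context_len = 4096, 65536   # assumed sizes
numel = max_num_tokens * max_context_len        # 268,435,456 entries
print(numel * 4 / 2**30)   # 1.0 -- GiB at int32 (4 bytes per index)
print(numel * 8 / 2**30)   # 2.0 -- GiB at int64 (8 bytes per index)
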
@@ -888,7 +889,7 @@ class TritonMultiStepDraftBackend:
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
             ),
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=self.device,
         )
@@ -906,7 +907,7 @@ class TritonMultiStepDraftBackend:
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_num_tokens * self.max_context_len),
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=self.device,
         )
         for i in range(self.speculative_num_steps):
@@ -1015,7 +1016,7 @@ def update_sliding_window_buffer(
     window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
     window_kv_indptr = window_kv_indptr[: bs + 1]
     window_kv_indices = torch.empty(
-        window_kv_indptr[-1], dtype=torch.int32, device=device
+        window_kv_indptr[-1], dtype=torch.int64, device=device
     )
     window_kv_start_idx = seq_lens - window_kv_lens
     create_flashinfer_kv_indices_triton[(bs,)](
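
As a usage note, every allocation in this diff hardcodes torch.int64. A hypothetical helper, not part of this commit, could instead derive the dtype from the number of addressable entries, keeping the int32 memory savings for small configurations:

import torch

def kv_indices_dtype(numel: int) -> torch.dtype:
    # Hypothetical helper (assumed, not in the commit): pick the
    # narrowest integer dtype whose range covers every flat offset in
    # a buffer of `numel` entries.
    return torch.int32 if numel <= torch.iinfo(torch.int32).max else torch.int64

print(kv_indices_dtype(4096 * 8192))    # torch.int32
print(kv_indices_dtype(4096 * 2**20))   # torch.int64
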