Unverified Commit ef9d3b3c authored by Ke Bao, committed by GitHub

Fix triton kernel illegal memory issue for eagle (#4100)

parent fc91d08a
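Context for the change: `compute_position_kernel` unconditionally executed `tl.load(extend_prefix_lens + pid)`, but on the EAGLE speculative-decoding path `extend_prefix_lens` can have fewer entries than the batch, so the load read past the end of the tensor; the old code worked around this by excluding EAGLE from the Triton path entirely. The fix derives a `has_prefix` flag on the host and passes it to the kernel as a `tl.constexpr`, so when it is `False` the load is compiled out. Below is a minimal standalone sketch of that guarded-load pattern (requires a CUDA GPU with Triton installed; the kernel and tensor names here are illustrative, not from the repo):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def load_prefix_kernel(prefix_lens_ptr, out_ptr, has_prefix: tl.constexpr):
    pid = tl.program_id(0)
    # `has_prefix` is a compile-time constant: when it is False, Triton
    # never emits the load, so an empty `prefix_lens` tensor is never
    # dereferenced (the out-of-bounds access this commit fixes).
    prefix_len = tl.load(prefix_lens_ptr + pid) if has_prefix else 0
    tl.store(out_ptr + pid, prefix_len)


batch_size = 4
out = torch.empty(batch_size, dtype=torch.int32, device="cuda")

# EAGLE-like case: no per-request prefix lengths are available.
empty = torch.empty(0, dtype=torch.int32, device="cuda")
load_prefix_kernel[(batch_size,)](empty, out, has_prefix=False)
assert (out == 0).all()

# Normal extend case: one prefix length per request.
prefix_lens = torch.tensor([0, 2, 5, 7], dtype=torch.int32, device="cuda")
load_prefix_kernel[(batch_size,)](prefix_lens, out, has_prefix=True)
assert out.tolist() == [0, 2, 5, 7]
```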
@@ -292,11 +292,7 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            if (
-                model_runner.server_args.attention_backend != "torch_native"
-                # TODO: Fix triton kernel illegal memory access for EAGLE
-                and model_runner.server_args.speculative_algorithm != "EAGLE"
-            ):
+            if model_runner.server_args.attention_backend != "torch_native":
                 ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,
@@ -386,6 +382,8 @@ def compute_position_triton(
 ):
     """Compute positions. It is a fused version of `compute_position_torch`."""
     batch_size = extend_seq_lens.shape[0]
+    has_prefix = extend_prefix_lens.shape[0] == batch_size
     positions = torch.empty(
         extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
     )
@@ -399,6 +397,7 @@
         extend_start_loc,
         extend_prefix_lens,
         extend_seq_lens,
+        has_prefix,
     )
     return positions, extend_start_loc
@@ -410,11 +409,12 @@ def compute_position_kernel(
     extend_start_loc,
     extend_prefix_lens,
     extend_seq_lens,
+    has_prefix: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(0).to(tl.int64)
-    prefix_len = tl.load(extend_prefix_lens + pid)
+    prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)
     # TODO: optimize this?
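For intuition about what the fused kernel computes, here is a rough torch-level sketch (an assumed equivalent based on the docstring's mention of `compute_position_torch`, not code copied from the repo): each request's tokens receive positions `prefix_len, ..., prefix_len + seq_len - 1`, `extend_start_loc` is the exclusive cumulative sum of `extend_seq_lens`, and a missing or short prefix tensor is treated as all zeros, mirroring the `has_prefix` guard:

```python
import torch


def compute_position_reference(extend_prefix_lens, extend_seq_lens):
    # Treat a missing/short prefix tensor as all-zero prefixes, which is
    # what the `has_prefix` constexpr guard achieves inside the kernel.
    if extend_prefix_lens.shape[0] != extend_seq_lens.shape[0]:
        extend_prefix_lens = torch.zeros_like(extend_seq_lens)
    positions = torch.cat(
        [
            torch.arange(p, p + s, dtype=torch.int64)
            for p, s in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
        ]
    )
    # Start offset of each request's tokens within the flattened batch.
    extend_start_loc = torch.zeros_like(extend_seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
    return positions, extend_start_loc


seq_lens = torch.tensor([3, 2], dtype=torch.int32)
pos, start = compute_position_reference(torch.empty(0, dtype=torch.int32), seq_lens)
# pos == [0, 1, 2, 0, 1], start == [0, 3]
```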