"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5ce4814af1de6d2dc2cc67a46d3862ce62261e2b"
Unverified Commit ef9d3b3c authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Fix triton kernel illegal memory issue for eagle (#4100)

parent fc91d08a
...@@ -292,11 +292,7 @@ class ForwardBatch: ...@@ -292,11 +292,7 @@ class ForwardBatch:
ret.extend_prefix_lens = torch.tensor( ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32 batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True) ).to(device, non_blocking=True)
if ( if model_runner.server_args.attention_backend != "torch_native":
model_runner.server_args.attention_backend != "torch_native"
# TODO: Fix triton kernel illegal memory access for EAGLE
and model_runner.server_args.speculative_algorithm != "EAGLE"
):
ret.extend_num_tokens = batch.extend_num_tokens ret.extend_num_tokens = batch.extend_num_tokens
positions, ret.extend_start_loc = compute_position_triton( positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens, ret.extend_prefix_lens,
...@@ -386,6 +382,8 @@ def compute_position_triton( ...@@ -386,6 +382,8 @@ def compute_position_triton(
): ):
"""Compute positions. It is a fused version of `compute_position_torch`.""" """Compute positions. It is a fused version of `compute_position_torch`."""
batch_size = extend_seq_lens.shape[0] batch_size = extend_seq_lens.shape[0]
has_prefix = extend_prefix_lens.shape[0] == batch_size
positions = torch.empty( positions = torch.empty(
extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
) )
...@@ -399,6 +397,7 @@ def compute_position_triton( ...@@ -399,6 +397,7 @@ def compute_position_triton(
extend_start_loc, extend_start_loc,
extend_prefix_lens, extend_prefix_lens,
extend_seq_lens, extend_seq_lens,
has_prefix,
) )
return positions, extend_start_loc return positions, extend_start_loc
...@@ -410,11 +409,12 @@ def compute_position_kernel( ...@@ -410,11 +409,12 @@ def compute_position_kernel(
extend_start_loc, extend_start_loc,
extend_prefix_lens, extend_prefix_lens,
extend_seq_lens, extend_seq_lens,
has_prefix: tl.constexpr,
): ):
BLOCK_SIZE: tl.constexpr = 512 BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0).to(tl.int64) pid = tl.program_id(0).to(tl.int64)
prefix_len = tl.load(extend_prefix_lens + pid) prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
seq_len = tl.load(extend_seq_lens + pid) seq_len = tl.load(extend_seq_lens + pid)
# TODO: optimize this? # TODO: optimize this?
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment