Unverified Commit 7b6a5332 authored by Liangsheng Yin, committed by GitHub

Fix triton args init (#1034)

parent 4080e822
@@ -148,9 +148,6 @@ class InputMetadata:
         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
         self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
 
-    def init_total_num_tokens(self, batch: ScheduleBatch):
-        self.total_num_tokens = sum(len(req.fill_ids) for req in batch.reqs)
-
     @classmethod
     def from_schedule_batch(
         cls,
@@ -174,7 +171,11 @@ class InputMetadata:
         ret.compute_extend_infos(batch)
-        ret.init_total_num_tokens(batch)
+
+        if (
+            forward_mode != ForwardMode.DECODE
+            or model_runner.server_args.disable_flashinfer
+        ):
+            ret.total_num_tokens = int(torch.sum(ret.seq_lens))
 
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)
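The first two hunks together replace the precomputed `init_total_num_tokens` helper (which summed each request's `fill_ids`) with a lazy sum over `ret.seq_lens`, gated on the same condition the patch introduces. A minimal sketch of that gating logic, using plain Python placeholders for `ForwardMode`, `server_args`, and the batch tensors from the real code:

```python
import torch

# Placeholders for state that InputMetadata receives from the scheduler /
# model runner in the real code.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)  # per-request sequence lengths
is_decode = True            # stands in for forward_mode == ForwardMode.DECODE
disable_flashinfer = True   # stands in for model_runner.server_args.disable_flashinfer

total_num_tokens = None
# Same condition as the patch: compute the sum for any non-decode pass,
# or for a decode pass when flashinfer is disabled.
if not is_decode or disable_flashinfer:
    total_num_tokens = int(torch.sum(seq_lens))  # 5 + 3 + 7 = 15
```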
@@ -203,7 +204,7 @@ class InputMetadata:
     def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
         """Init auxiliary variables for triton attention backend."""
-        self.triton_max_seq_len = max(len(r.fill_ids) for r in batch.reqs)
+        self.triton_max_seq_len = int(torch.max(self.seq_lens))
         self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
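In the triton attention path, the max sequence length is now taken directly from the already-built `seq_lens` tensor instead of being re-derived from each request's `fill_ids`. A standalone sketch of the resulting computations, with made-up tensors standing in for the `InputMetadata` fields and the `prefix_lens` argument:

```python
import torch

# Made-up per-request lengths standing in for InputMetadata.seq_lens
# and the prefix_lens argument.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
prefix_lens = torch.tensor([2, 0, 4], dtype=torch.int32)

# New: max sequence length straight from seq_lens.
triton_max_seq_len = int(torch.max(seq_lens))                # 7

# Unchanged: exclusive cumulative sum gives each request's start
# offset in the flattened token layout.
triton_start_loc = torch.zeros_like(seq_lens, dtype=torch.int32)
triton_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)    # [0, 5, 8]

triton_prefix_lens = prefix_lens
print(triton_max_seq_len, triton_start_loc.tolist())         # 7 [0, 5, 8]
```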