".github/vscode:/vscode.git/clone" did not exist on "2255a0fc9f6b204f152da7920f116a0c22a1da35"
Unverified commit 8f790ac1, authored by Lianmin Zheng, committed by GitHub
Browse files

Fix a bug in cuda graph runner (#1094)

parent 616b59f3
......@@ -98,8 +98,8 @@ class CudaGraphRunner:
self.req_pool_indices = torch.zeros(
(self.max_bs,), dtype=torch.int32, device="cuda"
)
- self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
- self.position_ids_offsets = torch.zeros(
+ self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+ self.position_ids_offsets = torch.ones(
(self.max_bs,), dtype=torch.int32, device="cuda"
)
self.out_cache_loc = torch.zeros(
......@@ -201,7 +201,7 @@ class CudaGraphRunner:
out_cache_loc=out_cache_loc,
return_logprob=False,
top_logprobs_nums=0,
- positions=(seq_lens - 1).to(torch.int64),
+ positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
)
......@@ -225,8 +225,8 @@ class CudaGraphRunner:
index = bisect.bisect_left(self.batch_size_list, raw_bs)
bs = self.batch_size_list[index]
if bs != raw_bs:
- self.seq_lens.fill_(1)
- self.position_ids_offsets.zero_()
+ self.seq_lens.zero_()
+ self.position_ids_offsets.fill_(1)
self.out_cache_loc.zero_()
# Common inputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment