"...text-generation-inference.git" did not exist on "9e2fdf57c04bae65827b2b03ad2b696eb6e8dec7"
Unverified Commit 8f790ac1 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix a bug in cuda graph runner (#1094)

parent 616b59f3
...@@ -98,8 +98,8 @@ class CudaGraphRunner: ...@@ -98,8 +98,8 @@ class CudaGraphRunner:
self.req_pool_indices = torch.zeros( self.req_pool_indices = torch.zeros(
(self.max_bs,), dtype=torch.int32, device="cuda" (self.max_bs,), dtype=torch.int32, device="cuda"
) )
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda") self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.position_ids_offsets = torch.zeros( self.position_ids_offsets = torch.ones(
(self.max_bs,), dtype=torch.int32, device="cuda" (self.max_bs,), dtype=torch.int32, device="cuda"
) )
self.out_cache_loc = torch.zeros( self.out_cache_loc = torch.zeros(
...@@ -201,7 +201,7 @@ class CudaGraphRunner: ...@@ -201,7 +201,7 @@ class CudaGraphRunner:
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
return_logprob=False, return_logprob=False,
top_logprobs_nums=0, top_logprobs_nums=0,
positions=(seq_lens - 1).to(torch.int64), positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
flashinfer_decode_wrapper=flashinfer_decode_wrapper, flashinfer_decode_wrapper=flashinfer_decode_wrapper,
) )
...@@ -225,8 +225,8 @@ class CudaGraphRunner: ...@@ -225,8 +225,8 @@ class CudaGraphRunner:
index = bisect.bisect_left(self.batch_size_list, raw_bs) index = bisect.bisect_left(self.batch_size_list, raw_bs)
bs = self.batch_size_list[index] bs = self.batch_size_list[index]
if bs != raw_bs: if bs != raw_bs:
self.seq_lens.fill_(1) self.seq_lens.zero_()
self.position_ids_offsets.zero_() self.position_ids_offsets.fill_(1)
self.out_cache_loc.zero_() self.out_cache_loc.zero_()
# Common inputs # Common inputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment