Unverified Commit 490a1f39 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix cuda graph with flashinfer (#675)

parent 06487f12
...@@ -64,7 +64,7 @@ def main(args): ...@@ -64,7 +64,7 @@ def main(args):
@sgl.function @sgl.function
def few_shot_gsm8k(s, question): def few_shot_gsm8k(s, question):
s += few_shot_examples + question s += few_shot_examples + question
s += sgl.gen("answer", max_tokens=256, stop="Question") s += sgl.gen("answer", max_tokens=512, stop="Question")
##################################### #####################################
########## SGL Program End ########## ########## SGL Program End ##########
......
...@@ -150,8 +150,8 @@ class CudaGraphRunner: ...@@ -150,8 +150,8 @@ class CudaGraphRunner:
index = bisect.bisect_left(self.batch_size_list, raw_bs) index = bisect.bisect_left(self.batch_size_list, raw_bs)
bs = self.batch_size_list[index] bs = self.batch_size_list[index]
if bs != raw_bs: if bs != raw_bs:
self.seq_lens.zero_() self.seq_lens.fill_(1)
self.position_ids_offsets.fill_(1) self.position_ids_offsets.zero_()
self.out_cache_loc.zero_() self.out_cache_loc.zero_()
# Common inputs # Common inputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment