Commit 0089c4bc (unverified), authored by Lianmin Zheng, committed by GitHub
Browse files

[Fix] Fix NaN issues by fixing the cuda graph padding values for flashinfer (#1779)

parent 72e7b57a
...@@ -290,7 +290,7 @@ class CudaGraphRunner: ...@@ -290,7 +290,7 @@ class CudaGraphRunner:
index = bisect.bisect_left(self.capture_bs, raw_bs) index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index] bs = self.capture_bs[index]
if bs != raw_bs: if bs != raw_bs:
self.seq_lens.fill_(self.seq_len_fill_value) self.seq_lens.fill_(1)
self.out_cache_loc.zero_() self.out_cache_loc.zero_()
# Common inputs # Common inputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment