Unverified Commit 6ae3f05b authored by ur4t's avatar ur4t Committed by GitHub
Browse files

Fix CUDA illegal memory access issues in speculative decoding (#10892)

parent fdc4e1e5
...@@ -302,6 +302,7 @@ class EAGLEDraftCudaGraphRunner: ...@@ -302,6 +302,7 @@ class EAGLEDraftCudaGraphRunner:
if bs != raw_bs: if bs != raw_bs:
self.seq_lens.fill_(self.seq_len_fill_value) self.seq_lens.fill_(self.seq_len_fill_value)
self.out_cache_loc.zero_() self.out_cache_loc.zero_()
self.positions.zero_()
num_tokens = bs * self.num_tokens_per_bs num_tokens = bs * self.num_tokens_per_bs
......
...@@ -332,6 +332,7 @@ class EAGLEDraftExtendCudaGraphRunner: ...@@ -332,6 +332,7 @@ class EAGLEDraftExtendCudaGraphRunner:
if bs * self.num_tokens_per_bs != num_tokens: if bs * self.num_tokens_per_bs != num_tokens:
self.seq_lens.fill_(self.seq_len_fill_value) self.seq_lens.fill_(self.seq_len_fill_value)
self.out_cache_loc.zero_() self.out_cache_loc.zero_()
self.positions.zero_()
self.accept_length.fill_(1) self.accept_length.fill_(1)
self.extend_seq_lens.fill_(1) self.extend_seq_lens.fill_(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment