Unverified Commit be0a3f75 authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[Bugfix] Fix race in non-blocking num_accepted_tokens GPU->CPU copy (#36013)


Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 17dc9c7f
......@@ -727,8 +727,10 @@ class GPUModelRunner(
self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None
self.valid_sampled_token_count_cpu: torch.Tensor | None = None
self.draft_token_ids_cpu: torch.Tensor | None = None
self.num_accepted_tokens_event: torch.Event | None = None
if self.num_spec_tokens:
self.draft_token_ids_event = torch.Event()
self.num_accepted_tokens_event = torch.Event()
self.draft_token_ids_copy_stream = torch.cuda.Stream()
self.draft_token_ids_cpu = torch.empty(
(self.max_num_reqs, self.num_spec_tokens),
......@@ -1229,6 +1231,8 @@ class GPUModelRunner(
self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_(
self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True
)
assert self.num_accepted_tokens_event is not None
self.num_accepted_tokens_event.record()
def _update_streaming_request(
self, req_id: str, new_req_data: NewRequestData
......@@ -1773,6 +1777,8 @@ class GPUModelRunner(
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
if use_spec_decode:
if self.num_accepted_tokens_event is not None:
self.num_accepted_tokens_event.synchronize()
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment