Unverified Commit be0a3f75 authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[Bugfix] Fix race in non-blocking num_accepted_tokens GPU->CPU copy (#36013)


Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 17dc9c7f
...@@ -727,8 +727,10 @@ class GPUModelRunner( ...@@ -727,8 +727,10 @@ class GPUModelRunner(
self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None
self.valid_sampled_token_count_cpu: torch.Tensor | None = None self.valid_sampled_token_count_cpu: torch.Tensor | None = None
self.draft_token_ids_cpu: torch.Tensor | None = None self.draft_token_ids_cpu: torch.Tensor | None = None
self.num_accepted_tokens_event: torch.Event | None = None
if self.num_spec_tokens: if self.num_spec_tokens:
self.draft_token_ids_event = torch.Event() self.draft_token_ids_event = torch.Event()
self.num_accepted_tokens_event = torch.Event()
self.draft_token_ids_copy_stream = torch.cuda.Stream() self.draft_token_ids_copy_stream = torch.cuda.Stream()
self.draft_token_ids_cpu = torch.empty( self.draft_token_ids_cpu = torch.empty(
(self.max_num_reqs, self.num_spec_tokens), (self.max_num_reqs, self.num_spec_tokens),
...@@ -1229,6 +1231,8 @@ class GPUModelRunner( ...@@ -1229,6 +1231,8 @@ class GPUModelRunner(
self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_( self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_(
self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True
) )
assert self.num_accepted_tokens_event is not None
self.num_accepted_tokens_event.record()
def _update_streaming_request( def _update_streaming_request(
self, req_id: str, new_req_data: NewRequestData self, req_id: str, new_req_data: NewRequestData
...@@ -1773,6 +1777,8 @@ class GPUModelRunner( ...@@ -1773,6 +1777,8 @@ class GPUModelRunner(
max_seq_len = self.seq_lens.np[:num_reqs].max().item() max_seq_len = self.seq_lens.np[:num_reqs].max().item()
if use_spec_decode: if use_spec_decode:
if self.num_accepted_tokens_event is not None:
self.num_accepted_tokens_event.synchronize()
self.num_accepted_tokens.np[:num_reqs] = ( self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs] self.input_batch.num_accepted_tokens_cpu[:num_reqs]
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment