Unverified Commit ad9d09e2 authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[Perf] [Hybrid] Copy num_accepted_tokens in non-blocking way when not using prefix caching (#35442)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
parent 4beebfd1
...@@ -1191,13 +1191,14 @@ class GPUModelRunner( ...@@ -1191,13 +1191,14 @@ class GPUModelRunner(
return return
# Find the number of accepted tokens for each sequence. # Find the number of accepted tokens for each sequence.
num_accepted_tokens = ( num_reqs = output_token_ids.size(0)
self.num_accepted_tokens.gpu[:num_reqs] = (
( (
torch.cat( torch.cat(
[ [
output_token_ids, output_token_ids,
torch.full( torch.full(
(output_token_ids.size(0), 1), (num_reqs, 1),
-1, -1,
device=output_token_ids.device, device=output_token_ids.device,
), ),
...@@ -1208,12 +1209,13 @@ class GPUModelRunner( ...@@ -1208,12 +1209,13 @@ class GPUModelRunner(
) )
.int() .int()
.argmax(-1) .argmax(-1)
.cpu()
.numpy()
) )
for i, num_tokens in enumerate(num_accepted_tokens):
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
if self.cache_config.mamba_cache_mode == "align": if self.cache_config.mamba_cache_mode == "align":
for i, num_tokens in enumerate(
self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy()
):
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
mamba_utils.postprocess_mamba( mamba_utils.postprocess_mamba(
scheduler_output, scheduler_output,
self.kv_cache_config, self.kv_cache_config,
...@@ -1224,6 +1226,10 @@ class GPUModelRunner( ...@@ -1224,6 +1226,10 @@ class GPUModelRunner(
self.model.get_mamba_state_copy_func(), self.model.get_mamba_state_copy_func(),
self._get_mamba_copy_bufs(), self._get_mamba_copy_bufs(),
) )
else:
self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_(
self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True
)
def _update_streaming_request( def _update_streaming_request(
self, req_id: str, new_req_data: NewRequestData self, req_id: str, new_req_data: NewRequestData
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment