Unverified commit 97ba96fb, authored by zhrrr and committed by GitHub
Browse files

[perf][async] support non cpu sync get logprob tensors for spec (#31336)


Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
parent 94578127
...@@ -69,6 +69,14 @@ class LogprobsTensors(NamedTuple): ...@@ -69,6 +69,14 @@ class LogprobsTensors(NamedTuple):
self.selected_token_ranks.to("cpu", non_blocking=True), self.selected_token_ranks.to("cpu", non_blocking=True),
) )
def filter(self, mask: torch.Tensor) -> "LogprobsTensors":
    """Return a new ``LogprobsTensors`` restricted to the entries picked by ``mask``.

    The same bool mask is used to index each of the three tensors held by
    this tuple, so the results stay aligned with each other.

    Args:
        mask: Bool mask used to index ``logprob_token_ids``, ``logprobs``
            and ``selected_token_ranks``.

    Returns:
        A new ``LogprobsTensors`` built from the three masked tensors.
    """
    kept_token_ids = self.logprob_token_ids[mask]
    kept_logprobs = self.logprobs[mask]
    kept_ranks = self.selected_token_ranks[mask]
    return LogprobsTensors(kept_token_ids, kept_logprobs, kept_ranks)
@staticmethod @staticmethod
def empty_cpu( def empty_cpu(
num_positions: int, num_tokens_per_position: int num_positions: int, num_tokens_per_position: int
......
...@@ -9,7 +9,7 @@ import torch.nn as nn ...@@ -9,7 +9,7 @@ import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties from vllm.v1.sample.ops.penalties import apply_all_penalties
...@@ -185,13 +185,22 @@ class RejectionSampler(nn.Module): ...@@ -185,13 +185,22 @@ class RejectionSampler(nn.Module):
final_logits[target_logits_indices] = target_logits.to(torch.float32) final_logits[target_logits_indices] = target_logits.to(torch.float32)
final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32) final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)
# Compute accepted token indices. # NOTE: To avoid cpu-gpu synchronization, we now simply compute indices for
accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID # all draft tokens, including the rejected ones. The rejected tokens will
num_accepted_tokens = accepted_mask.sum(dim=-1) # be filtered out in the `parse_output`.
accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1] logit_start_indices = cu_num_sampled_tokens
accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave( offsets = torch.arange(
num_accepted_tokens sampled_token_ids.shape[-1],
) device=logit_start_indices.device,
dtype=logit_start_indices.dtype,
)
accepted_logit_indices = (
logit_start_indices.unsqueeze(1) + offsets.unsqueeze(0)
).flatten()
accepted_logit_indices.clamp_(max=final_logits.shape[0] - 1)
accepted_tokens = sampled_token_ids.clone().flatten()
# we replace rejected token ids with 0 to avoid gather_logprobs error
accepted_tokens[accepted_tokens == PLACEHOLDER_TOKEN_ID] = 0
# Compute logprobs for accepted tokens. # Compute logprobs for accepted tokens.
accepted_logits = final_logits[accepted_logit_indices] accepted_logits = final_logits[accepted_logit_indices]
...@@ -200,7 +209,6 @@ class RejectionSampler(nn.Module): ...@@ -200,7 +209,6 @@ class RejectionSampler(nn.Module):
if self.is_logits_logprobs_mode if self.is_logits_logprobs_mode
else self.sampler.compute_logprobs(accepted_logits) else self.sampler.compute_logprobs(accepted_logits)
) )
accepted_tokens = sampled_token_ids[accepted_mask]
return self.sampler.gather_logprobs( return self.sampler.gather_logprobs(
accepted_logprobs, accepted_logprobs,
max_num_logprobs, max_num_logprobs,
...@@ -212,8 +220,8 @@ class RejectionSampler(nn.Module): ...@@ -212,8 +220,8 @@ class RejectionSampler(nn.Module):
output_token_ids: torch.Tensor, output_token_ids: torch.Tensor,
vocab_size: int, vocab_size: int,
discard_req_indices: Sequence[int] = (), discard_req_indices: Sequence[int] = (),
return_cu_num_tokens: bool = False, logprobs_tensors: LogprobsTensors | None = None,
) -> tuple[list[list[int]], list[int] | None]: ) -> tuple[list[list[int]], LogprobsLists | None]:
"""Parse the output of the rejection sampler. """Parse the output of the rejection sampler.
Args: Args:
output_token_ids: The sampled token IDs in shape output_token_ids: The sampled token IDs in shape
...@@ -222,7 +230,7 @@ class RejectionSampler(nn.Module): ...@@ -222,7 +230,7 @@ class RejectionSampler(nn.Module):
and will be filtered out in this function. and will be filtered out in this function.
vocab_size: The size of the vocabulary. vocab_size: The size of the vocabulary.
discard_req_indices: Optional row indices to discard tokens in. discard_req_indices: Optional row indices to discard tokens in.
return_cu_num_tokens: Whether to also return cumulative token counts. logprobs_tensors: Optional logprobs tensors to filter.
Returns: Returns:
A list of lists of token IDs. A list of lists of token IDs.
""" """
...@@ -231,15 +239,18 @@ class RejectionSampler(nn.Module): ...@@ -231,15 +239,18 @@ class RejectionSampler(nn.Module):
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & ( valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size output_token_ids_np < vocab_size
) )
cu_num_tokens = None output_logprobs = None
if return_cu_num_tokens: if logprobs_tensors is not None:
cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist() cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
filtered_tensors = logprobs_tensors.filter(valid_mask.flatten())
output_logprobs = filtered_tensors.tolists(cu_num_tokens)
if len(discard_req_indices) > 0: if len(discard_req_indices) > 0:
valid_mask[discard_req_indices] = False valid_mask[discard_req_indices] = False
outputs = [ outputs = [
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np) row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
] ]
return outputs, cu_num_tokens return outputs, output_logprobs
def apply_logits_processors( def apply_logits_processors(
self, self,
......
...@@ -237,19 +237,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): ...@@ -237,19 +237,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices: for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear() valid_sampled_token_ids[i].clear()
cu_num_tokens = None logprobs_lists = None
if self._logprobs_tensors_cpu is not None:
logprobs_lists = self._logprobs_tensors_cpu.tolists()
else: else:
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output(
self.sampled_token_ids_cpu, self.sampled_token_ids_cpu,
self.vocab_size, self.vocab_size,
self._invalid_req_indices, self._invalid_req_indices,
return_cu_num_tokens=self._logprobs_tensors_cpu is not None, logprobs_tensors=self._logprobs_tensors_cpu,
) )
output = self._model_runner_output output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids output.sampled_token_ids = valid_sampled_token_ids
if self._logprobs_tensors_cpu: output.logprobs = logprobs_lists
output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens)
return output return output
...@@ -395,6 +396,9 @@ class GPUModelRunner( ...@@ -395,6 +396,9 @@ class GPUModelRunner(
else: else:
self.max_encoder_len = 0 self.max_encoder_len = 0
# Async scheduling
self.use_async_scheduling = self.scheduler_config.async_scheduling
# Sampler # Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
...@@ -504,7 +508,6 @@ class GPUModelRunner( ...@@ -504,7 +508,6 @@ class GPUModelRunner(
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
) )
self.use_async_scheduling = self.scheduler_config.async_scheduling
# Separate cuda stream for overlapping transfer of sampled token ids from # Separate cuda stream for overlapping transfer of sampled token ids from
# GPU to CPU when async scheduling is enabled. # GPU to CPU when async scheduling is enabled.
self.async_output_copy_stream: torch.cuda.Stream | None = None self.async_output_copy_stream: torch.cuda.Stream | None = None
...@@ -2784,7 +2787,7 @@ class GPUModelRunner( ...@@ -2784,7 +2787,7 @@ class GPUModelRunner(
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
logprobs_tensors = sampler_output.logprobs_tensors logprobs_tensors = sampler_output.logprobs_tensors
invalid_req_indices = [] invalid_req_indices = []
cu_num_tokens: list[int] | None = None logprobs_lists = None
if not self.use_async_scheduling: if not self.use_async_scheduling:
# Get the valid generated tokens. # Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1] max_gen_len = sampled_token_ids.shape[-1]
...@@ -2794,13 +2797,16 @@ class GPUModelRunner( ...@@ -2794,13 +2797,16 @@ class GPUModelRunner(
# Mask out the sampled tokens that should not be sampled. # Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices: for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear() valid_sampled_token_ids[int(i)].clear()
if logprobs_tensors is not None:
logprobs_lists = logprobs_tensors.tolists()
else: else:
# Includes spec decode tokens. # Includes spec decode tokens.
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output(
sampled_token_ids, sampled_token_ids,
self.input_batch.vocab_size, self.input_batch.vocab_size,
discard_sampled_tokens_req_indices, discard_sampled_tokens_req_indices,
return_cu_num_tokens=logprobs_tensors is not None, logprobs_tensors=logprobs_tensors,
) )
else: else:
valid_sampled_token_ids = [] valid_sampled_token_ids = []
...@@ -2853,12 +2859,6 @@ class GPUModelRunner( ...@@ -2853,12 +2859,6 @@ class GPUModelRunner(
req_state = self.requests[req_id] req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids) req_state.output_token_ids.extend(sampled_ids)
logprobs_lists = (
logprobs_tensors.tolists(cu_num_tokens)
if not self.use_async_scheduling and logprobs_tensors is not None
else None
)
# Compute prompt logprobs if needed. # Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict( prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens], hidden_states[:num_scheduled_tokens],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment