Commit 35017fdf authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm' into 'v0.9.2-dev'

[fix]解决宽松mtp引入的同步问题

See merge request dcutoolkit/deeplearing/vllm!417
parents d73be361 b70256d7
......@@ -321,7 +321,7 @@ def bind_kv_cache(
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
length: int) -> torch.Tensor:
length: int, repeat_counts: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Copy the first length elements of a tensor into another tensor in a
non-blocking manner.
......@@ -330,6 +330,11 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
Returns the sliced target tensor.
"""
if repeat_counts is not None:
from_tensor_tmp = torch.repeat_interleave(from_tensor[:length], repeat_counts, dim=0)
length = torch.sum(repeat_counts).item()
from_tensor[:length].copy_(from_tensor_tmp)
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
......
......@@ -9,6 +9,7 @@ import numpy as np
import torch
from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
......@@ -79,6 +80,10 @@ class InputBatch:
is_spec_decode: bool = False,
logits_processing_needs_token_ids: bool = False,
):
ori_max_num_reqs = max_num_reqs
if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
vllm_config = get_current_vllm_config()
max_num_reqs = max_num_reqs * (1 + vllm_config.speculative_config.num_speculative_tokens)
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
......@@ -97,7 +102,7 @@ class InputBatch:
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
(ori_max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
pin_memory=False,
......@@ -651,36 +656,44 @@ class InputBatch:
or repeat_counts is not None
or self._sampling_metadata_is_expanded)
if needs_rebuild:
if repeat_counts is None:
self.sampling_metadata = self._make_sampling_metadata()
else:
self.sampling_metadata = self._make_sampling_metadata_expanded(
repeat_counts)
# if repeat_counts is None:
# self.sampling_metadata = self._make_sampling_metadata()
# else:
# self.sampling_metadata = self._make_sampling_metadata_expanded(
# repeat_counts)
self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
self._sampling_metadata_is_expanded = repeat_counts is not None
# Expanded metadata is built on demand; do not cache a copy here.
def _make_sampling_metadata(self) -> SamplingMetadata:
def _make_sampling_metadata(self, repeat_counts: Optional[torch.Tensor] = None) -> SamplingMetadata:
num_reqs = self.num_reqs
if not self.all_greedy:
temperature = copy_slice(self.temperature_cpu_tensor,
self.temperature, num_reqs)
self.temperature, num_reqs,
repeat_counts)
else:
temperature = None
if not self.no_top_p:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
top_p = copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs, repeat_counts)
if not self.no_top_k:
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
top_k = copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs, repeat_counts)
frequency_penalties = None
presence_penalties = None
repetition_penalties = None
if not self.no_penalties:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
copy_slice(self.frequency_penalties_cpu_tensor,
self.frequency_penalties, num_reqs)
copy_slice(self.presence_penalties_cpu_tensor,
self.presence_penalties, num_reqs)
copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs)
frequency_penalties = copy_slice(self.frequency_penalties_cpu_tensor,
self.frequency_penalties, num_reqs,
repeat_counts)
presence_penalties = copy_slice(self.presence_penalties_cpu_tensor,
self.presence_penalties, num_reqs,
repeat_counts)
repetition_penalties = copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs,
repeat_counts)
needs_prompt_token_ids = (not self.no_penalties or
(self.num_reqs > 0
......@@ -697,9 +710,9 @@ class InputBatch:
allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
allowed_token_ids_mask = copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs,
repeat_counts)
# Host-side summaries to avoid device synchronization in sampling
# fast paths (e.g. reduced top-k/top-p sampling).
......@@ -714,14 +727,14 @@ class InputBatch:
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
top_p=None if self.no_top_p else top_p,
top_k=None if self.no_top_k else top_k,
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
frequency_penalties=None if self.no_penalties else frequency_penalties,
presence_penalties=None if self.no_penalties else presence_penalties,
repetition_penalties=None if self.no_penalties else repetition_penalties,
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
......
......@@ -586,17 +586,19 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Refresh batch metadata with any pending updates. If we are in spec
# decode + reject mode, also expand sampling metadata to token shape
# using per-request repeat counts.
repeat_counts: Optional[torch.Tensor] = None
repeat_counts = None
if envs.VLLM_REJECT_SAMPLE_OPT and \
scheduler_output.scheduled_spec_decode_tokens:
num_reqs = self.input_batch.num_reqs
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
repeat_counts = [1] * self.input_batch.num_reqs
#num_reqs = self.input_batch.num_reqs
#num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index.get(req_id)
if req_idx is not None:
num_draft_tokens[req_idx] = len(draft_token_ids)
repeat_counts = torch.from_numpy(num_draft_tokens).add_(1)
repeat_counts[req_idx] += len(draft_token_ids)
repeat_counts = torch.tensor(repeat_counts, dtype=torch.int32, device="cpu")
self.input_batch.refresh_metadata(repeat_counts)
def _get_cumsum_and_arange(
......@@ -1565,8 +1567,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
)
sampler_output.sampled_token_ids = output_token_ids
else:
sampling_metadata.all_greedy = True
sampling_metadata.all_random = False
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
......@@ -3431,8 +3431,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
)
sampler_output.sampled_token_ids = output_token_ids
else:
# sampling_metadata.all_greedy = True
# sampling_metadata.all_random = False
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment