Commit b70256d7 authored by 王敏's avatar 王敏
Browse files

[fix]解决宽松mtp引入的同步问题

parent d73be361
...@@ -321,7 +321,7 @@ def bind_kv_cache( ...@@ -321,7 +321,7 @@ def bind_kv_cache(
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
length: int) -> torch.Tensor: length: int, repeat_counts: Optional[torch.Tensor] = None) -> torch.Tensor:
""" """
Copy the first length elements of a tensor into another tensor in a Copy the first length elements of a tensor into another tensor in a
non-blocking manner. non-blocking manner.
...@@ -330,6 +330,11 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, ...@@ -330,6 +330,11 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
Returns the sliced target tensor. Returns the sliced target tensor.
""" """
if repeat_counts is not None:
from_tensor_tmp = torch.repeat_interleave(from_tensor[:length], repeat_counts, dim=0)
length = torch.sum(repeat_counts).item()
from_tensor[:length].copy_(from_tensor_tmp)
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
......
...@@ -9,6 +9,7 @@ import numpy as np ...@@ -9,6 +9,7 @@ import numpy as np
import torch import torch
from vllm import envs from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -79,6 +80,10 @@ class InputBatch: ...@@ -79,6 +80,10 @@ class InputBatch:
is_spec_decode: bool = False, is_spec_decode: bool = False,
logits_processing_needs_token_ids: bool = False, logits_processing_needs_token_ids: bool = False,
): ):
ori_max_num_reqs = max_num_reqs
if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
vllm_config = get_current_vllm_config()
max_num_reqs = max_num_reqs * (1 + vllm_config.speculative_config.num_speculative_tokens)
self.is_spec_decode = is_spec_decode self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
...@@ -97,7 +102,7 @@ class InputBatch: ...@@ -97,7 +102,7 @@ class InputBatch:
# This buffer is not directly transferred to the GPU, so it does not # This buffer is not directly transferred to the GPU, so it does not
# need to be pinned. # need to be pinned.
self.token_ids_cpu_tensor = torch.zeros( self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len), (ori_max_num_reqs, max_model_len),
device="cpu", device="cpu",
dtype=torch.int32, dtype=torch.int32,
pin_memory=False, pin_memory=False,
...@@ -651,36 +656,44 @@ class InputBatch: ...@@ -651,36 +656,44 @@ class InputBatch:
or repeat_counts is not None or repeat_counts is not None
or self._sampling_metadata_is_expanded) or self._sampling_metadata_is_expanded)
if needs_rebuild: if needs_rebuild:
if repeat_counts is None: # if repeat_counts is None:
self.sampling_metadata = self._make_sampling_metadata() # self.sampling_metadata = self._make_sampling_metadata()
else: # else:
self.sampling_metadata = self._make_sampling_metadata_expanded( # self.sampling_metadata = self._make_sampling_metadata_expanded(
repeat_counts) # repeat_counts)
self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
self._sampling_metadata_is_expanded = repeat_counts is not None self._sampling_metadata_is_expanded = repeat_counts is not None
# Expanded metadata is built on demand; do not cache a copy here. # Expanded metadata is built on demand; do not cache a copy here.
def _make_sampling_metadata(self) -> SamplingMetadata: def _make_sampling_metadata(self, repeat_counts: Optional[torch.Tensor] = None) -> SamplingMetadata:
num_reqs = self.num_reqs num_reqs = self.num_reqs
if not self.all_greedy: if not self.all_greedy:
temperature = copy_slice(self.temperature_cpu_tensor, temperature = copy_slice(self.temperature_cpu_tensor,
self.temperature, num_reqs) self.temperature, num_reqs,
repeat_counts)
else: else:
temperature = None temperature = None
if not self.no_top_p: if not self.no_top_p:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) top_p = copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs, repeat_counts)
if not self.no_top_k: if not self.no_top_k:
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) top_k = copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs, repeat_counts)
frequency_penalties = None
presence_penalties = None
repetition_penalties = None
if not self.no_penalties: if not self.no_penalties:
# Since syncing these tensors is expensive only copy them # Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require # if necessary i.e. if there are requests which require
# penalties to be applied during sampling. # penalties to be applied during sampling.
copy_slice(self.frequency_penalties_cpu_tensor, frequency_penalties = copy_slice(self.frequency_penalties_cpu_tensor,
self.frequency_penalties, num_reqs) self.frequency_penalties, num_reqs,
copy_slice(self.presence_penalties_cpu_tensor, repeat_counts)
self.presence_penalties, num_reqs) presence_penalties = copy_slice(self.presence_penalties_cpu_tensor,
copy_slice(self.repetition_penalties_cpu_tensor, self.presence_penalties, num_reqs,
self.repetition_penalties, num_reqs) repeat_counts)
repetition_penalties = copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs,
repeat_counts)
needs_prompt_token_ids = (not self.no_penalties or needs_prompt_token_ids = (not self.no_penalties or
(self.num_reqs > 0 (self.num_reqs > 0
...@@ -697,9 +710,9 @@ class InputBatch: ...@@ -697,9 +710,9 @@ class InputBatch:
allowed_token_ids_mask: Optional[torch.Tensor] = None allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids: if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None assert self.allowed_token_ids_mask is not None
copy_slice(self.allowed_token_ids_mask_cpu_tensor, allowed_token_ids_mask = copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs) self.allowed_token_ids_mask, num_reqs,
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] repeat_counts)
# Host-side summaries to avoid device synchronization in sampling # Host-side summaries to avoid device synchronization in sampling
# fast paths (e.g. reduced top-k/top-p sampling). # fast paths (e.g. reduced top-k/top-p sampling).
...@@ -714,14 +727,14 @@ class InputBatch: ...@@ -714,14 +727,14 @@ class InputBatch:
temperature=temperature, temperature=temperature,
all_greedy=self.all_greedy, all_greedy=self.all_greedy,
all_random=self.all_random, all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs], top_p=None if self.no_top_p else top_p,
top_k=None if self.no_top_k else self.top_k[:num_reqs], top_k=None if self.no_top_k else top_k,
generators=self.generators, generators=self.generators,
max_num_logprobs=self.max_num_logprobs, max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs], frequency_penalties=None if self.no_penalties else frequency_penalties,
presence_penalties=self.presence_penalties[:num_reqs], presence_penalties=None if self.no_penalties else presence_penalties,
repetition_penalties=self.repetition_penalties[:num_reqs], repetition_penalties=None if self.no_penalties else repetition_penalties,
output_token_ids=cast(list[list[int]], self.req_output_token_ids), output_token_ids=cast(list[list[int]], self.req_output_token_ids),
no_penalties=self.no_penalties, no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask, allowed_token_ids_mask=allowed_token_ids_mask,
......
...@@ -586,17 +586,19 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -586,17 +586,19 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Refresh batch metadata with any pending updates. If we are in spec # Refresh batch metadata with any pending updates. If we are in spec
# decode + reject mode, also expand sampling metadata to token shape # decode + reject mode, also expand sampling metadata to token shape
# using per-request repeat counts. # using per-request repeat counts.
repeat_counts: Optional[torch.Tensor] = None repeat_counts = None
if envs.VLLM_REJECT_SAMPLE_OPT and \ if envs.VLLM_REJECT_SAMPLE_OPT and \
scheduler_output.scheduled_spec_decode_tokens: scheduler_output.scheduled_spec_decode_tokens:
num_reqs = self.input_batch.num_reqs repeat_counts = [1] * self.input_batch.num_reqs
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) #num_reqs = self.input_batch.num_reqs
#num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in ( for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()): scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index.get(req_id) req_idx = self.input_batch.req_id_to_index.get(req_id)
if req_idx is not None: if req_idx is not None:
num_draft_tokens[req_idx] = len(draft_token_ids) repeat_counts[req_idx] += len(draft_token_ids)
repeat_counts = torch.from_numpy(num_draft_tokens).add_(1) repeat_counts = torch.tensor(repeat_counts, dtype=torch.int32, device="cpu")
self.input_batch.refresh_metadata(repeat_counts) self.input_batch.refresh_metadata(repeat_counts)
def _get_cumsum_and_arange( def _get_cumsum_and_arange(
...@@ -1565,8 +1567,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1565,8 +1567,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
else: else:
sampling_metadata.all_greedy = True
sampling_metadata.all_random = False
sampler_output = self.sampler( sampler_output = self.sampler(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
...@@ -3431,8 +3431,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3431,8 +3431,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
else: else:
# sampling_metadata.all_greedy = True
# sampling_metadata.all_random = False
sampler_output = self.sampler( sampler_output = self.sampler(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment