Commit ff8b5e11 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.15.1-dev-wm-mtp' into 'v0.15.1-dev'

[feat]宽松mtp支持temp,top-p等参数设置

See merge request dcutoolkit/deeplearing/vllm!420
parents be4dea75 78e20661
......@@ -320,9 +320,8 @@ def shutdown(procs: list[BaseProcess]):
kill_process_tree(pid)
def copy_slice(
from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int
) -> torch.Tensor:
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
length: int, repeat_counts: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Copy the first length elements of a tensor into another tensor in a
non-blocking manner.
......@@ -331,6 +330,11 @@ def copy_slice(
Returns the sliced target tensor.
"""
if repeat_counts is not None:
from_tensor_tmp = torch.repeat_interleave(from_tensor[:length], repeat_counts, dim=0)
length = torch.sum(repeat_counts).item()
from_tensor[:length].copy_(from_tensor_tmp)
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
......
......@@ -3,11 +3,13 @@
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from typing import cast
from typing import Optional, cast
import numpy as np
import torch
from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
......@@ -96,6 +98,11 @@ class InputBatch:
is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1,
):
ori_max_num_reqs = max_num_reqs
if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
vllm_config = get_current_vllm_config()
max_num_reqs = max_num_reqs * (1 + vllm_config.speculative_config.num_speculative_tokens)
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
......@@ -113,7 +120,7 @@ class InputBatch:
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
(ori_max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
pin_memory=False,
......@@ -753,13 +760,13 @@ class InputBatch:
del self.req_output_token_ids[num_reqs:]
del self.spec_token_ids[num_reqs:]
def refresh_metadata(self):
def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None):
"""Apply any batch updates to sampling metadata."""
if self.is_pooling_model:
batch_changed = self.batch_update_builder.reset()
if batch_changed:
self.sampling_metadata = self._make_sampling_metadata()
self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
return
# For non-pooling models - generate and apply logitsprocs update;
......@@ -769,36 +776,36 @@ class InputBatch:
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
if batch_update:
self.sampling_metadata = self._make_sampling_metadata()
self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
def _make_sampling_metadata(self) -> SamplingMetadata:
def _make_sampling_metadata(self, repeat_counts: Optional[torch.Tensor] = None) -> SamplingMetadata:
num_reqs = self.num_reqs
if not self.all_greedy:
temperature = copy_slice(
self.temperature_cpu_tensor, self.temperature, num_reqs
self.temperature_cpu_tensor, self.temperature,
num_reqs, repeat_counts
)
else:
temperature = None
if not self.no_top_p:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
top_p = copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs, repeat_counts)
if not self.no_top_k:
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
top_k = copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs, repeat_counts)
if not self.no_penalties:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
copy_slice(
self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs
)
copy_slice(
self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs
)
copy_slice(
self.repetition_penalties_cpu_tensor,
self.repetition_penalties,
num_reqs,
)
frequency_penalties = copy_slice(self.frequency_penalties_cpu_tensor,
self.frequency_penalties, num_reqs,
repeat_counts)
presence_penalties = copy_slice(self.presence_penalties_cpu_tensor,
self.presence_penalties, num_reqs,
repeat_counts)
repetition_penalties = copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs,
repeat_counts)
needs_prompt_token_ids = (
not self.no_penalties
......@@ -828,25 +835,22 @@ class InputBatch:
allowed_token_ids_mask: torch.Tensor | None = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
copy_slice(
self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask,
num_reqs,
)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
allowed_token_ids_mask = copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs,
repeat_counts)
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
top_p=None if self.no_top_p else top_p,
top_k=None if self.no_top_k else top_k,
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
frequency_penalties=None if self.no_penalties else frequency_penalties,
presence_penalties=None if self.no_penalties else presence_penalties,
repetition_penalties=None if self.no_penalties else repetition_penalties,
output_token_ids=output_token_ids,
spec_token_ids=self.spec_token_ids,
no_penalties=self.no_penalties,
......
......@@ -1102,7 +1102,17 @@ class GPUModelRunner(
# Allow attention backend to reorder the batch, potentially
self._may_reorder_batch(scheduler_output)
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
repeat_counts = None
if envs.VLLM_REJECT_SAMPLE_OPT and \
scheduler_output.scheduled_spec_decode_tokens:
repeat_counts = [1] * self.input_batch.num_reqs
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index.get(req_id)
if req_idx is not None:
repeat_counts[req_idx] += len(draft_token_ids)
repeat_counts = torch.tensor(repeat_counts, dtype=torch.int32, device="cpu")
self.input_batch.refresh_metadata(repeat_counts)
def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment