"vllm/vscode:/vscode.git/clone" did not exist on "76e4dcf225e4de115bdc20b00a78d49bec767c09"
Commit 78e20661 authored by 王敏
Browse files

[feat]宽松mtp支持temp,top-p等参数设置

parent e807ec39
...@@ -320,9 +320,8 @@ def shutdown(procs: list[BaseProcess]): ...@@ -320,9 +320,8 @@ def shutdown(procs: list[BaseProcess]):
kill_process_tree(pid) kill_process_tree(pid)
def copy_slice( def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int length: int, repeat_counts: Optional[torch.Tensor] = None) -> torch.Tensor:
) -> torch.Tensor:
""" """
Copy the first length elements of a tensor into another tensor in a Copy the first length elements of a tensor into another tensor in a
non-blocking manner. non-blocking manner.
...@@ -331,6 +330,11 @@ def copy_slice( ...@@ -331,6 +330,11 @@ def copy_slice(
Returns the sliced target tensor. Returns the sliced target tensor.
""" """
if repeat_counts is not None:
from_tensor_tmp = torch.repeat_interleave(from_tensor[:length], repeat_counts, dim=0)
length = torch.sum(repeat_counts).item()
from_tensor[:length].copy_(from_tensor_tmp)
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
......
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
# Datastructures defining a GPU input batch # Datastructures defining a GPU input batch
from dataclasses import dataclass from dataclasses import dataclass
from typing import cast from typing import Optional, cast
import numpy as np import numpy as np
import torch import torch
from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -96,6 +98,11 @@ class InputBatch: ...@@ -96,6 +98,11 @@ class InputBatch:
is_pooling_model: bool = False, is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1, cp_kv_cache_interleave_size: int = 1,
): ):
ori_max_num_reqs = max_num_reqs
if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
vllm_config = get_current_vllm_config()
max_num_reqs = max_num_reqs * (1 + vllm_config.speculative_config.num_speculative_tokens)
self.is_pooling_model = is_pooling_model self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
...@@ -113,7 +120,7 @@ class InputBatch: ...@@ -113,7 +120,7 @@ class InputBatch:
# This buffer is not directly transferred to the GPU, so it does not # This buffer is not directly transferred to the GPU, so it does not
# need to be pinned. # need to be pinned.
self.token_ids_cpu_tensor = torch.zeros( self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len), (ori_max_num_reqs, max_model_len),
device="cpu", device="cpu",
dtype=torch.int32, dtype=torch.int32,
pin_memory=False, pin_memory=False,
...@@ -753,13 +760,13 @@ class InputBatch: ...@@ -753,13 +760,13 @@ class InputBatch:
del self.req_output_token_ids[num_reqs:] del self.req_output_token_ids[num_reqs:]
del self.spec_token_ids[num_reqs:] del self.spec_token_ids[num_reqs:]
def refresh_metadata(self): def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None):
"""Apply any batch updates to sampling metadata.""" """Apply any batch updates to sampling metadata."""
if self.is_pooling_model: if self.is_pooling_model:
batch_changed = self.batch_update_builder.reset() batch_changed = self.batch_update_builder.reset()
if batch_changed: if batch_changed:
self.sampling_metadata = self._make_sampling_metadata() self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
return return
# For non-pooling models - generate and apply logitsprocs update; # For non-pooling models - generate and apply logitsprocs update;
...@@ -769,36 +776,36 @@ class InputBatch: ...@@ -769,36 +776,36 @@ class InputBatch:
for logit_proc in self.logitsprocs.all: for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update) logit_proc.update_state(batch_update)
if batch_update: if batch_update:
self.sampling_metadata = self._make_sampling_metadata() self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
def _make_sampling_metadata(self) -> SamplingMetadata: def _make_sampling_metadata(self, repeat_counts: Optional[torch.Tensor] = None) -> SamplingMetadata:
num_reqs = self.num_reqs num_reqs = self.num_reqs
if not self.all_greedy: if not self.all_greedy:
temperature = copy_slice( temperature = copy_slice(
self.temperature_cpu_tensor, self.temperature, num_reqs self.temperature_cpu_tensor, self.temperature,
num_reqs, repeat_counts
) )
else: else:
temperature = None temperature = None
if not self.no_top_p: if not self.no_top_p:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) top_p = copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs, repeat_counts)
if not self.no_top_k: if not self.no_top_k:
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) top_k = copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs, repeat_counts)
if not self.no_penalties: if not self.no_penalties:
# Since syncing these tensors is expensive only copy them # Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require # if necessary i.e. if there are requests which require
# penalties to be applied during sampling. # penalties to be applied during sampling.
copy_slice( frequency_penalties = copy_slice(self.frequency_penalties_cpu_tensor,
self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs self.frequency_penalties, num_reqs,
) repeat_counts)
copy_slice( presence_penalties = copy_slice(self.presence_penalties_cpu_tensor,
self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs self.presence_penalties, num_reqs,
) repeat_counts)
copy_slice( repetition_penalties = copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties_cpu_tensor, self.repetition_penalties, num_reqs,
self.repetition_penalties, repeat_counts)
num_reqs,
)
needs_prompt_token_ids = ( needs_prompt_token_ids = (
not self.no_penalties not self.no_penalties
...@@ -828,25 +835,22 @@ class InputBatch: ...@@ -828,25 +835,22 @@ class InputBatch:
allowed_token_ids_mask: torch.Tensor | None = None allowed_token_ids_mask: torch.Tensor | None = None
if not self.no_allowed_token_ids: if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None assert self.allowed_token_ids_mask is not None
copy_slice( allowed_token_ids_mask = copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask_cpu_tensor, self.allowed_token_ids_mask, num_reqs,
self.allowed_token_ids_mask, repeat_counts)
num_reqs,
)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
return SamplingMetadata( return SamplingMetadata(
temperature=temperature, temperature=temperature,
all_greedy=self.all_greedy, all_greedy=self.all_greedy,
all_random=self.all_random, all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs], top_p=None if self.no_top_p else top_p,
top_k=None if self.no_top_k else self.top_k[:num_reqs], top_k=None if self.no_top_k else top_k,
generators=self.generators, generators=self.generators,
max_num_logprobs=self.max_num_logprobs, max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs], frequency_penalties=None if self.no_penalties else frequency_penalties,
presence_penalties=self.presence_penalties[:num_reqs], presence_penalties=None if self.no_penalties else presence_penalties,
repetition_penalties=self.repetition_penalties[:num_reqs], repetition_penalties=None if self.no_penalties else repetition_penalties,
output_token_ids=output_token_ids, output_token_ids=output_token_ids,
spec_token_ids=self.spec_token_ids, spec_token_ids=self.spec_token_ids,
no_penalties=self.no_penalties, no_penalties=self.no_penalties,
......
...@@ -1102,7 +1102,17 @@ class GPUModelRunner( ...@@ -1102,7 +1102,17 @@ class GPUModelRunner(
# Allow attention backend to reorder the batch, potentially # Allow attention backend to reorder the batch, potentially
self._may_reorder_batch(scheduler_output) self._may_reorder_batch(scheduler_output)
# Refresh batch metadata with any pending updates. # Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata() repeat_counts = None
if envs.VLLM_REJECT_SAMPLE_OPT and \
scheduler_output.scheduled_spec_decode_tokens:
repeat_counts = [1] * self.input_batch.num_reqs
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index.get(req_id)
if req_idx is not None:
repeat_counts[req_idx] += len(draft_token_ids)
repeat_counts = torch.tensor(repeat_counts, dtype=torch.int32, device="cpu")
self.input_batch.refresh_metadata(repeat_counts)
def _update_states_after_model_execute( def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput" self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment