Unverified Commit a134ef6f authored by Simon Mo's avatar Simon Mo Committed by GitHub
Browse files

Support eos_token_id from generation_config.json (#4182)

parent 8a7a3e44
import time
from typing import Iterable, List, Optional, Type, Union
from transformers import PreTrainedTokenizer
from transformers import GenerationConfig, PreTrainedTokenizer
import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
......@@ -34,6 +34,17 @@ logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig):
try:
return GenerationConfig.from_pretrained(
model_config.model,
revision=model_config.revision,
).to_diff_dict()
except OSError:
# Not found.
return {}
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
......@@ -124,6 +135,8 @@ class LLMEngine:
self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.model_executor = executor_class(
model_config=model_config,
......@@ -391,6 +404,8 @@ class LLMEngine:
# inject the eos token id into the sampling_params to support min_tokens
# processing
sampling_params.eos_token_id = seq.eos_token_id
sampling_params.update_from_generation_config(
self.generation_config_fields)
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
......@@ -435,7 +450,7 @@ class LLMEngine:
scheduled_seq_groups: List[SequenceGroup],
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
"""
......
......@@ -2,7 +2,7 @@
import copy
from enum import IntEnum
from functools import cached_property
from typing import Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from pydantic import Field
......@@ -271,6 +271,18 @@ class SamplingParams:
raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.")
def update_from_generation_config(
self, generation_config: Dict[str, Any]) -> None:
"""Update if there are non-default values from generation_config"""
# Update eos_token_id for generation
if eos_ids := generation_config.get("eos_token_id"):
# it can be either int or list of int
if isinstance(eos_ids, int):
eos_ids = [eos_ids]
original_stop_token_ids = set(self.stop_token_ids)
original_stop_token_ids.update(eos_ids)
self.stop_token_ids = list(original_stop_token_ids)
@cached_property
def sampling_type(self) -> SamplingType:
if self.use_beam_search:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment