Unverified Commit a134ef6f authored by Simon Mo's avatar Simon Mo Committed by GitHub
Browse files

Support eos_token_id from generation_config.json (#4182)

parent 8a7a3e44
import time import time
from typing import Iterable, List, Optional, Type, Union from typing import Iterable, List, Optional, Type, Union
from transformers import PreTrainedTokenizer from transformers import GenerationConfig, PreTrainedTokenizer
import vllm import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
...@@ -34,6 +34,17 @@ logger = init_logger(__name__) ...@@ -34,6 +34,17 @@ logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig):
try:
return GenerationConfig.from_pretrained(
model_config.model,
revision=model_config.revision,
).to_diff_dict()
except OSError:
# Not found.
return {}
class LLMEngine: class LLMEngine:
"""An LLM engine that receives requests and generates texts. """An LLM engine that receives requests and generates texts.
...@@ -124,6 +135,8 @@ class LLMEngine: ...@@ -124,6 +135,8 @@ class LLMEngine:
self._init_tokenizer() self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer) self.detokenizer = Detokenizer(self.tokenizer)
self.seq_counter = Counter() self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.model_executor = executor_class( self.model_executor = executor_class(
model_config=model_config, model_config=model_config,
...@@ -391,6 +404,8 @@ class LLMEngine: ...@@ -391,6 +404,8 @@ class LLMEngine:
# inject the eos token id into the sampling_params to support min_tokens # inject the eos token id into the sampling_params to support min_tokens
# processing # processing
sampling_params.eos_token_id = seq.eos_token_id sampling_params.eos_token_id = seq.eos_token_id
sampling_params.update_from_generation_config(
self.generation_config_fields)
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params, seq_group = SequenceGroup(request_id, [seq], sampling_params,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import copy import copy
from enum import IntEnum from enum import IntEnum
from functools import cached_property from functools import cached_property
from typing import Callable, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from pydantic import Field from pydantic import Field
...@@ -271,6 +271,18 @@ class SamplingParams: ...@@ -271,6 +271,18 @@ class SamplingParams:
raise ValueError("best_of must be 1 when using greedy sampling." raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.") f"Got {self.best_of}.")
def update_from_generation_config(
self, generation_config: Dict[str, Any]) -> None:
"""Update if there are non-default values from generation_config"""
# Update eos_token_id for generation
if eos_ids := generation_config.get("eos_token_id"):
# it can be either int or list of int
if isinstance(eos_ids, int):
eos_ids = [eos_ids]
original_stop_token_ids = set(self.stop_token_ids)
original_stop_token_ids.update(eos_ids)
self.stop_token_ids = list(original_stop_token_ids)
@cached_property @cached_property
def sampling_type(self) -> SamplingType: def sampling_type(self) -> SamplingType:
if self.use_beam_search: if self.use_beam_search:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment