Unverified commit 51e971d3, authored by Cyrus Leung and committed by GitHub
Browse files

[Bugfix] Support `eos_token_id` from `config.json` (#5954)

parent 329df38f
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
:meth:`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
def test_get_llama3_eos_token():
    """Check the EOS ids reported for Llama 3 Instruct.

    The tokenizer is expected to report a single EOS id, while the
    generation config is expected to list two.
    """
    llama3 = "meta-llama/Meta-Llama-3-8B-Instruct"

    # The tokenizer alone exposes only one EOS token id.
    assert get_tokenizer(llama3).eos_token_id == 128009

    # The generation config must load and carry both EOS token ids.
    gen_config = try_get_generation_config(llama3, trust_remote_code=False)
    assert gen_config is not None
    assert gen_config.eos_token_id == [128001, 128009]
def test_get_blip2_eos_token():
    """Check the EOS ids reported for BLIP-2 (OPT 2.7B).

    The tokenizer and the generation config are expected to disagree on
    the EOS token id for this model.
    """
    blip2 = "Salesforce/blip2-opt-2.7b"

    # EOS id as seen by the tokenizer.
    assert get_tokenizer(blip2).eos_token_id == 2

    # EOS id as seen by the generation config, which must be loadable.
    gen_config = try_get_generation_config(blip2, trust_remote_code=False)
    assert gen_config is not None
    assert gen_config.eos_token_id == 50118
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union from typing import Set, Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ObservabilityConfig, LoRAConfig, ModelConfig, ObservabilityConfig,
...@@ -34,6 +34,7 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, ...@@ -34,6 +34,7 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
SequenceStatus) SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group) get_tokenizer_group)
...@@ -46,16 +47,18 @@ logger = init_logger(__name__) ...@@ -46,16 +47,18 @@ logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation config as a plain dict.

    Args:
        model_config: Supplies the ``model``, ``trust_remote_code`` and
            ``revision`` values forwarded to
            :func:`try_get_generation_config`.

    Returns:
        The result of ``to_diff_dict()`` on the loaded generation config,
        or an empty dict when no generation config could be obtained.
    """
    config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    # The helper returns None when no generation config was found.
    if config is None:
        return {}

    return config.to_diff_dict()
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
......
import contextlib import contextlib
from typing import Dict, Optional, Type from typing import Dict, Optional, Type
from transformers import PretrainedConfig from transformers import GenerationConfig, PretrainedConfig
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -80,3 +80,25 @@ def get_hf_text_config(config: PretrainedConfig): ...@@ -80,3 +80,25 @@ def get_hf_text_config(config: PretrainedConfig):
return config.text_config return config.text_config
else: else:
return config return config
def try_get_generation_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
) -> Optional[GenerationConfig]:
    """Load a ``GenerationConfig`` for ``model``, or ``None`` if unavailable.

    Two sources are tried in order:

    1. ``GenerationConfig.from_pretrained(model, revision=...)``.
    2. The model config returned by ``get_config``, converted with
       ``GenerationConfig.from_model_config``.

    An ``OSError`` from either attempt is treated as "not found".

    Args:
        model: Model name or path handed to the loaders.
        trust_remote_code: Forwarded only to ``get_config`` (second attempt).
        revision: Optional revision forwarded to both attempts.

    Returns:
        The loaded generation config, or ``None`` when both attempts raise
        ``OSError``.
    """
    # First choice: a dedicated generation config for the model.
    try:
        return GenerationConfig.from_pretrained(model, revision=revision)
    except OSError:
        # Not found; fall back to the model config below.
        pass

    # Fallback: derive the generation config from the model config.
    try:
        hf_config = get_config(
            model,
            trust_remote_code=trust_remote_code,
            revision=revision,
        )
        return GenerationConfig.from_model_config(hf_config)
    except OSError:
        # Not found either.
        return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment