Unverified Commit aed16879 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Move `ModelConfig` from `config/__init__.py` to `config/model.py` (#25252)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent cf278ff3
...@@ -39,7 +39,8 @@ from vllm import LLM, SamplingParams ...@@ -39,7 +39,8 @@ from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype from vllm.config.model import (ConvertOption, RunnerOption,
_get_and_verify_dtype)
from vllm.connections import global_http_connection from vllm.connections import global_http_connection
from vllm.distributed import (cleanup_dist_env_and_memory, from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment, init_distributed_environment,
......
...@@ -14,7 +14,7 @@ from typing import Literal, NamedTuple, Optional ...@@ -14,7 +14,7 @@ from typing import Literal, NamedTuple, Optional
import pytest import pytest
from vllm.config import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption from vllm.config.model import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
......
...@@ -7,7 +7,6 @@ from unittest.mock import patch ...@@ -7,7 +7,6 @@ from unittest.mock import patch
import pytest import pytest
from vllm import LLM from vllm import LLM
from vllm.config import ModelImpl
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.utils import GiB_bytes from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
...@@ -111,8 +110,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, ...@@ -111,8 +110,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
# these tests seem to produce leftover memory # these tests seem to produce leftover memory
gpu_memory_utilization=0.80, gpu_memory_utilization=0.80,
load_format="dummy", load_format="dummy",
model_impl=ModelImpl.TRANSFORMERS model_impl="transformers"
if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM, if model_arch in _TRANSFORMERS_BACKEND_MODELS else "vllm",
hf_overrides=hf_overrides_fn, hf_overrides=hf_overrides_fn,
max_num_seqs=model_info.max_num_seqs) max_num_seqs=model_info.max_num_seqs)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import itertools import itertools
from collections.abc import Generator from collections.abc import Generator
from typing import get_args
import pytest import pytest
import torch import torch
...@@ -464,7 +465,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): ...@@ -464,7 +465,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
assert len(prompt_logprob) == vocab_size assert len(prompt_logprob) == vocab_size
@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode)) @pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
def test_logprobs_mode(logprobs_mode: LogprobsMode, def test_logprobs_mode(logprobs_mode: LogprobsMode,
monkeypatch: pytest.MonkeyPatch): monkeypatch: pytest.MonkeyPatch):
"""Test with LLM engine with different logprobs_mode. """Test with LLM engine with different logprobs_mode.
...@@ -493,14 +494,12 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode, ...@@ -493,14 +494,12 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
for logprobs in output.logprobs: for logprobs in output.logprobs:
for token_id in logprobs: for token_id in logprobs:
logprob = logprobs[token_id] logprob = logprobs[token_id]
if logprobs_mode in (LogprobsMode.RAW_LOGPROBS, if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
LogprobsMode.PROCESSED_LOGPROBS):
assert logprob.logprob <= 0 assert logprob.logprob <= 0
if logprob.logprob > 0: if logprob.logprob > 0:
positive_values = positive_values + 1 positive_values = positive_values + 1
total_token_with_logprobs = total_token_with_logprobs + 1 total_token_with_logprobs = total_token_with_logprobs + 1
assert total_token_with_logprobs >= len(results[0].outputs) assert total_token_with_logprobs >= len(results[0].outputs)
if logprobs_mode in (LogprobsMode.RAW_LOGITS, if logprobs_mode in ("raw_logits", "processed_logits"):
LogprobsMode.PROCESSED_LOGITS):
assert positive_values > 0 assert positive_values > 0
del llm del llm
This diff is collapsed.
This diff is collapsed.
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import hashlib import hashlib
from dataclasses import field from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
from pydantic import SkipValidation, model_validator from pydantic import SkipValidation, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
...@@ -15,13 +15,9 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, ...@@ -15,13 +15,9 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS) POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)
if TYPE_CHECKING:
from vllm.config import RunnerType
else:
RunnerType = Any
logger = init_logger(__name__) logger = init_logger(__name__)
RunnerType = Literal["generate", "pooling", "draft"]
PreemptionMode = Literal["swap", "recompute"] PreemptionMode = Literal["swap", "recompute"]
SchedulerPolicy = Literal["fcfs", "priority"] SchedulerPolicy = Literal["fcfs", "priority"]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import inspect
import textwrap
from dataclasses import MISSING, Field, field, fields, is_dataclass from dataclasses import MISSING, Field, field, fields, is_dataclass
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, Any, TypeVar
import regex as re
if TYPE_CHECKING: if TYPE_CHECKING:
from _typeshed import DataclassInstance from _typeshed import DataclassInstance
...@@ -45,3 +50,96 @@ def get_field(cls: ConfigType, name: str) -> Field: ...@@ -45,3 +50,96 @@ def get_field(cls: ConfigType, name: str) -> Field:
return field(default=default) return field(default=default)
raise ValueError( raise ValueError(
f"{cls.__name__}.{name} must have a default value or default factory.") f"{cls.__name__}.{name} must have a default value or default factory.")
def contains_object_print(text: str) -> bool:
"""
Check if the text looks like a printed Python object, e.g.
contains any substring matching the pattern: "at 0xFFFFFFF>"
We match against 0x followed by 2-16 hex chars (there's
a max of 16 on a 64-bit system).
Args:
text (str): The text to check
Returns:
result (bool): `True` if a match is found, `False` otherwise.
"""
pattern = r'at 0x[a-fA-F0-9]{2,16}>'
match = re.search(pattern, text)
return match is not None
def assert_hashable(text: str) -> bool:
if not contains_object_print(text):
return True
raise AssertionError(
f"vLLM tried to hash some configs that may have Python objects ids "
f"in them. This is a bug, please file an issue. "
f"Text being hashed: {text}")
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
"""
Get any docstrings placed after attribute assignments in a class body.
https://davidism.com/mit-license/
"""
def pairwise(iterable):
"""
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
Can be removed when Python 3.9 support is dropped.
"""
iterator = iter(iterable)
a = next(iterator, None)
for b in iterator:
yield a, b
a = b
try:
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
except (OSError, KeyError, TypeError):
# HACK: Python 3.13+ workaround - set missing __firstlineno__
# Workaround can be removed after we upgrade to pydantic==2.12.0
with open(inspect.getfile(cls)) as f:
for i, line in enumerate(f):
if f"class {cls.__name__}" in line and ":" in line:
cls.__firstlineno__ = i + 1
break
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
if not isinstance(cls_node, ast.ClassDef):
raise TypeError("Given object was not a class.")
out = {}
# Consider each pair of nodes.
for a, b in pairwise(cls_node.body):
# Must be an assignment then a constant string.
if (not isinstance(a, (ast.Assign, ast.AnnAssign))
or not isinstance(b, ast.Expr)
or not isinstance(b.value, ast.Constant)
or not isinstance(b.value.value, str)):
continue
doc = inspect.cleandoc(b.value.value)
# An assignment can have multiple targets (a = b = v), but an
# annotated assignment only has one target.
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
for target in targets:
# Must be assigning to a plain name.
if not isinstance(target, ast.Name):
continue
out[target.id] = doc
return out
def is_init_field(cls: ConfigType, name: str) -> bool:
return next(f for f in fields(cls) if f.name == name).init
...@@ -27,11 +27,11 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ...@@ -27,11 +27,11 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
EPLBConfig, HfOverrides, KVEventsConfig, EPLBConfig, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode, KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
ModelDType, ModelImpl, ObservabilityConfig, ModelDType, ObservabilityConfig, ParallelConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
RunnerOption, SchedulerConfig, SchedulerPolicy, SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
SpeculativeConfig, StructuredOutputsConfig, StructuredOutputsConfig, TaskOption, TokenizerMode,
TaskOption, TokenizerMode, VllmConfig, get_attr_docs) VllmConfig, get_attr_docs)
from vllm.config.multimodal import MMCacheType, MultiModalConfig from vllm.config.multimodal import MMCacheType, MultiModalConfig
from vllm.config.parallel import ExpertPlacementStrategy from vllm.config.parallel import ExpertPlacementStrategy
from vllm.config.utils import get_field from vllm.config.utils import get_field
...@@ -548,7 +548,6 @@ class EngineArgs: ...@@ -548,7 +548,6 @@ class EngineArgs:
model_group.add_argument("--max-logprobs", model_group.add_argument("--max-logprobs",
**model_kwargs["max_logprobs"]) **model_kwargs["max_logprobs"])
model_group.add_argument("--logprobs-mode", model_group.add_argument("--logprobs-mode",
choices=[f.value for f in LogprobsMode],
**model_kwargs["logprobs_mode"]) **model_kwargs["logprobs_mode"])
model_group.add_argument("--disable-sliding-window", model_group.add_argument("--disable-sliding-window",
**model_kwargs["disable_sliding_window"]) **model_kwargs["disable_sliding_window"])
...@@ -593,9 +592,7 @@ class EngineArgs: ...@@ -593,9 +592,7 @@ class EngineArgs:
**model_kwargs["override_generation_config"]) **model_kwargs["override_generation_config"])
model_group.add_argument("--enable-sleep-mode", model_group.add_argument("--enable-sleep-mode",
**model_kwargs["enable_sleep_mode"]) **model_kwargs["enable_sleep_mode"])
model_group.add_argument("--model-impl", model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
choices=[f.value for f in ModelImpl],
**model_kwargs["model_impl"])
model_group.add_argument("--override-attention-dtype", model_group.add_argument("--override-attention-dtype",
**model_kwargs["override_attention_dtype"]) **model_kwargs["override_attention_dtype"])
model_group.add_argument("--logits-processors", model_group.add_argument("--logits-processors",
......
...@@ -13,8 +13,7 @@ from torch import nn ...@@ -13,8 +13,7 @@ from torch import nn
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.attention import Attention from vllm.attention import Attention
from vllm.config import (ModelConfig, ModelImpl, VllmConfig, from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
set_current_vllm_config)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.layers.linear import QKVCrossParallelLinear
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -176,8 +175,8 @@ def get_model_architecture( ...@@ -176,8 +175,8 @@ def get_model_architecture(
) )
if arch == model_config._get_transformers_backend_cls(): if arch == model_config._get_transformers_backend_cls():
assert model_config.model_impl != ModelImpl.VLLM assert model_config.model_impl != "vllm"
if model_config.model_impl == ModelImpl.AUTO: if model_config.model_impl == "auto":
logger.warning_once( logger.warning_once(
"%s has no vLLM implementation, falling back to Transformers " "%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and " "implementation. Some features may not be supported and "
......
...@@ -19,7 +19,7 @@ from typing import Callable, Optional, TypeVar, Union ...@@ -19,7 +19,7 @@ from typing import Callable, Optional, TypeVar, Union
import torch.nn as nn import torch.nn as nn
import transformers import transformers
from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults, from vllm.config import (ModelConfig, iter_architecture_defaults,
try_match_architecture_defaults) try_match_architecture_defaults)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.dynamic_module import ( from vllm.transformers_utils.dynamic_module import (
...@@ -587,7 +587,7 @@ class _ModelRegistry: ...@@ -587,7 +587,7 @@ class _ModelRegistry:
if model_module is not None: if model_module is not None:
break break
else: else:
if model_config.model_impl != ModelImpl.TRANSFORMERS: if model_config.model_impl != "transformers":
return None return None
raise ValueError( raise ValueError(
...@@ -598,7 +598,7 @@ class _ModelRegistry: ...@@ -598,7 +598,7 @@ class _ModelRegistry:
"'auto_map' (relevant if the model is custom).") "'auto_map' (relevant if the model is custom).")
if not model_module.is_backend_compatible(): if not model_module.is_backend_compatible():
if model_config.model_impl != ModelImpl.TRANSFORMERS: if model_config.model_impl != "transformers":
return None return None
raise ValueError( raise ValueError(
...@@ -644,20 +644,20 @@ class _ModelRegistry: ...@@ -644,20 +644,20 @@ class _ModelRegistry:
raise ValueError("No model architectures are specified") raise ValueError("No model architectures are specified")
# Require transformers impl # Require transformers impl
if model_config.model_impl == ModelImpl.TRANSFORMERS: if model_config.model_impl == "transformers":
arch = self._try_resolve_transformers(architectures[0], arch = self._try_resolve_transformers(architectures[0],
model_config) model_config)
if arch is not None: if arch is not None:
model_info = self._try_inspect_model_cls(arch) model_info = self._try_inspect_model_cls(arch)
if model_info is not None: if model_info is not None:
return (model_info, arch) return (model_info, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH: elif model_config.model_impl == "terratorch":
model_info = self._try_inspect_model_cls("Terratorch") model_info = self._try_inspect_model_cls("Terratorch")
return (model_info, "Terratorch") return (model_info, "Terratorch")
# Fallback to transformers impl (after resolving convert_type) # Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures) if (all(arch not in self.models for arch in architectures)
and model_config.model_impl == ModelImpl.AUTO and model_config.model_impl == "auto"
and getattr(model_config, "convert_type", "none") == "none"): and getattr(model_config, "convert_type", "none") == "none"):
arch = self._try_resolve_transformers(architectures[0], arch = self._try_resolve_transformers(architectures[0],
model_config) model_config)
...@@ -674,7 +674,7 @@ class _ModelRegistry: ...@@ -674,7 +674,7 @@ class _ModelRegistry:
# Fallback to transformers impl (before resolving runner_type) # Fallback to transformers impl (before resolving runner_type)
if (all(arch not in self.models for arch in architectures) if (all(arch not in self.models for arch in architectures)
and model_config.model_impl == ModelImpl.AUTO): and model_config.model_impl == "auto"):
arch = self._try_resolve_transformers(architectures[0], arch = self._try_resolve_transformers(architectures[0],
model_config) model_config)
if arch is not None: if arch is not None:
...@@ -695,14 +695,14 @@ class _ModelRegistry: ...@@ -695,14 +695,14 @@ class _ModelRegistry:
raise ValueError("No model architectures are specified") raise ValueError("No model architectures are specified")
# Require transformers impl # Require transformers impl
if model_config.model_impl == ModelImpl.TRANSFORMERS: if model_config.model_impl == "transformers":
arch = self._try_resolve_transformers(architectures[0], arch = self._try_resolve_transformers(architectures[0],
model_config) model_config)
if arch is not None: if arch is not None:
model_cls = self._try_load_model_cls(arch) model_cls = self._try_load_model_cls(arch)
if model_cls is not None: if model_cls is not None:
return (model_cls, arch) return (model_cls, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH: elif model_config.model_impl == "terratorch":
arch = "Terratorch" arch = "Terratorch"
model_cls = self._try_load_model_cls(arch) model_cls = self._try_load_model_cls(arch)
if model_cls is not None: if model_cls is not None:
...@@ -710,7 +710,7 @@ class _ModelRegistry: ...@@ -710,7 +710,7 @@ class _ModelRegistry:
# Fallback to transformers impl (after resolving convert_type) # Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures) if (all(arch not in self.models for arch in architectures)
and model_config.model_impl == ModelImpl.AUTO and model_config.model_impl == "auto"
and getattr(model_config, "convert_type", "none") == "none"): and getattr(model_config, "convert_type", "none") == "none"):
arch = self._try_resolve_transformers(architectures[0], arch = self._try_resolve_transformers(architectures[0],
model_config) model_config)
...@@ -727,7 +727,7 @@ class _ModelRegistry: ...@@ -727,7 +727,7 @@ class _ModelRegistry:
# Fallback to transformers impl (before resolving runner_type) # Fallback to transformers impl (before resolving runner_type)
if (all(arch not in self.models for arch in architectures) if (all(arch not in self.models for arch in architectures)
and model_config.model_impl == ModelImpl.AUTO): and model_config.model_impl == "auto"):
arch = self._try_resolve_transformers(architectures[0], arch = self._try_resolve_transformers(architectures[0],
model_config) model_config)
if arch is not None: if arch is not None:
......
...@@ -29,15 +29,12 @@ class TopKTopPSampler(nn.Module): ...@@ -29,15 +29,12 @@ class TopKTopPSampler(nn.Module):
Implementations may update the logits tensor in-place. Implementations may update the logits tensor in-place.
""" """
def __init__( def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
self,
logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None:
super().__init__() super().__init__()
self.logprobs_mode = logprobs_mode self.logprobs_mode = logprobs_mode
# flashinfer optimization does not apply if intermediate # flashinfer optimization does not apply if intermediate
# logprobs/logits after top_k/top_p need to be returned # logprobs/logits after top_k/top_p need to be returned
if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS, if logprobs_mode not in ("processed_logits", "processed_logprobs"
LogprobsMode.PROCESSED_LOGPROBS
) and current_platform.is_cuda(): ) and current_platform.is_cuda():
if is_flashinfer_available: if is_flashinfer_available:
flashinfer_version = flashinfer.__version__ flashinfer_version = flashinfer.__version__
...@@ -90,9 +87,9 @@ class TopKTopPSampler(nn.Module): ...@@ -90,9 +87,9 @@ class TopKTopPSampler(nn.Module):
""" """
logits = self.apply_top_k_top_p(logits, k, p) logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None logits_to_return = None
if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: if self.logprobs_mode == "processed_logits":
logits_to_return = logits logits_to_return = logits
elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
probs = logits.softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators), logits_to_return return random_sample(probs, generators), logits_to_return
...@@ -115,7 +112,7 @@ class TopKTopPSampler(nn.Module): ...@@ -115,7 +112,7 @@ class TopKTopPSampler(nn.Module):
"PyTorch-native implementation.") "PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p) return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in ( assert self.logprobs_mode not in (
LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS "processed_logits", "processed_logprobs"
), "FlashInfer does not support returning logits/logprobs" ), "FlashInfer does not support returning logits/logprobs"
# flashinfer sampling functions expect contiguous logits. # flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
......
...@@ -60,8 +60,7 @@ class Sampler(nn.Module): ...@@ -60,8 +60,7 @@ class Sampler(nn.Module):
9. Return the final `SamplerOutput`. 9. Return the final `SamplerOutput`.
""" """
def __init__(self, def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS):
super().__init__() super().__init__()
self.topk_topp_sampler = TopKTopPSampler(logprobs_mode) self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
...@@ -78,9 +77,9 @@ class Sampler(nn.Module): ...@@ -78,9 +77,9 @@ class Sampler(nn.Module):
# is used for sampling (after penalties and temperature scaling). # is used for sampling (after penalties and temperature scaling).
num_logprobs = sampling_metadata.max_num_logprobs num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None: if num_logprobs is not None:
if self.logprobs_mode == LogprobsMode.RAW_LOGPROBS: if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits) raw_logprobs = self.compute_logprobs(logits)
elif self.logprobs_mode == LogprobsMode.RAW_LOGITS: elif self.logprobs_mode == "raw_logits":
raw_logprobs = logits.clone() raw_logprobs = logits.clone()
# Use float32 for the logits. # Use float32 for the logits.
...@@ -156,9 +155,9 @@ class Sampler(nn.Module): ...@@ -156,9 +155,9 @@ class Sampler(nn.Module):
if sampling_metadata.all_greedy: if sampling_metadata.all_greedy:
processed_logprobs = None processed_logprobs = None
if sampling_metadata.max_num_logprobs is not None: if sampling_metadata.max_num_logprobs is not None:
if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: if self.logprobs_mode == "processed_logits":
processed_logprobs = logits processed_logprobs = logits
elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: elif self.logprobs_mode == "processed_logprobs":
processed_logprobs = self.compute_logprobs(logits) processed_logprobs = self.compute_logprobs(logits)
return greedy_sampled, processed_logprobs return greedy_sampled, processed_logprobs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment