"vscode:/vscode.git/clone" did not exist on "6f1355a1b74e4502e6a4e6ba9a811cc50729ee1f"
Unverified commit f5710ef0, authored by Cyrus Leung, committed by GitHub
Browse files

[Misc] Make `LayerBlockType` a `Literal` instead of `Enum` (#27658)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent a8c02fb5
...@@ -41,7 +41,6 @@ from vllm.transformers_utils.config import ( ...@@ -41,7 +41,6 @@ from vllm.transformers_utils.config import (
) )
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import LayerBlockType
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype from vllm.utils.torch_utils import common_broadcastable_dtype
...@@ -91,6 +90,7 @@ LogprobsMode = Literal[ ...@@ -91,6 +90,7 @@ LogprobsMode = Literal[
] ]
HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig] HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig]
ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"] ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]
LayerBlockType = Literal["attention", "linear_attention", "mamba"]
_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = { _RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
"generate": ["generate", "transcription"], "generate": ["generate", "transcription"],
...@@ -1433,11 +1433,11 @@ class ModelConfig: ...@@ -1433,11 +1433,11 @@ class ModelConfig:
def get_num_layers_by_block_type( def get_num_layers_by_block_type(
self, self,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
block_type: LayerBlockType = LayerBlockType.attention, block_type: LayerBlockType = "attention",
) -> int: ) -> int:
# This function relies on 'layers_block_type' in hf_config, # This function relies on 'layers_block_type' in hf_config,
# for w/o this attribute, we will need to have workarounds like so # for w/o this attribute, we will need to have workarounds like so
attn_block_type = block_type == LayerBlockType.attention attn_block_type = block_type == "attention"
is_transformer = ( is_transformer = (
not self.is_hybrid and not self.has_noops and not self.is_attention_free not self.is_hybrid and not self.has_noops and not self.is_attention_free
) )
...@@ -1469,9 +1469,7 @@ class ModelConfig: ...@@ -1469,9 +1469,7 @@ class ModelConfig:
) )
else: else:
return self.get_num_layers(parallel_config) return self.get_num_layers(parallel_config)
return sum( return sum(t == block_type for t in layers_block_type_value[start:end])
t == block_type.value for t in layers_block_type_value[start:end]
)
# Hybrid model Minimax # Hybrid model Minimax
attn_type_list = getattr(self.hf_config, "attn_type_list", None) attn_type_list = getattr(self.hf_config, "attn_type_list", None)
...@@ -1481,19 +1479,16 @@ class ModelConfig: ...@@ -1481,19 +1479,16 @@ class ModelConfig:
# Hybrid model Qwen3Next # Hybrid model Qwen3Next
layer_types_value = getattr(self.hf_config, "layer_types", None) layer_types_value = getattr(self.hf_config, "layer_types", None)
if layer_types_value is not None: if layer_types_value is not None:
if getattr(block_type, "value", block_type) == "attention": if block_type == "attention":
return sum( return sum(
t == "full_attention" for t in layer_types_value[start:end] t == "full_attention" for t in layer_types_value[start:end]
) )
elif getattr(block_type, "value", block_type) == "linear_attention": elif block_type == "linear_attention":
return sum( return sum(
t == "linear_attention" for t in layer_types_value[start:end] t == "linear_attention" for t in layer_types_value[start:end]
) )
else: else:
return sum( return sum(t == block_type for t in layer_types_value[start:end])
t == getattr(block_type, "value", block_type)
for t in layer_types_value[start:end]
)
if ( if (
layers_block_type_value is None layers_block_type_value is None
...@@ -1501,10 +1496,9 @@ class ModelConfig: ...@@ -1501,10 +1496,9 @@ class ModelConfig:
and layer_types_value is None and layer_types_value is None
): ):
raise ValueError( raise ValueError(
"The model is an hybrid without a" "The model is an hybrid without a layers_block_type or an "
"layers_block_type or an attn_type_list, or a layer_types " "attn_type_list, or a layer_types in the hf_config, "
"in the hf_config, cannot determine the num of " f"cannot determine the num of {block_type} layers"
f"{block_type.value} layers"
) )
def get_mamba_chunk_size(self) -> int | None: def get_mamba_chunk_size(self) -> int | None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import inspect import inspect
import uuid import uuid
import warnings import warnings
...@@ -67,11 +66,6 @@ STR_INVALID_VAL: str = "INVALID" ...@@ -67,11 +66,6 @@ STR_INVALID_VAL: str = "INVALID"
T = TypeVar("T") T = TypeVar("T")
class LayerBlockType(enum.Enum):
attention = "attention"
mamba = "mamba"
def random_uuid() -> str: def random_uuid() -> str:
return str(uuid.uuid4().hex) return str(uuid.uuid4().hex)
......
...@@ -53,7 +53,6 @@ from vllm.multimodal.inputs import ( ...@@ -53,7 +53,6 @@ from vllm.multimodal.inputs import (
from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import LayerBlockType
from vllm.utils.math_utils import cdiv, prev_power_of_2 from vllm.utils.math_utils import cdiv, prev_power_of_2
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.pallas import ( from vllm.v1.attention.backends.pallas import (
...@@ -212,7 +211,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -212,7 +211,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Model-related. # Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type( self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention parallel_config, "attention"
) )
self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment