Unverified Commit 81f09cfd authored by Went-Liang's avatar Went-Liang Committed by GitHub
Browse files

[Model] Support math-shepherd-mistral-7b-prm model (#9697)


Signed-off-by: default avatarWent-Liang <wenteng_liang@163.com>
parent cc98f1e0
...@@ -112,38 +112,58 @@ class ModelConfig: ...@@ -112,38 +112,58 @@ class ModelConfig:
Defaults to 'auto' which defaults to 'hf'. Defaults to 'auto' which defaults to 'hf'.
mm_processor_kwargs: Arguments to be forwarded to the model's processor mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor. for multi-modal data, e.g., image processor.
pooling_type: Used to configure the pooling method in the embedding
model.
pooling_norm: Used to determine whether to normalize the pooled
data in the embedding model.
pooling_softmax: Used to determine whether to softmax the pooled
data in the embedding model.
pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates
that the score corresponding to the pooling_step_tag_id in the
generated sentence should be returned. Otherwise, it returns
the scores for all tokens.
pooling_returned_token_ids: pooling_returned_token_ids represents a
list of indices for the vocabulary dimensions to be extracted,
such as the token IDs of good_token and bad_token in the
math-shepherd-mistral-7b-prm model.
""" """
def __init__(self, def __init__(
model: str, self,
task: Union[TaskOption, _Task], model: str,
tokenizer: str, task: Union[TaskOption, _Task],
tokenizer_mode: str, tokenizer: str,
trust_remote_code: bool, tokenizer_mode: str,
dtype: Union[str, torch.dtype], trust_remote_code: bool,
seed: int, dtype: Union[str, torch.dtype],
revision: Optional[str] = None, seed: int,
code_revision: Optional[str] = None, revision: Optional[str] = None,
rope_scaling: Optional[dict] = None, code_revision: Optional[str] = None,
rope_theta: Optional[float] = None, rope_scaling: Optional[dict] = None,
tokenizer_revision: Optional[str] = None, rope_theta: Optional[float] = None,
max_model_len: Optional[int] = None, tokenizer_revision: Optional[str] = None,
spec_target_max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
quantization: Optional[str] = None, spec_target_max_model_len: Optional[int] = None,
quantization_param_path: Optional[str] = None, quantization: Optional[str] = None,
enforce_eager: Optional[bool] = None, quantization_param_path: Optional[str] = None,
max_context_len_to_capture: Optional[int] = None, enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None, max_context_len_to_capture: Optional[int] = None,
max_logprobs: int = 20, max_seq_len_to_capture: Optional[int] = None,
disable_sliding_window: bool = False, max_logprobs: int = 20,
skip_tokenizer_init: bool = False, disable_sliding_window: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None, skip_tokenizer_init: bool = False,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None, served_model_name: Optional[Union[str, List[str]]] = None,
use_async_output_proc: bool = True, limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
override_neuron_config: Optional[Dict[str, Any]] = None, use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO, override_neuron_config: Optional[Dict[str, Any]] = None,
chat_template_text_format: str = "string", config_format: ConfigFormat = ConfigFormat.AUTO,
mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None: chat_template_text_format: str = "string",
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
pooling_softmax: Optional[bool] = None,
pooling_step_tag_id: Optional[int] = None,
pooling_returned_token_ids: Optional[List[int]] = None) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
...@@ -224,6 +244,13 @@ class ModelConfig: ...@@ -224,6 +244,13 @@ class ModelConfig:
supported_tasks, task = self._resolve_task(task, self.hf_config) supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks self.supported_tasks = supported_tasks
self.task: Final = task self.task: Final = task
self.pooler_config = self._init_pooler_config(
pooling_type,
pooling_norm,
pooling_softmax,
pooling_step_tag_id,
pooling_returned_token_ids,
)
self._verify_quantization() self._verify_quantization()
self._verify_cuda_graph() self._verify_cuda_graph()
...@@ -242,6 +269,23 @@ class ModelConfig: ...@@ -242,6 +269,23 @@ class ModelConfig:
return None return None
def _init_pooler_config(
self,
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
pooling_softmax: Optional[bool] = None,
pooling_step_tag_id: Optional[int] = None,
pooling_returned_token_ids: Optional[List[int]] = None
) -> Optional["PoolerConfig"]:
if self.task == "embedding":
return PoolerConfig(
pooling_type=pooling_type,
pooling_norm=pooling_norm,
pooling_softmax=pooling_softmax,
pooling_step_tag_id=pooling_step_tag_id,
pooling_returned_token_ids=pooling_returned_token_ids)
return None
def _init_attention_free(self) -> bool: def _init_attention_free(self) -> bool:
architectures = getattr(self.hf_config, "architectures", []) architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_attention_free_model(architectures) return ModelRegistry.is_attention_free_model(architectures)
...@@ -1647,6 +1691,17 @@ class MultiModalConfig: ...@@ -1647,6 +1691,17 @@ class MultiModalConfig:
# TODO: Add configs to init vision tower or not. # TODO: Add configs to init vision tower or not.
@dataclass
class PoolerConfig:
"""Controls the behavior of pooler in embedding model"""
pooling_type: Optional[str] = None
pooling_norm: Optional[bool] = None
pooling_softmax: Optional[bool] = None
pooling_step_tag_id: Optional[int] = None
pooling_returned_token_ids: Optional[List[int]] = None
_STR_DTYPE_TO_TORCH_DTYPE = { _STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16, "half": torch.float16,
"float16": torch.float16, "float16": torch.float16,
......
...@@ -184,6 +184,13 @@ class EngineArgs: ...@@ -184,6 +184,13 @@ class EngineArgs:
mm_processor_kwargs: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None
scheduling_policy: Literal["fcfs", "priority"] = "fcfs" scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
# Pooling configuration.
pooling_type: Optional[str] = None
pooling_norm: Optional[bool] = None
pooling_softmax: Optional[bool] = None
pooling_step_tag_id: Optional[int] = None
pooling_returned_token_ids: Optional[List[int]] = None
def __post_init__(self): def __post_init__(self):
if not self.tokenizer: if not self.tokenizer:
self.tokenizer = self.model self.tokenizer = self.model
...@@ -850,6 +857,58 @@ class EngineArgs: ...@@ -850,6 +857,58 @@ class EngineArgs:
'priority (lower value means earlier handling) and time of ' 'priority (lower value means earlier handling) and time of '
'arrival deciding any ties).') 'arrival deciding any ties).')
parser.add_argument(
'--pooling-type',
choices=['LAST', 'ALL', 'CLS', 'STEP'],
default=None,
help='Used to configure the pooling method in the embedding model.'
)
parser.add_argument('--pooling-norm',
default=None,
action='store_true',
help="Used to determine whether to normalize "
"the pooled data in the embedding model.")
parser.add_argument('--no-pooling-norm',
default=None,
action='store_false',
dest='pooling_norm',
help="Used to determine whether to normalize "
"the pooled data in the embedding model.")
parser.add_argument('--pooling-softmax',
default=None,
action='store_true',
help="Used to determine whether to softmax "
"the pooled data in the embedding model.")
parser.add_argument('--no-pooling-softmax',
default=None,
action='store_false',
dest='pooling_softmax',
help="Used to determine whether to softmax "
"the pooled data in the embedding model.")
parser.add_argument(
'--pooling-step-tag-id',
type=int,
default=None,
help="When pooling-step-tag-id is not -1, it indicates "
"that the score corresponding to the step-tag-ids in the "
"generated sentence should be returned. Otherwise, it "
"returns the scores for all tokens.")
parser.add_argument(
'--pooling-returned-token-ids',
nargs='+',
type=int,
default=None,
help="pooling-returned-token-ids represents a list of "
"indices for the vocabulary dimensions to be extracted, "
"such as the token IDs of good_token and bad_token in "
"the math-shepherd-mistral-7b-prm model.")
return parser return parser
@classmethod @classmethod
...@@ -891,6 +950,11 @@ class EngineArgs: ...@@ -891,6 +950,11 @@ class EngineArgs:
override_neuron_config=self.override_neuron_config, override_neuron_config=self.override_neuron_config,
config_format=self.config_format, config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
pooling_type=self.pooling_type,
pooling_norm=self.pooling_norm,
pooling_softmax=self.pooling_softmax,
pooling_step_tag_id=self.pooling_step_tag_id,
pooling_returned_token_ids=self.pooling_returned_token_ids,
) )
def create_load_config(self) -> LoadConfig: def create_load_config(self) -> LoadConfig:
......
...@@ -257,7 +257,8 @@ class LLMEngine: ...@@ -257,7 +257,8 @@ class LLMEngine:
"num_scheduler_steps=%d, chunked_prefill_enabled=%s " "num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, "
"chat_template_text_format=%s, mm_processor_kwargs=%s)", "chat_template_text_format=%s, mm_processor_kwargs=%s, "
"pooler_config=%r)",
VLLM_VERSION, VLLM_VERSION,
model_config.model, model_config.model,
speculative_config, speculative_config,
...@@ -294,6 +295,7 @@ class LLMEngine: ...@@ -294,6 +295,7 @@ class LLMEngine:
use_cached_outputs, use_cached_outputs,
model_config.chat_template_text_format, model_config.chat_template_text_format,
model_config.mm_processor_kwargs, model_config.mm_processor_kwargs,
model_config.pooler_config,
) )
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config self.model_config = model_config
......
...@@ -159,6 +159,11 @@ class LLM: ...@@ -159,6 +159,11 @@ class LLM:
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model` # After positional args are removed, move this right below `model`
task: TaskOption = "auto", task: TaskOption = "auto",
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
pooling_softmax: Optional[bool] = None,
pooling_step_tag_id: Optional[int] = None,
pooling_returned_token_ids: Optional[List[int]] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
''' '''
...@@ -193,6 +198,11 @@ class LLM: ...@@ -193,6 +198,11 @@ class LLM:
disable_custom_all_reduce=disable_custom_all_reduce, disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc, disable_async_output_proc=disable_async_output_proc,
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
pooling_type=pooling_type,
pooling_norm=pooling_norm,
pooling_softmax=pooling_softmax,
pooling_step_tag_id=pooling_step_tag_id,
pooling_returned_token_ids=pooling_returned_token_ids,
**kwargs, **kwargs,
) )
self.llm_engine = LLMEngine.from_engine_args( self.llm_engine = LLMEngine.from_engine_args(
......
from enum import IntEnum from enum import IntEnum
from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import PoolerConfig
from vllm.model_executor.pooling_metadata import (PoolingMetadata, from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors) PoolingTensors)
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
...@@ -13,6 +15,7 @@ class PoolingType(IntEnum): ...@@ -13,6 +15,7 @@ class PoolingType(IntEnum):
LAST = 0 LAST = 0
ALL = 1 ALL = 1
CLS = 2 CLS = 2
STEP = 3
class Pooler(nn.Module): class Pooler(nn.Module):
...@@ -28,15 +31,47 @@ class Pooler(nn.Module): ...@@ -28,15 +31,47 @@ class Pooler(nn.Module):
normalize: Whether to normalize the pooled data. normalize: Whether to normalize the pooled data.
""" """
def __init__(self, def __init__(
pooling_type: PoolingType, self,
normalize: bool, pooling_type: PoolingType,
softmax: bool = False): normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
):
super().__init__() super().__init__()
self.pooling_type = pooling_type self.pooling_type = pooling_type
self.normalize = normalize self.normalize = normalize
self.softmax = softmax self.softmax = softmax
self.step_tag_id = step_tag_id
self.returned_token_ids = returned_token_ids
@classmethod
def from_config_with_defaults(
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
) -> Optional["Pooler"]:
if pooler_config is None:
return None
return cls(
pooling_type=PoolingType[pooler_config.pooling_type]
if pooler_config.pooling_type is not None else pooling_type,
normalize=pooler_config.pooling_norm
if pooler_config.pooling_norm is not None else normalize,
softmax=pooler_config.pooling_softmax
if pooler_config.pooling_softmax is not None else softmax,
step_tag_id=pooler_config.pooling_step_tag_id
if pooler_config.pooling_step_tag_id is not None else step_tag_id,
returned_token_ids=pooler_config.pooling_returned_token_ids
if pooler_config.pooling_returned_token_ids is not None else
returned_token_ids,
)
def forward( def forward(
self, self,
...@@ -62,6 +97,25 @@ class Pooler(nn.Module): ...@@ -62,6 +97,25 @@ class Pooler(nn.Module):
for prompt_len in prompt_lens: for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len]) pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len offset += prompt_len
elif self.pooling_type == PoolingType.STEP:
if self.returned_token_ids is not None and len(
self.returned_token_ids) > 0:
logits = hidden_states[:,
self.returned_token_ids].softmax(dim=-1)
else:
logits = hidden_states.softmax(dim=-1)
offset = 0
pooled_data = []
for prompt_len, seq_data_i in zip(
prompt_lens, pooling_metadata.seq_data.values()):
if self.step_tag_id is None:
pooled_data.append(logits[offset:offset + prompt_len])
else:
step_idxs = torch.tensor(
seq_data_i.prompt_token_ids) == self.step_tag_id
pooled_data.append(logits[offset:offset +
prompt_len][step_idxs])
offset += prompt_len
else: else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}") raise ValueError(f"Invalid pooling type: {self.pooling_type}")
......
...@@ -23,7 +23,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME ...@@ -23,7 +23,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig, LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, PoolerConfig, SchedulerConfig)
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
...@@ -122,7 +122,8 @@ def _get_model_initialization_kwargs( ...@@ -122,7 +122,8 @@ def _get_model_initialization_kwargs(
model_class: Type[nn.Module], model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]: scheduler_config: Optional[SchedulerConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> Dict[str, Any]:
"""Get extra kwargs for model initialization.""" """Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {} extra_kwargs: Dict[str, Any] = {}
...@@ -143,7 +144,8 @@ def _get_model_initialization_kwargs( ...@@ -143,7 +144,8 @@ def _get_model_initialization_kwargs(
if has_inner_state(model_class) and scheduler_config: if has_inner_state(model_class) and scheduler_config:
extra_kwargs["scheduler_config"] = scheduler_config extra_kwargs["scheduler_config"] = scheduler_config
if pooler_config:
extra_kwargs["pooler_config"] = pooler_config
return extra_kwargs return extra_kwargs
...@@ -155,10 +157,12 @@ def build_model(model_class: Type[nn.Module], ...@@ -155,10 +157,12 @@ def build_model(model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig], multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig], scheduler_config: Optional[SchedulerConfig],
prefix: Optional[str] = None) -> nn.Module: prefix: Optional[str] = None,
pooler_config: Optional[PoolerConfig] = None) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config, multimodal_config,
scheduler_config) scheduler_config,
pooler_config)
if prefix: if prefix:
extra_kwargs["prefix"] = prefix extra_kwargs["prefix"] = prefix
...@@ -185,6 +189,7 @@ def _initialize_model( ...@@ -185,6 +189,7 @@ def _initialize_model(
lora_config=lora_config, lora_config=lora_config,
multimodal_config=model_config.multimodal_config, multimodal_config=model_config.multimodal_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
pooler_config=model_config.pooler_config,
) )
......
...@@ -6,7 +6,7 @@ from transformers import BertConfig ...@@ -6,7 +6,7 @@ from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.xformers import XFormersImpl from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import CacheConfig from vllm.config import CacheConfig, PoolerConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -387,10 +387,15 @@ class BertEmbeddingModel(nn.Module): ...@@ -387,10 +387,15 @@ class BertEmbeddingModel(nn.Module):
config: BertConfig, config: BertConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.model = BertModel(config, cache_config, quant_config) self.model = BertModel(config, cache_config, quant_config)
self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True) self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)
def forward( def forward(
self, self,
......
...@@ -22,7 +22,7 @@ from transformers import Gemma2Config ...@@ -22,7 +22,7 @@ from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
...@@ -473,13 +473,17 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP): ...@@ -473,13 +473,17 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
pooler_config: Optional[PoolerConfig] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__() super().__init__()
self.model = Gemma2Model(**kwargs) self.model = Gemma2Model(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -29,7 +29,7 @@ from transformers import LlamaConfig ...@@ -29,7 +29,7 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -502,6 +502,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -502,6 +502,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
prefix: str = "", prefix: str = "",
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -543,6 +544,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -543,6 +544,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.STEP,
normalize=False,
softmax=False)
def forward( def forward(
self, self,
...@@ -565,6 +571,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -565,6 +571,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
logits = self.compute_logits(hidden_states, None)
return self._pooler(logits, pooling_metadata)
def sample(self, logits: torch.Tensor, def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
...@@ -630,12 +644,17 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP): ...@@ -630,12 +644,17 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
pooler_config: Optional[PoolerConfig] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__() super().__init__()
self.model = LlamaModel(**kwargs) self.model = LlamaModel(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -11,7 +11,7 @@ from transformers.models.llava_next.modeling_llava_next import ( ...@@ -11,7 +11,7 @@ from transformers.models.llava_next.modeling_llava_next import (
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -285,7 +285,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -285,7 +285,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
config: LlavaNextConfig, config: LlavaNextConfig,
multimodal_config: MultiModalConfig, multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -312,8 +313,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -312,8 +313,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
# The same model class supports both language generation and embedding # The same model class supports both language generation and embedding
# because the architecture name is the same # because the architecture name is the same
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
......
...@@ -26,7 +26,8 @@ from PIL import Image ...@@ -26,7 +26,8 @@ from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig,
PoolerConfig)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs) token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -530,7 +531,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -530,7 +531,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config: PretrainedConfig, config: PretrainedConfig,
multimodal_config: MultiModalConfig, multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -556,8 +558,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -556,8 +558,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# The same model class supports both language generation and embedding # The same model class supports both language generation and embedding
# because the architecture name is the same # because the architecture name is the same
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
......
...@@ -12,7 +12,7 @@ from torch import nn ...@@ -12,7 +12,7 @@ from torch import nn
from transformers import Qwen2Config from transformers import Qwen2Config
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -53,6 +53,7 @@ class Qwen2ForSequenceClassification(nn.Module): ...@@ -53,6 +53,7 @@ class Qwen2ForSequenceClassification(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
# TODO (@robertgshaw2): see if this can be moved out # TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None if (cache_config.sliding_window is not None
...@@ -77,9 +78,11 @@ class Qwen2ForSequenceClassification(nn.Module): ...@@ -77,9 +78,11 @@ class Qwen2ForSequenceClassification(nn.Module):
self.score = RowParallelLinear(config.hidden_size, self.score = RowParallelLinear(config.hidden_size,
config.num_labels, config.num_labels,
quant_config=quant_config) quant_config=quant_config)
self._pooler = Pooler(pooling_type=PoolingType.LAST, self._pooler = Pooler.from_config_with_defaults(
normalize=False, pooler_config,
softmax=True) pooling_type=PoolingType.LAST,
normalize=False,
softmax=True)
def forward( def forward(
self, self,
......
...@@ -11,7 +11,7 @@ from torch import nn ...@@ -11,7 +11,7 @@ from torch import nn
from transformers import Qwen2Config from transformers import Qwen2Config
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
...@@ -64,6 +64,7 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP): ...@@ -64,6 +64,7 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
# TODO (@robertgshaw2): see if this can be moved out # TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None if (cache_config.sliding_window is not None
...@@ -93,8 +94,11 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP): ...@@ -93,8 +94,11 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
RowParallelLinear(config.hidden_size, 1, RowParallelLinear(config.hidden_size, 1,
quant_config=quant_config), quant_config=quant_config),
) )
self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False) self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.ALL,
normalize=False,
softmax=False)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -100,11 +100,27 @@ _EMBEDDING_MODELS = { ...@@ -100,11 +100,27 @@ _EMBEDDING_MODELS = {
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": ( "Qwen2ForSequenceClassification": (
"qwen2_cls", "Qwen2ForSequenceClassification"), "qwen2_cls", "Qwen2ForSequenceClassification"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
# [Multimodal] # [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
} }
def add_embedding_models(base_models, embedding_models):
with_pooler_method_models = {}
embedding_models_name = embedding_models.keys()
for name, (path, arch) in base_models.items():
if arch in embedding_models_name:
with_pooler_method_models[name] = (path, arch)
return with_pooler_method_models
_EMBEDDING_MODELS = {
**add_embedding_models(_TEXT_GENERATION_MODELS, _EMBEDDING_MODELS),
**_EMBEDDING_MODELS,
}
_MULTIMODAL_MODELS = { _MULTIMODAL_MODELS = {
# [Decoder-only] # [Decoder-only]
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment