Commit 4b4eeb26 authored by zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents 2216a4e5 4fdc581f
...@@ -254,7 +254,7 @@ class LLMEngine: ...@@ -254,7 +254,7 @@ class LLMEngine:
"num_scheduler_steps=%d, chunked_prefill_enabled=%s " "num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s)", "chat_template_text_format=%s, mm_processor_kwargs=%s)",
VLLM_VERSION, VLLM_VERSION,
model_config.model, model_config.model,
speculative_config, speculative_config,
...@@ -289,6 +289,7 @@ class LLMEngine: ...@@ -289,6 +289,7 @@ class LLMEngine:
cache_config.enable_prefix_caching, cache_config.enable_prefix_caching,
model_config.use_async_output_proc, model_config.use_async_output_proc,
use_cached_outputs, use_cached_outputs,
model_config.chat_template_text_format,
model_config.mm_processor_kwargs, model_config.mm_processor_kwargs,
) )
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
...@@ -646,10 +647,24 @@ class LLMEngine: ...@@ -646,10 +647,24 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
) -> SequenceGroup: ) -> Optional[SequenceGroup]:
"""Add a processed request to the engine's request pool. """Add a processed request to the engine's request pool.
return the created sequence group. return the created sequence group.
""" """
if isinstance(params, SamplingParams) and params.n > 1:
ParallelSampleSequenceGroup.add_request(
request_id,
self,
params,
processed_inputs=processed_inputs,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
return None
self._validate_model_inputs(processed_inputs) self._validate_model_inputs(processed_inputs)
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
...@@ -720,7 +735,7 @@ class LLMEngine: ...@@ -720,7 +735,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> Optional[SequenceGroup]: ) -> None:
... ...
@overload @overload
...@@ -734,7 +749,7 @@ class LLMEngine: ...@@ -734,7 +749,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> Optional[SequenceGroup]: ) -> None:
... ...
@deprecate_kwargs( @deprecate_kwargs(
...@@ -753,7 +768,7 @@ class LLMEngine: ...@@ -753,7 +768,7 @@ class LLMEngine:
priority: int = 0, priority: int = 0,
*, *,
inputs: Optional[PromptType] = None, # DEPRECATED inputs: Optional[PromptType] = None, # DEPRECATED
) -> Optional[SequenceGroup]: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the The request is added to the request pool and will be processed by the
...@@ -797,22 +812,6 @@ class LLMEngine: ...@@ -797,22 +812,6 @@ class LLMEngine:
>>> # continue the request processing >>> # continue the request processing
>>> ... >>> ...
""" """
if isinstance(params, SamplingParams) and params.n > 1:
ParallelSampleSequenceGroup.add_request(
request_id,
self,
params,
prompt=prompt,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
inputs=inputs,
)
return None
if inputs is not None: if inputs is not None:
prompt = inputs prompt = inputs
assert prompt is not None and params is not None assert prompt is not None and params is not None
...@@ -843,7 +842,7 @@ class LLMEngine: ...@@ -843,7 +842,7 @@ class LLMEngine:
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get( processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs") "mm_processor_kwargs")
return self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
processed_inputs=processed_inputs, processed_inputs=processed_inputs,
params=params, params=params,
...@@ -1612,7 +1611,7 @@ class LLMEngine: ...@@ -1612,7 +1611,7 @@ class LLMEngine:
# KV Cache Usage in % # KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks num_total_gpu = self.cache_config.num_gpu_blocks
gpu_cache_usage_sys = 0. gpu_cache_usage_sys = 0.
if num_total_gpu is not None: if num_total_gpu: # Guard against both None and 0
num_free_gpu = sum( num_free_gpu = sum(
scheduler.block_manager.get_num_free_gpu_blocks() scheduler.block_manager.get_num_free_gpu_blocks()
for scheduler in self.scheduler) for scheduler in self.scheduler)
...@@ -1620,7 +1619,7 @@ class LLMEngine: ...@@ -1620,7 +1619,7 @@ class LLMEngine:
num_total_cpu = self.cache_config.num_cpu_blocks num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage_sys = 0. cpu_cache_usage_sys = 0.
if num_total_cpu is not None and num_total_cpu > 0: if num_total_cpu: # Guard against both None and 0
num_free_cpu = sum( num_free_cpu = sum(
scheduler.block_manager.get_num_free_cpu_blocks() scheduler.block_manager.get_num_free_cpu_blocks()
for scheduler in self.scheduler) for scheduler in self.scheduler)
......
from typing import Dict, List, Tuple from typing import List
from vllm.config import SchedulerConfig from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
...@@ -6,9 +6,8 @@ from vllm.engine.output_processor.interfaces import ( ...@@ -6,9 +6,8 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor) SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import (CompletionSequenceGroupOutput, Sequence, from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup,
SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceGroupOutput)
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter from vllm.utils import Counter
...@@ -114,104 +113,22 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -114,104 +113,22 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
outputs: SequenceGroupOutput, outputs: SequenceGroupOutput,
is_async: bool) -> None: is_async: bool) -> None:
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
if sampling_params.n == 1:
# only have one output sample sample = outputs.samples[0]
sample = outputs.samples[0] seq = seq_group.first_seq
# only have one sequence if not is_async:
seq = seq_group.seqs[0] seq.append_token_id(sample.output_token, sample.logprobs)
if not is_async: if sampling_params.detokenize and self.detokenizer:
seq.append_token_id(sample.output_token, sample.logprobs) new_char_count = self.detokenizer.decode_sequence_inplace(
if sampling_params.detokenize and self.detokenizer: seq, sampling_params)
new_char_count = self.detokenizer.decode_sequence_inplace( else:
seq, sampling_params) new_char_count = 0
else: self.stop_checker.maybe_stop_sequence(
new_char_count = 0 seq,
self.stop_checker.maybe_stop_sequence( new_char_count,
seq, sampling_params,
new_char_count, lora_req=seq_group.lora_request,
sampling_params, )
lora_req=seq_group.lora_request, if seq.is_finished():
) for scheduler in self.scheduler:
if seq.is_finished(): scheduler.free_seq(seq)
for scheduler in self.scheduler:
scheduler.free_seq(seq)
return
# TODO: Add support for async for beam search
assert not is_async
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
for sample in samples:
# Guard against a KeyError which can occur if the request was
# aborted while the output was generated
if (child_list :=
parent_child_dict.get(sample.parent_seq_id)) is not None:
child_list.append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutput] = parent_child_dict[
parent.seq_id]
if len(child_samples) == 0:
# This parent sequence has no children samples. Remove
# the parent sequence from the sequence group since it will
# not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
for scheduler in self.scheduler:
scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id: int = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
else:
new_char_count = 0
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count,
sampling_params,
lora_req=seq_group.lora_request,
)
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
return
...@@ -121,7 +121,7 @@ class ConversationMessage(TypedDict, total=False): ...@@ -121,7 +121,7 @@ class ConversationMessage(TypedDict, total=False):
role: Required[str] role: Required[str]
"""The role of the message's author.""" """The role of the message's author."""
content: Optional[str] content: Union[Optional[str], List[Dict[str, str]]]
"""The contents of the message""" """The contents of the message"""
tool_call_id: Optional[str] tool_call_id: Optional[str]
...@@ -196,7 +196,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -196,7 +196,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
elif modality == "audio": elif modality == "audio":
if model_type == "ultravox": if model_type == "ultravox":
return "<|reserved_special_token_0|>" return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown {modality} model type: {model_type}") if model_type == "qwen2_audio":
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video": elif modality == "video":
if model_type == "qwen2_vl": if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>" return "<|vision_start|><|video_pad|><|vision_end|>"
...@@ -428,7 +431,7 @@ MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = { ...@@ -428,7 +431,7 @@ MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
def _parse_chat_message_content_mm_part( def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str, str]: part: ChatCompletionContentPartParam) -> Tuple[str, str]:
""" """
Parses a given multi modal content part based on its type. Parses a given multi-modal content part based on its type.
Args: Args:
part: A dict containing the content part, with a potential 'type' field. part: A dict containing the content part, with a potential 'type' field.
...@@ -482,54 +485,76 @@ def _parse_chat_message_content_parts( ...@@ -482,54 +485,76 @@ def _parse_chat_message_content_parts(
role: str, role: str,
parts: Iterable[ChatCompletionContentPartParam], parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker, mm_tracker: BaseMultiModalItemTracker,
chat_template_text_format: str,
) -> List[ConversationMessage]: ) -> List[ConversationMessage]:
texts: List[str] = [] content: List[Union[str, Dict[str, str]]] = []
mm_parser = mm_tracker.create_parser() mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \ wrap_dicts = \
mm_tracker._model_config.hf_config.model_type in \ mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT MODEL_KEEP_MULTI_MODAL_CONTENT or \
(chat_template_text_format == "openai")
has_image = False
for part in parts: for part in parts:
if isinstance(part, str): # Handle plain text parts parse_res = _parse_chat_message_content_part(
text = _TextParser(part) part,
texts.append(text) mm_parser,
else: # Handle structured dictionary parts wrap_dicts=wrap_dicts,
part_type, content = _parse_chat_message_content_mm_part(part) )
if parse_res:
# if part_type is text/refusal/image_url/audio_url but content.append(parse_res)
# content is empty, logg a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content: if wrap_dicts:
logger.warning("Skipping multimodal part " # Parsing wraps images and texts as interleaved dictionaries
"with empty / unparsable content.") return [ConversationMessage(role=role,
continue content=content)] # type: ignore
texts = cast(List[str], content)
if part_type in ("text", "refusal"):
texts.append(content)
elif part_type == "image_url":
mm_parser.parse_image(content)
has_image = True
elif part_type == "audio_url":
mm_parser.parse_audio(content)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts) text_prompt = "\n".join(texts)
if keep_multimodal_content: mm_placeholder_counts = mm_parser.mm_placeholder_counts()
text_prompt = "\n".join(texts) if mm_placeholder_counts:
role_content = [{'type': 'text', 'text': text_prompt}] text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
def _parse_chat_message_content_part(
        part: ChatCompletionContentPartParam,
        mm_parser: BaseMultiModalContentParser,
        wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]:
    """Parse a single content part of a conversation message.

    Args:
        part: Either a plain string or a structured content-part dict.
        mm_parser: Parser that receives image / audio payloads through
            ``parse_image`` / ``parse_audio``.
        wrap_dicts: If True, structured text and multimodal parts are
            returned as dictionaries, i.e. ``{"type": "text", "text": ...}``,
            ``{"type": "image"}`` and ``{"type": "audio"}``. If False, text
            is returned as a plain string (to be joined with multimodal
            placeholders by the caller) and multimodal parts yield ``None``
            after being handed to ``mm_parser``.

    Returns:
        The parsed piece, or ``None`` when the part is skipped (empty
        content) or fully consumed by ``mm_parser``. A plain-string ``part``
        is always returned as a string, regardless of ``wrap_dicts``.

    Raises:
        NotImplementedError: If the part type is not text, refusal,
            image_url or audio_url.
    """
    if isinstance(part, str):  # Handle plain text parts
        text = _TextParser(part)
        return text

    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)

    # If part_type is a known content type but the content is empty,
    # log a warning and skip the part.
    if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
        # The trailing space matters: the two adjacent literals are
        # concatenated into a single format string.
        logger.warning(
            "Skipping multimodal part (type: '%s') "
            "with empty / unparsable content.", part_type)
        return None

    if part_type in ("text", "refusal"):
        return {'type': 'text', 'text': content} if wrap_dicts else content

    if part_type == "image_url":
        # The payload is always registered with the parser; only the
        # returned representation depends on wrap_dicts.
        mm_parser.parse_image(content)
        return {'type': 'image'} if wrap_dicts else None

    if part_type == "audio_url":
        mm_parser.parse_audio(content)
        return {'type': 'audio'} if wrap_dicts else None

    raise NotImplementedError(f"Unknown part type: {part_type}")
# No need to validate using Pydantic again # No need to validate using Pydantic again
...@@ -540,6 +565,7 @@ _ToolParser = partial(cast, ChatCompletionToolMessageParam) ...@@ -540,6 +565,7 @@ _ToolParser = partial(cast, ChatCompletionToolMessageParam)
def _parse_chat_message_content( def _parse_chat_message_content(
message: ChatCompletionMessageParam, message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker, mm_tracker: BaseMultiModalItemTracker,
chat_template_text_format: str,
) -> List[ConversationMessage]: ) -> List[ConversationMessage]:
role = message["role"] role = message["role"]
content = message.get("content") content = message.get("content")
...@@ -555,6 +581,7 @@ def _parse_chat_message_content( ...@@ -555,6 +581,7 @@ def _parse_chat_message_content(
role, role,
content, # type: ignore content, # type: ignore
mm_tracker, mm_tracker,
chat_template_text_format,
) )
for result_msg in result: for result_msg in result:
...@@ -598,7 +625,11 @@ def parse_chat_messages( ...@@ -598,7 +625,11 @@ def parse_chat_messages(
mm_tracker = MultiModalItemTracker(model_config, tokenizer) mm_tracker = MultiModalItemTracker(model_config, tokenizer)
for msg in messages: for msg in messages:
sub_messages = _parse_chat_message_content(msg, mm_tracker) sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
model_config.chat_template_text_format,
)
conversation.extend(sub_messages) conversation.extend(sub_messages)
...@@ -616,7 +647,11 @@ def parse_chat_messages_futures( ...@@ -616,7 +647,11 @@ def parse_chat_messages_futures(
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
for msg in messages: for msg in messages:
sub_messages = _parse_chat_message_content(msg, mm_tracker) sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
model_config.chat_template_text_format,
)
conversation.extend(sub_messages) conversation.extend(sub_messages)
......
...@@ -384,7 +384,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -384,7 +384,7 @@ class OpenAIServingChat(OpenAIServing):
# Send response to echo the input portion of the # Send response to echo the input portion of the
# last message # last message
if request.echo or request.continue_final_message: if request.echo or request.continue_final_message:
last_msg_content: str = "" last_msg_content: Union[str, List[Dict[str, str]]] = ""
if conversation and "content" in conversation[ if conversation and "content" in conversation[
-1] and conversation[-1].get("role") == role: -1] and conversation[-1].get("role") == role:
last_msg_content = conversation[-1]["content"] or "" last_msg_content = conversation[-1]["content"] or ""
...@@ -724,10 +724,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -724,10 +724,13 @@ class OpenAIServingChat(OpenAIServing):
choices.append(choice_data) choices.append(choice_data)
if request.echo or request.continue_final_message: if request.echo or request.continue_final_message:
last_msg_content = "" last_msg_content: Union[str, List[Dict[str, str]]] = ""
if conversation and "content" in conversation[-1] and conversation[ if conversation and "content" in conversation[-1] and conversation[
-1].get("role") == role: -1].get("role") == role:
last_msg_content = conversation[-1]["content"] or "" last_msg_content = conversation[-1]["content"] or ""
if isinstance(last_msg_content, list):
last_msg_content = "\n".join(msg['text']
for msg in last_msg_content)
for choice in choices: for choice in choices:
full_message = last_msg_content + (choice.message.content full_message = last_msg_content + (choice.message.content
......
...@@ -10,7 +10,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook ...@@ -10,7 +10,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip, is_hip, is_xpu from vllm.utils import get_ip, is_hip
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -231,7 +231,7 @@ def initialize_ray_cluster( ...@@ -231,7 +231,7 @@ def initialize_ray_cluster(
assert_ray_available() assert_ray_available()
# Connect to a ray cluster. # Connect to a ray cluster.
if is_hip() or is_xpu(): if is_hip() or current_platform.is_xpu():
ray.init(address=ray_address, ray.init(address=ray_address,
ignore_reinit_error=True, ignore_reinit_error=True,
num_gpus=parallel_config.world_size) num_gpus=parallel_config.world_size)
......
...@@ -7,7 +7,7 @@ import vllm.envs as envs ...@@ -7,7 +7,7 @@ import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip, is_xpu, print_warning_once from vllm.utils import is_hip, print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -78,7 +78,7 @@ class CustomOp(nn.Module): ...@@ -78,7 +78,7 @@ class CustomOp(nn.Module):
return self.forward_cpu return self.forward_cpu
elif current_platform.is_tpu(): elif current_platform.is_tpu():
return self.forward_tpu return self.forward_tpu
elif is_xpu(): elif current_platform.is_xpu():
return self.forward_xpu return self.forward_xpu
else: else:
return self.forward_cuda return self.forward_cuda
......
...@@ -5,7 +5,8 @@ import os ...@@ -5,7 +5,8 @@ import os
import torch.nn.functional as F import torch.nn.functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
...@@ -37,10 +38,12 @@ class AWQConfig(QuantizationConfig): ...@@ -37,10 +38,12 @@ class AWQConfig(QuantizationConfig):
weight_bits: int, weight_bits: int,
group_size: int, group_size: int,
zero_point: bool, zero_point: bool,
modules_to_not_convert: Optional[List[str]] = None,
) -> None: ) -> None:
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.group_size = group_size self.group_size = group_size
self.zero_point = zero_point self.zero_point = zero_point
self.modules_to_not_convert = modules_to_not_convert or []
if self.weight_bits != 4: if self.weight_bits != 4:
raise ValueError( raise ValueError(
...@@ -51,7 +54,8 @@ class AWQConfig(QuantizationConfig): ...@@ -51,7 +54,8 @@ class AWQConfig(QuantizationConfig):
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, " return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, " f"group_size={self.group_size}, "
f"zero_point={self.zero_point})") f"zero_point={self.zero_point}, "
f"modules_to_not_convert={self.modules_to_not_convert})")
def get_name(self) -> str: def get_name(self) -> str:
return "awq" return "awq"
...@@ -77,11 +81,15 @@ class AWQConfig(QuantizationConfig): ...@@ -77,11 +81,15 @@ class AWQConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"]) zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point) modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQLinearMethod"]: prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return AWQLinearMethod(self) return AWQLinearMethod(self)
return None return None
...@@ -89,6 +97,10 @@ class AWQConfig(QuantizationConfig): ...@@ -89,6 +97,10 @@ class AWQConfig(QuantizationConfig):
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
def is_layer_skipped_awq(prefix: str,
                         modules_to_not_convert: List[str]) -> bool:
    """Return True if the layer named by ``prefix`` must stay unquantized.

    Args:
        prefix: Dotted name of the layer being constructed.
        modules_to_not_convert: Module-name fragments excluded from AWQ
            quantization. Matching is by substring, so an entry matches any
            prefix that contains it anywhere.

    Returns:
        True when at least one entry occurs in ``prefix``; False otherwise,
        including when the list is empty.
    """
    return any(module_name in prefix for module_name in modules_to_not_convert)
class AWQLinearMethod(LinearMethodBase): class AWQLinearMethod(LinearMethodBase):
"""Linear method for AWQ. """Linear method for AWQ.
......
...@@ -28,6 +28,7 @@ import os ...@@ -28,6 +28,7 @@ import os
import re import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
...@@ -264,6 +265,7 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -264,6 +265,7 @@ class BaiChuanDecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_torch_compile
class BaiChuanModel(nn.Module): class BaiChuanModel(nn.Module):
def __init__(self, def __init__(self,
...@@ -527,7 +529,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -527,7 +529,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous() qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 13B and Baichuan2 7B/13B.""" """Baichuan 13B and Baichuan2 7B/13B.
NOTE: the class name has a lower case 'c'.
"""
def __init__( def __init__(
self, self,
...@@ -545,7 +549,9 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -545,7 +549,9 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 7B.""" """Baichuan 7B.
NOTE: the class name has an upper case 'C'.
"""
def __init__( def __init__(
self, self,
......
...@@ -122,7 +122,7 @@ def input_processor_for_blip( ...@@ -122,7 +122,7 @@ def input_processor_for_blip(
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module): class BlipVisionEmbeddings(nn.Module):
def __init__(self, config: BlipVisionConfig): def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -167,9 +167,10 @@ class BlipParallelAttention(nn.Module): ...@@ -167,9 +167,10 @@ class BlipParallelAttention(nn.Module):
def __init__( def __init__(
self, self,
config: BlipVisionConfig, config: Union[BlipVisionConfig, Blip2VisionConfig],
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -189,11 +190,13 @@ class BlipParallelAttention(nn.Module): ...@@ -189,11 +190,13 @@ class BlipParallelAttention(nn.Module):
self.num_heads, self.num_heads,
bias=config.qkv_bias, bias=config.qkv_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv",
) )
self.projection = RowParallelLinear( self.projection = RowParallelLinear(
self.embed_dim, self.embed_dim,
self.embed_dim, self.embed_dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.projection",
) )
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
...@@ -235,9 +238,12 @@ class BlipParallelAttention(nn.Module): ...@@ -235,9 +238,12 @@ class BlipParallelAttention(nn.Module):
class BlipMLP(nn.Module): class BlipMLP(nn.Module):
def __init__(self, def __init__(
config: BlipVisionConfig, self,
quant_config: Optional[QuantizationConfig] = None): config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -246,11 +252,13 @@ class BlipMLP(nn.Module): ...@@ -246,11 +252,13 @@ class BlipMLP(nn.Module):
self.fc1 = ColumnParallelLinear(config.hidden_size, self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size, config.intermediate_size,
bias=True, bias=True,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size, self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=True, bias=True,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states) hidden_states, _ = self.fc1(hidden_states)
...@@ -262,24 +270,32 @@ class BlipMLP(nn.Module): ...@@ -262,24 +270,32 @@ class BlipMLP(nn.Module):
class BlipEncoderLayer(nn.Module): class BlipEncoderLayer(nn.Module):
def __init__(self, def __init__(
config: BlipVisionConfig, self,
quant_config: Optional[QuantizationConfig] = None): config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
# fallback to sdpa attention if tp unavailable # fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0: if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(config, self.self_attn = BlipParallelAttention(
quant_config=quant_config) config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else: else:
# Blip doesn't have SDPA attention implemented in transformers # Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend # use eager attention instead for cpu backend
self.self_attn = BlipAttention(config) self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config) self.mlp = BlipMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(config.hidden_size, self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
...@@ -307,10 +323,13 @@ class BlipEncoder(nn.Module): ...@@ -307,10 +323,13 @@ class BlipEncoder(nn.Module):
config: BlipConfig config: BlipConfig
""" """
def __init__(self, def __init__(
config: BlipVisionConfig, self,
quant_config: Optional[QuantizationConfig] = None, config: BlipVisionConfig,
num_hidden_layers_override: Optional[int] = None): quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -321,8 +340,10 @@ class BlipEncoder(nn.Module): ...@@ -321,8 +340,10 @@ class BlipEncoder(nn.Module):
num_hidden_layers = num_hidden_layers_override num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
BlipEncoderLayer(config=config, quant_config=quant_config) BlipEncoderLayer(config=config,
for _ in range(num_hidden_layers) quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
]) ])
def forward(self, inputs_embeds: torch.Tensor): def forward(self, inputs_embeds: torch.Tensor):
...@@ -337,10 +358,15 @@ class BlipVisionModel(nn.Module): ...@@ -337,10 +358,15 @@ class BlipVisionModel(nn.Module):
config_class = BlipVisionConfig config_class = BlipVisionConfig
main_input_name = "pixel_values" main_input_name = "pixel_values"
def __init__(self, def __init__(
config: BlipVisionConfig, self,
quant_config: Optional[QuantizationConfig] = None, config: BlipVisionConfig,
num_hidden_layers_override: Optional[int] = None): quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -354,19 +380,24 @@ class BlipVisionModel(nn.Module): ...@@ -354,19 +380,24 @@ class BlipVisionModel(nn.Module):
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
) )
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers: if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError( raise ValueError(
f"The original encoder only has {config.num_hidden_layers} " f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers." f"layers, but you requested {len(self.encoder.layers)} layers."
) )
elif len(self.encoder.layers) == config.num_hidden_layers:
# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(config.hidden_size, self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
else: else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None self.post_layernorm = None
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
......
...@@ -490,7 +490,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -490,7 +490,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_model = BlipVisionModel(config.vision_config) self.vision_model = BlipVisionModel(config.vision_config, quant_config)
self.query_tokens = nn.Parameter( self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens, torch.zeros(1, config.num_query_tokens,
......
...@@ -26,6 +26,7 @@ import os ...@@ -26,6 +26,7 @@ import os
import re import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
...@@ -226,6 +227,7 @@ class BloomBlock(nn.Module): ...@@ -226,6 +227,7 @@ class BloomBlock(nn.Module):
return output return output
@support_torch_compile
class BloomModel(nn.Module): class BloomModel(nn.Module):
def __init__( def __init__(
......
...@@ -15,8 +15,9 @@ import re ...@@ -15,8 +15,9 @@ import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -24,8 +25,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -24,8 +25,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -41,11 +41,13 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, ...@@ -41,11 +41,13 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from .interfaces import SupportsLoRA, SupportsMultiModal
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -155,6 +157,10 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]: ...@@ -155,6 +157,10 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]:
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
hf_config = ctx.get_hf_config(ChatGLMConfig) hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None) vision_config = getattr(hf_config, 'vision_config', None)
...@@ -166,8 +172,8 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -166,8 +172,8 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
input_ids = inputs.get("prompt_token_ids") input_ids = inputs["prompt_token_ids"]
position_ids = inputs.get("position_ids")
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
ctx.model_config.model, ctx.model_config.model,
trust_remote_code=ctx.model_config.trust_remote_code) trust_remote_code=ctx.model_config.trust_remote_code)
...@@ -176,20 +182,19 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -176,20 +182,19 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
raw_batch_data = tokenizer.apply_chat_template( raw_batch_data = tokenizer.apply_chat_template(
conversation=[{ conversation=[{
"role": "user", "role": "user",
"image": inputs['multi_modal_data']["image"], "image": multi_modal_data["image"],
"content": inputs['prompt'] "content": inputs['prompt'],
}], }],
add_generation_prompt=True, add_generation_prompt=True,
tokenize=True, tokenize=True,
return_tensors="pt", return_tensors="pt",
return_dict=True).data return_dict=True,
).data
except Exception: except Exception:
logger.error("Failed to process content (%s)", inputs['prompt']) logger.error("Failed to process content (%s)", inputs['prompt'])
raise raise
input_ids = raw_batch_data['input_ids'][0].tolist() input_ids = raw_batch_data['input_ids'][0].tolist()
if position_ids is None:
position_ids = list(range(len(input_ids)))
boi_token_id = hf_config.boi_token_id boi_token_id = hf_config.boi_token_id
eoi_token_id = hf_config.eoi_token_id eoi_token_id = hf_config.eoi_token_id
boi_positions = find_all_positions(input_ids, boi_token_id) boi_positions = find_all_positions(input_ids, boi_token_id)
...@@ -198,7 +203,6 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -198,7 +203,6 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
assert len(boi_positions) == len(eoi_positions) assert len(boi_positions) == len(eoi_positions)
new_input_ids = [] new_input_ids = []
new_position_ids = []
final_processed_position = 0 final_processed_position = 0
final_processed_position = 0 final_processed_position = 0
...@@ -206,29 +210,28 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -206,29 +210,28 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
assert boi_position < eoi_position assert boi_position < eoi_position
new_input_ids.extend(input_ids[final_processed_position:boi_position + new_input_ids.extend(input_ids[final_processed_position:boi_position +
1]) 1])
new_position_ids.extend(
list(range(final_processed_position, boi_position + 1)))
new_input_ids.extend([input_ids[boi_position + 1]] * new_input_ids.extend([input_ids[boi_position + 1]] *
image_placeholder_length) image_placeholder_length)
new_position_ids.extend([boi_position + 1] * image_placeholder_length)
final_processed_position = eoi_position final_processed_position = eoi_position
new_input_ids.extend(input_ids[final_processed_position:]) new_input_ids.extend(input_ids[final_processed_position:])
new_position_ids.extend(
list(range(final_processed_position, len(input_ids))))
assert len(new_input_ids) == len(new_position_ids) prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(new_input_ids)
inputs["prompt_token_ids"] = new_input_ids return token_inputs(
inputs["position_ids"] = new_position_ids prompt_token_ids=new_input_ids,
return inputs prompt=prompt,
multi_modal_data=multi_modal_data,
)
class GLMAttention(nn.Module): class GLMAttention(nn.Module):
def __init__( def __init__(
self, self,
config, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
...@@ -326,7 +329,7 @@ class GLMMLP(nn.Module): ...@@ -326,7 +329,7 @@ class GLMMLP(nn.Module):
def __init__( def __init__(
self, self,
config, config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -369,7 +372,7 @@ class GLMBlock(nn.Module): ...@@ -369,7 +372,7 @@ class GLMBlock(nn.Module):
def __init__( def __init__(
self, self,
config, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
...@@ -440,9 +443,10 @@ class GLMTransformer(nn.Module): ...@@ -440,9 +443,10 @@ class GLMTransformer(nn.Module):
def __init__( def __init__(
self, self,
config, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.post_layer_norm = config.post_layer_norm self.post_layer_norm = config.post_layer_norm
...@@ -451,10 +455,11 @@ class GLMTransformer(nn.Module): ...@@ -451,10 +455,11 @@ class GLMTransformer(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
# Transformer layers. # Transformer layers.
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
GLMBlock(config, cache_config, quant_config) self.num_layers,
for i in range(self.num_layers) lambda prefix: GLMBlock(config, cache_config, quant_config),
]) prefix=f"{prefix}.layers",
)
if self.post_layer_norm: if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
...@@ -462,6 +467,10 @@ class GLMTransformer(nn.Module): ...@@ -462,6 +467,10 @@ class GLMTransformer(nn.Module):
self.final_layernorm = layer_norm_func( self.final_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -469,16 +478,16 @@ class GLMTransformer(nn.Module): ...@@ -469,16 +478,16 @@ class GLMTransformer(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
for i in range(self.num_layers): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(
hidden_states=hidden_states, hidden_states=hidden_states,
position_ids=position_ids, position_ids=position_ids,
kv_cache=kv_caches[i], kv_cache=kv_caches[i - self.start_layer],
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
) )
# Final layer norm. # Final layer norm.
if self.post_layer_norm: if get_pp_group().is_last_rank and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states) hidden_states = self.final_layernorm(hidden_states)
return hidden_states return hidden_states
...@@ -488,7 +497,7 @@ class ChatGLMModel(nn.Module): ...@@ -488,7 +497,7 @@ class ChatGLMModel(nn.Module):
def __init__( def __init__(
self, self,
config, config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
...@@ -516,6 +525,9 @@ class ChatGLMModel(nn.Module): ...@@ -516,6 +525,9 @@ class ChatGLMModel(nn.Module):
else: else:
self.vision = None self.vision = None
self.make_empty_intermediate_tensors = (
self.encoder.make_empty_intermediate_tensors)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> GLMImagePixelInputs: self, **kwargs: object) -> GLMImagePixelInputs:
...@@ -541,24 +553,26 @@ class ChatGLMModel(nn.Module): ...@@ -541,24 +553,26 @@ class ChatGLMModel(nn.Module):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> torch.Tensor: ) -> torch.Tensor:
if intermediate_tensors is None:
inputs_embeds = self.embedding(input_ids) inputs_embeds = self.embedding(input_ids)
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input["pixel_values"] is not None: if image_input["pixel_values"] is not None:
pixel_values = image_input["pixel_values"].to( pixel_values = image_input["pixel_values"].to(
dtype=inputs_embeds.dtype) dtype=inputs_embeds.dtype)
image_embeds = self.vision(pixel_values) image_embeds = self.vision(pixel_values)
boi_token_id = self.config.boi_token_id boi_token_id = self.config.boi_token_id
eoi_token_id = self.config.eoi_token_id eoi_token_id = self.config.eoi_token_id
inputs_embeds = merge_glm_vision_embeddings( inputs_embeds = merge_glm_vision_embeddings(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
vision_embeddings=image_embeds, vision_embeddings=image_embeds,
boi_token_id=boi_token_id, boi_token_id=boi_token_id,
eoi_token_id=eoi_token_id) eoi_token_id=eoi_token_id)
else:
inputs_embeds = intermediate_tensors["hidden_states"]
# Run encoder. # Run encoder.
hidden_states = self.encoder( hidden_states = self.encoder(
...@@ -567,6 +581,9 @@ class ChatGLMModel(nn.Module): ...@@ -567,6 +581,9 @@ class ChatGLMModel(nn.Module):
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states return hidden_states
...@@ -574,7 +591,8 @@ class ChatGLMModel(nn.Module): ...@@ -574,7 +591,8 @@ class ChatGLMModel(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv) @INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
SupportsMultiModal):
packed_modules_mapping = { packed_modules_mapping = {
"query_key_value": ["query_key_value"], "query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"] "dense_h_to_4h": ["dense_h_to_4h"]
...@@ -631,7 +649,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -631,7 +649,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs) -> torch.Tensor: **kwargs) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, **kwargs) attn_metadata, intermediate_tensors,
**kwargs)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -677,6 +696,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -677,6 +696,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -192,6 +192,7 @@ class CLIPParallelAttention(nn.Module): ...@@ -192,6 +192,7 @@ class CLIPParallelAttention(nn.Module):
self, self,
config: CLIPVisionConfig, config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -211,12 +212,14 @@ class CLIPParallelAttention(nn.Module): ...@@ -211,12 +212,14 @@ class CLIPParallelAttention(nn.Module):
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.num_heads, total_num_heads=self.num_heads,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
input_size=self.embed_dim, input_size=self.embed_dim,
output_size=self.embed_dim, output_size=self.embed_dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.out_proj",
) )
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
...@@ -259,20 +262,25 @@ class CLIPParallelAttention(nn.Module): ...@@ -259,20 +262,25 @@ class CLIPParallelAttention(nn.Module):
class CLIPMLP(nn.Module): class CLIPMLP(nn.Module):
def __init__(self, def __init__(
config: CLIPVisionConfig, self,
quant_config: Optional[QuantizationConfig] = None): config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size, self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size, config.intermediate_size,
bias=True, bias=True,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size, self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=True, bias=True,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states) hidden_states, _ = self.fc1(hidden_states)
...@@ -284,21 +292,29 @@ class CLIPMLP(nn.Module): ...@@ -284,21 +292,29 @@ class CLIPMLP(nn.Module):
class CLIPEncoderLayer(nn.Module): class CLIPEncoderLayer(nn.Module):
def __init__(self, def __init__(
config: CLIPVisionConfig, self,
quant_config: Optional[QuantizationConfig] = None): config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
num_heads = config.num_attention_heads num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0: if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(config, self.self_attn = CLIPParallelAttention(
quant_config=quant_config) config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else: else:
self.self_attn = CLIPSdpaAttention(config) self.self_attn = CLIPSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config) self.mlp = CLIPMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(config.hidden_size, self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
...@@ -327,11 +343,15 @@ class CLIPEncoder(nn.Module): ...@@ -327,11 +343,15 @@ class CLIPEncoder(nn.Module):
config: CLIPConfig config: CLIPConfig
""" """
def __init__(self, def __init__(
config: CLIPVisionConfig, self,
quant_config: Optional[QuantizationConfig] = None, config: CLIPVisionConfig,
num_hidden_layers_override: Optional[int] = None): quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
if num_hidden_layers_override is None: if num_hidden_layers_override is None:
...@@ -339,8 +359,10 @@ class CLIPEncoder(nn.Module): ...@@ -339,8 +359,10 @@ class CLIPEncoder(nn.Module):
else: else:
num_hidden_layers = num_hidden_layers_override num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
CLIPEncoderLayer(config=config, quant_config=quant_config) CLIPEncoderLayer(config=config,
for _ in range(num_hidden_layers) quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
]) ])
def forward(self, inputs_embeds: torch.Tensor): def forward(self, inputs_embeds: torch.Tensor):
...@@ -354,11 +376,17 @@ class CLIPEncoder(nn.Module): ...@@ -354,11 +376,17 @@ class CLIPEncoder(nn.Module):
class CLIPVisionTransformer(nn.Module): class CLIPVisionTransformer(nn.Module):
def __init__(self, def __init__(
config: CLIPVisionConfig, self,
quant_config: Optional[QuantizationConfig] = None, config: CLIPVisionConfig,
num_hidden_layers_override: Optional[int] = None): quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
embed_dim = config.hidden_size embed_dim = config.hidden_size
...@@ -370,19 +398,25 @@ class CLIPVisionTransformer(nn.Module): ...@@ -370,19 +398,25 @@ class CLIPVisionTransformer(nn.Module):
self.encoder = CLIPEncoder( self.encoder = CLIPEncoder(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override) num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers: if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError( raise ValueError(
f"The original encoder only has {config.num_hidden_layers} " f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers." f"layers, but you requested {len(self.encoder.layers)} layers."
) )
elif len(self.encoder.layers) == config.num_hidden_layers:
# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(embed_dim, self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
else: else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None self.post_layernorm = None
def forward( def forward(
...@@ -405,10 +439,15 @@ class CLIPVisionModel(nn.Module): ...@@ -405,10 +439,15 @@ class CLIPVisionModel(nn.Module):
config_class = CLIPVisionConfig config_class = CLIPVisionConfig
main_input_name = "pixel_values" main_input_name = "pixel_values"
def __init__(self, def __init__(
config: CLIPVisionConfig, self,
quant_config: Optional[QuantizationConfig] = None, config: CLIPVisionConfig,
num_hidden_layers_override: Optional[int] = None): quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -418,7 +457,10 @@ class CLIPVisionModel(nn.Module): ...@@ -418,7 +457,10 @@ class CLIPVisionModel(nn.Module):
self.vision_model = CLIPVisionTransformer( self.vision_model = CLIPVisionTransformer(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override) num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.vision_model(pixel_values) return self.vision_model(pixel_values)
......
...@@ -28,6 +28,7 @@ from torch import nn ...@@ -28,6 +28,7 @@ from torch import nn
from transformers import CohereConfig from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -250,6 +251,7 @@ class CohereDecoderLayer(nn.Module): ...@@ -250,6 +251,7 @@ class CohereDecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_torch_compile
class CohereModel(nn.Module): class CohereModel(nn.Module):
def __init__( def __init__(
......
...@@ -29,6 +29,7 @@ import torch ...@@ -29,6 +29,7 @@ import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
...@@ -311,6 +312,7 @@ class ExaoneDecoderLayer(nn.Module): ...@@ -311,6 +312,7 @@ class ExaoneDecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_torch_compile
class ExaoneModel(nn.Module): class ExaoneModel(nn.Module):
def __init__( def __init__(
......
import math
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
BartParallelLMHead,
BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import AutoWeightsLoader
class Florence2LanguageModel(nn.Module):
    """Encoder/decoder text stack assembled from the BART building blocks.

    Holds a shared scaled word embedding plus a ``BartEncoder`` and a
    ``BartDecoder``. When ``config.tie_word_embeddings`` is set, both
    stacks reuse the shared embedding weight.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the shared embedding, the encoder and the decoder.

        Args:
            config: Text model configuration; ``pad_token_id``,
                ``vocab_size``, ``d_model`` and ``tie_word_embeddings``
                are read here.
            cache_config: Forwarded unchanged to the encoder and decoder.
            quant_config: Forwarded unchanged to the encoder and decoder.
        """
        super().__init__()

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Keep the registration order shared -> encoder -> decoder so that
        # parameter enumeration order is unchanged.
        self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
        stack_kwargs = dict(cache_config=cache_config,
                            quant_config=quant_config)
        self.encoder = BartEncoder(config, **stack_kwargs)
        self.decoder = BartDecoder(config, **stack_kwargs)

        if self.config.tie_word_embeddings:
            # Point both token embeddings at the shared weight.
            for stack in (self.encoder, self.decoder):
                stack.embed_tokens.weight = self.shared.weight

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""Run the (optional) encoder pass followed by the decoder.

        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
                May be empty, in which case the encoder is not run.
            encoder_positions:
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Whatever the decoder returns, passed through unchanged.
        """
        if encoder_input_ids.numel() > 0:
            # Only invoke the encoder when at least one encoder token
            # was supplied.
            encoder_hidden_states = self.encoder(
                input_ids=encoder_input_ids,
                positions=encoder_positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata)
        else:
            encoder_hidden_states = None

        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_hidden_states,
                            kv_caches=kv_caches,
                            attn_metadata=attn_metadata)
class Florence2LanguageForConditionalGeneration(nn.Module):
    """``Florence2LanguageModel`` plus LM head, logits processor and sampler."""

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Create the wrapped language model and its output head.

        Args:
            config: Text model configuration; ``scale_embedding``,
                ``d_model`` and ``vocab_size`` are read here.
            cache_config: Forwarded to ``Florence2LanguageModel``.
            quant_config: Forwarded to ``Florence2LanguageModel``.
        """
        super().__init__()
        self.config = config
        self.model = Florence2LanguageModel(config,
                                            cache_config=cache_config,
                                            quant_config=quant_config)

        # The head is scaled by sqrt(d_model) only when the config asks
        # for scaled embeddings; otherwise the scale is neutral.
        if config.scale_embedding:
            embed_scale = math.sqrt(config.d_model)
        else:
            embed_scale = 1.0

        self.vocab_size = config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)
        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        r"""Delegate to the wrapped language model.

        Extra keyword arguments are accepted but not forwarded.

        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        text_model = self.model
        return text_model(input_ids, positions, encoder_input_ids,
                          encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor over ``hidden_states`` using the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> SamplerOutput:
        """Return the sampler's output for ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Separate q/k/v projection weights are routed into the fused
        ``qkv_proj`` parameter through that parameter's ``weight_loader``.
        Entries whose name contains ``final_logits_bias`` are skipped, as
        are ``embed_tokens`` entries when word embeddings are tied.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        fused_shards = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        named_params = dict(self.named_parameters())

        for name, loaded_weight in weights:
            # First mapping whose shard name occurs in the weight name wins.
            match = next(
                (entry for entry in fused_shards if entry[1] in name), None)
            if match is not None:
                fused_name, shard_name, shard_id = match
                target = named_params[name.replace(shard_name, fused_name)]
                load_shard = target.weight_loader
                load_shard(target, loaded_weight, shard_id)
                continue

            if "final_logits_bias" in name:
                continue
            if self.config.tie_word_embeddings and "embed_tokens" in name:
                continue

            target = named_params[name]
            load_full = getattr(target, "weight_loader",
                                default_weight_loader)
            load_full(target, loaded_weight)
class Florence2ForConditionalGeneration(nn.Module):
    """Top-level Florence-2 wrapper that currently exposes only the text model.

    Every call is forwarded to ``self.language_model``; weights belonging
    to the image-side components are ignored at load time.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Instantiate the language model from ``config.text_config``.

        Args:
            config: Full model configuration; only ``text_config`` is used.
            cache_config: Forwarded to the language model.
            quant_config: Forwarded to the language model.
        """
        super().__init__()

        # TODO(Isotr0py): Add vision backbone
        self.language_model = Florence2LanguageForConditionalGeneration(
            config=config.text_config,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    @property
    def sampler(self):
        """Sampler owned by the wrapped language model."""
        return self.language_model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""Forward the call to the wrapped language model.

        ``intermediate_tensors`` and any extra keyword arguments are
        accepted but not used.

        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        text_model = self.language_model
        return text_model(input_ids, positions, encoder_input_ids,
                          encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        text_model = self.language_model
        return text_model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Delegate sampling to the language model."""
        text_model = self.language_model
        return text_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights via ``AutoWeightsLoader``, skipping image-side prefixes.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """
        # Prefixes handed to the loader as ``skip_prefixes``.
        ignored_prefixes = [
            "image_projection",
            "vision_tower",
            "image_proj_norm",
            "image_pos_embed",
            "visual_temporal_embed",
        ]
        AutoWeightsLoader(
            self, skip_prefixes=ignored_prefixes).load_weights(weights)
...@@ -22,6 +22,7 @@ from torch import nn ...@@ -22,6 +22,7 @@ from torch import nn
from transformers import GemmaConfig from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -239,6 +240,7 @@ class GemmaDecoderLayer(nn.Module): ...@@ -239,6 +240,7 @@ class GemmaDecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_torch_compile
class GemmaModel(nn.Module): class GemmaModel(nn.Module):
def __init__( def __init__(
......
...@@ -24,6 +24,7 @@ from torch import nn ...@@ -24,6 +24,7 @@ from torch import nn
from transformers import GPT2Config from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size) get_pp_group, get_tensor_model_parallel_world_size)
...@@ -182,6 +183,7 @@ class GPT2Block(nn.Module): ...@@ -182,6 +183,7 @@ class GPT2Block(nn.Module):
return hidden_states return hidden_states
@support_torch_compile
class GPT2Model(nn.Module): class GPT2Model(nn.Module):
def __init__( def __init__(
......
...@@ -113,7 +113,8 @@ class Idefics2VisionAttention(nn.Module): ...@@ -113,7 +113,8 @@ class Idefics2VisionAttention(nn.Module):
self, self,
config: Idefics2Config, config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -130,12 +131,14 @@ class Idefics2VisionAttention(nn.Module): ...@@ -130,12 +131,14 @@ class Idefics2VisionAttention(nn.Module):
self.head_dim, self.head_dim,
self.num_heads, self.num_heads,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
self.embed_dim, self.embed_dim,
self.embed_dim, self.embed_dim,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.out_proj",
) )
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
...@@ -178,7 +181,8 @@ class Idefics2VisionMLP(nn.Module): ...@@ -178,7 +181,8 @@ class Idefics2VisionMLP(nn.Module):
self, self,
config: Idefics2Config, config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
...@@ -187,12 +191,14 @@ class Idefics2VisionMLP(nn.Module): ...@@ -187,12 +191,14 @@ class Idefics2VisionMLP(nn.Module):
config.intermediate_size, config.intermediate_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc1",
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc2",
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -204,13 +210,22 @@ class Idefics2VisionMLP(nn.Module): ...@@ -204,13 +210,22 @@ class Idefics2VisionMLP(nn.Module):
class Idefics2EncoderLayer(nn.Module): class Idefics2EncoderLayer(nn.Module):
def __init__(self, config: Idefics2Config): def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.self_attn = Idefics2VisionAttention(config) self.self_attn = Idefics2VisionAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.layer_norm1 = nn.LayerNorm(self.embed_dim, self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = Idefics2VisionMLP(config) self.mlp = Idefics2VisionMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(self.embed_dim, self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
...@@ -245,12 +260,20 @@ class Idefics2Encoder(nn.Module): ...@@ -245,12 +260,20 @@ class Idefics2Encoder(nn.Module):
config: Idefics2Config config: Idefics2Config
""" """
def __init__(self, config: Idefics2Config): def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Idefics2EncoderLayer(config) Idefics2EncoderLayer(config,
for _ in range(config.num_hidden_layers) quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
]) ])
def forward( def forward(
...@@ -275,12 +298,20 @@ class Idefics2Encoder(nn.Module): ...@@ -275,12 +298,20 @@ class Idefics2Encoder(nn.Module):
class Idefics2VisionTransformer(nn.Module): class Idefics2VisionTransformer(nn.Module):
def __init__(self, config: Idefics2VisionConfig): def __init__(
self,
config: Idefics2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
embed_dim = config.hidden_size embed_dim = config.hidden_size
self.config = config self.config = config
self.embeddings = Idefics2VisionEmbeddings(config) self.embeddings = Idefics2VisionEmbeddings(config)
self.encoder = Idefics2Encoder(config) self.encoder = Idefics2Encoder(config,
quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.post_layernorm = nn.LayerNorm(embed_dim, self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
......
...@@ -137,6 +137,7 @@ class InternParallelAttention(nn.Module): ...@@ -137,6 +137,7 @@ class InternParallelAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
*, *,
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -165,6 +166,7 @@ class InternParallelAttention(nn.Module): ...@@ -165,6 +166,7 @@ class InternParallelAttention(nn.Module):
num_dummy_heads + self.num_heads, num_dummy_heads + self.num_heads,
bias=config.qkv_bias, bias=config.qkv_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv",
) )
self.qk_normalization = config.qk_normalization self.qk_normalization = config.qk_normalization
...@@ -181,6 +183,7 @@ class InternParallelAttention(nn.Module): ...@@ -181,6 +183,7 @@ class InternParallelAttention(nn.Module):
self.dummy_dim, self.dummy_dim,
self.embed_dim, self.embed_dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.proj",
) )
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
...@@ -284,20 +287,26 @@ class InternSdpaAttention(nn.Module): ...@@ -284,20 +287,26 @@ class InternSdpaAttention(nn.Module):
class InternMLP(nn.Module): class InternMLP(nn.Module):
def __init__(self, def __init__(
config: PretrainedConfig, self,
quant_config: Optional[QuantizationConfig] = None): config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size, self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size, config.intermediate_size,
bias=True, bias=True,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size, self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=True, bias=True,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states) hidden_states, _ = self.fc1(hidden_states)
...@@ -315,6 +324,7 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -315,6 +324,7 @@ class InternVisionEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
*, *,
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -324,9 +334,12 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -324,9 +334,12 @@ class InternVisionEncoderLayer(nn.Module):
self.attn = self._init_attn(config, self.attn = self._init_attn(config,
quant_config, quant_config,
num_dummy_heads=num_dummy_heads) num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.attn")
self.mlp = InternMLP(config, quant_config=quant_config) self.mlp = InternMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
...@@ -343,6 +356,7 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -343,6 +356,7 @@ class InternVisionEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig], quant_config: Optional[QuantizationConfig],
*, *,
num_dummy_heads: int, num_dummy_heads: int,
prefix: str = "",
): ):
# fallback to sdpa attention if tp unavailable # fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -351,7 +365,8 @@ class InternVisionEncoderLayer(nn.Module): ...@@ -351,7 +365,8 @@ class InternVisionEncoderLayer(nn.Module):
if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0: if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
return InternParallelAttention(config, return InternParallelAttention(config,
quant_config=quant_config, quant_config=quant_config,
num_dummy_heads=num_dummy_heads) num_dummy_heads=num_dummy_heads,
prefix=prefix)
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads) return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
...@@ -377,6 +392,7 @@ class InternVisionEncoder(nn.Module): ...@@ -377,6 +392,7 @@ class InternVisionEncoder(nn.Module):
*, *,
num_hidden_layers_override: Optional[int] = None, num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -390,8 +406,9 @@ class InternVisionEncoder(nn.Module): ...@@ -390,8 +406,9 @@ class InternVisionEncoder(nn.Module):
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
InternVisionEncoderLayer(config, InternVisionEncoderLayer(config,
quant_config, quant_config,
num_dummy_heads=num_dummy_heads) num_dummy_heads=num_dummy_heads,
for _ in range(num_hidden_layers) prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
]) ])
def forward(self, inputs_embeds: torch.Tensor): def forward(self, inputs_embeds: torch.Tensor):
...@@ -412,7 +429,8 @@ class InternVisionModel(nn.Module): ...@@ -412,7 +429,8 @@ class InternVisionModel(nn.Module):
*, *,
num_hidden_layers_override: Optional[int] = None, num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0, num_dummy_heads: int = 0,
): prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -423,6 +441,7 @@ class InternVisionModel(nn.Module): ...@@ -423,6 +441,7 @@ class InternVisionModel(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads, num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.encoder",
) )
def get_input_embeddings(self): def get_input_embeddings(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment