Commit 112bf76b authored by chenzk

v1.0

from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalInputs, MultiModalPlugin,
NestedTensors)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""
The global :class:`~MultiModalRegistry` is used by model runners to
dispatch data processing according to its modality and the target model.
See also:
:ref:`input_processing_pipeline`
"""
__all__ = [
"BatchedTensorInputs",
"BatchedTensors",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalInputs",
"MultiModalPlugin",
"NestedTensors",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
]
import sys
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
TypedDict, TypeVar, Union, cast, final)
import numpy as np
import torch
import torch.types
from PIL import Image
from torch import nn
from typing_extensions import TypeAlias
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
json_map_leaves)
logger = init_logger(__name__)
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
BatchedTensors: TypeAlias = JSONTree[torch.Tensor]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
if sys.version_info < (3, 9):
# UserDict cannot be subscripted
class _MultiModalInputsBase(UserDict):
pass
else:
class _MultiModalInputsBase(UserDict[str, NestedTensors]):
pass
class MultiModalInputs(_MultiModalInputsBase):
"""
A dictionary that represents the keyword arguments to
:meth:`~torch.nn.Module.forward`.
"""
@staticmethod
def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
"""
If each input tensor in the batch has the same shape, return a single
batched tensor; otherwise, return a list of :class:`NestedTensors` with
one element per item in the batch.
"""
# may be list rather than tensors
if isinstance(tensors[0], list):
return [[t for t in tensor[0]]
for tensor in cast(List[List[torch.Tensor]], tensors)]
tensors_ = cast(List[torch.Tensor], tensors)
unbatched_shape = tensors_[0].shape[1:]
for tensor in tensors_:
if tensor.shape[1:] != unbatched_shape:
return [tensor.squeeze(0) for tensor in tensors_]
return torch.cat(tensors_, dim=0)
@staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
"""
Recursively stacks lists of tensors when they all have the same shape.
"""
if isinstance(nested_tensors, torch.Tensor):
return nested_tensors
if isinstance(nested_tensors, np.ndarray):
return torch.from_numpy(nested_tensors)
if isinstance(nested_tensors, (int, float)):
return torch.tensor(nested_tensors)
stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.
return stacked
tensors_ = cast(List[torch.Tensor], stacked)
if any(t.shape != tensors_[0].shape for t in tensors_):
# The tensors have incompatible shapes and can't be stacked.
return tensors_
return torch.stack(tensors_)
@staticmethod
def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
"""
Batch multiple inputs together into a dictionary.
The resulting dictionary has the same keys as the inputs.
If the corresponding value from each input is a tensor and they all
share the same shape, the output value is a single batched tensor;
otherwise, the output value is a list containing the original value
from each input.
"""
if len(inputs_list) == 0:
return {}
item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
for inputs in inputs_list:
            # For models that support multiple modalities (e.g. Qwen2-VL),
            # different modalities will return different data keys,
            # so batch() should skip the same-key check.
for k, v in inputs.items():
item_lists[k].append(v)
return {
k: MultiModalInputs._try_stack(item_list)
for k, item_list in item_lists.items()
}
@staticmethod
def as_kwargs(
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
) -> BatchedTensorInputs:
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
json_mapped = json_map_leaves(
lambda x: x.to(device, non_blocking=True),
json_inputs,
)
return cast(BatchedTensorInputs, json_mapped)
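# Illustrative sketch (not part of the original file) of how `MultiModalInputs.batch`
# and `as_kwargs` behave, based on the definitions above; the "pixel_values" key and
# the shapes are hypothetical.
def _example_batch_multimodal_inputs() -> BatchedTensorInputs:
    # Two requests whose "pixel_values" share a shape are stacked into one tensor.
    a = MultiModalInputs({"pixel_values": torch.zeros(1, 3, 224, 224)})
    b = MultiModalInputs({"pixel_values": torch.ones(1, 3, 224, 224)})
    batched = MultiModalInputs.batch([a, b])
    assert batched["pixel_values"].shape == (2, 1, 3, 224, 224)

    # Mismatched shapes fall back to a plain list of per-request tensors.
    c = MultiModalInputs({"pixel_values": torch.zeros(1, 3, 448, 448)})
    ragged = MultiModalInputs.batch([a, c])
    assert isinstance(ragged["pixel_values"], list)

    # Move every leaf tensor to the target device before calling the model.
    return MultiModalInputs.as_kwargs(batched, device="cpu")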
_T = TypeVar("_T")
MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.
The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
"""Modality types that are predefined by vLLM."""
image: MultiModalData[Image.Image]
"""The input image(s)."""
audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
"""The input audio item(s) and corresponding sampling rate(s)."""
MultiModalDataDict = Union[MultiModalDataBuiltins,
Mapping[str, MultiModalData[object]]]
"""
A dictionary containing an item for each modality type to input.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalDataBuiltins` as long as a customized plugin is registered
through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
MultiModalInputs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.
If the data is not supported, throw :exc:`TypeError`.
"""
MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""
N = TypeVar("N", bound=Type[nn.Module])
class MultiModalPlugin(ABC):
"""
Base class that defines data processing logic for a specific modality.
In particular, we adopt a registry pattern to dispatch data processing
according to the model being used (considering that different models may
process the same data differently). This registry is in turn used by
:class:`~MultiModalRegistry` which acts at a higher level
(i.e., the modality of the data).
See also:
:ref:`adding_multimodal_plugin`
"""
def __init__(self) -> None:
self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}
@abstractmethod
def get_data_key(self) -> str:
"""
Get the data key corresponding to the modality.
"""
raise NotImplementedError
@abstractmethod
def _default_input_mapper(
self,
ctx: InputContext,
data: MultiModalData[object],
) -> MultiModalInputs:
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to
tokenizers and processors in HuggingFace Transformers.
If the data is not supported, throw :exc:`TypeError`.
"""
raise NotImplementedError
def register_input_mapper(
self,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper to a model class.
When the model receives input data that matches the modality served by
this plugin (see :meth:`get_data_key`), the provided function is
invoked to transform the data into a dictionary of model inputs.
If `None` is provided, then the default input mapper is used instead.
See also:
- :ref:`input_processing_pipeline`
- :ref:`enabling_multimodal_inputs`
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._input_mappers:
logger.warning(
"Model class %s already has an input mapper "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._input_mappers[model_cls] = mapper \
or self._default_input_mapper
return model_cls
return wrapper
def map_input(self, model_config: ModelConfig,
data: MultiModalData[object]) -> MultiModalInputs:
"""
Transform the data into a dictionary of model inputs using the
input mapper registered for that model.
The model is identified by ``model_config``.
Raises:
TypeError: If the data type is not supported.
See also:
- :ref:`input_processing_pipeline`
- :ref:`enabling_multimodal_inputs`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
mapper = self._input_mappers.get(model_cls)
        # Only get processor kwargs at mapping time if we are not using the
        # default input mapper; no overrides are used on the default here
        # because they should be passed to the HuggingFace resource at
        # initialization time.
if mapper is not None and mapper != self._default_input_mapper:
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
mapper, overrides=model_config.mm_processor_kwargs)
else:
mm_processor_kwargs = {}
if mapper is None:
raise KeyError(f"No input mapper in {self} is registered for "
f"model class {model_cls.__name__}.")
return mapper(InputContext(model_config), data, **mm_processor_kwargs)
@abstractmethod
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
"""
Calculate the maximum number of tokens, corresponding to a single
instance of multimodal data, that are passed to the language model.
"""
raise NotImplementedError
def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
if max_mm_tokens < 1:
raise ValueError("You should set the number of tokens to a "
f"positive integer. Found: {max_mm_tokens}")
def register_max_multimodal_tokens(
self,
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of tokens, corresponding to a single
instance of multimodal data, that are passed to the language model
for a model class.
If `None` is provided, then the default calculation is used instead.
See also:
:ref:`enabling_multimodal_inputs`
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._max_mm_tokens:
logger.warning(
"Model class %s already calculates maximum number of "
"tokens in %s. It is overwritten by the new one.",
model_cls, self)
if isinstance(max_mm_tokens, int):
self._validate_max_multimodal_tokens(max_mm_tokens)
self._max_mm_tokens[model_cls] = max_mm_tokens \
or self._default_max_multimodal_tokens
return model_cls
return wrapper
def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
"""
Get the maximum number of multi-modal tokens
for profiling the memory usage of a model.
If this registry is not applicable to the model, `0` is returned.
The model is identified by ``model_config``.
See also:
:ref:`enabling_multimodal_inputs`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
if model_cls not in self._input_mappers:
return 0
max_mm_tokens = self._max_mm_tokens.get(model_cls)
if max_mm_tokens is None:
raise KeyError(f"No maximum number of multi-modal tokens is given "
f"for model class {model_cls.__name__} in {self}.")
if callable(max_mm_tokens):
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
max_mm_tokens, overrides=model_config.mm_processor_kwargs)
max_mm_tokens = max_mm_tokens(InputContext(model_config),
**mm_processor_kwargs)
self._validate_max_multimodal_tokens(max_mm_tokens)
return max_mm_tokens
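# Illustrative sketch (not part of the original file): a minimal custom plugin built
# on the abstract interface above. The "video" modality, the "video_frames" key and
# the fixed token budget are hypothetical; a real plugin would also be registered
# with the global registry (see :ref:`adding_multimodal_plugin`).
class _ExampleVideoPlugin(MultiModalPlugin):

    def get_data_key(self) -> str:
        # Key looked up in the MultiModalDataDict, e.g. {"video": frames}.
        return "video"

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
    ) -> MultiModalInputs:
        # Pass raw frame tensors straight through as a single keyword argument.
        if isinstance(data, torch.Tensor):
            return MultiModalInputs({"video_frames": data})
        raise TypeError(f"Unsupported video data type: {type(data)}")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        # Fixed per-item token budget used when profiling memory usage.
        return 256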
import functools
import importlib
from typing import Dict, List, Optional, Tuple, Type
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip
logger = init_logger(__name__)
_GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
# ChatGLMModel supports multimodal
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
# QWenLMHeadModel supports multimodal
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen2VLForConditionalGeneration":
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"TeleChat12BForCausalLM": ("telechat_12B", "TeleChat12BForCausalLM"), # telechat12b
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MedusaModel": ("medusa", "Medusa"),
"EAGLEModel": ("eagle", "EAGLE"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"GraniteForCausalLM": ("granite", "GraniteForCausalLM")
}
_EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
}
_MULTIMODAL_MODELS = {
"Blip2ForConditionalGeneration":
("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"LlavaForConditionalGeneration": ("llava",
"LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next",
"LlavaNextForConditionalGeneration"),
"LlavaNextVideoForConditionalGeneration":
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
"LlavaOnevisionForConditionalGeneration":
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
"MixtralForConditionalGeneration":
("mixtral", "MixtralForConditionalGeneration"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"PixtralForConditionalGeneration": ("pixtral",
"PixtralForConditionalGeneration"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"MllamaForConditionalGeneration": ("mllama",
"MllamaForConditionalGeneration"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}
_MODELS = {
**_GENERATION_MODELS,
**_EMBEDDING_MODELS,
**_MULTIMODAL_MODELS,
**_CONDITIONAL_GENERATION_MODELS,
}
# Architecture -> type.
# out of tree models
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
# "Qwen2ForCausalLM":
# _ROCM_SWA_REASON,
# "MistralForCausalLM":
# _ROCM_SWA_REASON,
# "MixtralForCausalLM":
# _ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration":
("ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma"),
"Phi3VForCausalLM":
("ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}
class ModelRegistry:
@staticmethod
@functools.lru_cache(maxsize=128)
def _get_model(model_arch: str):
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
return None
if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
"Model architecture %s is partially supported by ROCm: %s",
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
return ModelRegistry._get_model(model_arch)
@staticmethod
def resolve_model_cls(
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
for arch in architectures:
model_cls = ModelRegistry._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
if model_arch in _MODELS:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls.__name__)
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls
@staticmethod
def is_embedding_model(model_arch: str) -> bool:
return model_arch in _EMBEDDING_MODELS
@staticmethod
def is_multimodal_model(model_arch: str) -> bool:
# TODO: find a way to avoid initializing CUDA prematurely to
# use `supports_multimodal` to determine if a model is multimodal
# model_cls = ModelRegistry._try_load_model_cls(model_arch)
# from vllm.model_executor.models.interfaces import supports_multimodal
return model_arch in _MULTIMODAL_MODELS
__all__ = [
"ModelRegistry",
]
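# Usage sketch (illustrative, not part of the original file): registering an
# out-of-tree architecture and resolving it afterwards; `MyCustomForCausalLM`
# is a hypothetical nn.Module subclass.
#
#     ModelRegistry.register_model("MyCustomForCausalLM", MyCustomForCausalLM)
#     model_cls, arch = ModelRegistry.resolve_model_cls(["MyCustomForCausalLM"])
#     assert arch in ModelRegistry.get_supported_archs()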
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union, Mapping, TypeVar
from typing_extensions import NotRequired
import torch
from torch import nn
from transformers import MixtralConfig
from transformers.activations import ACT2FN
from transformers import PreTrainedTokenizerBase
from PIL import Image
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig, ModelConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.multimodal.image import cached_get_image_processor
from vllm.logger import init_logger
from vllm.utils import print_warning_once
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,)
from .internvl import get_max_internvl_image_tokens, dynamic_preprocess, get_internvl_num_patches, calculate_num_blocks
from .interfaces import SupportsLoRA, SupportsMultiModal
from .utils import is_pp_missing_parameter, make_layers
from .intern_vit import InternVisionModel
from .whale import WhaleAudioModel
logger = init_logger(__name__)
class MixtralImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: BatchedTensors
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
class MixtralAudioFeaturesInputs(TypedDict):
type: Literal["audio_input"]
data: BatchedTensors
"""
Shape: `(batch_size, num_channels, time)`
"""
mask: NotRequired[torch.Tensor]
MixtralImageInputs = MixtralImagePixelInputs
MixtralAudioInputs = MixtralAudioFeaturesInputs
# Utilities for input processors
_T = TypeVar("_T", str, int)
def repeat_and_pad_image_tokens(
tokenizer: PreTrainedTokenizerBase,
prompt: Optional[str],
prompt_token_ids: List[int],
*,
image_token_id: Union[int, List[int]],
repeat_count: Union[int, List[int]] = 1,
pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
def repeat_and_pad_token(
token: _T,
*,
repeat_count: int = 1,
pad_token_left: Optional[_T] = None,
pad_token_right: Optional[_T] = None,
) -> List[_T]:
replacement = [token] * repeat_count
if pad_token_left is not None:
replacement = [pad_token_left] + replacement
if pad_token_right is not None:
replacement = replacement + [pad_token_right]
return replacement
if prompt is not None:
image_token_str = tokenizer.decode(image_token_id)
image_token_count = prompt.count(image_token_str)
elif prompt_token_ids is not None:
image_token_count = prompt_token_ids.count(image_token_id)
else:
raise ValueError("Either prompt or prompt_token_ids must be provided.")
if isinstance(repeat_count, int):
repeat_count = [repeat_count] * image_token_count
assert len(repeat_count) == image_token_count, (
f"Length of repeat_count ({len(repeat_count)}) does not match "
f"the number of image tokens ({image_token_count})."
)
if prompt is None:
new_prompt = None
else:
pad_token_str_left = (None if pad_token_left is None else
tokenizer.decode(pad_token_left))
pad_token_str_right = (None if pad_token_right is None else
tokenizer.decode(pad_token_right))
replacement_strs = []
for i, rp_count in enumerate(repeat_count):
replacement_str = "".join(
repeat_and_pad_token(
image_token_str,
repeat_count=rp_count,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
)
)
replacement_strs.append(replacement_str)
prompt_split = prompt.split(image_token_str)
        assert len(prompt_split) == len(replacement_strs) + 1, (
            f"Number of prompt segments ({len(prompt_split)}) does not match "
            f"the number of replacement strings ({len(replacement_strs)}) plus one."
        )
new_prompt = []
for a, b in zip(prompt_split, replacement_strs + [None]):
new_prompt.append(a)
if b is not None:
new_prompt.append(b)
new_prompt = "".join(new_prompt)
new_token_ids: List[int] = []
for i, token in enumerate(prompt_token_ids):
if token == image_token_id:
rp_count = repeat_count.pop(0)
replacement_ids = repeat_and_pad_token(
image_token_id,
repeat_count=rp_count,
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)
new_token_ids.extend(replacement_ids)
else:
new_token_ids.append(token)
return new_prompt, new_token_ids
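# Worked example (illustrative) for repeat_and_pad_image_tokens above: with
# prompt_token_ids=[1, 32000, 2], image_token_id=32000 (hypothetical id) and
# repeat_count=[3], the single placeholder expands to three copies, giving
# new_token_ids=[1, 32000, 32000, 32000, 2]; if pad_token_left/pad_token_right
# are set, each expanded run is additionally wrapped by those ids.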
def input_processor_for_mixtral_multimodal_base(
model_config: ModelConfig,
hf_config,
llm_inputs: LLMInputs,
*,
image_token_id: int,
audio_token_id: int,
image_feature_size_override: Optional[int] = None,
audio_feature_size_override: Optional[int] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None:
return llm_inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
image_modal_exists = "image" in multi_modal_data
audio_modal_exists = "audio" in multi_modal_data
if image_modal_exists:
images = multi_modal_data["image"]
vision_config = hf_config.vision_config
if not isinstance(images, list):
images = [images]
image_feature_sizes = []
patched_images = []
for image in images:
patched_image = dynamic_preprocess(
image,
min_num=hf_config.min_dynamic_patch,
max_num=hf_config.max_dynamic_patch,
image_size=hf_config.vision_config.image_size,
use_thumbnail=hf_config.use_thumbnail
)
patched_images += patched_image
width, height = image.size
num_blocks, _, _ = calculate_num_blocks(
width, height,
hf_config.min_dynamic_patch,
hf_config.max_dynamic_patch,
vision_config.image_size,
)
if hf_config.use_thumbnail and num_blocks > 1:
num_blocks += 1
assert num_blocks == len(patched_image), (
f"Number of patches ({len(patched_image)}) does not match "
f"the number of blocks ({num_blocks})."
)
image_feature_size_per_patch = get_internvl_num_patches(
vision_config.image_size, vision_config.patch_size,
hf_config.downsample_ratio
)
image_feature_size = image_feature_size_per_patch * num_blocks
image_feature_sizes.append(image_feature_size)
if image_feature_size_override is None:
image_feature_size = image_feature_sizes
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
repeat_count=image_feature_size,
)
multi_modal_data["image"] = patched_images
if audio_modal_exists:
def get_audio_feature_size(audio: torch.Tensor) -> int:
input_size = int(audio.shape[0])
downsample_size = ((input_size - 1) // 2 - 1) // 2
projected_size = (downsample_size - 1) // 2 + 1
return projected_size
if audio_feature_size_override is None:
audio_feature_size = [get_audio_feature_size(x) for x in multi_modal_data["audio"]]
else:
audio_feature_size = audio_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
tokenizer,
new_prompt if image_modal_exists else llm_inputs.get("prompt"),
new_token_ids if image_modal_exists else llm_inputs["prompt_token_ids"],
image_token_id=audio_token_id,
repeat_count=audio_feature_size,
)
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
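# Worked example (illustrative) for get_audio_feature_size above: an audio feature
# sequence of length 1000 yields ((1000 - 1) // 2 - 1) // 2 = 249 frames after two
# stride-2 reductions (consistent with the audio tower's subsampling), and
# (249 - 1) // 2 + 1 = 125 after one more stride-2 step (consistent with the audio
# projector's convolution), so 125 placeholder tokens are reserved for that clip.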
def input_processor_for_mixtral_multimodal(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or ("image" not in multi_modal_data and "audio" not in multi_modal_data):
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.model_config.hf_config
return input_processor_for_mixtral_multimodal_base(
model_config,
hf_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
audio_token_id=hf_config.audio_token_index,
)
def vision_input_mapper_for_mixtral(ctx: InputContext, data: object):
model_config = ctx.model_config
    if not isinstance(data, list):
data = [data]
if all(isinstance(x, Image.Image) for x in data):
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
try:
batch_data = image_processor \
.preprocess(data, return_tensors="pt").to(model_config.dtype) \
.data
except Exception:
logger.error("Failed to process image (%s)", data)
raise
return MultiModalInputs(batch_data)
elif all(isinstance(x, torch.Tensor) for x in data):
raise NotImplementedError("Embeddings input is not supported yet")
raise TypeError(f"Invalid image type: {type(data)}")
def dummy_data_for_mixtral_multimodal(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
image_feature_size = get_internvl_num_patches(
vision_config.image_size,
vision_config.patch_size,
hf_config.downsample_ratio,
)
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images=num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
vision_config,
num_images,
image_width_override=vision_config.image_size,
image_height_override=vision_config.image_size,
)
return seq_data, mm_data
class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = hidden_size
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(hidden_size,
num_experts,
bias=False,
params_dtype=params_dtype,
quant_config=None,
prefix=f"{prefix}.gate")
self.experts = FusedMoE(num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
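# Shape sketch (illustrative) of the forward pass above: hidden_states of shape
# (num_tokens, hidden_size) is routed through `gate` to router_logits of shape
# (num_tokens, num_experts); FusedMoE then combines the top_k selected experts per
# token and the result is reshaped back to the original input shape.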
class MixtralAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
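# Worked example (illustrative) of the KV-head partitioning above: with
# total_num_kv_heads=8 and tp_size=2, each rank keeps 8 // 2 = 4 KV heads; with
# total_num_kv_heads=8 and tp_size=16, the KV heads are replicated so each rank
# keeps max(1, 8 // 16) = 1 head.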
class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe")
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
input_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) if input_ids is not None else input_embeds
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class MixtralForCausalLM(nn.Module, SupportsLoRA):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.lora_config = lora_config
self.model = MixtralModel(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
class MixtralMultiModalVisionProjector(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
# self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_1 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class MixtralMultiModalAudioProjector(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.audio_hidden_size = config.audio_config.hidden_size
self.text_hidden_size = config.text_config.hidden_size
self.kernel_size = config.audio_projector_kernel_size
self.left_padding = nn.ConstantPad1d(padding=(0, self.kernel_size - 1), value=0.0)
self.conv1d = nn.Conv1d(
in_channels=self.audio_hidden_size,
out_channels=2 * self.audio_hidden_size,
kernel_size=self.kernel_size,
stride=2,
padding=0
)
self.norm = nn.LayerNorm(2 * self.audio_hidden_size, eps=1e-3)
self.act = ACT2FN[config.audio_projector_hidden_act]
self.linear = nn.Linear(2 * self.audio_hidden_size, self.text_hidden_size)
def forward(self, audio_features, mask_pad=None):
"""
x: B, T, enc_out_dim
mask: (B, T)
"""
if mask_pad is not None:
audio_features.masked_fill_(~mask_pad.bool().unsqueeze(-1), 0.0)
audio_features = audio_features.transpose(1, 2) # B, channels, T
hidden_states = self.left_padding(audio_features)
hidden_states = self.conv1d(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.norm(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear(hidden_states)
return hidden_states, mask_pad[:, 0::2]
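# Shape sketch (illustrative) of the projector above: audio_features of shape
# (B, T, audio_hidden_size) is transposed to (B, audio_hidden_size, T), padded on
# the right by kernel_size - 1 and reduced to ceil(T / 2) frames by the stride-2
# Conv1d with 2 * audio_hidden_size output channels; after LayerNorm, the
# activation and the final Linear, the output is (B, ceil(T / 2), text_hidden_size),
# and the padding mask is subsampled to match via mask_pad[:, 0::2].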
def audio_input_mapper_for_mixtral(ctx: InputContext, data: object):
model_config = ctx.model_config
    if not isinstance(data, list):
data = [data]
from torch.nn.utils.rnn import pad_sequence
if all(isinstance(x, torch.Tensor) for x in data):
# x of shape (length, hidden_size)
lengths = [x.shape[0] for x in data]
data = pad_sequence(data, batch_first=True, padding_value=0)
num_samples, max_length = data.shape[:2]
# Create mask
mask = torch.zeros((num_samples, max_length), dtype=torch.float)
for i, length in enumerate(lengths):
mask[i, :length] = 1
batch_data = {
"audio_input": data.to(model_config.dtype),
"audio_mask": mask.to(model_config.dtype),
}
return MultiModalInputs(batch_data)
    raise TypeError(f"Invalid audio input type: {type(data)}")
@MULTIMODAL_REGISTRY.register_input_mapper("image", vision_input_mapper_for_mixtral)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", audio_input_mapper_for_mixtral)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("image", get_max_internvl_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", 1024)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_mixtral_multimodal)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mixtral_multimodal)
class MixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: MixtralConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
lora_config: Optional[LoRAConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
self.lora_config = lora_config
vision_feature_layer = config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
# TODO: Optionally initializes this for supporting embeddings.
if hasattr(config, "vision_config"):
self.vision_tower = InternVisionModel(
config=config.vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
)
self.vision_projector = MixtralMultiModalVisionProjector(config)
if hasattr(config, "audio_config"):
self.audio_tower = WhaleAudioModel(config=config.audio_config, quant_config=quant_config)
self.audio_projector = MixtralMultiModalAudioProjector(config)
self.quant_config = quant_config
self.language_model = MixtralModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.text_config.vocab_size, logit_scale)
self.sampler = Sampler()
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
int(c / (scale_factor * scale_factor)))
x = x.permute(0, 2, 1, 3).contiguous()
return x
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[MixtralImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return MixtralImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(self, vision_tower: InternVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
image_features = vision_tower(pixel_values)
image_features = self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)
h = w = int(image_features.shape[1] ** 0.5)
assert image_features.shape[1] == h * w
image_features = image_features.reshape(image_features.shape[0], h, w, -1)
        image_features = self.pixel_shuffle(image_features, scale_factor=0.5)
image_features = image_features.reshape(image_features.shape[0], -1, image_features.shape[-1])
return image_features
def _process_image_input(self,
image_input: MixtralImageInputs) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = image_input["data"]
image_features = self._image_pixels_to_features(self.vision_tower, pixel_values)
return self.vision_projector(image_features)
def _process_audio_input(self,
inputs: MixtralAudioInputs) -> torch.Tensor:
assert self.audio_tower is not None
audio_input = inputs["data"]
audio_masks = inputs["mask"]
audio_features = self.audio_tower(audio_input, audio_masks)["last_hidden_state"]
audio_masks = audio_masks[:, 2::2][:, 2::2]
return self.audio_projector(audio_features, audio_masks)
def _validate_audio_input(self, data: torch.Tensor) -> torch.Tensor:
c, t = self.config.audio_config.num_channels, self.config.audio_config.input_dim
expected_dims = (c, t)
assert t == data.shape[-1]
return data
def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[MixtralAudioInputs]:
audio_input = kwargs.pop("audio_input", None)
audio_mask = kwargs.pop("audio_mask", None)
if audio_input is None:
return None
if not isinstance(audio_input, torch.Tensor):
raise ValueError("Incorrect type of audio input. "
f"Got type: {type(audio_input)}")
if audio_mask is not None and not isinstance(audio_mask, torch.Tensor):
raise ValueError("Incorrect type of audio mask. "
f"Got type: {type(audio_mask)}")
return MixtralAudioInputs(
type="audio_input",
data=self._validate_audio_input(audio_input),
mask=audio_mask,
)
def merge_multimodal_embeddings(self,
input_ids: torch.Tensor,
input_embeds: torch.Tensor,
vision_embeddings: BatchedTensors,
vision_masks: Optional[torch.Tensor],
image_token_id: int) -> torch.Tensor:
"""
Merge `vision_embeddings` into `input_embeds` by overwriting the positions
in `input_embeds` corresponding to placeholder image tokens in `input_ids`.
Note:
This updates `input_embeds` in place.
"""
mask = (input_ids == image_token_id)
num_expected_tokens = mask.sum()
if isinstance(vision_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
vision_embeddings = vision_embeddings.view(batch_size * batch_tokens, embed_dim)
if vision_masks is not None:
vision_masks = vision_masks.reshape(batch_size * batch_tokens).bool()
vision_embeddings = vision_embeddings[vision_masks]
total_tokens = vision_embeddings.shape[0]
if num_expected_tokens != total_tokens:
expr = f"{batch_size} x {batch_tokens}"
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders")
input_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
else:
size_per_batch = [t.shape[0] for t in vision_embeddings]
total_tokens = sum(size_per_batch)
if num_expected_tokens != total_tokens:
expr = ' + '.join(map(str, size_per_batch))
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders")
input_embeds[mask] = torch.cat(vision_embeddings)
return input_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
input_embeds = self.language_model.embed_tokens(input_ids)
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
input_embeds = self.merge_multimodal_embeddings(
input_ids, input_embeds, vision_embeddings, None,
self.config.image_token_index,
)
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is not None:
audio_embeddings, audio_masks = self._process_audio_input(audio_input)
input_embeds = self.merge_multimodal_embeddings(
input_ids, input_embeds, audio_embeddings, audio_masks,
self.config.audio_token_index,
)
if image_input is not None or audio_input is not None:
input_ids = None
hidden_states = self.language_model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=None,
input_embeds=input_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
_KEYS_TO_MODIFY_MAPPING_LLM = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
"model.layer": "language_model.layer",
"model.embed_tokens": "language_model.embed_tokens",
"model.norm": "language_model.norm",
}
_KEYS_TO_MODIFY_MAPPING_VISION = {
"model.vision_tower.vision_tower": "vision_tower",
"model.mm_projector.0": "vision_projector.linear_1",
"model.mm_projector.2": "vision_projector.linear_2",
}
_KEYS_TO_MODIFY_MAPPING_AUDIO = {
"self_attn": "attn",
"model.audio_encoder.encoder.enc.0.core.conv": "audio_tower.subsampling.conv_in",
"model.audio_encoder.encoder.enc.0.core.out.0": "audio_tower.subsampling.out",
"model.audio_encoder.encoder.enc.1.embed": "audio_tower.embeddings.embedding",
"model.audio_encoder.encoder.enc.1.encoders": "audio_tower.encoder.layers",
"model.audio_encoder.encoder.enc.1.after_norm": "audio_tower.encoder.layer_norm",
"model.audio_encoder.adpter.bn2": "audio_projector.norm",
"model.audio_encoder.adpter.conv1d2": "audio_projector.conv1d",
"model.audio_encoder.adpter.project": "audio_projector.linear",
}
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.text_config.num_local_experts)
params_dict = dict(self.named_parameters())
intialized_dict = {}
for name, loaded_weight in weights:
orig_name = name
intialized_dict[orig_name] = False
if not hasattr(self, "vision_tower") and ("vision_tower" in orig_name or "mm_projector" in orig_name):
continue
if not hasattr(self, "audio_tower") and ("audio_encoder" in orig_name or "audio_projector" in orig_name):
continue
if "rotary_emb.inv_freq" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING_LLM.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
if "vision_tower" in orig_name or "mm_projector" in orig_name:
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING_VISION.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
if "audio_encoder" in orig_name:
if "global_cmvn" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING_AUDIO.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
                initialized_dict[orig_name] = True
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
shard_id=shard_id,
expert_id=expert_id)
                    initialized_dict[orig_name] = True
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
                    initialized_dict[orig_name] = True
        uninitialized_names = [k for k, v in initialized_dict.items() if not v]
        if uninitialized_names:
            print(f"Uninitialized parameters: {uninitialized_names}")
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import os
from typing import Union
from dataclasses import dataclass
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.logger import init_logger
logger = init_logger(__name__)
_CONFIG_FOR_DOC = "WhaleConfig"
from transformers.configuration_utils import PretrainedConfig
class WhaleConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`WhaleModel`]. It is used to
instantiate a vision encoder according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
input_dim (`int`, *optional*, defaults to 80):
Dimensionality of the input features.
        num_channels (`int`, *optional*, defaults to 1):
            Number of input channels (e.g., 1 for a single spectrogram channel).
qkv_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the queries and values in the self-attention layers.
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers and the pooler layer.
num_attention_heads (`int`, *optional*, defaults to 25):
Number of attention heads for each attention layer in the Transformer encoder.
max_position_embeddings (`int`, *optional*, defaults to 5000):
The maximum number of position embeddings.
intermediate_size (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
qk_normalization (`bool`, *optional*, defaults to `True`):
Whether to normalize the queries and keys in the self-attention layers.
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of hidden layers in the Transformer encoder.
use_flash_attn (`bool`, *optional*, defaults to `True`):
Whether to use flash attention.
hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the layer normalization layers.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
positional_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the positional encodings.
normalize_before (`bool`, *optional*, defaults to `True`):
Whether to apply layer normalization before the attention and feed-forward operations.
concat_after (`bool`, *optional*, defaults to `True`):
Whether to concatenate the attention output with the input before the feed-forward layer.
use_relative_pe (`bool`, *optional*, defaults to `True`):
Whether to use relative position encodings.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 0.1):
A factor for layer scale.
"""
model_type = 'whale'
def __init__(
self,
input_dim=80,
num_channels=1,
qkv_bias=False,
hidden_size=1024,
num_attention_heads=25,
max_position_embeddings=5000,
intermediate_size=4096,
qk_normalization=True,
num_hidden_layers=48,
use_flash_attn=True,
hidden_act='relu',
layer_norm_eps=1e-6,
dropout=0.0,
attention_dropout=0.0,
positional_dropout=0.0,
normalize_before=True,
concat_after=True,
use_relative_pe=True,
initializer_range=0.02,
initializer_factor=0.1,
**kwargs,
):
super().__init__(**kwargs)
self.input_dim = input_dim
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.dropout = dropout
self.num_channels = num_channels
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.attention_dropout = attention_dropout
self.positional_dropout = positional_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.qkv_bias = qkv_bias
self.qk_normalization = qk_normalization
self.use_flash_attn = use_flash_attn
self.normalize_before = normalize_before
self.concat_after = concat_after
self.max_position_embeddings = max_position_embeddings
self.use_relative_pe = use_relative_pe
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if 'audio_config' in config_dict:
config_dict = config_dict['audio_config']
if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
)
return cls.from_dict(config_dict, **kwargs)
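# Illustrative sketch only (not used by the model code): constructing a WhaleConfig
# directly and checking a couple of the defaults documented in the class docstring.
# Loading from a checkpoint directory goes through WhaleConfig.from_pretrained above,
# which also accepts a multimodal config that nests the audio settings under
# "audio_config".
def _example_whale_config() -> WhaleConfig:
    cfg = WhaleConfig(use_flash_attn=False)  # override a single documented default
    assert cfg.hidden_size == 1024 and cfg.num_hidden_layers == 48
    return cfg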
has_flash_attn = False
class WhaleConv2dSubsampling4(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, config: WhaleConfig):
"""Construct an Conv2dSubsampling4 object."""
super().__init__()
self.config = config
self.in_channels = config.num_channels
self.hidden_size = config.hidden_size
self.input_dim = config.input_dim
self.conv_in = nn.Sequential(
nn.Conv2d(
in_channels=self.in_channels, out_channels=self.hidden_size, kernel_size=3, stride=2
),
nn.ReLU(),
nn.Conv2d(
in_channels=self.hidden_size, out_channels=self.hidden_size, kernel_size=3, stride=2
),
nn.ReLU(),
)
self.intermediate_size = self.hidden_size * (((self.input_dim - 1) // 2 - 1) // 2)
self.out = nn.Linear(self.intermediate_size, self.hidden_size)
# The right context for every conv layer is computed by:
# (kernel_size - 1) * frame_rate_of_this_layer
self.subsampling_rate = 4
# 6 = (3 - 1) * 1 + (3 - 1) * 2
self.right_context = 6
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Subsample x.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, time).
        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = ((time - 1) // 2 - 1) // 2.
            torch.Tensor: Subsampled mask (#batch, time'),
                where time' = ((time - 1) // 2 - 1) // 2.
        """
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv_in(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x, x_mask[:, 2::2][:, 2::2]
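# Worked example (a sketch, not called by the model): the two kernel-3 / stride-2
# convolutions above shrink T input frames to ((T - 1) // 2 - 1) // 2 frames, which
# is exactly what the mask slicing x_mask[:, 2::2][:, 2::2] keeps.
def _example_subsampled_length(num_frames: int) -> int:
    after_first_conv = (num_frames - 1) // 2  # Conv2d(kernel_size=3, stride=2), no padding
    return (after_first_conv - 1) // 2        # second identical convolution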
class WhalePositionalEncoding(torch.nn.Module):
"""Positional encoding.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
"""
def __init__(self, config: WhaleConfig):
"""Construct an PositionalEncoding object."""
super().__init__()
self.d_model = config.hidden_size
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=config.dropout)
self.max_len = config.max_position_embeddings
self.pe = torch.zeros(self.max_len, self.d_model)
position = torch.arange(0, self.max_len,
dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32) *
-(math.log(10000.0) / self.d_model))
self.pe[:, 0::2] = torch.sin(position * div_term)
self.pe[:, 1::2] = torch.cos(position * div_term)
self.pe = self.pe.unsqueeze(0)
def forward(self,
x: torch.Tensor,
offset: int = 0):
"""Add positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
offset (int): position offset
Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
torch.Tensor: for compatibility to RelPositionalEncoding
"""
assert offset + x.size(1) < self.max_len
self.pe = self.pe.to(x.device)
pos_emb = self.pe[:, offset:offset + x.size(1)]
x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb)
def position_encoding(self, offset: int, size: int):
""" For getting encoding in a streaming fashion
        Note:
            We apply dropout only once at the whole-utterance level in a
            non-streaming setup, but this function will be called several times
            with increasing input size in a streaming scenario, so the dropout
            will be applied several times.
Args:
offset (int): start offset
            size (int): required size of position encoding
Returns:
torch.Tensor: Corresponding encoding
"""
assert offset + size < self.max_len
return self.dropout(self.pe[:, offset:offset + size])
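# Minimal sketch (illustration only): the table built in __init__ above follows the
# standard sinusoidal formula from the docstring,
# PE(pos, 2i) = sin(pos / 10000^(2i / d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)).
def _example_sinusoid(pos: int, i: int, d_model: int) -> Tuple[float, float]:
    angle = pos / (10000 ** (2 * i / d_model))
    return math.sin(angle), math.cos(angle)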
class RelPositionalEncoding(WhalePositionalEncoding):
"""Relative positional encoding module.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, config: WhaleConfig):
"""Initialize class."""
super().__init__(config)
self.hidden_size = config.hidden_size
# self.chunk_size = chunk_size
# self.left_chunks = left_chunks
# self.full_chunk_size = (self.left_chunks + 1) * self.chunk_size
self.div_term = torch.exp(
torch.arange(0, self.hidden_size, 2, dtype=torch.float32) *
-(math.log(10000.0) / self.hidden_size))
self.max_length = config.max_position_embeddings
# self.max_len = self.chunk_size * (max_len // self.chunk_size) - self.full_chunk_size
@torch.jit.export
def forward(self,
x: torch.Tensor,
offset: int = 0):
"""Compute positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
self.pe = self.pe.to(x.device)
x = x * self.xscale
pos_emb = self.pe[:, offset:offset + x.size(1)]
return self.dropout(x), self.dropout(pos_emb)
class WhaleAudioEmbeddings(nn.Module):
def __init__(self, config: WhaleConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.embed_dim = config.hidden_size
self.dropout_rate = config.dropout
self.input_dim = config.input_dim
self.embedding = nn.Sequential(
nn.Linear(config.hidden_size, self.embed_dim),
nn.LayerNorm(self.embed_dim),
nn.Dropout(self.dropout_rate),
nn.ReLU()
)
self.positional_embedding = RelPositionalEncoding(config)
def forward(self, input_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.embedding(input_features)
hidden_states, pos_embeds = self.positional_embedding(hidden_states)
return hidden_states, pos_embeds
class WhaleAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: WhaleConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.use_flash_attn = config.use_flash_attn and has_flash_attn
if config.use_flash_attn and not has_flash_attn:
            logger.warning('Flash Attention is not available; use_flash_attn is set to False.')
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).'
)
self.scale = self.head_dim ** -0.5
self.linear_q = nn.Linear(self.embed_dim, self.embed_dim)
self.linear_k = nn.Linear(self.embed_dim, self.embed_dim)
self.linear_v = nn.Linear(self.embed_dim, self.embed_dim)
self.linear_out = nn.Linear(self.embed_dim, self.embed_dim)
self.attn_drop = nn.Dropout(config.attention_dropout)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.linear_out = nn.Linear(self.embed_dim, self.embed_dim)
self.use_relative_pe = config.use_relative_pe
if self.use_relative_pe:
self.linear_pos = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
def _naive_attn(self, x, attention_mask=None, pos_embeds=None):
B, N, C = x.shape
q = self.linear_q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
k = self.linear_k(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
v = self.linear_v(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
if self.qk_normalization:
B_, H_, N_, D_ = q.shape
q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
if self.use_relative_pe:
q = q.transpose(1, 2)
batch_size = pos_embeds.size(0)
p = self.linear_pos(pos_embeds.to(q.dtype)).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
query_with_bias_u = (q + self.pos_bias_u.to(q.device)).transpose(1, 2)
query_with_bias_v = (q + self.pos_bias_v.to(q.device)).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
matrix_ac = torch.matmul(query_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
matrix_bd = torch.matmul(query_with_bias_v, p.transpose(-2, -1))
attn = (matrix_ac + matrix_bd) * self.scale
else:
attn = ((q * self.scale) @ k.transpose(-2, -1))
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attn = attn.masked_fill(~attention_mask.bool(), float("-inf"))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.linear_out(x)
return x
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor = None,
pos_embeds: torch.Tensor = None
) -> torch.Tensor:
x = self._naive_attn(hidden_states, attention_mask, pos_embeds)
return x
class WhaleMLP(nn.Module):
def __init__(self, config: WhaleConfig, quant_config=None):
super().__init__()
self.config = config
self.act = get_act_fn(config.hidden_act)
self.w_1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config)
self.w_2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config)
self.dropout = nn.Dropout(config.dropout)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.w_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states, _ = self.w_2(hidden_states)
return hidden_states
class WhaleAudioEncoderLayer(nn.Module):
def __init__(self, config: WhaleConfig, quant_config=None):
super().__init__()
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.dropout_rate = config.dropout
self.normalize_before = config.normalize_before
self.concat_after = config.concat_after
self.attn = WhaleAttention(config)
self.feed_forward = WhaleMLP(config, quant_config=quant_config)
self.norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.dropout)
if self.concat_after:
self.concat_linear = nn.Linear(self.embed_dim * 2, self.embed_dim)
else:
self.concat_linear = nn.Identity()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
pos_emb: torch.Tensor,
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
"""
Args:
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
"""
residual = hidden_states
if self.normalize_before:
hidden_states = self.norm1(hidden_states)
if self.concat_after:
hidden_states = torch.cat(
[hidden_states, self.attn(hidden_states, attention_mask, pos_emb)],
dim=-1
)
hidden_states = self.concat_linear(hidden_states) + residual
else:
hidden_states = self.dropout(self.attn(hidden_states, attention_mask, pos_emb)) + residual
if not self.normalize_before:
hidden_states = self.norm1(hidden_states)
residual = hidden_states
if self.normalize_before:
hidden_states = self.norm2(hidden_states)
hidden_states = self.dropout(self.feed_forward(hidden_states)) + residual
if not self.normalize_before:
hidden_states = self.norm2(hidden_states)
return hidden_states
class WhaleAudioEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`WhaleAudioEncoderLayer`].
    Args:
        config (`WhaleConfig`):
            The audio configuration for the `WhaleAudioEncoder`.
"""
def __init__(self, config: WhaleConfig, quant_config=None):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
WhaleAudioEncoderLayer(config, quant_config=quant_config) for idx in range(config.num_hidden_layers)])
self.gradient_checkpointing = True
self.normalize_before = config.normalize_before
if self.normalize_before:
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.FloatTensor] = None,
pos_embeds: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = torch.utils.checkpoint.checkpoint(
encoder_layer,
hidden_states,
attention_mask,
pos_embeds,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
pos_embeds,
)
hidden_states = layer_outputs
if self.normalize_before:
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states
)
class WhaleAudioModel(PreTrainedModel):
main_input_name = 'input_features'
config_class = WhaleConfig
_no_split_modules = ['WhaleAudioEncoderLayer']
def __init__(self, config: WhaleConfig, quant_config=None):
super().__init__(config)
self.config = config
self.subsampling = WhaleConv2dSubsampling4(config)
self.embeddings = WhaleAudioEmbeddings(config)
self.encoder = WhaleAudioEncoder(config, quant_config=quant_config)
def get_input_embeddings(self):
return self.embeddings
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_embeds: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_features is None and pixel_embeds is None:
            raise ValueError('You have to specify input_features or pixel_embeds')
if pixel_embeds is not None:
hidden_states = pixel_embeds
else:
if len(input_features.shape) == 3:
input_features, attention_mask = self.subsampling(input_features, attention_mask)
hidden_states, pos_embeds = self.embeddings(input_features)
else:
                raise ValueError(f'wrong input_features size: {input_features.shape}')
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
pos_embeds=pos_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs.last_hidden_state
pooled_output = last_hidden_state[:, 0, :]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
import onnxruntime
import torch
import numpy as np
import math
import sys
import os
import torchaudio
import pyaudio
import torchaudio.compliance.kaldi as k
class VADIterator:
def __init__(self,
model,
threshold: float = 0.7,
sampling_rate: int = 16000,
min_silence_duration_ms: int = 500,
speech_pad_ms: int = 30
):
"""
Class for stream imitation
Parameters
----------
        model: preloaded .jit silero VAD model
        threshold: float (default - 0.7)
            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk; probabilities ABOVE this value are considered SPEECH.
            It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
        sampling_rate: int (default - 16000)
            Currently silero VAD models support 8000 and 16000 sample rates
        min_silence_duration_ms: int (default - 500 milliseconds)
            At the end of each speech chunk, wait for min_silence_duration_ms before separating it
        speech_pad_ms: int (default - 30 milliseconds)
            Final speech chunks are padded by speech_pad_ms on each side
"""
self.model = model
self.threshold = threshold
self.sampling_rate = sampling_rate
if sampling_rate not in [8000, 16000]:
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
self.reset_states()
def reset_states(self):
self.model.reset_states()
self.triggered = False
self.temp_end = 0
self.current_sample = 0
@torch.no_grad()
def __call__(self, x, return_seconds=False):
"""
x: torch.Tensor
audio chunk (see examples in repo)
return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples)
"""
if not torch.is_tensor(x):
try:
x = torch.Tensor(x)
            except Exception:
                raise TypeError("Audio cannot be cast to a tensor. Cast it manually")
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
self.current_sample += window_size_samples
speech_prob = self.model(x, self.sampling_rate).item()
if (speech_prob >= self.threshold) and self.temp_end:
self.temp_end = 0
if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True
speech_start = self.current_sample - self.speech_pad_samples - window_size_samples
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end:
self.temp_end = self.current_sample
if self.current_sample - self.temp_end < self.min_silence_samples:
return None
else:
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
self.temp_end = 0
self.triggered = False
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
return None
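# Usage sketch (illustration only, with assumptions): "silero_vad.jit" stands for a
# TorchScript Silero VAD checkpoint and `stream_chunks` for an iterable of float
# tensors of audio samples at 16 kHz; neither is shipped with this file.
def _example_vad_stream(jit_model_path, stream_chunks):
    model = torch.jit.load(jit_model_path)
    vad = VADIterator(model, threshold=0.7, sampling_rate=16000)
    for chunk in stream_chunks:
        event = vad(chunk, return_seconds=True)
        if event is not None:
            # {'start': 1.2} when speech begins, {'end': 3.4} when it stops
            print(event)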
class WakeupAndVAD:
def __init__(self, model_dir, keyword=None, cache_history=10, threshold=0.1):
self.model_dir = model_dir
self.chunk_size = 16
self.chunk_overlap = 0
self.feat_dim = 80
self.frame_shift = 256
self.CHUNK = self.frame_shift * self.chunk_size
self.cache_history = cache_history
self.keyword = keyword
self.threshold = threshold
self.is_wakeup = False
self.in_dialog = False
self.dialog_time = 0
with torch.no_grad():
self.load_vad()
self.reset_dialog()
        self.history = torch.zeros(self.cache_history * 16000)
        # Buffer the previous chunk so speech onsets keep a little leading audio;
        # initialized to silence so the first chunk cannot hit an AttributeError.
        self.last_audio = torch.zeros(self.CHUNK)
def get_chunk_size(self):
return self.CHUNK
def load_model(self):
self.sess_opt = onnxruntime.SessionOptions()
self.sess_opt.intra_op_num_threads = 4
self.sess_opt.inter_op_num_threads = 4
sys.path.append(os.path.abspath(self.model_dir))
self.input_chunk = torch.zeros([1, self.chunk_size + self.chunk_overlap, self.feat_dim])
self.input_sample = torch.zeros([1, self.CHUNK + self.frame_shift , 1])
def load_cmvn(self):
cmvn_info = torch.load(f"{self.model_dir}/cmvn.dict")
means = cmvn_info['mean_stat']
variance = cmvn_info['var_stat']
count = cmvn_info['frame_num']
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
self.cmvn = np.array([means, variance]).astype(np.float32)
def load_vad(self):
self.vad_model = torch.jit.load(f"{self.model_dir}/silero_vad.jit")
self.vad_iterator = VADIterator(self.vad_model)
self.vad_model_post = torch.jit.load(f"{self.model_dir}/silero_vad.jit")
self.vad_iterator_post = VADIterator(self.vad_model_post, min_silence_duration_ms=50)
def reset_dialog(self):
self.vad_iterator.reset_states()
self.in_dialog = False
self.dialog_time = 0
self.dialog_part = torch.zeros([0,])
def post_process_history(self, history):
self.vad_iterator_post.reset_states()
self.time_stamps = []
for i in range(0, len(history) // 1024 * 1024, 1024):
speech_dict = self.vad_iterator_post(history[i: i+ 1024], return_seconds=True)
if speech_dict is not None and 'start' in speech_dict:
self.time_stamps.append(speech_dict['start'])
        if self.time_stamps and self.cache_history - self.time_stamps[-1] < 1.5:
history = history[:int(self.time_stamps[-1] * 16000)]
return history
def predict(self,
audio: torch.Tensor):
with torch.no_grad():
audio = audio.clone().detach()
speech_dict = self.vad_iterator(audio.reshape(-1), return_seconds=True)
# print(speech_dict)
if self.in_dialog:
self.dialog_part = torch.cat([self.dialog_part, audio.reshape(-1)])
if speech_dict is not None:
if 'start' in speech_dict:
self.in_dialog = True
self.dialog_part = torch.cat([self.last_audio.reshape(-1), audio.reshape(-1)])
return speech_dict
if self.in_dialog and 'end' in speech_dict:
output = {"cache_dialog": self.dialog_part.clone()}
self.reset_dialog()
self.is_wakeup = False
return output
self.last_audio = audio.clone()
return None
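# Usage sketch (illustration only): feed fixed-size audio chunks into
# WakeupAndVAD.predict and collect the buffered dialog once an 'end' event fires.
# `model_dir` must contain silero_vad.jit; `read_chunk` is a hypothetical callable
# returning get_chunk_size() float samples per call.
def _example_wakeup_loop(model_dir, read_chunk):
    detector = WakeupAndVAD(model_dir)
    while True:
        audio = read_chunk(detector.get_chunk_size())
        result = detector.predict(torch.tensor(audio, dtype=torch.float32))
        if result is not None and "cache_dialog" in result:
            return result["cache_dialog"]  # full utterance between start and end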
import torch
import os
import argparse
import numpy as np
import copy
import gradio as gr
import re
import torchaudio
import io
import ffmpeg
from vita.constants import DEFAULT_AUDIO_TOKEN, DEFAULT_IMAGE_TOKEN, MAX_IMAGE_LENGTH, MIN_IMAGE_LENGTH
from vita.conversation import conv_templates, SeparatorStyle
from vita.util.mm_utils import tokenizer_image_token, tokenizer_image_audio_token
from PIL import Image
from decord import VideoReader, cpu
from vllm import LLM, SamplingParams
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoFeatureExtractor
PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
def remove_special_characters(input_str):
return input_str.replace('<2>', '').replace('<1>', '').replace('<3>', '')
def is_video(file_path):
video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'}
_, ext = os.path.splitext(file_path)
return ext.lower() in video_extensions
def is_image(file_path):
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff'}
_, ext = os.path.splitext(file_path)
return ext.lower() in image_extensions
def is_wav(file_path):
wav_extensions = {'.wav'}
_, ext = os.path.splitext(file_path)
return ext.lower() in wav_extensions
def convert_webm_to_mp4(input_file, output_file):
try:
(
ffmpeg
.input(input_file)
.output(output_file, vcodec='libx264', acodec='aac')
.run()
)
print(f"Conversion successful: {output_file}")
except ffmpeg.Error as e:
print(f"Error: {e.stderr.decode()}")
raise
def _get_rawvideo_dec(video_path, max_frames=MAX_IMAGE_LENGTH, min_frames=MIN_IMAGE_LENGTH, video_framerate=1, s=None, e=None):
if s is None or e is None:
start_time, end_time = None, None
else:
start_time = int(s)
end_time = int(e)
start_time = max(start_time, 0)
end_time = max(end_time, 0)
if start_time > end_time:
start_time, end_time = end_time, start_time
elif start_time == end_time:
end_time = start_time + 1
if os.path.exists(video_path):
vreader = VideoReader(video_path, ctx=cpu(0))
else:
raise FileNotFoundError
fps = vreader.get_avg_fps()
f_start = 0 if start_time is None else int(start_time * fps)
f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
num_frames = f_end - f_start + 1
if num_frames > 0:
sample_fps = int(video_framerate)
t_stride = int(round(float(fps) / sample_fps))
all_pos = list(range(f_start, f_end + 1, t_stride))
if len(all_pos) > max_frames:
sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
elif len(all_pos) < min_frames:
sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)]
else:
sample_pos = all_pos
patch_images = [Image.fromarray(f).convert("RGB") for f in vreader.get_batch(sample_pos).asnumpy()]
return patch_images, len(patch_images)
else:
print(f"video path: {video_path} error.")
def _parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = "<br></code></pre>"
else:
if i > 0 and count % 2 == 1:
line = line.replace("`", r"\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
return "".join(lines)
def _launch_demo(llm, model_config, sampling_params, tokenizer, feature_extractor):
def predict(_chatbot, task_history):
chat_query = task_history[-1][0]
print(task_history)
conv_mode = "mixtral_two"
conv = conv_templates[conv_mode].copy()
all_audio_path = []
all_visual_tensor = []
qs = ''
input_mode = 'lang'
for i, (q, a) in enumerate(task_history):
if isinstance(q, (tuple, list)):
if is_image(q[0]):
images = [Image.open(q[0]).convert("RGB")]
all_visual_tensor.extend(images)
input_mode = 'image'
qs += DEFAULT_IMAGE_TOKEN * len(images) + '\n'
elif is_video(q[0]):
video_frames, slice_len = _get_rawvideo_dec(q[0])
all_visual_tensor.extend(video_frames)
input_mode = 'video'
qs += DEFAULT_IMAGE_TOKEN * slice_len + '\n'
elif is_wav(q[0]):
if a is not None and a.startswith('<2>'):
continue
else:
all_audio_path.append(q[0])
new_q = qs + DEFAULT_AUDIO_TOKEN
qs = ''
conv.append_message(conv.roles[0], new_q)
conv.append_message(conv.roles[1], a)
else:
new_q = qs + q
qs = ''
conv.append_message(conv.roles[0], new_q)
conv.append_message(conv.roles[1], a)
print(conv)
prompt = conv.get_prompt(input_mode)
if all_audio_path != []:
input_ids = tokenizer_image_audio_token(
prompt, tokenizer,
image_token_index=model_config.image_token_index,
audio_token_index=model_config.audio_token_index
)
audio_list = []
for single_audio_path in all_audio_path:
try:
audio, original_sr = torchaudio.load(single_audio_path)
# The FeatureExtractor was trained using a sampling rate of 16000 Hz
target_sr = 16000
# Resample
if original_sr != target_sr:
resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
audio = resampler(audio)
audio_features = feature_extractor(audio, sampling_rate=target_sr, return_tensors="pt")["input_features"]
audio_list.append(audio_features.squeeze(0))
except Exception as e:
print(f"Error processing {single_audio_path}: {e}")
else:
input_ids = tokenizer_image_token(
prompt, tokenizer,
image_token_index=model_config.image_token_index
)
        if all_visual_tensor == [] and all_audio_path == []:
            data_prompt = {
                "prompt_token_ids": input_ids,
            }
        elif all_visual_tensor != [] and all_audio_path == []:
            data_prompt = {
                "prompt_token_ids": input_ids,
                "multi_modal_data": {
                    "image": all_visual_tensor
                },
            }
        elif all_visual_tensor == [] and all_audio_path != []:
            data_prompt = {
                "prompt_token_ids": input_ids,
                "multi_modal_data": {
                    "audio": audio_list
                },
            }
        else:
            data_prompt = {
                "prompt_token_ids": input_ids,
                "multi_modal_data": {
                    "image": all_visual_tensor,
                    "audio": audio_list
                },
            }
        output = llm.generate(data_prompt, sampling_params=sampling_params)
outputs = output[0].outputs[0].text
task_history[-1] = (chat_query, outputs)
remove_special_characters_output = remove_special_characters(outputs)
_chatbot[-1] = (chat_query, _parse_text(remove_special_characters_output))
print("query",chat_query)
print("task_history",task_history)
print(_chatbot)
print("answer: ",outputs)
yield _chatbot
def add_text(history, task_history, text):
task_text = text
if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
task_text = text[:-1]
history = history + [(_parse_text(text), None)]
task_history = task_history + [(task_text, None)]
return history, task_history, ""
def add_file(history, task_history, file):
history = history + [((file.name,), None)]
task_history = task_history + [((file.name,), None)]
return history, task_history
def add_audio(history, task_history, file):
print(file)
if file is None:
return history, task_history
history = history + [((file,), None)]
task_history = task_history + [((file,), None)]
return history, task_history
def add_video(history, task_history, file):
print(file)
if file is None:
return history, task_history
new_file_name = file.replace(".webm",".mp4")
if file.endswith(".webm"):
convert_webm_to_mp4(file, new_file_name)
task_history = task_history + [((new_file_name,), None)]
return history, task_history
def reset_user_input():
return gr.update(value="")
def reset_state(task_history):
task_history.clear()
return []
with gr.Blocks(title="VideoMLLM") as demo:
        gr.Markdown("""<center><font size=8>VITA</font></center>""")
chatbot = gr.Chatbot(label='VITA', elem_classes="control-height", height=500)
query = gr.Textbox(lines=2, label='Text Input')
task_history = gr.State([])
with gr.Row():
add_text_button = gr.Button("Submit Text (提交文本)")
add_audio_button = gr.Button("Submit Audio (提交音频)")
with gr.Row():
with gr.Column(scale=2):
addfile_btn = gr.UploadButton("📁 Upload (上传文件[视频,图片])", file_types=["video", "image"])
video_input = gr.Video(sources=[ "webcam"], height=400, width=700, container=True, interactive=True, show_download_button=True, label="📹 Video Recording (视频录制)")
with gr.Column(scale=1):
empty_bin = gr.Button("🧹 Clear History (清除历史)")
record_btn = gr.Audio(sources=[ "microphone","upload"], type="filepath", label="🎤 Record or Upload Audio (录音或上传音频)", show_download_button=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then(
reset_user_input, [], [query]
).then(
predict, [chatbot, task_history], [chatbot], show_progress=True
)
video_input.stop_recording(add_video, [chatbot, task_history, video_input], [chatbot, task_history], show_progress=True)
empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
add_audio_button.click(add_audio, [chatbot, task_history,record_btn], [chatbot, task_history], show_progress=True).then(
predict, [chatbot, task_history], [chatbot], show_progress=True
)
server_port = 18806
demo.launch(
share=False,
debug=True,
server_name="0.0.0.0",
server_port=server_port,
show_api=False,
show_error=False,
auth=('123','123'),
)
def main(model_path):
llm = LLM(
model=model_path,
dtype="float16",
tensor_parallel_size=2,
trust_remote_code=True,
gpu_memory_utilization=0.85,
disable_custom_all_reduce=True,
limit_mm_per_prompt={'image':256,'audio':50}
)
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.01, max_tokens=512, best_of=1, skip_special_tokens=False)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path, subfolder="feature_extractor", trust_remote_code=True)
_launch_demo(llm, model_config, sampling_params, tokenizer, feature_extractor)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Run the web demo with your model path.')
parser.add_argument('model_path', type=str, help='Path to the model')
args = parser.parse_args()
main(args.model_path)
decord==0.6.0
fastapi==0.115.2
ffmpeg-python==0.2.0
ffmpy==0.4.0
gradio==5.4.0
numpy==1.26.4
onnxruntime==1.16.3
opencv-python==4.10.0.84
pyaudio==0.2.14
pydantic==2.8.2
shortuuid==1.0.13
tencentcloud-sdk-python-tts
#torch==2.4.0
#torchaudio==2.4.0
#torchvision==0.19.0
transformers==4.44.2
#vllm==0.5.5
#vllm-flash-attn==2.6.1
#xformers==0.0.27.post2
import base64
import datetime
import io
import multiprocessing
import re
from typing import AsyncGenerator
from transformers import AutoTokenizer, AutoFeatureExtractor
from PIL import Image
from vllm import LLM, SamplingParams
import time
import torchaudio
import numpy as np
import os
from decord import VideoReader, cpu
import torch
import asyncio
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
import shortuuid
from vllm.utils import random_uuid
import gradio as gr
from collections import deque
from queue import Empty
import cv2
import json
from web_demo.wakeup_and_vad.wakeup_and_vad import WakeupAndVAD
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.tts.v20190823 import tts_client, models
IMAGE_TOKEN_INDEX = 51000
AUDIO_TOKEN_INDEX = 51001
IMAGE_TOKEN = "<image>"
AUDIO_TOKEN = "<audio>"
VIDEO_TOKEN = "<video>"
httpProfile = HttpProfile()
httpProfile.endpoint = "tts.tencentcloudapi.com"
cred = credential.Credential("", "")
clientProfile = ClientProfile()
clientProfile.httpProfile = httpProfile
client = tts_client.TtsClient(cred, "ap-shanghai", clientProfile)
req = models.TextToVoiceRequest()
def clear_queue(queue):
while not queue.empty():
try:
queue.get_nowait()
except Empty:
break
# The following code is used to run an async task in a synchronous way
def run_async_task(task):
loop = asyncio.get_event_loop()
if loop.is_running():
# If the event loop is already running, run the task in the current loop
return loop.run_until_complete(task)
else:
# Else, create a new loop and run the task in it
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(task)
finally:
loop.close()
# This is a function to tokenize the prompt with image and audio tokens
def tokenizer_image_audio_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, audio_token_index=AUDIO_TOKEN_INDEX, return_tensors=None):
prompt_chunks = []
for chunk in re.split(r'(<audio>|<image>)', prompt):
if chunk == '<audio>':
prompt_chunks.append([audio_token_index])
elif chunk == '<image>':
prompt_chunks.append([image_token_index])
else:
prompt_chunks.append(tokenizer(chunk).input_ids)
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in prompt_chunks:
if x != [image_token_index] and x != [audio_token_index]:
input_ids.extend(x[offset:])
else:
input_ids.extend(x[:])
if return_tensors is not None:
if return_tensors == 'pt':
return torch.LongTensor(input_ids)
raise ValueError(f'Unsupported tensor type: {return_tensors}')
return input_ids
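# Usage sketch (illustration only): a prompt that interleaves text with <image> and
# <audio> placeholders is split on those markers, text spans are tokenized normally,
# and each placeholder becomes a single sentinel id (51000 / 51001 above). The
# `tokenizer` argument is whatever AutoTokenizer.from_pretrained returns for the model.
def _example_multimodal_prompt_ids(tokenizer):
    prompt = "user:" + IMAGE_TOKEN + "\nDescribe the image, then answer the question in " + AUDIO_TOKEN
    ids = tokenizer_image_audio_token(prompt, tokenizer)
    assert IMAGE_TOKEN_INDEX in ids and AUDIO_TOKEN_INDEX in ids
    return ids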
def load_model(
llm_id,
engine_args,
cuda_devices,
inputs_queue,
outputs_queue,
tts_outputs_queue,
stop_event,
other_stop_event,
worker_ready,
wait_workers_ready,
start_event,
other_start_event,
start_event_lock,
interrupt_signal,
global_history,
global_history_limit=0,
):
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_devices
multiprocessing.set_start_method('spawn', force=True)
llm = AsyncLLMEngine.from_engine_args(engine_args)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path, subfolder="feature_extractor", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.001, max_tokens=512, best_of=1, skip_special_tokens=False)
def _process_inputs(inputs):
def _process_image(image_path):
if isinstance(image_path, str):
assert os.path.exists(image_path), f"Image file {image_path} does not exist."
return Image.open(image_path).convert("RGB").transpose(Image.FLIP_LEFT_RIGHT)
else:
assert isinstance(image_path, np.ndarray), "Image must be either a file path or a numpy array."
return Image.fromarray(image_path).convert("RGB").transpose(Image.FLIP_LEFT_RIGHT)
def _process_audio(audio_path):
assert os.path.exists(audio_path), f"Audio file {audio_path} does not exist."
audio, sr = torchaudio.load(audio_path)
audio_features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")["input_features"]
audio_features = audio_features.squeeze(0)
return audio_features
def _process_video(video_path, max_frames=4, min_frames=4, s=None, e=None):
# speed up video decode via decord.
if s is None or e is None:
start_time, end_time = None, None
else:
start_time = int(s)
end_time = int(e)
start_time = max(start_time, 0)
end_time = max(end_time, 0)
if start_time > end_time:
start_time, end_time = end_time, start_time
elif start_time == end_time:
end_time = start_time + 1
if os.path.exists(video_path):
vreader = VideoReader(video_path, ctx=cpu(0))
else:
raise FileNotFoundError(f"Video file {video_path} does not exist.")
fps = vreader.get_avg_fps()
f_start = 0 if start_time is None else int(start_time * fps)
f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
num_frames = f_end - f_start + 1
if num_frames > 0:
# T x 3 x H x W
all_pos = list(range(f_start, f_end + 1))
if len(all_pos) > max_frames:
sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
elif len(all_pos) < min_frames:
sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)]
else:
sample_pos = all_pos
patch_images = [Image.fromarray(f).transpose(Image.FLIP_LEFT_RIGHT) for f in vreader.get_batch(sample_pos).asnumpy()]
return patch_images
else:
print("video path: {} error.".format(video_path))
if "multi_modal_data" in inputs:
if "image" in inputs["multi_modal_data"]:
image_inputs = inputs["multi_modal_data"]["image"]
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]
inputs["multi_modal_data"]["image"] = [_process_image(f) for f in image_inputs]
if "prompt" in inputs:
assert inputs["prompt"].count(IMAGE_TOKEN) == len(image_inputs), \
f"Number of image token {IMAGE_TOKEN} in prompt must match the number of image inputs."
elif "prompt_token_ids" in inputs:
assert inputs["prompt_token_ids"].count(IMAGE_TOKEN_INDEX) == len(image_inputs), \
f"Number of image token ids {IMAGE_TOKEN_INDEX} in prompt_token_ids must match the number of image inputs."
else:
raise ValueError("Either 'prompt' or 'prompt_token_ids' must be provided.")
if "audio" in inputs["multi_modal_data"]:
audio_inputs = inputs["multi_modal_data"]["audio"]
if not isinstance(audio_inputs, list):
audio_inputs = [audio_inputs]
inputs["multi_modal_data"]["audio"] = [_process_audio(f) for f in audio_inputs]
if "prompt" in inputs:
assert inputs["prompt"].count(AUDIO_TOKEN) == len(inputs["multi_modal_data"]["audio"]), \
f"Number of audio token {AUDIO_TOKEN} in prompt must match the number of audio inputs."
elif "prompt_token_ids" in inputs:
assert inputs["prompt_token_ids"].count(AUDIO_TOKEN_INDEX) == len(inputs["multi_modal_data"]["audio"]), \
f"Number of audio token ids {AUDIO_TOKEN_INDEX} in prompt_token_ids must match the number of audio inputs."
else:
raise ValueError("Either 'prompt' or 'prompt_token_ids' must be provided.")
if "video" in inputs["multi_modal_data"]:
video_inputs = inputs["multi_modal_data"]["video"]
if not isinstance(video_inputs, list):
video_inputs = [video_inputs]
assert "prompt" in inputs, "Prompt must be provided when video inputs are provided."
assert "image" not in inputs["multi_modal_data"], "Image inputs are not supported when video inputs are provided."
assert inputs["prompt"].count(VIDEO_TOKEN) == 1, "Currently only one video token is supported in prompt."
assert inputs["prompt"].count(VIDEO_TOKEN) == len(inputs["multi_modal_data"]["video"]), \
f"Number of video token {VIDEO_TOKEN} in prompt must match the number of video inputs."
video_frames_inputs = []
for video_input in video_inputs:
video_frames_inputs.extend(_process_video(video_input, max_frames=4, min_frames=4))
inputs["prompt"] = inputs["prompt"].replace(VIDEO_TOKEN, IMAGE_TOKEN * len(video_frames_inputs))
if "image" not in inputs["multi_modal_data"]:
inputs["multi_modal_data"]["image"] = []
inputs["multi_modal_data"]["image"].extend(video_frames_inputs)
inputs["multi_modal_data"].pop("video", None)
return inputs
def judge_negative(text):
is_negative = text.startswith('<2>')
return is_negative
async def stream_results(results_generator) -> AsyncGenerator[bytes, None]:
previous_text = ""
async for request_output in results_generator:
text = request_output.outputs[0].text
newly_generated_text = text[len(previous_text):]
previous_text = text
yield newly_generated_text
async def collect_results_demo(results_generator):
async for newly_generated_text in stream_results(results_generator):
continue
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
worker_ready.set()
if not isinstance(wait_workers_ready, list):
wait_workers_ready = [wait_workers_ready]
while True:
# Wait for all workers to be ready
if not all([worker.is_set() for worker in wait_workers_ready]):
time.sleep(0.1)
continue
if not inputs_queue.empty():
with start_event_lock:
if start_event.is_set():
inputs = inputs_queue.get()
other_start_event.set()
start_event.clear()
else:
continue
inputs = _process_inputs(inputs)
current_inputs = inputs.copy()
inputs = merge_current_and_history(
global_history[-global_history_limit:],
inputs,
skip_history_vision=True,
move_image_token_to_start=True
)
print(f"Process {cuda_devices} is processing inputs: {inputs}")
if "prompt" in inputs:
# Process multimodal tokens
inputs["prompt_token_ids"] = tokenizer_image_audio_token(inputs["prompt"], tokenizer, image_token_index=IMAGE_TOKEN_INDEX, audio_token_index=AUDIO_TOKEN_INDEX)
else:
assert "prompt_token_ids" in inputs, "Either 'prompt' or 'prompt_token_ids' must be provided."
inputs.pop("prompt", None)
# print(f"Process {cuda_devices} is about to generate results, prompt: {current_inputs['prompt']}, prompt_token_ids: {inputs['prompt_token_ids']}")
results_generator = llm.generate(
inputs,
sampling_params=sampling_params,
request_id=random_uuid(),
)
async def stream_results(results_generator) -> AsyncGenerator[bytes, None]:
previous_text = ""
async for request_output in results_generator:
text = request_output.outputs[0].text
newly_generated_text = text[len(previous_text):]
previous_text = text
yield newly_generated_text
async def collect_results(results_generator):
results = []
is_first_time_to_work = True
history_generated_text = ''
async for newly_generated_text in stream_results(results_generator):
# if newly_generated_text.strip() == "":
# continue
# newly_generated_text = newly_generated_text.strip()
is_negative = judge_negative(newly_generated_text)
if not is_negative:
history_generated_text += newly_generated_text
if is_first_time_to_work:
print(f"Process {cuda_devices} is about to interrupt other process")
stop_event.clear()
other_stop_event.set()
clear_queue(outputs_queue)
clear_queue(tts_outputs_queue)
is_first_time_to_work = False
interrupt_signal.value = llm_id
if not stop_event.is_set():
results.append(newly_generated_text)
history_generated_text = history_generated_text.replace('<1> ', '').replace('<1>', '')
# print('newly_generated_text',newly_generated_text)
if newly_generated_text in [",", ",", ".", "。", "?", "\n", "?", "!", "!", "、"]:
# print('history_generated_text:',history_generated_text)
outputs_queue.put({"id": llm_id, "response": history_generated_text})
history_generated_text = ''
else:
print(f"Process {cuda_devices} is interrupted.")
break
else:
print(f"Process {cuda_devices} is generating negative text.")
break
current_inputs["response"] = "".join(results)
if not current_inputs["response"] == "":
global_history.append(current_inputs)
return results
results = loop.run_until_complete(collect_results(results_generator))
print(f"Process {cuda_devices} has generated results: {''.join(results)}")
def tts_tranform_text(text):
print(text)
params = {
"Text": text,
"SessionId": "session-1234",
"Volume": 1,
"Speed": 0,
"ProjectId": 0,
"ModelType": 1,
"VoiceType": 301009,
"PrimaryLanguage": 1,
"SampleRate": 16000,
"Codec": "wav",
"EnableSubtitle": True
}
req.from_json_string(json.dumps(params))
resp = client.TextToVoice(req)
    resp_dict = json.loads(resp.to_json_string())
    base64_audio_data = resp_dict['Audio']
    audio_data = base64.b64decode(base64_audio_data)
    wav_dir = "tmp_audio/"
    if not os.path.exists(wav_dir):
        os.makedirs(wav_dir)
    tmp_saved_wav_file = wav_dir + str(301009) + "_" + str(shortuuid.uuid()) + ".wav"
with open(tmp_saved_wav_file, "wb") as audio_file:
audio_file.write(audio_data)
return tmp_saved_wav_file
def tts_worker(
inputs_queue,
outputs_queue,
worker_ready,
wait_workers_ready,
):
def audio_file_to_html(audio_file: str) -> str:
"""
Convert audio file to HTML audio player.
Args:
audio_file: Path to audio file
Returns:
audio_player: HTML audio player that auto-plays
"""
# Read in audio file to audio_bytes
audio_bytes = io.BytesIO()
with open(audio_file, "rb") as f:
audio_bytes.write(f.read())
# Generate audio player HTML object for autoplay
audio_bytes.seek(0)
audio = base64.b64encode(audio_bytes.read()).decode("utf-8")
audio_player = (
f'<audio src="data:audio/mpeg;base64,{audio}" controls autoplay></audio>'
)
return audio_player
def remove_uncommon_punctuation(text):
common_punctuation = ".,!?;:()[],。!?、:;() "
uncommon_punctuation_pattern = rf"[^\w\s{re.escape(common_punctuation)}]"
cleaned_text = re.sub(uncommon_punctuation_pattern, "", text)
return cleaned_text
def remove_special_tokens(input_str):
# Remove special tokens
special_tokens = ['<1>', '<2>', '<3>', '<unk>', '</s>']
for token in special_tokens:
input_str = input_str.replace(token, '')
return input_str
def replace_equation(sentence):
special_notations = {
"sin": " sine ",
"cos": " cosine ",
"tan": " tangent ",
"cot": " cotangent ",
"sec": " secant ",
"csc": " cosecant ",
"log": " logarithm ",
"exp": "e^",
"sqrt": "根号 ",
"abs": "绝对值 ",
}
special_operators = {
"+": "加",
"-": "减",
"*": "乘",
"/": "除",
"=": "等于",
'!=': '不等于',
'>': '大于',
'<': '小于',
'>=': '大于等于',
'<=': '小于等于',
}
greek_letters = {
"α": "alpha ",
"β": "beta ",
"γ": "gamma ",
"δ": "delta ",
"ε": "epsilon ",
"ζ": "zeta ",
"η": "eta ",
"θ": "theta ",
"ι": "iota ",
"κ": "kappa ",
"λ": "lambda ",
"μ": "mu ",
"ν": "nu ",
"ξ": "xi ",
"ο": "omicron ",
"π": "派 ",
"ρ": "rho ",
"σ": "sigma ",
"τ": "tau ",
"υ": "upsilon ",
"φ": "phi ",
"χ": "chi ",
"ψ": "psi ",
"ω": "omega "
}
sentence = sentence.replace('**', ' ')
sentence = re.sub(r'(?<![\d)])-(\d+)', r'负\1', sentence)
for key in special_notations:
sentence = sentence.replace(key, special_notations[key])
for key in special_operators:
sentence = sentence.replace(key, special_operators[key])
for key in greek_letters:
sentence = sentence.replace(key, greek_letters[key])
sentence = re.sub(r'\(?(\d+)\)?\((\d+)\)', r'\1乘\2', sentence)
sentence = re.sub(r'\(?(\w+)\)?\^\(?(\w+)\)?', r'\1的\2次方', sentence)
return sentence
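    # Example (comment only): "x^2 + sin(y) = 3" becomes roughly
    # "x的2次方 加 sine (y) 等于 3" after the substitutions above, so the TTS front
    # end receives speakable text instead of raw math notation.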
worker_ready.set()
if not isinstance(wait_workers_ready, list):
wait_workers_ready = [wait_workers_ready]
past_llm_id = 0
while True:
# Wait for all workers to be ready
if not all([worker.is_set() for worker in wait_workers_ready]):
time.sleep(0.1)
continue
tts_input_text = ""
while not inputs_queue.empty():
time.sleep(0.03)
stop_at_punc_or_len = False
response = inputs_queue.get()
llm_id, newly_generated_text = response["id"], response["response"]
for character in newly_generated_text:
if past_llm_id != 0 and past_llm_id != llm_id:
# print(f"Past llm id {past_llm_id} is not equal to current llm id {llm_id}, resetting tts input text and putting pause signal")
tts_input_text = ""
                    outputs_queue.put(
                        {
                            "id": llm_id,
                            "response": ("|PAUSE|", None, 0.2)
                        }
                    )
tts_input_text += character
past_llm_id = llm_id
# print('tts_input_text',tts_input_text)
if character in [",", ",", ".", "。", "?", "\n", "?", "!", "!", "、"] and len(tts_input_text) >= 5:
stop_at_punc_or_len = True
break
if stop_at_punc_or_len:
break
if tts_input_text.strip() == "":
continue
tts_input_text = remove_special_tokens(tts_input_text)
tts_input_text = replace_equation(tts_input_text)
tts_input_text = tts_input_text.lower()
# print(f"Start to generate audio for: {tts_input_text}, llm id {llm_id}")
if tts_input_text.strip() == "":
continue
audio_file = tts_tranform_text(tts_input_text)
html = audio_file_to_html(audio_file)
audio_duration = torchaudio.info(audio_file).num_frames / 24000
if past_llm_id == 0 or past_llm_id == llm_id:
outputs_queue.put(
{
"id": llm_id,
"response": (tts_input_text, html, audio_duration)
}
)
def merge_current_and_history(
global_history,
current_request,
skip_history_vision=False,
move_image_token_to_start=False
):
system_prompts = {
"video": "system:You are an AI robot and your name is Vita. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the video given by the user, and it is strictly forbidden to answer the question without the content of the video. Please note that you are seeing the video, not the image.</s>\n",
"image": "system:You are an AI robot and your name is Vita. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user. \n- You must answer the question strictly according to the content of the image given by the user, and it is strictly forbidden to answer the question without the content of the image. Please note that you are seeing the image, not the video.</s>\n",
"audio": "system:You are an AI robot and your name is Vita. \n- You are a multimodal large language model developed by the open source community. Your aim is to be helpful, honest and harmless. \n- You support the ability to communicate fluently and answer user questions in multiple languages of the user's choice. \n- If the user corrects the wrong answer you generated, you will apologize and discuss the correct answer with the user.</s>\n"
}
def select_system_prompt(current_request):
if "multi_modal_data" in current_request:
if "video" in current_request["multi_modal_data"]:
return system_prompts["video"]
elif "image" in current_request["multi_modal_data"]:
return system_prompts["image"]
elif "audio" in current_request["multi_modal_data"]:
return system_prompts["audio"]
return system_prompts["audio"]
system_prompt = select_system_prompt(current_request)
print('current request:', current_request)
user_prefix = "user:"
bot_prefix = "bot:"
eos = "</s>\n"
if len(global_history) == 0:
current_request["prompt"] = (system_prompt + user_prefix + current_request["prompt"] + eos + bot_prefix).replace('<1> ','<1>').replace('<2> ','<2>')
return current_request
# Initialize the current prompt and multimodal data
current_prompt = system_prompt
current_multi_modal_data = {"image": [], "audio": [], "video": []}
# Add the history to the current prompt
for history in global_history:
assert "prompt" in history, "Prompt must be provided in history."
assert "response" in history, "Response must be provided in history."
if skip_history_vision:
history_prompt = history["prompt"].replace(IMAGE_TOKEN, "").replace(VIDEO_TOKEN, "")
else:
history_prompt = history["prompt"]
history_prompt = user_prefix + history_prompt + eos + bot_prefix + history["response"] + eos
for modality in ["image", "audio", "video"]:
if skip_history_vision and modality in ["image", "video"]:
continue
if "multi_modal_data" in history and modality in history["multi_modal_data"]:
current_multi_modal_data[modality].extend(history["multi_modal_data"][modality])
current_prompt += history_prompt
# Add the current request to the current prompt
current_prompt += user_prefix + current_request["prompt"] + eos + bot_prefix
for modality in ["image", "audio", "video"]:
if "multi_modal_data" in current_request and modality in current_request["multi_modal_data"]:
current_multi_modal_data[modality].extend(current_request["multi_modal_data"][modality])
for modality in ["image", "audio", "video"]:
if current_multi_modal_data[modality] == []:
current_multi_modal_data.pop(modality, None)
if move_image_token_to_start:
num_image_tokens = current_prompt.count(IMAGE_TOKEN)
current_prompt = current_prompt.replace(IMAGE_TOKEN, "")
current_prompt = current_prompt.replace(system_prompt, "")
# str.lstrip(user_prefix) strips a character set rather than a prefix, so remove the leading "user:" explicitly
current_prompt = system_prompt + user_prefix + IMAGE_TOKEN * num_image_tokens + (current_prompt[len(user_prefix):] if current_prompt.startswith(user_prefix) else current_prompt)
current_request["prompt"] = current_prompt.replace('<1> ','<1>').replace('<2> ','<2>')
current_request["multi_modal_data"] = current_multi_modal_data
return current_request
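# Gradio front end: streams webcam frames and microphone audio, runs VAD to detect complete utterances,
# and forwards them as requests to the LLM workers.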
def launch_demo(
request_inputs_queue,
tts_output_queue,
worker_ready,
wait_workers_ready,
global_history,
interrupt_signal,
):
vad_path = "web_demo/wakeup_and_vad/resource"
vad_model = WakeupAndVAD(vad_path, cache_history=10)
collected_images = deque(maxlen=8)
collecting_images = False
collected_audio = torch.tensor([])
collecting_audio = False
start_time = time.time()
last_time_to_collect_image = start_time
last_time_to_collect_audio = start_time
last_output_id = 0
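# Write the buffered webcam frames to an MP4 file at 20 fps (OpenCV expects BGR, so frames are converted from RGB).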
def save_video(images, video_filename):
copy_images = list(images)
if len(copy_images) == 0:
return
height, width, layers = copy_images[0].shape
size = (width, height)
out = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*'mp4v'), 20, size)
for image in copy_images:
out.write(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
out.release()
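# Keep a rolling buffer of webcam frames; clear it if no frame has arrived for more than one second.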
def process_image(image):
nonlocal last_time_to_collect_image
current_time_to_collect_image = time.time()
if current_time_to_collect_image - last_time_to_collect_image > 1:
collected_images.clear()
print("Clearing the collected images")
collected_images.append(image)
last_time_to_collect_image = current_time_to_collect_image
def reset_state():
nonlocal collected_images, collected_audio
print("Resetting the state")
while len(global_history) > 0:
global_history.pop()
collected_audio = torch.tensor([])
collected_images.clear()
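# Drain synthesized audio from tts_output_queue and yield it to the HTML component,
# skipping outputs whose id does not match the currently active LLM worker.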
def text_streamer():
nonlocal last_output_id
if tts_output_queue.empty():
yield None, None
while not tts_output_queue.empty():
try:
output = tts_output_queue.get_nowait()
llm_id = output["id"]
temp_output, audio, length = output["response"]
if llm_id != interrupt_signal.value:
print(f"Received output from other process {llm_id}, skipping...")
continue
# print(f"Received audio output {temp_output}")
if last_output_id != 0 and last_output_id != llm_id:
print(f"Received pause signal, pausing for 0.2s")
time.sleep(0.2)
last_output_id = llm_id
yield None, audio
time.sleep(length * 1.5 + 0.02)
except Empty:
print(f"The queue is empty, text output {temp_output}")
yield None, None
yield None, None
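# Buffer incoming microphone audio, resample it to 16 kHz, feed fixed-size chunks to the VAD,
# and enqueue an LLM request once a complete utterance (plus any recently collected video frames) is available.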
def add_audio(
audio,
answer_ready,
):
nonlocal collected_audio, collecting_audio
nonlocal last_time_to_collect_audio
current_time_to_collect_audio = time.time()
if current_time_to_collect_audio - last_time_to_collect_audio > 1:
collected_audio = torch.tensor([])
print("Clearing the collected audio")
last_time_to_collect_audio = current_time_to_collect_audio
target_sample_rate = 16000
# Load the audio file
waveform, sr = torchaudio.load(audio)
# Resample the audio if necessary
if sr != target_sample_rate:
waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)(waveform)
chunk_size = vad_model.get_chunk_size()
# Add the audio to the FIFO tensor
if collected_audio.numel() == 0:
collected_audio = waveform
else:
collected_audio = torch.cat((collected_audio, waveform), dim=1)
while collected_audio.shape[1] >= chunk_size:
# Get the chunk of data
data = collected_audio[:, :chunk_size]
# Process the chunk
res = vad_model.predict(data)
# Remove the processed chunk from the FIFO tensor
collected_audio = collected_audio[:, chunk_size:]
if res is not None:
if "start" in res:
print("Start of dialog: %f" % res["start"])
# collecting_images = True
if "cache_dialog" in res:
print('res', res)
directory = './chat_history'
if not os.path.exists(directory):
os.makedirs(directory)
audio_duration = len(res["cache_dialog"]) / target_sample_rate
if audio_duration < 1.5:
print("The duration of the audio is less than 1.5s, skipping...")
continue
current_time = datetime.datetime.now()
# Format the time to create a unique filename
timestamp = current_time.strftime("%Y%m%d_%H%M%S")
audio_filename = f"{directory}/test_dialog_{timestamp}.wav"
torchaudio.save(audio_filename, res["cache_dialog"].unsqueeze(0), target_sample_rate)
if len(collected_images) > 0:
video_filename = f"{directory}/test_video_{timestamp}.mp4"
save_video(collected_images, video_filename)
else:
video_filename = ""
print("Start to generate response")
if video_filename:
current_request = {
"prompt": "<video><audio>",
"multi_modal_data": {
"video": [video_filename],
"audio": [audio_filename],
},
}
else:
current_request = {
"prompt": "<audio>",
"multi_modal_data": {
"audio": [audio_filename],
},
}
print(f"Start to put request into queue {current_request}")
request_inputs_queue.put(current_request)
if not tts_output_queue.empty():
answer_ready = 1 - answer_ready
return answer_ready
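# Build the Gradio UI: a webcam stream, a microphone stream, a clear-history button,
# and an HTML component used to play back the generated audio.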
with gr.Blocks(title="VITA") as demo:
gr.Markdown("""<center><font size=8> VITA </center>""")
with gr.Row():
with gr.Column():
webcam = gr.Image(sources="webcam", type="numpy", streaming=True, label="📹 Video Recording (视频录制)", scale=2)
with gr.Column():
audio_stream = gr.Audio(sources=["microphone"], type='filepath', streaming=True, label="🎤 Record Audio (录音)", scale=0.5)
answer_ready = gr.State(value=0)
reset_context = gr.Button("🧹 Clear History (清除历史)")
html = gr.HTML(visible=True)
audio_stream.change(add_audio, [audio_stream, answer_ready], [answer_ready], show_progress=True)
answer_ready.change(fn=text_streamer, inputs=[], outputs=[html])
reset_context.click(fn=reset_state, inputs=[], outputs=[])
webcam.stream(fn=process_image, inputs=webcam, outputs=[])
while not all([worker.is_set() for worker in wait_workers_ready]):
time.sleep(0.1)
worker_ready.set()  # signal readiness via the event passed to this process
demo.launch(
share=False,
debug=True,
server_name="0.0.0.0",
server_port=18806,
show_api=False,
show_error=False,
auth=("123", "123")
)
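# Entry point: start two LLM workers on separate GPU pairs (so one can be interrupted while the other answers),
# a TTS worker, and the Gradio demo, all wired together through manager queues, events, and shared state.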
if __name__ == "__main__":
manager = multiprocessing.Manager()
request_inputs_queue = manager.Queue()
tts_inputs_queue = manager.Queue()
tts_output_queue = manager.Queue()
worker_1_stop_event = manager.Event()
worker_2_stop_event = manager.Event()
worker_1_start_event = manager.Event()
worker_2_start_event = manager.Event()
worker_1_start_event.set()
worker_1_2_start_event_lock = manager.Lock()
llm_worker_1_ready = manager.Event()
llm_worker_2_ready = manager.Event()
tts_worker_ready = manager.Event()
gradio_worker_ready = manager.Event()
interrupt_signal = manager.Value("i", 0)
model_path = "demo_VITA_ckpt/"
global_history = manager.list()
global_history_limit = 1
# Engine arguments for vLLM
engine_args = AsyncEngineArgs(
model=model_path,
dtype="float16",
tensor_parallel_size=2,
trust_remote_code=True,
gpu_memory_utilization=0.8,
disable_custom_all_reduce=True,
limit_mm_per_prompt={"image": 256, "audio":50},
)
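# LLM worker 1 runs on GPUs 0,1 and waits for worker 2 and the TTS worker before serving requests.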
model_1_process = multiprocessing.Process(
target=load_model,
kwargs={
"llm_id": 1,
"engine_args": engine_args,
"cuda_devices": "0,1",
"inputs_queue": request_inputs_queue,
"outputs_queue": tts_inputs_queue,
"tts_outputs_queue": tts_output_queue,
"start_event": worker_1_start_event,
"other_start_event": worker_2_start_event,
"start_event_lock": worker_1_2_start_event_lock,
"stop_event": worker_1_stop_event,
"other_stop_event": worker_2_stop_event,
"worker_ready": llm_worker_1_ready,
"wait_workers_ready": [llm_worker_2_ready, tts_worker_ready],
"global_history": global_history,
"global_history_limit": global_history_limit,
"interrupt_signal": interrupt_signal,
}
)
model_2_process = multiprocessing.Process(
target=load_model,
kwargs={
"llm_id": 2,
"engine_args": engine_args,
"cuda_devices": "2,3",
"inputs_queue": request_inputs_queue,
"outputs_queue": tts_inputs_queue,
"tts_outputs_queue": tts_output_queue,
"start_event": worker_2_start_event,
"other_start_event": worker_1_start_event,
"start_event_lock": worker_1_2_start_event_lock,
"stop_event": worker_2_stop_event,
"other_stop_event": worker_1_stop_event,
"worker_ready": llm_worker_2_ready,
"wait_workers_ready": [llm_worker_1_ready, tts_worker_ready],
"global_history": global_history,
"global_history_limit": global_history_limit,
"interrupt_signal": interrupt_signal,
}
)
tts_worker_process = multiprocessing.Process(
target=tts_worker,
kwargs={
"inputs_queue": tts_inputs_queue,
"outputs_queue": tts_output_queue,
"worker_ready": tts_worker_ready,
"wait_workers_ready": [llm_worker_1_ready, llm_worker_2_ready],
}
)
gradio_demo_process = multiprocessing.Process(
target=launch_demo,
kwargs={
"request_inputs_queue": request_inputs_queue,
"tts_output_queue": tts_output_queue,
"worker_ready": gradio_worker_ready,
"wait_workers_ready": [llm_worker_1_ready, llm_worker_2_ready, tts_worker_ready],
"global_history": global_history,
"interrupt_signal": interrupt_signal,
}
)
model_1_process.start()
model_2_process.start()
tts_worker_process.start()
gradio_demo_process.start()
model_1_process.join()
model_2_process.join()
tts_worker_process.join()
gradio_demo_process.join()