Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
......@@ -314,7 +314,6 @@ class ArcticAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias
import torch
import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from transformers.models.audioflamingo3 import (
AudioFlamingo3Config,
AudioFlamingo3Processor,
)
from transformers.models.qwen2_audio import Qwen2AudioEncoder
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
DictEmbeddingItems,
ModalityData,
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
init_vllm_registered_model,
maybe_prefix,
)
# Upper bound on the audio duration handled per clip (10 * 60). Elsewhere in
# this file it is multiplied by the feature extractor's sampling rate and
# divided by its chunk length, so the unit is presumably seconds.
MAX_AUDIO_LEN = 10 * 60
# === Audio Inputs === #
class AudioFlamingo3FeatureInputs(TensorSchema):
    """
    Schema for audio passed as feature-extractor output.

    Dimensions:
        - num_chunks: Number of audio chunks (flattened)
        - nmb: Number of mel bins
        - num_audios: Number of original audio files
    """

    type: Literal["audio_features"]

    # Per-chunk features; the trailing axis is fixed at 3000.
    input_features: Annotated[
        torch.Tensor | list[torch.Tensor], TensorShape("num_chunks", "nmb", 3000)
    ]

    # Per-chunk mask; its row sums are used as input lengths by the model.
    feature_attention_mask: Annotated[torch.Tensor, TensorShape("num_chunks", 3000)]

    # How many consecutive chunks belong to each original audio.
    chunk_counts: Annotated[torch.Tensor, TensorShape("num_audios")]
class AudioFlamingo3EmbeddingInputs(TensorSchema):
    """
    Schema for audio passed as precomputed embeddings.

    Dimensions:
        - bn: Batch size
        - naf: Number of audio features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["audio_embeds"] = "audio_embeds"

    # One embedding tensor per audio item.
    audio_embeds: Annotated[list[torch.Tensor], TensorShape("bn", "naf", "hs")]
# Either of the two accepted audio input forms: feature-extractor output or
# precomputed embeddings.
AudioFlamingo3Inputs: TypeAlias = (
AudioFlamingo3FeatureInputs | AudioFlamingo3EmbeddingInputs
)
class AudioFlamingo3Encoder(Qwen2AudioEncoder):
    """Qwen2Audio encoder variant that average-pools the time axis by 2.

    The forward pass runs the two convolutions, adds the positional
    embeddings, applies every encoder layer, halves the sequence length with
    an average pool and finishes with the layer norm.
    """

    def __init__(
        self,
        config: PretrainedConfig,
    ):
        super().__init__(config)
        # self.layer_norm is already initialized in super().__init__; only
        # the pooling layer has to be created here.
        self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)

    def forward(
        self,
        input_features: torch.Tensor | list[torch.Tensor],
        attention_mask: torch.Tensor = None,
    ):
        """Encode mel features into pooled, normalized hidden states.

        Args:
            input_features: Tensor of shape (batch, num_mel_bins, seq_len),
                or a list of tensors that is stacked along a new batch axis.
            attention_mask: Passed unchanged to every encoder layer.

        Returns:
            Tensor of shape (batch, seq_len / 2, hidden_size), where seq_len
            is the length after the convolutions.
        """
        if isinstance(input_features, list):
            input_features = torch.stack(input_features)

        conv_out = nn.functional.gelu(self.conv1(input_features))
        conv_out = nn.functional.gelu(self.conv2(conv_out))

        # (batch, channels, time) -> (batch, time, channels)
        tokens = conv_out.transpose(-1, -2)
        positions = self.embed_positions.weight[: tokens.size(-2), :]
        # Cast back so the addition does not change the activation dtype.
        tokens = (tokens + positions).to(tokens.dtype)

        for encoder_layer in self.layers:
            tokens = encoder_layer(tokens, attention_mask)[0]

        # AvgPool1d works on the last axis, so move time there and back.
        pooled = self.avg_pooler(tokens.permute(0, 2, 1))
        pooled = pooled.permute(0, 2, 1)  # (batch, seq_len/2, hidden_size)
        return self.layer_norm(pooled)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.Tensor):
        """
        Return the sequence length after the convolutional layers and the
        final output length of the audio encoder, in that order.
        """
        conv_lengths = (input_lengths - 1) // 2 + 1
        pooled_lengths = (conv_lengths - 2) // 2 + 1
        return conv_lengths, pooled_lengths
class AudioFlamingo3MultiModalProjector(nn.Module):
    """Two-layer MLP mapping audio hidden states to the text hidden size."""

    def __init__(self, config: PretrainedConfig):
        """Build the projector from the audio/text sizes in `config`."""
        super().__init__()
        audio_hidden = config.audio_config.hidden_size
        text_hidden = config.text_config.hidden_size
        use_bias = config.projector_bias

        self.linear_1 = nn.Linear(audio_hidden, text_hidden, bias=use_bias)
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden, text_hidden, bias=use_bias)

    def forward(self, audio_features):
        """Apply linear_1, the activation and linear_2 in sequence."""
        return self.linear_2(self.act(self.linear_1(audio_features)))
class AudioFlamingo3ProcessingInfo(BaseProcessingInfo):
    """Accessors for the AudioFlamingo3 config, processor and limits."""

    def get_hf_config(self):
        """Return the HF config, requested as `AudioFlamingo3Config`."""
        return self.ctx.get_hf_config(AudioFlamingo3Config)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF processor, requested as `AudioFlamingo3Processor`."""
        return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs)

    def get_feature_extractor(self, **kwargs: object):
        """Return the `feature_extractor` attribute of the HF processor."""
        return self.get_hf_processor(**kwargs).feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare the "audio" modality with a limit of `None`."""
        return {"audio": None}
class AudioFlamingo3DummyInputsBuilder(
    BaseDummyInputsBuilder[AudioFlamingo3ProcessingInfo]
):
    """Builds placeholder prompts and audio for AudioFlamingo3."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's audio token repeated once per audio."""
        audio_count = mm_counts.get("audio", 0)
        return self.info.get_hf_processor().audio_token * audio_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy audios of the maximum supported length.

        Each clip has `MAX_AUDIO_LEN * sampling_rate` samples; `seq_len` is
        not used. The "audio" entry of `mm_options`, if any, is forwarded as
        the overrides.
        """
        sampling_rate = self.info.get_feature_extractor().sampling_rate
        if mm_options:
            overrides = mm_options.get("audio")
        else:
            overrides = None

        dummy_audios = self._get_dummy_audios(
            length=MAX_AUDIO_LEN * sampling_rate,
            num_audios=mm_counts.get("audio", 0),
            overrides=overrides,
        )
        return {"audio": dummy_audios}
def _audioflamingo3_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Describe how each processor output field maps onto audio items.

    `audio_embeds` and `chunk_counts` are always batched per audio.
    `input_features` and `feature_attention_mask` are split along dim 0 by
    `chunk_counts` when that key is present, and batched per audio otherwise.
    """
    chunk_counts = hf_inputs.get("chunk_counts")

    def per_chunk_field():
        # Without chunk counts there is nothing to split by.
        if chunk_counts is None:
            return MultiModalFieldConfig.batched("audio")
        return MultiModalFieldConfig.flat_from_sizes("audio", chunk_counts, dim=0)

    return {
        "audio_embeds": MultiModalFieldConfig.batched("audio"),
        "input_features": per_chunk_field(),
        "feature_attention_mask": per_chunk_field(),
        "chunk_counts": MultiModalFieldConfig.batched("audio"),
    }
class AudioFlamingo3MultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts audio supplied as a dict of tensors."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse audio input.

        A dict is wrapped as embedding items that must contain
        "audio_embeds"; anything else goes to the parent implementation.
        """
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_audioflamingo3_field_config,
        )
class AudioFlamingo3MultiModalProcessor(
    BaseMultiModalProcessor[AudioFlamingo3ProcessingInfo]
):
    """Multimodal processor for AudioFlamingo3.

    Calls the HF processor, records how many chunks each audio is split
    into, and expands every audio placeholder to the number of embeddings
    the encoder produces for that audio.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser targeting the feature extractor's sampling rate."""
        feature_extractor = self.info.get_feature_extractor()
        return AudioFlamingo3MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate
        )

    @staticmethod
    def _count_audio_chunks(
        audios: Sequence[Any],
        window_size: int,
        max_windows: int,
    ) -> list[int]:
        """Return ceil(n_samples / window_size) per audio, clamped.

        Each count is at least 1 and at most `max_windows`.
        """
        counts = []
        for audio in audios:
            # audio is numpy array or list
            n_samples = len(audio) if isinstance(audio, list) else audio.shape[0]
            n_win = (n_samples + window_size - 1) // window_size
            counts.append(min(max(1, n_win), max_windows))
        return counts

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: dict[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach per-audio chunk counts.

        Args:
            prompt: Prompt text.
            mm_data: Multimodal data; an "audios" entry is moved to "audio".
            mm_kwargs: Extra processor kwargs. `sampling_rate` is always set
                to the feature extractor's rate.
            tok_kwargs: Tokenizer kwargs, forwarded unchanged.

        Returns:
            The processor output with `input_features_mask` renamed to
            `feature_attention_mask` and a `chunk_counts` tensor added. When
            there is no audio, only tokenized `input_ids` are returned.
        """
        # Accept audio supplied under the "audios" key.
        audios = mm_data.pop("audios", [])
        if audios:
            mm_data["audio"] = audios

        if not mm_data.get("audio", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        sampling_rate = feature_extractor.sampling_rate
        # Merge with a dict display instead of dict(**mm_kwargs, key=...):
        # the latter raises TypeError when mm_kwargs already has
        # "sampling_rate". The feature extractor's rate takes precedence.
        mm_kwargs = {**mm_kwargs, "sampling_rate": sampling_rate}

        # Calculate chunk counts
        audio_list = mm_data.get("audio")
        if not isinstance(audio_list, list):
            audio_list = [audio_list]

        chunk_length = feature_extractor.chunk_length
        window_size = int(sampling_rate * chunk_length)
        # MAX_AUDIO_LEN is 10 * 60 in HF processor.
        max_windows = int(MAX_AUDIO_LEN // chunk_length)
        chunk_counts = self._count_audio_chunks(audio_list, window_size, max_windows)

        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        if "input_features_mask" in outputs:
            outputs["feature_attention_mask"] = outputs.pop("input_features_mask")
        outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long)
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field layout from `_audioflamingo3_field_config`."""
        return _audioflamingo3_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each audio token by one token per audio embedding.

        The count comes from the attention mask of the audio's chunks, or
        from the first dimension of `audio_embeds` when no mask is present.

        Raises:
            ValueError: From the replacement callback when an audio yields
                zero features.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        audio_token = getattr(processor, "audio_token", "<sound>")
        audio_token_id = vocab.get(audio_token)
        if audio_token_id is None:
            # Fallback if not found, though it should be there
            audio_token_id = processor.audio_token_id

        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        chunk_counts = out_mm_data.get("chunk_counts")

        # Prefix sums of the chunk counts, computed once rather than per
        # item: chunk_offsets[i] is the first chunk row of audio i.
        chunk_offsets = None
        if feature_attention_mask is not None and chunk_counts is not None:
            counts = (
                chunk_counts.tolist()
                if isinstance(chunk_counts, torch.Tensor)
                else chunk_counts
            )
            chunk_offsets = [0]
            for count in counts:
                chunk_offsets.append(chunk_offsets[-1] + count)

        def get_item_mask(item_idx: int):
            """Return the mask rows belonging to audio `item_idx`."""
            if chunk_offsets is None:
                # feature_attention_mask is list[Tensor] or Tensor
                if isinstance(feature_attention_mask, list):
                    return feature_attention_mask[item_idx]
                return feature_attention_mask[item_idx].unsqueeze(0)

            start_idx = chunk_offsets[item_idx]
            end_idx = chunk_offsets[item_idx + 1]
            if not isinstance(feature_attention_mask, list):
                return feature_attention_mask[start_idx:end_idx]

            mask_list = feature_attention_mask[start_idx:end_idx]
            if len(mask_list) > 0 and isinstance(mask_list[0], torch.Tensor):
                return torch.stack(mask_list)
            return torch.tensor(mask_list)

        def get_replacement_audioflamingo3(item_idx: int):
            if feature_attention_mask is not None:
                # mask shape: (num_chunks, 3000)
                mask = get_item_mask(item_idx)
                input_lengths = mask.sum(-1)
                # Same downsampling as the encoder: convs, then avg pool.
                conv_lengths = (input_lengths - 1) // 2 + 1
                audio_output_lengths = (conv_lengths - 2) // 2 + 1
                num_features = audio_output_lengths.sum().item()
            else:
                audio_embeds = out_mm_data["audio_embeds"][item_idx]
                num_features = audio_embeds.shape[0]

            if num_features == 0:
                raise ValueError("Audio is too short")

            audio_tokens = [audio_token_id] * int(num_features)
            return PromptUpdateDetails.select_token_id(
                audio_tokens,
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_audioflamingo3,
            )
        ]
@MULTIMODAL_REGISTRY.register_processor(
    AudioFlamingo3MultiModalProcessor,
    info=AudioFlamingo3ProcessingInfo,
    dummy_inputs=AudioFlamingo3DummyInputsBuilder,
)
class AudioFlamingo3ForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
):
    """
    AudioFlamingo3 model for conditional generation.

    This model integrates a Whisper-based audio encoder with a Qwen2 language model.
    It supports multi-chunk audio processing.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the audio tower, the projector and the language model."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.audio_tower = AudioFlamingo3Encoder(
            config.audio_config,
        )
        self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)

        self.quant_config = quant_config

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> AudioFlamingo3Inputs | None:
        """Build the audio input schema from the keyword arguments.

        Precomputed `audio_embeds` take priority over `input_features`.
        Returns `None` when neither is supplied.
        """
        input_features = kwargs.pop("input_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        feature_attention_mask = kwargs.pop("feature_attention_mask", None)
        chunk_counts = kwargs.pop("chunk_counts", None)

        if audio_embeds is not None:
            return AudioFlamingo3EmbeddingInputs(
                type="audio_embeds", audio_embeds=audio_embeds
            )

        if input_features is not None:
            return AudioFlamingo3FeatureInputs(
                type="audio_features",
                input_features=input_features,
                feature_attention_mask=feature_attention_mask,
                chunk_counts=chunk_counts,
            )

        return None

    def _process_audio_input(
        self, audio_input: AudioFlamingo3Inputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Encode audio chunks and regroup the embeddings per audio.

        Precomputed embeddings are returned as a tuple unchanged. Otherwise
        every chunk is encoded and projected, padded positions are dropped,
        and the chunks of each audio are concatenated in order.

        Returns:
            A tuple with one 2-D embedding tensor per audio.
        """
        if audio_input["type"] == "audio_embeds":
            audio_embeds = audio_input["audio_embeds"]
            return tuple(audio_embeds)

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]
        chunk_counts = audio_input.get("chunk_counts")

        if isinstance(input_features, list):
            input_features = torch.cat(input_features, dim=0)
            feature_attention_mask = torch.cat(feature_attention_mask, dim=0)

        # Normalize chunk_counts to a plain list of ints.
        if chunk_counts is None:
            chunk_counts = [1] * input_features.shape[0]
        elif isinstance(chunk_counts, torch.Tensor):
            chunk_counts = chunk_counts.tolist()
        elif (
            isinstance(chunk_counts, list)
            and chunk_counts
            and isinstance(chunk_counts[0], torch.Tensor)
        ):
            chunk_counts = [c.item() for c in chunk_counts]

        # Lengths after the convs and after the average pool, taken from the
        # encoder's own helper so the arithmetic lives in one place.
        input_lengths = feature_attention_mask.sum(-1)
        conv_lengths, audio_output_lengths = (
            self.audio_tower._get_feat_extract_output_lengths(input_lengths)
        )

        batch_size, _, max_mel_seq_len = input_features.shape

        # Calculate max_seq_len after convs (before pooling) for attention mask
        max_seq_len = (max_mel_seq_len - 1) // 2 + 1

        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=conv_lengths.dtype,
                device=conv_lengths.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = conv_lengths.unsqueeze(-1).expand(batch_size, max_seq_len)

        # True at padded positions; converted below into an additive mask
        # holding -inf there and 0 elsewhere.
        padding_mask = seq_range >= lengths_expand
        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device,
        )
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        # Forward pass
        audio_features = self.audio_tower(
            input_features, attention_mask=audio_attention_mask
        )

        # Project
        audio_features = self.multi_modal_projector(audio_features)

        # Masking after pooling: keep only the valid tokens of every chunk.
        num_audios, max_audio_tokens, embed_dim = audio_features.shape
        audio_output_lengths = audio_output_lengths.unsqueeze(1)
        # Built directly on the lengths' device instead of on CPU followed
        # by a copy.
        token_positions = torch.arange(
            max_audio_tokens, device=audio_output_lengths.device
        ).expand(num_audios, max_audio_tokens)
        audio_features_mask = token_positions < audio_output_lengths
        masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)

        # Split to tuple of embeddings for individual audio input.
        chunk_embeddings = torch.split(
            masked_audio_features, audio_output_lengths.flatten().tolist()
        )

        # Concatenate the consecutive chunks that belong to the same audio.
        grouped_embeddings = []
        current_idx = 0
        for count in chunk_counts:
            audio_chunks = chunk_embeddings[current_idx : current_idx + count]
            grouped_embeddings.append(torch.cat(audio_chunks, dim=0))
            current_idx += count
        return tuple(grouped_embeddings)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-audio embeddings, or an empty list without audio input."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        return self._process_audio_input(audio_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model backbone.

        `inputs_embeds` is ignored whenever `intermediate_tensors` is given.
        Extra keyword arguments are accepted but not used.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights with `AutoWeightsLoader` and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
"""Inference-only BAGEL model compatible with HuggingFace weights.
BAGEL is a unified multimodal model for image understanding and generation.
For vLLM, we focus on the image understanding (vision-to-text) capabilities.
"""
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, TypeAlias
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.bagel import BagelProcessor
from vllm.utils.tensor_schema import TensorSchema
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
# Module-level logger, named after this module.
logger = init_logger(__name__)
class BagelImagePixelInputs(TensorSchema):
    """
    Schema for images passed to BAGEL as pixel values.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    type: Literal["pixel_values"]

    # Shape: (bn, 3, h, w)
    pixel_values: torch.Tensor
# Pixel values are the only image input form defined for BAGEL in this file.
BagelImageInputs: TypeAlias = BagelImagePixelInputs
class BagelVisionMLP(nn.Module):
    """MLP connector for vision features."""

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        act_layer: str = "gelu_pytorch_tanh",
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create fc1 (column-parallel), the activation and fc2 (row-parallel).

        Both linear layers have a bias and share `quant_config`; their
        prefixes are `{prefix}.fc1` and `{prefix}.fc2`.
        """
        super().__init__()
        shared_kwargs = dict(bias=True, quant_config=quant_config)
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            prefix=f"{prefix}.fc1",
            **shared_kwargs,
        )
        self.act = get_act_fn(act_layer)
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            prefix=f"{prefix}.fc2",
            **shared_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation and fc2 in sequence."""
        # Each parallel linear returns a pair; only the first element is used.
        hidden, _ = self.fc1(x)
        projected, _ = self.fc2(self.act(hidden))
        return projected
class PositionEmbedding(nn.Module):
    """Fixed 2D sin-cos position table for vision tokens."""

    def __init__(self, max_num_patch_per_side: int, hidden_size: int):
        """Precompute the table for a square grid.

        Args:
            max_num_patch_per_side: Side length of the grid; the table has
                `max_num_patch_per_side ** 2` rows.
            hidden_size: Width of each embedding row.
        """
        super().__init__()
        self.max_num_patch_per_side = max_num_patch_per_side
        self.hidden_size = hidden_size

        # The table is a float32 buffer that is left out of the state dict.
        table = self._get_2d_sincos_pos_embed(hidden_size, max_num_patch_per_side)
        self.register_buffer(
            "pos_embed", torch.from_numpy(table).float(), persistent=False
        )

    @staticmethod
    def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int):
        """Return the (grid_size**2, embed_dim) table as a numpy array."""
        import numpy as np

        axis_h = np.arange(grid_size, dtype=np.float32)
        axis_w = np.arange(grid_size, dtype=np.float32)
        # w goes first, so the first stacked plane varies along the width.
        mesh = np.stack(np.meshgrid(axis_w, axis_h), axis=0)
        mesh = mesh.reshape([2, 1, grid_size, grid_size])
        return PositionEmbedding._get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)

    @staticmethod
    def _get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid):
        """Encode each grid plane with half the channels and concatenate."""
        import numpy as np

        assert embed_dim % 2 == 0
        half_dim = embed_dim // 2
        encode_axis = PositionEmbedding._get_1d_sincos_pos_embed_from_grid
        first_half = encode_axis(half_dim, grid[0])
        second_half = encode_axis(half_dim, grid[1])
        return np.concatenate([first_half, second_half], axis=1)

    @staticmethod
    def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos):
        """Return [sin | cos] embeddings, shape (pos.size, embed_dim)."""
        import numpy as np

        assert embed_dim % 2 == 0
        # Frequencies 1 / 10000**(i / (embed_dim / 2)), computed in float64.
        freqs = np.arange(embed_dim // 2, dtype=np.float64)
        freqs /= embed_dim / 2.0
        freqs = 1.0 / 10000**freqs

        angles = np.einsum("m,d->md", pos.reshape(-1), freqs)
        return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        """
        Look up the embeddings for flattened grid positions.

        Args:
            position_ids: Flattened position IDs, shape (N,) where each ID
                corresponds to a position in the flattened grid
        Returns:
            Position embeddings of shape (N, hidden_size)
        """
        # The index must live on the same device as the table.
        table = self.pos_embed
        return table[position_ids.to(table.device)]
class BagelProcessingInfo(BaseProcessingInfo):
    """Processing information for BAGEL model."""

    def get_hf_processor(self, **kwargs: object) -> BagelProcessor:
        """Assemble a `BagelProcessor` from the image processor and tokenizer.

        The image processor is fetched for the configured model, revision
        and `trust_remote_code` setting; `kwargs` go to `BagelProcessor`.
        """
        from vllm.transformers_utils.processor import cached_get_image_processor

        model_config = self.ctx.model_config
        image_processor = cached_get_image_processor(
            model_config.model,
            revision=model_config.revision,
            trust_remote_code=model_config.trust_remote_code,
        )
        return BagelProcessor(
            image_processor=image_processor,
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare the "image" modality with a limit of `None`."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-image token cap, `vit_max_num_patch_per_side ** 2`.

        Neither `seq_len` nor `mm_counts` is used.
        """
        max_side = self.get_hf_config().vit_max_num_patch_per_side
        return {"image": max_side**2}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of whole patches covering the image."""
        patch_size = self.get_hf_config().vit_config.patch_size
        rows = image_height // patch_size
        cols = image_width // patch_size
        return rows * cols
class BagelDummyInputsBuilder(BaseDummyInputsBuilder[BagelProcessingInfo]):
    """Build dummy inputs for BAGEL model profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one "<|image_pad|>" placeholder per requested image."""
        return "<|image_pad|>" * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return square dummy images of the configured ViT image size.

        `seq_len` is not used. The "image" entry of `mm_options`, if any, is
        forwarded as the overrides.
        """
        image_count = mm_counts.get("image", 0)
        side = self.info.get_hf_config().vit_config.image_size
        if mm_options:
            overrides = mm_options.get("image")
        else:
            overrides = None

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}
class BagelMultiModalProcessor(BaseMultiModalProcessor[BagelProcessingInfo]):
    """Multimodal processor for BAGEL model."""

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Always report that the HF processor does not apply the updates."""
        return False

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptReplacement]:
        """Replace image placeholders with the correct number of tokens.

        Raises:
            ValueError: If "<|image_pad|>" is missing from the vocabulary.
        """
        hf_config = self.info.get_hf_config()

        # The placeholder ID is looked up in the tokenizer vocabulary.
        token_vocab = self.info.get_tokenizer().get_vocab()
        image_token_id = token_vocab.get("<|image_pad|>")
        if image_token_id is None:
            raise ValueError(
                "Image token '<|image_pad|>' not found in tokenizer vocabulary"
            )

        def get_replacement_bagel(item_idx: int):
            # Every image expands to the full grid, whatever its index.
            max_side = hf_config.vit_max_num_patch_per_side
            return [image_token_id] * max_side**2

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement_bagel,
        )
        return [replacement]

    def _get_mm_fields_config(
        self,
        hf_inputs: Any,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` as batched per image."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}
@MULTIMODAL_REGISTRY.register_processor(
BagelMultiModalProcessor,
info=BagelProcessingInfo,
dummy_inputs=BagelDummyInputsBuilder,
)
class BagelForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
):
"""
BAGEL: A unified multimodal model for image understanding and generation.
For vLLM, we focus on the image understanding (vision-to-text) capabilities.
The image generation part is not supported in vLLM.
"""
# Weight mapping from HF to vLLM
# NOTE(review): every prefix listed here maps to itself, so this is
# presumably a no-op rename — confirm against WeightsMapper.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.": "language_model.",
"vit_model.": "vit_model.",
"connector.": "connector.",
"vit_pos_embed.": "vit_pos_embed.",
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the language model and, if enabled, the vision stack.

    Raises:
        ValueError: If the HF config class is not named "BagelConfig".
    """
    super().__init__()
    model_config = vllm_config.model_config
    config = model_config.hf_config
    quant_config = vllm_config.quant_config
    multimodal_config = model_config.multimodal_config

    # Compare by class name: with trust_remote_code=True the config class
    # comes from transformers_modules, so an isinstance check would not match.
    if type(config).__name__ != "BagelConfig":
        raise ValueError(
            f"Expected BagelConfig, got {type(config).__name__}. "
            "Make sure the model config is properly loaded."
        )

    self.config = config
    self.multimodal_config = multimodal_config

    # Language model (Qwen2), initialized from BagelConfig.llm_config.
    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=config.llm_config,
        prefix=maybe_prefix(prefix, "language_model"),
        architectures=["Qwen2ForCausalLM"],
    )

    if config.visual_und:
        # Vision model (SigLIP). The checkpoint has 26 layers (0-25) but the
        # config says 27, and it has no head, so vit_config is patched in
        # place before the model is built.
        siglip_config = config.vit_config
        if siglip_config.num_hidden_layers == 27:
            logger.warning(
                "Overriding vit_config.num_hidden_layers from 27 to 26 "
                "to match the Bagel model checkpoint."
            )
            siglip_config.num_hidden_layers = 26
        if not hasattr(siglip_config, "vision_use_head"):
            logger.warning(
                "Setting vit_config.vision_use_head to False as it is not "
                "present in the Bagel model checkpoint."
            )
            siglip_config.vision_use_head = False

        self.vit_model = SiglipVisionModel(
            config=siglip_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vit_model"),
        )

        # Connector (MLP) from the ViT width to the LLM width.
        vision_width = config.vit_config.hidden_size
        llm_width = config.llm_config.hidden_size
        self.connector = BagelVisionMLP(
            in_features=vision_width,
            hidden_features=llm_width,
            out_features=llm_width,
            act_layer=config.connector_act,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "connector"),
        )

        # Position embedding for vision tokens.
        self.vit_pos_embed = PositionEmbedding(
            max_num_patch_per_side=config.vit_max_num_patch_per_side,
            hidden_size=llm_width,
        )
    else:
        # Visual understanding disabled: no vision modules are created.
        self.vit_model = None
        self.connector = None
        self.vit_pos_embed = None

    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors
    )
def _parse_and_validate_image_input(
    self, **kwargs: object
) -> BagelImageInputs | None:
    """Wrap `pixel_values` from the kwargs, or return None if absent."""
    pixels = kwargs.pop("pixel_values", None)
    if pixels is not None:
        return BagelImagePixelInputs(type="pixel_values", pixel_values=pixels)
    return None
def _process_image_input(
    self, image_input: BagelImageInputs
) -> tuple[torch.Tensor, ...]:
    """Process image inputs through vision encoder and connector.

    Returns:
        One (num_patches, hidden_size) embedding tensor per image, with the
        position embeddings already added.
    """
    pixel_values = image_input["pixel_values"]

    # Expected shape: (batch_size * num_images, 3, H, W). A 5-D input of
    # shape (batch_size, num_images, 3, H, W) has its two leading axes merged.
    if pixel_values.ndim == 5:
        outer, inner, *image_dims = pixel_values.shape
        pixel_values = pixel_values.reshape(outer * inner, *image_dims)

    # SigLIP features, then the MLP connector.
    vision_embeds = self.connector(self.vit_model(pixel_values))
    batch_size, num_patches, hidden_size = vision_embeds.shape

    vit_config = self.config.vit_config
    side = vit_config.image_size // vit_config.patch_size

    # Flattened position IDs: row * vit_max_num_patch_per_side + column, so
    # a smaller grid indexes into the larger precomputed table.
    coords = torch.arange(side, device=vision_embeds.device)
    grid_ids = coords[:, None] * self.config.vit_max_num_patch_per_side + coords
    position_ids = grid_ids.flatten().unsqueeze(0).expand(batch_size, -1).flatten()

    pos_embeds = self.vit_pos_embed(position_ids).reshape(
        batch_size, num_patches, hidden_size
    )
    # The lookup result may sit on the table's device; move it next to the
    # features before adding.
    vision_embeds = vision_embeds + pos_embeds.to(vision_embeds.device)

    # Split by image
    return tuple(vision_embeds)
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
    """Get multimodal embeddings from input.

    Named `embed_multimodal` to match the entry point exposed by the other
    multimodal model in this file.

    Returns:
        One embedding tensor per image, or an empty list when the kwargs
        carry no `pixel_values`.
    """
    image_input = self._parse_and_validate_image_input(**kwargs)
    if image_input is None:
        return []
    return self._process_image_input(image_input)

def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
    """Backward-compatible alias for `embed_multimodal`."""
    return self.embed_multimodal(**kwargs)
def get_language_model(self) -> nn.Module:
    """Return the language model stored on ``self.language_model``."""
    return self.language_model
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor | IntermediateTensors:
    """Run the BAGEL forward pass through the wrapped language model.

    Args:
        input_ids: Flattened (concatenated) input_ids corresponding to a batch.
        positions: Flattened (concatenated) position ids corresponding to a
            batch.
        intermediate_tensors: Intermediate tensors from a prior forward pass.
            When given, ``inputs_embeds`` is not forwarded.
        inputs_embeds: Optional tensor of input embeddings.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Whatever ``self.language_model.model`` returns for these inputs.
    """
    # Embeddings are dropped whenever intermediate tensors are supplied
    # (presumably a non-first pipeline stage -- confirm against the runner).
    effective_embeds = inputs_embeds if intermediate_tensors is None else None
    return self.language_model.model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=effective_embeds,
    )
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
    """Compute logits by delegating to ``self.language_model.compute_logits``.

    Args:
        hidden_states: Hidden states to be turned into logits.

    Returns:
        The value returned by the language model, passed through unchanged.
    """
    return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load the text/image-understanding weights of a BAGEL checkpoint.

    Image-generation weights are dropped because only text2text and
    image2text are supported. Flattened 2-D ``patch_embedding.weight``
    tensors are converted to ``(out, in_channels, patch, patch)`` layout.
    The surviving entries are streamed to ``AutoWeightsLoader`` one at a
    time, so no intermediate list of tensors is held while loading.

    Args:
        weights: Iterable of ``(parameter_name, tensor)`` pairs from the
            checkpoint. It is consumed exactly once.

    Returns:
        The set of names reported by ``AutoWeightsLoader.load_weights``.
    """
    # vit_pos_embed.pos_embed is handled by the PositionEmbedding module.
    skip_prefixes = ["vit_pos_embed.pos_embed"]
    # If visual understanding is disabled, skip vision-related weights.
    if self.vit_model is None:
        skip_prefixes.extend(["vit_model.", "connector.", "vit_pos_embed"])

    # Image-generation components, matched anywhere in the name:
    # - 'moe_gen': MoE generation weights
    # - 'latent_pos_embed': Latent position embeddings for VAE
    # - 'llm2vae', 'vae2llm': LLM-VAE projections
    # - 'time_embedder': Timestep embeddings for diffusion
    generation_keywords = (
        "moe_gen",
        "latent_pos_embed",
        "llm2vae",
        "vae2llm",
        "time_embedder",
    )
    # VAE encoder/decoder. Matched as prefixes only, so the vision encoder
    # (whose names merely contain "encoder.") is not filtered out.
    vae_prefixes = ("decoder.", "encoder.")

    def _iter_understanding_weights():
        """Yield the checkpoint entries that should be loaded, lazily."""
        for name, tensor in weights:
            if any(keyword in name for keyword in generation_keywords):
                continue
            if name.startswith(vae_prefixes):
                continue
            if "patch_embedding.weight" in name and tensor.ndim == 2:
                out_channels, in_features = tensor.shape
                patch_size = self.config.vit_config.patch_size
                in_channels = self.config.vit_config.num_channels
                # Only reshape when the flat width matches the expected
                # patch volume; otherwise pass the tensor through untouched.
                if in_features == in_channels * patch_size * patch_size:
                    tensor = tensor.reshape(
                        out_channels, patch_size, patch_size, in_channels
                    )
                    # (out, p, p, in) -> (out, in, p, p)
                    tensor = tensor.permute(0, 3, 1, 2).contiguous()
            yield name, tensor

    loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
    return loader.load_weights(
        _iter_understanding_weights(), mapper=self.hf_to_vllm_mapper
    )
......@@ -189,7 +189,6 @@ class BaiChuanAttention(nn.Module):
else:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=rope_parameters,
)
......
......@@ -127,11 +127,11 @@ class BailingAttention(nn.Module):
prefix=f"{prefix}.dense",
)
self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)
rotary_dim = getattr(config, "rotary_dim", self.head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -178,14 +178,11 @@ class BambaAttentionDecoderLayer(nn.Module):
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
if hasattr(config, "attn_rotary_emb"):
rotary_dim = config.attn_rotary_emb # for backward compatibility
else:
rotary_dim = self.head_dim # default
rotary_dim = getattr(config, "attn_rotary_emb", self.head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -314,7 +314,6 @@ class ChameleonAttention(nn.Module):
self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
)
......
......@@ -99,13 +99,16 @@ class GLMAttention(nn.Module):
# https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192)
rope_parameters = {"rope_type": "default", "rope_theta": 10000 * rope_ratio}
rope_parameters = {
"rope_type": "default",
"rope_theta": 10000 * rope_ratio,
"partial_rotary_factor": 0.5,
}
# NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
# which is equivalent to is_neox_style=True
is_neox_style = not config.original_rope
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim // 2,
max_position=max_positions,
rope_parameters=rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -175,7 +175,6 @@ class CohereAttention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=False,
......
......@@ -42,9 +42,10 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"rope_parameters": config.rope_parameters,
}
......@@ -77,9 +78,11 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
if not model_config.enforce_eager:
max_position = round_up(max_position, 8)
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": max_position,
"rope_parameters": config.rope_parameters,
}
......@@ -113,12 +116,10 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
config.num_hidden_layers = config.n_layer
head_dim = config.hidden_size // config.num_attention_heads
rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
max_trained_positions = getattr(config, "max_trained_positions", 2048)
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": rotary_emb_dim,
"max_position": max_trained_positions,
"rope_parameters": config.rope_parameters,
}
......@@ -214,7 +215,7 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
tokens = getattr(config, "classifier_from_token", None)
assert tokens is not None and len(tokens) == 2, (
"Try loading the original Qwen3 Reranker?, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py"
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py"
)
vllm_config.model_config.hf_config.method = "from_2_way_softmax"
......@@ -240,9 +241,10 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"rope_parameters": config.rope_parameters,
}
......@@ -361,7 +363,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
else:
kernel_block_alignment_size = 16
if (
current_platform.is_device_capability(100)
current_platform.is_device_capability_family(100)
and model_config.get_head_size() == 256
and (
attention_config.backend is None
......
......@@ -222,7 +222,6 @@ class DbrxAttention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position,
rope_parameters=rope_parameters,
is_neox_style=True,
......
......@@ -85,6 +85,7 @@ from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata,
)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .utils import (
......@@ -158,7 +159,6 @@ class DeepseekAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......@@ -501,7 +501,6 @@ class DeepseekV2Attention(nn.Module):
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=False,
......@@ -620,8 +619,15 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype()
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run
current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8),
)
return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
......@@ -655,17 +661,17 @@ def sparse_attn_indexer(
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
k_fp8 = torch.empty(
[chunk.total_seq_lens, head_dim],
device=k.device,
dtype=fp8_dtype,
)
k_scale = torch.empty(
[chunk.total_seq_lens, 4],
device=k.device,
dtype=torch.uint8,
)
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
......@@ -781,15 +787,6 @@ def sparse_attn_indexer_fake(
total_seq_lens: int,
topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
# profile run
# NOTE(Chen): create the max possible flattened_kv. So that
# profile_run can get correct memory usage.
_flattened_kv = torch.empty(
[total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
)
fp8_dtype = current_platform.fp8_dtype()
_k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
return topk_indices_buffer
......@@ -1020,7 +1017,6 @@ class DeepseekV2MLAAttention(nn.Module):
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=False,
......@@ -1040,7 +1036,6 @@ class DeepseekV2MLAAttention(nn.Module):
if self.is_v32:
self.indexer_rope_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -250,7 +250,6 @@ class Dots1Attention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -5,15 +5,14 @@ from typing import Annotated, Literal, TypeAlias
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
......@@ -30,6 +29,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
......@@ -159,32 +161,6 @@ class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
return processor
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension split as ``[a, b]`` at index ``size // 2`` the
    result is ``[-b, a]``.
    """
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat((-trailing, leading), dim=-1)
def apply_rotary_pos_emb_vision(
    tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embedding to ``tensor`` using angles ``freqs``.

    The arithmetic is done in float32 and the result is cast back to the
    dtype ``tensor`` had on entry.
    """
    input_dtype = tensor.dtype
    tensor_fp32 = tensor.float()

    def _as_table(values: torch.Tensor) -> torch.Tensor:
        # Insert an axis at dim 1, tile the last axis twice, then add a
        # leading axis. Assumes freqs is (seq, head_dim // 2), giving
        # (1, seq, 1, head_dim) -- TODO confirm against the caller.
        return values.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()

    cos_table = _as_table(freqs.cos())
    sin_table = _as_table(freqs.sin())
    rotated = (tensor_fp32 * cos_table) + (rotate_half(tensor_fp32) * sin_table)
    return rotated.to(input_dtype)
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
......@@ -254,11 +230,15 @@ class DotsVisionAttention(nn.Module):
bias: bool = True,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.embed_dim = dim
self.tp_size = (
......@@ -287,31 +267,18 @@ class DotsVisionAttention(nn.Module):
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel,
)
# Select attention backend
self.attn_backend = get_vit_attn_backend(
self.hidden_size_per_attention_head,
torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Unsupported vision attention backend: {self.attn_backend}"
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward(
self,
......@@ -319,7 +286,7 @@ class DotsVisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None = None,
*,
max_seqlen: int | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
# [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1)
......@@ -333,44 +300,20 @@ class DotsVisionAttention(nn.Module):
if rotary_pos_emb is not None:
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb.cos(),
rotary_pos_emb.sin(),
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
output = self.flash_attn_varlen_func(
q_,
k_,
v_,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
context_layer = output.view(
bs,
-1,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
s = int(cu_seqlens[i - 1])
e = int(cu_seqlens[i])
q_i = q[:, s:e].permute(0, 2, 1, 3)
k_i = k[:, s:e].permute(0, 2, 1, 3)
v_i = v[:, s:e].permute(0, 2, 1, 3)
out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
else:
raise RuntimeError("Unsupported attention backend")
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
# [B,S,H,D] -> [S,B,H*D] -> [S, C]
context_layer = context_layer.permute(1, 0, 2, 3).contiguous()
......@@ -385,14 +328,19 @@ class DotsSwiGLUFFN(nn.Module):
config,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
hidden_features = config.intermediate_size
in_features = config.embed_dim
bias = config.use_bias
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
# Referenced aimv2.py AIMv2SwiGLUFFN
self.fc13 = MergedColumnParallelLinear(
in_features,
......@@ -498,9 +446,8 @@ class DotsVisionBlock(nn.Module):
config,
*,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
......@@ -510,16 +457,15 @@ class DotsVisionBlock(nn.Module):
num_heads=config.num_attention_heads,
bias=config.use_bias,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
self.mlp = DotsSwiGLUFFN(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
......@@ -546,12 +492,11 @@ class DotsVisionTransformer(nn.Module):
self,
config: DotsVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.config = config
......@@ -561,6 +506,11 @@ class DotsVisionTransformer(nn.Module):
head_dim = config.embed_dim // config.num_attention_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
......@@ -578,9 +528,8 @@ class DotsVisionTransformer(nn.Module):
DotsVisionBlock(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{i}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
for i in range(num_layers)
]
......@@ -592,6 +541,11 @@ class DotsVisionTransformer(nn.Module):
else:
self.post_trunk_norm = None
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.merger = PatchMerger(
dim=config.hidden_size,
context_dim=config.embed_dim,
......@@ -647,7 +601,7 @@ class DotsVisionTransformer(nn.Module):
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
def forward(
......@@ -733,17 +687,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
self.config.vision_config = vision_config
else:
vision_config = self.config.vision_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.vision_tower = DotsVisionTransformer(
vision_config,
quant_config=self.quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
vllm_config=vllm_config,
......
......@@ -288,7 +288,6 @@ class Ernie4_5_MoeAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=False,
......
......@@ -33,14 +33,14 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops import rearrange
from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
from vllm.attention.layers.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
......@@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
......@@ -69,7 +72,6 @@ from vllm.multimodal.processing import (
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......@@ -89,52 +91,6 @@ logger = init_logger(__name__)
# === Vision Transformer === #
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate pairs of the last dimension by a quarter turn.

    With ``interleaved=False`` the last dimension is chunked into two halves
    ``[a, b]`` and ``[-b, a]`` is returned. With ``interleaved=True`` the
    even/odd elements form the pairs, so ``[a0, b0, a1, b1, ...]`` becomes
    ``[-b0, a0, -b1, a1, ...]``.
    """
    if interleaved:
        evens, odds = x[..., ::2], x[..., 1::2]
        # Stack to (..., d, 2) and merge the last two axes back into (..., 2*d).
        return torch.stack((-odds, evens), dim=-1).flatten(-2)
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Apply rotary embedding to the leading ``rotary_dim`` features of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Features beyond ``rotary_dim`` (twice the last dimension of ``cos``) are
    returned unchanged.
    """
    rotary_width = cos.shape[-1] * 2
    assert rotary_width <= x.shape[-1]
    # Both tables gain a singleton heads axis and are widened to rotary_width:
    # tiled as (2 d) for the half-split layout, pairwise as (d 2) when
    # interleaved.
    if interleaved:
        expand_pattern = "... d -> ... 1 (d 2)"
    else:
        expand_pattern = "... d -> ... 1 (2 d)"
    cos_wide = repeat(cos, expand_pattern)
    sin_wide = repeat(sin, expand_pattern)
    rotary_part = x[..., :rotary_width]
    untouched_part = x[..., rotary_width:]
    rotated = rotary_part * cos_wide + rotate_half(rotary_part, interleaved) * sin_wide
    return torch.cat([rotated, untouched_part], dim=-1)
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to ``t`` using angles ``freqs``.

    The input is promoted to float32 for the computation and the result is
    converted back to the type of ``t``. On CUDA platforms the
    ``vllm_flash_attn`` rotary implementation is used; elsewhere the pure
    PyTorch ``apply_rotary_emb_torch`` is used.
    """
    t_fp32 = t.float()
    cos_table = freqs.cos()
    sin_table = freqs.sin()
    if current_platform.is_cuda():
        # Imported lazily so non-CUDA platforms never touch this module.
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb as rotary_impl
    else:
        rotary_impl = apply_rotary_emb_torch
    return rotary_impl(t_fp32, cos_table, sin_table).type_as(t)
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
"""All-gather the input tensor interleavely across model parallel group."""
import torch.distributed as dist
......@@ -163,8 +119,8 @@ class Ernie4_5_VisionAttention(nn.Module):
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
......@@ -193,33 +149,18 @@ class Ernie4_5_VisionAttention(nn.Module):
prefix=f"{prefix}.proj",
)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Ernie45-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
......@@ -253,58 +194,32 @@ class Ernie4_5_VisionAttention(nn.Module):
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb is not None:
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = self.flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb.cos(),
rotary_pos_emb.sin(),
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
output = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(output, "b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
......@@ -350,8 +265,8 @@ class Ernie4_5_VisionBlock(nn.Module):
act_layer: type[nn.Module] = QuickGELU,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
......@@ -366,8 +281,8 @@ class Ernie4_5_VisionBlock(nn.Module):
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
attn_backend_override=attn_backend_override,
)
self.mlp = Ernie4_5_VisionMLP(
......@@ -383,7 +298,7 @@ class Ernie4_5_VisionBlock(nn.Module):
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
......@@ -441,8 +356,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
vision_config,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
patch_size = vision_config.patch_size
......@@ -477,8 +392,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
......@@ -489,6 +404,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
)
self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
......@@ -535,13 +453,13 @@ class Ernie4_5_VisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
def forward(
......@@ -1304,17 +1222,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.vision_model = Ernie4_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
attn_backend_override=attn_backend_override,
)
self.language_model = Ernie4_5_VLMoeForCausalLM(
......
......@@ -167,7 +167,6 @@ class ExaoneAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -176,7 +176,6 @@ class Exaone4Attention(nn.Module):
set_default_rope_theta(config, default_theta=1000000)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -167,7 +167,6 @@ class FalconAttention(nn.Module):
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -242,14 +242,11 @@ class FalconH1AttentionDecoderLayer(nn.Module):
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
if hasattr(config, "attn_rotary_emb"):
rotary_dim = config.attn_rotary_emb # for backward compatibility
else:
rotary_dim = self.head_dim # default
rotary_dim = getattr(config, "attn_rotary_emb", self.head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment