Unverified Commit 8616300a authored by zhoukz's avatar zhoukz Committed by GitHub
Browse files

[Model][Bugfix] Fix issues in MiDashengLM implementation for quantized models (#25854)


Signed-off-by: zhoukz's avatarzhoukz <me@zhoukz.com>
parent edbaadd9
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only MiDashengLM model compatible with HuggingFace weights.""" """Inference-only MiDashengLM model compatible with HuggingFace weights."""
import collections import collections
import collections.abc import collections.abc
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
...@@ -30,10 +31,10 @@ from typing import Any, Callable, Optional, TypedDict, Union, cast ...@@ -30,10 +31,10 @@ from typing import Any, Callable, Optional, TypedDict, Union, cast
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchaudio.transforms as audio_transforms import torchaudio.functional as F
from torch.nn.functional import scaled_dot_product_attention
from transformers import BatchFeature from transformers import BatchFeature
from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -41,7 +42,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -41,7 +42,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
...@@ -147,15 +147,19 @@ class DashengMlp(nn.Module): ...@@ -147,15 +147,19 @@ class DashengMlp(nn.Module):
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
hidden_features = hidden_features or in_features hidden_features = hidden_features or in_features
self.fc1 = ColumnParallelLinear(input_size=in_features, self.fc1 = ColumnParallelLinear(
input_size=in_features,
output_size=hidden_features, output_size=hidden_features,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc1") prefix=f"{prefix}.fc1",
)
self.act = get_act_fn("gelu") self.act = get_act_fn("gelu")
self.fc2 = RowParallelLinear(input_size=hidden_features, self.fc2 = RowParallelLinear(
input_size=hidden_features,
output_size=out_features, output_size=out_features,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc2") prefix=f"{prefix}.fc2",
)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.fc1(x) x, _ = self.fc1(x)
...@@ -171,7 +175,6 @@ class DashengAttention(nn.Module): ...@@ -171,7 +175,6 @@ class DashengAttention(nn.Module):
dim: int, dim: int,
num_heads: int = 8, num_heads: int = 8,
qkv_bias: bool = False, qkv_bias: bool = False,
causal: bool = False,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ):
...@@ -205,33 +208,30 @@ class DashengAttention(nn.Module): ...@@ -205,33 +208,30 @@ class DashengAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv", prefix=f"{prefix}.qkv",
) )
self.attn = MultiHeadAttention(
self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
)
self.proj = RowParallelLinear( self.proj = RowParallelLinear(
input_size=dim, input_size=dim,
output_size=dim, output_size=dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.proj", prefix=f"{prefix}.proj",
) )
self.causal = causal
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
B, N, C = x.shape B, N, C = x.shape
qkv_out, _ = self.qkv(x) qkv, _ = self.qkv(x)
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
dim=-1) qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
attn_out = self.attn(q, k, v)
C_local = attn_out.numel() // (B * N) # C_local for parallel
attn_out = attn_out.view(B, N, C_local)
x, _ = self.proj(attn_out) x = scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask[:, None, None, :] if mask is not None else None,
)
x = x.transpose(1, 2).reshape(B, N, C)
x, _ = self.proj(x)
return x return x
...@@ -280,6 +280,63 @@ class DashengBlock(nn.Module): ...@@ -280,6 +280,63 @@ class DashengBlock(nn.Module):
return x return x
class DashengFrontend(nn.Module):
def __init__(self, config: DashengConfig):
super().__init__()
self.config = config
spectrogram_window = torch.hann_window(self.config.win_length)
self.register_buffer(
"spectrogram_window",
spectrogram_window,
persistent=False,
)
self.spectrogram_window: torch.Tensor
melscale_fbanks = F.melscale_fbanks(
n_freqs=self.config.n_fft // 2 + 1,
f_min=self.config.f_min,
f_max=self.config.f_max,
n_mels=self.config.n_mels,
sample_rate=self.config.sample_rate,
)
self.register_buffer("melscale_fbanks",
melscale_fbanks,
persistent=False)
self.melscale_fbanks: torch.Tensor
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
spectrogram = F.spectrogram(
waveform=waveform.to(torch.float32),
pad=0,
window=self.spectrogram_window,
n_fft=self.config.n_fft,
hop_length=self.config.hop_length,
win_length=self.config.win_length,
power=2,
normalized=False,
center=self.config.center,
)
mel_spectrogram = (
spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
# x has shape [batch, freq, time].
# F.amplitude_to_DB accepts inputs shaped as:
# - [freq, time]
# - [channel, freq, time]
# - [..., channel, freq, time]
# Here we insert a channel dimension of size 1 before calling it,
# then remove that extra dimension afterward.
log_mel_spectrogram = F.amplitude_to_DB(
mel_spectrogram.unsqueeze(1),
multiplier=10,
amin=1e-10,
db_multiplier=0,
top_db=120,
).squeeze(1)
return log_mel_spectrogram.to(waveform.dtype)
class DashengAudioTransformer(nn.Module): class DashengAudioTransformer(nn.Module):
def __init__( def __init__(
...@@ -293,7 +350,7 @@ class DashengAudioTransformer(nn.Module): ...@@ -293,7 +350,7 @@ class DashengAudioTransformer(nn.Module):
self.target_length = config.target_length self.target_length = config.target_length
self.hop_length = config.hop_length self.hop_length = config.hop_length
self._init_front_end(config) self.front_end = DashengFrontend(config)
self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01) self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
...@@ -318,34 +375,10 @@ class DashengAudioTransformer(nn.Module): ...@@ -318,34 +375,10 @@ class DashengAudioTransformer(nn.Module):
qkv_bias=config.qkv_bias, qkv_bias=config.qkv_bias,
init_values=config.init_values, init_values=config.init_values,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.block{i}", prefix=f"{prefix}.blocks.{i}",
) for i in range(config.depth)) ) for i in range(config.depth))
self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6) self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
def _init_front_end(self, config):
with set_default_torch_dtype(torch.float32):
self.front_end = nn.Sequential(
audio_transforms.MelSpectrogram(
f_min=config.f_min,
f_max=config.f_max,
center=config.center,
win_length=config.win_length,
hop_length=config.hop_length,
sample_rate=config.sample_rate,
n_fft=config.n_fft,
n_mels=config.n_mels,
),
audio_transforms.AmplitudeToDB(top_db=120),
)
mel_spectrogram = self.front_end[0]
fb = mel_spectrogram.mel_scale.fb
win = mel_spectrogram.spectrogram.window
mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to(
torch.float32)
mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to(
torch.float32)
def forward_features( def forward_features(
self, self,
x: torch.Tensor, x: torch.Tensor,
...@@ -430,14 +463,16 @@ class AudioProjectorSubsample(nn.Module): ...@@ -430,14 +463,16 @@ class AudioProjectorSubsample(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.net.0", prefix=f"{prefix}.net.0",
return_bias=False, return_bias=False,
), get_act_fn("gelu"), ),
get_act_fn("gelu"),
RowParallelLinear( RowParallelLinear(
input_size=out_dim, input_size=out_dim,
output_size=out_dim, output_size=out_dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.net.2", prefix=f"{prefix}.net.2",
return_bias=False, return_bias=False,
)) ),
)
def forward(self, x, mask=None): def forward(self, x, mask=None):
batch_size, seq_len, dim = x.shape batch_size, seq_len, dim = x.shape
...@@ -534,9 +569,12 @@ class MiDashengLMMultiModalProcessor( ...@@ -534,9 +569,12 @@ class MiDashengLMMultiModalProcessor(
# + Padding # + Padding
min_audio_len = self.info.get_min_audio_len() min_audio_len = self.info.get_min_audio_len()
processed_audios = [ processed_audios = [
np.pad(audio, (0, min_audio_len - audio.shape[-1]), np.pad(
mode='constant', audio,
constant_values=0) if isinstance(audio, np.ndarray) (0, min_audio_len - audio.shape[-1]),
mode="constant",
constant_values=0,
) if isinstance(audio, np.ndarray)
and audio.shape[-1] < min_audio_len else audio for audio in audios and audio.shape[-1] < min_audio_len else audio for audio in audios
] ]
...@@ -585,8 +623,8 @@ class MiDashengLMMultiModalProcessor( ...@@ -585,8 +623,8 @@ class MiDashengLMMultiModalProcessor(
if audio_length is None: if audio_length is None:
audio_output_lengths = [] audio_output_lengths = []
else: else:
audio_length_np = audio_length.cpu().numpy() if isinstance( audio_length_np = (audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length audio_length, torch.Tensor) else audio_length)
audio_output_lengths = [ audio_output_lengths = [
max(1, calculate_mel_frames_dasheng( max(1, calculate_mel_frames_dasheng(
int(length))) # at least one frame int(length))) # at least one frame
...@@ -617,6 +655,17 @@ class MiDashengLMMultiModalProcessor( ...@@ -617,6 +655,17 @@ class MiDashengLMMultiModalProcessor(
dummy_inputs=MiDashengLMDummyInputsBuilder, dummy_inputs=MiDashengLMDummyInputsBuilder,
) )
class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
...@@ -660,8 +709,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -660,8 +709,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
def _validate_and_reshape_mm_tensor(self, mm_input: object, def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor: name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)): if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. " raise ValueError(
f"Got type: {type(mm_input)}") f"Incorrect type of {name}. Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor): if isinstance(mm_input, torch.Tensor):
return mm_input.reshape(-1, *mm_input.shape[2:]) return mm_input.reshape(-1, *mm_input.shape[2:])
...@@ -710,8 +759,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -710,8 +759,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_input["input_values"].dtype) audio_input["input_values"].dtype)
batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
audio_length_np = audio_length.cpu().numpy() if isinstance( audio_length_np = (audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length audio_length, torch.Tensor) else audio_length)
audio_output_lengths = [ audio_output_lengths = [
max(1, calculate_mel_frames_dasheng( max(1, calculate_mel_frames_dasheng(
int(length))) # at least one frame int(length))) # at least one frame
...@@ -720,11 +769,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -720,11 +769,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_output_lengths = torch.tensor(audio_output_lengths).to( audio_output_lengths = torch.tensor(audio_output_lengths).to(
audio_embeddings.device) audio_embeddings.device)
audio_feature_mask = (torch.arange( audio_feature_mask = torch.arange(
max_audio_tokens, max_audio_tokens,
device=audio_embeddings.device).unsqueeze(0).expand( device=audio_embeddings.device).unsqueeze(0).expand(
batch_size, max_audio_tokens) batch_size,
< audio_output_lengths.unsqueeze(1)) max_audio_tokens) < audio_output_lengths.unsqueeze(1)
masked_audio_features = audio_embeddings[audio_feature_mask].view( masked_audio_features = audio_embeddings[audio_feature_mask].view(
-1, embed_dim) -1, embed_dim)
...@@ -762,10 +811,12 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -762,10 +811,12 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
) )
input_ids = None input_ids = None
return self.decoder.model(input_ids, return self.decoder.model(
input_ids,
positions, positions,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds,
)
def compute_logits( def compute_logits(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment