Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8616300a
Unverified
Commit
8616300a
authored
Sep 29, 2025
by
zhoukz
Committed by
GitHub
Sep 29, 2025
Browse files
[Model][Bugfix] Fix issues in MiDashengLM implementation for quantized models (#25854)
Signed-off-by:
zhoukz
<
me@zhoukz.com
>
parent
edbaadd9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
122 additions
and
71 deletions
+122
-71
vllm/model_executor/models/midashenglm.py
vllm/model_executor/models/midashenglm.py
+122
-71
No files found.
vllm/model_executor/models/midashenglm.py
View file @
8616300a
...
...
@@ -22,6 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiDashengLM model compatible with HuggingFace weights."""
import
collections
import
collections.abc
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
...
...
@@ -30,10 +31,10 @@ from typing import Any, Callable, Optional, TypedDict, Union, cast
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torchaudio.transforms
as
audio_transforms
import
torchaudio.functional
as
F
from
torch.nn.functional
import
scaled_dot_product_attention
from
transformers
import
BatchFeature
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
...
...
@@ -41,7 +42,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
)
...
...
@@ -147,15 +147,19 @@ class DashengMlp(nn.Module):
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
self
.
fc1
=
ColumnParallelLinear
(
input_size
=
in_features
,
self
.
fc1
=
ColumnParallelLinear
(
input_size
=
in_features
,
output_size
=
hidden_features
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
)
prefix
=
f
"
{
prefix
}
.fc1"
,
)
self
.
act
=
get_act_fn
(
"gelu"
)
self
.
fc2
=
RowParallelLinear
(
input_size
=
hidden_features
,
self
.
fc2
=
RowParallelLinear
(
input_size
=
hidden_features
,
output_size
=
out_features
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
)
prefix
=
f
"
{
prefix
}
.fc2"
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
,
_
=
self
.
fc1
(
x
)
...
...
@@ -171,7 +175,6 @@ class DashengAttention(nn.Module):
dim
:
int
,
num_heads
:
int
=
8
,
qkv_bias
:
bool
=
False
,
causal
:
bool
=
False
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
...
...
@@ -205,33 +208,30 @@ class DashengAttention(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv"
,
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scale
,
num_kv_heads
=
self
.
num_kv_heads
,
)
self
.
proj
=
RowParallelLinear
(
input_size
=
dim
,
output_size
=
dim
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
,
)
self
.
causal
=
causal
def
forward
(
self
,
x
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
):
B
,
N
,
C
=
x
.
shape
qkv_out
,
_
=
self
.
qkv
(
x
)
q
,
k
,
v
=
qkv_out
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_out
=
self
.
attn
(
q
,
k
,
v
)
C_local
=
attn_out
.
numel
()
//
(
B
*
N
)
# C_local for parallel
attn_out
=
attn_out
.
view
(
B
,
N
,
C_local
)
qkv
,
_
=
self
.
qkv
(
x
)
qkv
=
qkv
.
reshape
(
B
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
)
qkv
=
qkv
.
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
.
unbind
(
0
)
x
,
_
=
self
.
proj
(
attn_out
)
x
=
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
mask
[:,
None
,
None
,
:]
if
mask
is
not
None
else
None
,
)
x
=
x
.
transpose
(
1
,
2
).
reshape
(
B
,
N
,
C
)
x
,
_
=
self
.
proj
(
x
)
return
x
...
...
@@ -280,6 +280,63 @@ class DashengBlock(nn.Module):
return
x
class
DashengFrontend
(
nn
.
Module
):
def
__init__
(
self
,
config
:
DashengConfig
):
super
().
__init__
()
self
.
config
=
config
spectrogram_window
=
torch
.
hann_window
(
self
.
config
.
win_length
)
self
.
register_buffer
(
"spectrogram_window"
,
spectrogram_window
,
persistent
=
False
,
)
self
.
spectrogram_window
:
torch
.
Tensor
melscale_fbanks
=
F
.
melscale_fbanks
(
n_freqs
=
self
.
config
.
n_fft
//
2
+
1
,
f_min
=
self
.
config
.
f_min
,
f_max
=
self
.
config
.
f_max
,
n_mels
=
self
.
config
.
n_mels
,
sample_rate
=
self
.
config
.
sample_rate
,
)
self
.
register_buffer
(
"melscale_fbanks"
,
melscale_fbanks
,
persistent
=
False
)
self
.
melscale_fbanks
:
torch
.
Tensor
def
forward
(
self
,
waveform
:
torch
.
Tensor
)
->
torch
.
Tensor
:
spectrogram
=
F
.
spectrogram
(
waveform
=
waveform
.
to
(
torch
.
float32
),
pad
=
0
,
window
=
self
.
spectrogram_window
,
n_fft
=
self
.
config
.
n_fft
,
hop_length
=
self
.
config
.
hop_length
,
win_length
=
self
.
config
.
win_length
,
power
=
2
,
normalized
=
False
,
center
=
self
.
config
.
center
,
)
mel_spectrogram
=
(
spectrogram
.
mT
@
self
.
melscale_fbanks
.
to
(
torch
.
float32
)).
mT
# x has shape [batch, freq, time].
# F.amplitude_to_DB accepts inputs shaped as:
# - [freq, time]
# - [channel, freq, time]
# - [..., channel, freq, time]
# Here we insert a channel dimension of size 1 before calling it,
# then remove that extra dimension afterward.
log_mel_spectrogram
=
F
.
amplitude_to_DB
(
mel_spectrogram
.
unsqueeze
(
1
),
multiplier
=
10
,
amin
=
1e-10
,
db_multiplier
=
0
,
top_db
=
120
,
).
squeeze
(
1
)
return
log_mel_spectrogram
.
to
(
waveform
.
dtype
)
class
DashengAudioTransformer
(
nn
.
Module
):
def
__init__
(
...
...
@@ -293,7 +350,7 @@ class DashengAudioTransformer(nn.Module):
self
.
target_length
=
config
.
target_length
self
.
hop_length
=
config
.
hop_length
self
.
_init_f
ront
_
end
(
config
)
self
.
front_end
=
DashengF
rontend
(
config
)
self
.
init_bn
=
nn
.
BatchNorm2d
(
config
.
n_mels
,
momentum
=
0.01
)
...
...
@@ -318,34 +375,10 @@ class DashengAudioTransformer(nn.Module):
qkv_bias
=
config
.
qkv_bias
,
init_values
=
config
.
init_values
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.block
{
i
}
"
,
prefix
=
f
"
{
prefix
}
.block
s.
{
i
}
"
,
)
for
i
in
range
(
config
.
depth
))
self
.
norm
=
nn
.
LayerNorm
(
config
.
embed_dim
,
eps
=
1e-6
)
def
_init_front_end
(
self
,
config
):
with
set_default_torch_dtype
(
torch
.
float32
):
self
.
front_end
=
nn
.
Sequential
(
audio_transforms
.
MelSpectrogram
(
f_min
=
config
.
f_min
,
f_max
=
config
.
f_max
,
center
=
config
.
center
,
win_length
=
config
.
win_length
,
hop_length
=
config
.
hop_length
,
sample_rate
=
config
.
sample_rate
,
n_fft
=
config
.
n_fft
,
n_mels
=
config
.
n_mels
,
),
audio_transforms
.
AmplitudeToDB
(
top_db
=
120
),
)
mel_spectrogram
=
self
.
front_end
[
0
]
fb
=
mel_spectrogram
.
mel_scale
.
fb
win
=
mel_spectrogram
.
spectrogram
.
window
mel_spectrogram
.
mel_scale
.
fb
=
fb
.
to
(
torch
.
bfloat16
).
to
(
torch
.
float32
)
mel_spectrogram
.
spectrogram
.
window
=
win
.
to
(
torch
.
bfloat16
).
to
(
torch
.
float32
)
def
forward_features
(
self
,
x
:
torch
.
Tensor
,
...
...
@@ -430,14 +463,16 @@ class AudioProjectorSubsample(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.net.0"
,
return_bias
=
False
,
),
get_act_fn
(
"gelu"
),
),
get_act_fn
(
"gelu"
),
RowParallelLinear
(
input_size
=
out_dim
,
output_size
=
out_dim
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.net.2"
,
return_bias
=
False
,
))
),
)
def
forward
(
self
,
x
,
mask
=
None
):
batch_size
,
seq_len
,
dim
=
x
.
shape
...
...
@@ -534,9 +569,12 @@ class MiDashengLMMultiModalProcessor(
# + Padding
min_audio_len
=
self
.
info
.
get_min_audio_len
()
processed_audios
=
[
np
.
pad
(
audio
,
(
0
,
min_audio_len
-
audio
.
shape
[
-
1
]),
mode
=
'constant'
,
constant_values
=
0
)
if
isinstance
(
audio
,
np
.
ndarray
)
np
.
pad
(
audio
,
(
0
,
min_audio_len
-
audio
.
shape
[
-
1
]),
mode
=
"constant"
,
constant_values
=
0
,
)
if
isinstance
(
audio
,
np
.
ndarray
)
and
audio
.
shape
[
-
1
]
<
min_audio_len
else
audio
for
audio
in
audios
]
...
...
@@ -585,8 +623,8 @@ class MiDashengLMMultiModalProcessor(
if
audio_length
is
None
:
audio_output_lengths
=
[]
else
:
audio_length_np
=
audio_length
.
cpu
().
numpy
()
if
isinstance
(
audio_length
,
torch
.
Tensor
)
else
audio_length
audio_length_np
=
(
audio_length
.
cpu
().
numpy
()
if
isinstance
(
audio_length
,
torch
.
Tensor
)
else
audio_length
)
audio_output_lengths
=
[
max
(
1
,
calculate_mel_frames_dasheng
(
int
(
length
)))
# at least one frame
...
...
@@ -617,6 +655,17 @@ class MiDashengLMMultiModalProcessor(
dummy_inputs
=
MiDashengLMDummyInputsBuilder
,
)
class
MiDashengLMModel
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
...
...
@@ -660,8 +709,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. "
f
"
Got type:
{
type
(
mm_input
)
}
"
)
raise
ValueError
(
f
"Incorrect type of
{
name
}
.
Got type:
{
type
(
mm_input
)
}
"
)
if
isinstance
(
mm_input
,
torch
.
Tensor
):
return
mm_input
.
reshape
(
-
1
,
*
mm_input
.
shape
[
2
:])
...
...
@@ -710,8 +759,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_input
[
"input_values"
].
dtype
)
batch_size
,
max_audio_tokens
,
embed_dim
=
audio_embeddings
.
shape
audio_length_np
=
audio_length
.
cpu
().
numpy
()
if
isinstance
(
audio_length
,
torch
.
Tensor
)
else
audio_length
audio_length_np
=
(
audio_length
.
cpu
().
numpy
()
if
isinstance
(
audio_length
,
torch
.
Tensor
)
else
audio_length
)
audio_output_lengths
=
[
max
(
1
,
calculate_mel_frames_dasheng
(
int
(
length
)))
# at least one frame
...
...
@@ -720,11 +769,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_output_lengths
=
torch
.
tensor
(
audio_output_lengths
).
to
(
audio_embeddings
.
device
)
audio_feature_mask
=
(
torch
.
arange
(
audio_feature_mask
=
torch
.
arange
(
max_audio_tokens
,
device
=
audio_embeddings
.
device
).
unsqueeze
(
0
).
expand
(
batch_size
,
max_audio_tokens
)
<
audio_output_lengths
.
unsqueeze
(
1
)
)
batch_size
,
max_audio_tokens
)
<
audio_output_lengths
.
unsqueeze
(
1
)
masked_audio_features
=
audio_embeddings
[
audio_feature_mask
].
view
(
-
1
,
embed_dim
)
...
...
@@ -762,10 +811,12 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
)
input_ids
=
None
return
self
.
decoder
.
model
(
input_ids
,
return
self
.
decoder
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
,
)
def
compute_logits
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment