Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8616300a
Unverified
Commit
8616300a
authored
Sep 29, 2025
by
zhoukz
Committed by
GitHub
Sep 29, 2025
Browse files
[Model][Bugfix] Fix issues in MiDashengLM implementation for quantized models (#25854)
Signed-off-by:
zhoukz
<
me@zhoukz.com
>
parent
edbaadd9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
122 additions
and
71 deletions
+122
-71
vllm/model_executor/models/midashenglm.py
vllm/model_executor/models/midashenglm.py
+122
-71
No files found.
vllm/model_executor/models/midashenglm.py
View file @
8616300a
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only MiDashengLM model compatible with HuggingFace weights."""
"""Inference-only MiDashengLM model compatible with HuggingFace weights."""
import
collections
import
collections
import
collections.abc
import
collections.abc
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
...
@@ -30,10 +31,10 @@ from typing import Any, Callable, Optional, TypedDict, Union, cast
...
@@ -30,10 +31,10 @@ from typing import Any, Callable, Optional, TypedDict, Union, cast
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torchaudio.transforms
as
audio_transforms
import
torchaudio.functional
as
F
from
torch.nn.functional
import
scaled_dot_product_attention
from
transformers
import
BatchFeature
from
transformers
import
BatchFeature
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
...
@@ -41,7 +42,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -41,7 +42,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
)
MultiModalKwargsItems
)
...
@@ -147,15 +147,19 @@ class DashengMlp(nn.Module):
...
@@ -147,15 +147,19 @@ class DashengMlp(nn.Module):
super
().
__init__
()
super
().
__init__
()
out_features
=
out_features
or
in_features
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
hidden_features
=
hidden_features
or
in_features
self
.
fc1
=
ColumnParallelLinear
(
input_size
=
in_features
,
self
.
fc1
=
ColumnParallelLinear
(
input_size
=
in_features
,
output_size
=
hidden_features
,
output_size
=
hidden_features
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
)
prefix
=
f
"
{
prefix
}
.fc1"
,
)
self
.
act
=
get_act_fn
(
"gelu"
)
self
.
act
=
get_act_fn
(
"gelu"
)
self
.
fc2
=
RowParallelLinear
(
input_size
=
hidden_features
,
self
.
fc2
=
RowParallelLinear
(
input_size
=
hidden_features
,
output_size
=
out_features
,
output_size
=
out_features
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
)
prefix
=
f
"
{
prefix
}
.fc2"
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
,
_
=
self
.
fc1
(
x
)
x
,
_
=
self
.
fc1
(
x
)
...
@@ -171,7 +175,6 @@ class DashengAttention(nn.Module):
...
@@ -171,7 +175,6 @@ class DashengAttention(nn.Module):
dim
:
int
,
dim
:
int
,
num_heads
:
int
=
8
,
num_heads
:
int
=
8
,
qkv_bias
:
bool
=
False
,
qkv_bias
:
bool
=
False
,
causal
:
bool
=
False
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
...
@@ -205,33 +208,30 @@ class DashengAttention(nn.Module):
...
@@ -205,33 +208,30 @@ class DashengAttention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv"
,
prefix
=
f
"
{
prefix
}
.qkv"
,
)
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scale
,
num_kv_heads
=
self
.
num_kv_heads
,
)
self
.
proj
=
RowParallelLinear
(
self
.
proj
=
RowParallelLinear
(
input_size
=
dim
,
input_size
=
dim
,
output_size
=
dim
,
output_size
=
dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
,
prefix
=
f
"
{
prefix
}
.proj"
,
)
)
self
.
causal
=
causal
def
forward
(
self
,
x
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
):
def
forward
(
self
,
x
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
):
B
,
N
,
C
=
x
.
shape
B
,
N
,
C
=
x
.
shape
qkv_out
,
_
=
self
.
qkv
(
x
)
qkv
,
_
=
self
.
qkv
(
x
)
q
,
k
,
v
=
qkv_out
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
qkv
=
qkv
.
reshape
(
B
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
)
dim
=-
1
)
qkv
=
qkv
.
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
.
unbind
(
0
)
attn_out
=
self
.
attn
(
q
,
k
,
v
)
C_local
=
attn_out
.
numel
()
//
(
B
*
N
)
# C_local for parallel
attn_out
=
attn_out
.
view
(
B
,
N
,
C_local
)
x
,
_
=
self
.
proj
(
attn_out
)
x
=
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
mask
[:,
None
,
None
,
:]
if
mask
is
not
None
else
None
,
)
x
=
x
.
transpose
(
1
,
2
).
reshape
(
B
,
N
,
C
)
x
,
_
=
self
.
proj
(
x
)
return
x
return
x
...
@@ -280,6 +280,63 @@ class DashengBlock(nn.Module):
...
@@ -280,6 +280,63 @@ class DashengBlock(nn.Module):
return
x
return
x
class
DashengFrontend
(
nn
.
Module
):
def
__init__
(
self
,
config
:
DashengConfig
):
super
().
__init__
()
self
.
config
=
config
spectrogram_window
=
torch
.
hann_window
(
self
.
config
.
win_length
)
self
.
register_buffer
(
"spectrogram_window"
,
spectrogram_window
,
persistent
=
False
,
)
self
.
spectrogram_window
:
torch
.
Tensor
melscale_fbanks
=
F
.
melscale_fbanks
(
n_freqs
=
self
.
config
.
n_fft
//
2
+
1
,
f_min
=
self
.
config
.
f_min
,
f_max
=
self
.
config
.
f_max
,
n_mels
=
self
.
config
.
n_mels
,
sample_rate
=
self
.
config
.
sample_rate
,
)
self
.
register_buffer
(
"melscale_fbanks"
,
melscale_fbanks
,
persistent
=
False
)
self
.
melscale_fbanks
:
torch
.
Tensor
def
forward
(
self
,
waveform
:
torch
.
Tensor
)
->
torch
.
Tensor
:
spectrogram
=
F
.
spectrogram
(
waveform
=
waveform
.
to
(
torch
.
float32
),
pad
=
0
,
window
=
self
.
spectrogram_window
,
n_fft
=
self
.
config
.
n_fft
,
hop_length
=
self
.
config
.
hop_length
,
win_length
=
self
.
config
.
win_length
,
power
=
2
,
normalized
=
False
,
center
=
self
.
config
.
center
,
)
mel_spectrogram
=
(
spectrogram
.
mT
@
self
.
melscale_fbanks
.
to
(
torch
.
float32
)).
mT
# x has shape [batch, freq, time].
# F.amplitude_to_DB accepts inputs shaped as:
# - [freq, time]
# - [channel, freq, time]
# - [..., channel, freq, time]
# Here we insert a channel dimension of size 1 before calling it,
# then remove that extra dimension afterward.
log_mel_spectrogram
=
F
.
amplitude_to_DB
(
mel_spectrogram
.
unsqueeze
(
1
),
multiplier
=
10
,
amin
=
1e-10
,
db_multiplier
=
0
,
top_db
=
120
,
).
squeeze
(
1
)
return
log_mel_spectrogram
.
to
(
waveform
.
dtype
)
class
DashengAudioTransformer
(
nn
.
Module
):
class
DashengAudioTransformer
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
@@ -293,7 +350,7 @@ class DashengAudioTransformer(nn.Module):
...
@@ -293,7 +350,7 @@ class DashengAudioTransformer(nn.Module):
self
.
target_length
=
config
.
target_length
self
.
target_length
=
config
.
target_length
self
.
hop_length
=
config
.
hop_length
self
.
hop_length
=
config
.
hop_length
self
.
_init_f
ront
_
end
(
config
)
self
.
front_end
=
DashengF
rontend
(
config
)
self
.
init_bn
=
nn
.
BatchNorm2d
(
config
.
n_mels
,
momentum
=
0.01
)
self
.
init_bn
=
nn
.
BatchNorm2d
(
config
.
n_mels
,
momentum
=
0.01
)
...
@@ -318,34 +375,10 @@ class DashengAudioTransformer(nn.Module):
...
@@ -318,34 +375,10 @@ class DashengAudioTransformer(nn.Module):
qkv_bias
=
config
.
qkv_bias
,
qkv_bias
=
config
.
qkv_bias
,
init_values
=
config
.
init_values
,
init_values
=
config
.
init_values
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.block
{
i
}
"
,
prefix
=
f
"
{
prefix
}
.block
s.
{
i
}
"
,
)
for
i
in
range
(
config
.
depth
))
)
for
i
in
range
(
config
.
depth
))
self
.
norm
=
nn
.
LayerNorm
(
config
.
embed_dim
,
eps
=
1e-6
)
self
.
norm
=
nn
.
LayerNorm
(
config
.
embed_dim
,
eps
=
1e-6
)
def
_init_front_end
(
self
,
config
):
with
set_default_torch_dtype
(
torch
.
float32
):
self
.
front_end
=
nn
.
Sequential
(
audio_transforms
.
MelSpectrogram
(
f_min
=
config
.
f_min
,
f_max
=
config
.
f_max
,
center
=
config
.
center
,
win_length
=
config
.
win_length
,
hop_length
=
config
.
hop_length
,
sample_rate
=
config
.
sample_rate
,
n_fft
=
config
.
n_fft
,
n_mels
=
config
.
n_mels
,
),
audio_transforms
.
AmplitudeToDB
(
top_db
=
120
),
)
mel_spectrogram
=
self
.
front_end
[
0
]
fb
=
mel_spectrogram
.
mel_scale
.
fb
win
=
mel_spectrogram
.
spectrogram
.
window
mel_spectrogram
.
mel_scale
.
fb
=
fb
.
to
(
torch
.
bfloat16
).
to
(
torch
.
float32
)
mel_spectrogram
.
spectrogram
.
window
=
win
.
to
(
torch
.
bfloat16
).
to
(
torch
.
float32
)
def
forward_features
(
def
forward_features
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
@@ -430,14 +463,16 @@ class AudioProjectorSubsample(nn.Module):
...
@@ -430,14 +463,16 @@ class AudioProjectorSubsample(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.net.0"
,
prefix
=
f
"
{
prefix
}
.net.0"
,
return_bias
=
False
,
return_bias
=
False
,
),
get_act_fn
(
"gelu"
),
),
get_act_fn
(
"gelu"
),
RowParallelLinear
(
RowParallelLinear
(
input_size
=
out_dim
,
input_size
=
out_dim
,
output_size
=
out_dim
,
output_size
=
out_dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.net.2"
,
prefix
=
f
"
{
prefix
}
.net.2"
,
return_bias
=
False
,
return_bias
=
False
,
))
),
)
def
forward
(
self
,
x
,
mask
=
None
):
def
forward
(
self
,
x
,
mask
=
None
):
batch_size
,
seq_len
,
dim
=
x
.
shape
batch_size
,
seq_len
,
dim
=
x
.
shape
...
@@ -534,9 +569,12 @@ class MiDashengLMMultiModalProcessor(
...
@@ -534,9 +569,12 @@ class MiDashengLMMultiModalProcessor(
# + Padding
# + Padding
min_audio_len
=
self
.
info
.
get_min_audio_len
()
min_audio_len
=
self
.
info
.
get_min_audio_len
()
processed_audios
=
[
processed_audios
=
[
np
.
pad
(
audio
,
(
0
,
min_audio_len
-
audio
.
shape
[
-
1
]),
np
.
pad
(
mode
=
'constant'
,
audio
,
constant_values
=
0
)
if
isinstance
(
audio
,
np
.
ndarray
)
(
0
,
min_audio_len
-
audio
.
shape
[
-
1
]),
mode
=
"constant"
,
constant_values
=
0
,
)
if
isinstance
(
audio
,
np
.
ndarray
)
and
audio
.
shape
[
-
1
]
<
min_audio_len
else
audio
for
audio
in
audios
and
audio
.
shape
[
-
1
]
<
min_audio_len
else
audio
for
audio
in
audios
]
]
...
@@ -585,8 +623,8 @@ class MiDashengLMMultiModalProcessor(
...
@@ -585,8 +623,8 @@ class MiDashengLMMultiModalProcessor(
if
audio_length
is
None
:
if
audio_length
is
None
:
audio_output_lengths
=
[]
audio_output_lengths
=
[]
else
:
else
:
audio_length_np
=
audio_length
.
cpu
().
numpy
()
if
isinstance
(
audio_length_np
=
(
audio_length
.
cpu
().
numpy
()
if
isinstance
(
audio_length
,
torch
.
Tensor
)
else
audio_length
audio_length
,
torch
.
Tensor
)
else
audio_length
)
audio_output_lengths
=
[
audio_output_lengths
=
[
max
(
1
,
calculate_mel_frames_dasheng
(
max
(
1
,
calculate_mel_frames_dasheng
(
int
(
length
)))
# at least one frame
int
(
length
)))
# at least one frame
...
@@ -617,6 +655,17 @@ class MiDashengLMMultiModalProcessor(
...
@@ -617,6 +655,17 @@ class MiDashengLMMultiModalProcessor(
dummy_inputs
=
MiDashengLMDummyInputsBuilder
,
dummy_inputs
=
MiDashengLMDummyInputsBuilder
,
)
)
class
MiDashengLMModel
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
MiDashengLMModel
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
@
classmethod
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
...
@@ -660,8 +709,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -660,8 +709,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. "
raise
ValueError
(
f
"
Got type:
{
type
(
mm_input
)
}
"
)
f
"Incorrect type of
{
name
}
.
Got type:
{
type
(
mm_input
)
}
"
)
if
isinstance
(
mm_input
,
torch
.
Tensor
):
if
isinstance
(
mm_input
,
torch
.
Tensor
):
return
mm_input
.
reshape
(
-
1
,
*
mm_input
.
shape
[
2
:])
return
mm_input
.
reshape
(
-
1
,
*
mm_input
.
shape
[
2
:])
...
@@ -710,8 +759,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -710,8 +759,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_input
[
"input_values"
].
dtype
)
audio_input
[
"input_values"
].
dtype
)
batch_size
,
max_audio_tokens
,
embed_dim
=
audio_embeddings
.
shape
batch_size
,
max_audio_tokens
,
embed_dim
=
audio_embeddings
.
shape
audio_length_np
=
audio_length
.
cpu
().
numpy
()
if
isinstance
(
audio_length_np
=
(
audio_length
.
cpu
().
numpy
()
if
isinstance
(
audio_length
,
torch
.
Tensor
)
else
audio_length
audio_length
,
torch
.
Tensor
)
else
audio_length
)
audio_output_lengths
=
[
audio_output_lengths
=
[
max
(
1
,
calculate_mel_frames_dasheng
(
max
(
1
,
calculate_mel_frames_dasheng
(
int
(
length
)))
# at least one frame
int
(
length
)))
# at least one frame
...
@@ -720,11 +769,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -720,11 +769,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_output_lengths
=
torch
.
tensor
(
audio_output_lengths
).
to
(
audio_output_lengths
=
torch
.
tensor
(
audio_output_lengths
).
to
(
audio_embeddings
.
device
)
audio_embeddings
.
device
)
audio_feature_mask
=
(
torch
.
arange
(
audio_feature_mask
=
torch
.
arange
(
max_audio_tokens
,
max_audio_tokens
,
device
=
audio_embeddings
.
device
).
unsqueeze
(
0
).
expand
(
device
=
audio_embeddings
.
device
).
unsqueeze
(
0
).
expand
(
batch_size
,
max_audio_tokens
)
batch_size
,
<
audio_output_lengths
.
unsqueeze
(
1
)
)
max_audio_tokens
)
<
audio_output_lengths
.
unsqueeze
(
1
)
masked_audio_features
=
audio_embeddings
[
audio_feature_mask
].
view
(
masked_audio_features
=
audio_embeddings
[
audio_feature_mask
].
view
(
-
1
,
embed_dim
)
-
1
,
embed_dim
)
...
@@ -762,10 +811,12 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -762,10 +811,12 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
)
)
input_ids
=
None
input_ids
=
None
return
self
.
decoder
.
model
(
input_ids
,
return
self
.
decoder
.
model
(
input_ids
,
positions
,
positions
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
,
)
def
compute_logits
(
def
compute_logits
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment