Unverified Commit 20b14095 authored by Nick Cao's avatar Nick Cao Committed by GitHub
Browse files

[Bugfix] Fix loading Music Flamingo (#35535)


Signed-off-by: default avatarNick Cao <ncao@redhat.com>
parent 17c1bdf3
......@@ -128,12 +128,6 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
super().__init__(config)
self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
# self.layer_norm is already initialized in super().__init__
# Keep a dummy freqs parameter for MusicFlamingo checkpoints.
self.pos_emb = nn.Module()
freqs = torch.empty(getattr(config, "num_mel_bins", 128))
self.pos_emb.register_parameter(
"freqs", nn.Parameter(freqs, requires_grad=False)
)
def forward(
self,
......
......@@ -21,6 +21,7 @@ from vllm.multimodal.processing import BaseProcessingInfo
from .audioflamingo3 import (
AudioFlamingo3DummyInputsBuilder,
AudioFlamingo3ForConditionalGeneration,
AudioFlamingo3MultiModalDataParser,
AudioFlamingo3MultiModalProcessor,
)
......@@ -53,8 +54,16 @@ class MusicFlamingoProcessingInfo(BaseProcessingInfo):
hf_processor = self.get_hf_processor(**kwargs)
return hf_processor.feature_extractor
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return AudioFlamingo3MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
return {"audio": 1}
class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment