"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "636e1499fcf93b6aced186d863404da1f3340060"
Unverified Commit 7b81f956 authored by Binyao Jiang, committed by GitHub

Fix qwen2 audio not working bug (#8600)

parent d3e67deb
@@ -614,8 +614,7 @@ def general_mm_embed_routine(
         input_ids: Input token IDs tensor
         forward_batch: Batch information for model forward pass
         language_model: Base language model to use
-        image_data_embedding_func: Function to embed image data
-        audio_data_embedding_func: Function to embed audio data
+        data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
         placeholder_tokens: Token IDs for multimodal placeholders
         **kwargs: Additional arguments passed to language model
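For callers, the two per-modality keyword arguments collapse into one data_embedding_funcs dict keyed by Modality. A minimal sketch of the new call shape, assuming a hypothetical model wrapper that handles both images and audio (Modality.IMAGE and get_image_feature are illustrative here; only Modality.AUDIO appears in this commit):

from sglang.srt.managers.mm_utils import general_mm_embed_routine
from sglang.srt.managers.schedule_batch import Modality

def forward(self, input_ids, positions, forward_batch, **kwargs):
    # One entry per modality the model supports; the routine dispatches
    # each multimodal item to the matching embedding function.
    return general_mm_embed_routine(
        input_ids=input_ids,
        forward_batch=forward_batch,
        language_model=self.language_model,
        data_embedding_funcs={
            Modality.IMAGE: self.get_image_feature,  # assumed member
            Modality.AUDIO: self.get_audio_feature,
        },
        positions=positions,
    )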
@@ -52,7 +52,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -106,15 +110,10 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
         self.language_model = Qwen2ForCausalLM(
             config.text_config, quant_config, prefix=add_prefix("model", prefix)
         )
+        self.pattern = MultiModalityDataPaddingPatternMultimodalTokens()

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs for audio
-        audio_token_id: int = getattr(
-            mm_inputs, "audio_token_id", mm_inputs.im_token_id
-        )
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
-        return pattern.pad_input_tokens(input_ids, mm_inputs)
+        return self.pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # Extract audio features from input items
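The removed branch resolved a single audio placeholder ID with getattr(mm_inputs, "audio_token_id", mm_inputs.im_token_id), which silently fell back to the image token ID and broke audio inputs; the shared pattern, constructed once with no explicit token list, now derives the placeholders from mm_inputs itself. A rough illustration of the padding idea only, not sglang's actual implementation (token_id and pad_value are hypothetical field names):

from typing import List

def pad_input_tokens_sketch(input_ids: List[int], mm_items) -> List[int]:
    # Map each modality's placeholder token to the pad value of the data
    # item it stands for, so the embedding stage can find those slots later.
    pad_values = {item.token_id: item.pad_value for item in mm_items}
    return [pad_values.get(tok, tok) for tok in input_ids]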
@@ -143,7 +142,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            audio_data_embedding_func=self.get_audio_feature,
+            data_embedding_funcs={
+                Modality.AUDIO: self.get_audio_feature,
+            },
             positions=positions,
         )
@@ -190,6 +190,53 @@ class TestGemma3nServer(TestOpenAIVisionServer):
         # self._test_audio_ambient_completion()


+class TestQwen2AudioServer(TestOpenAIVisionServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "Qwen/Qwen2-Audio-7B-Instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--mem-fraction-static",
+                "0.70",
+            ],
+        )
+        cls.base_url += "/v1"
+
+    def test_audio_chat_completion(self):
+        self._test_audio_speech_completion()
+        self._test_audio_ambient_completion()
+
+    # Qwen2Audio does not support image
+    def test_single_image_chat_completion(self):
+        pass
+
+    # Qwen2Audio does not support image
+    def test_multi_turn_chat_completion(self):
+        pass
+
+    # Qwen2Audio does not support image
+    def test_multi_images_chat_completion(self):
+        pass
+
+    # Qwen2Audio does not support image
+    def test_video_images_chat_completion(self):
+        pass
+
+    # Qwen2Audio does not support image
+    def test_regex(self):
+        pass
+
+    # Qwen2Audio does not support image
+    def test_mixed_batch(self):
+        pass
+
+
 class TestKimiVLServer(TestOpenAIVisionServer):
     @classmethod
     def setUpClass(cls):
@@ -547,7 +547,7 @@ class TestOpenAIVisionServer(CustomTestCase):
         # bird song
         audio_response = self.get_audio_response(
             AUDIO_BIRD_SONG_URL,
-            "Please listen to the audio snippet carefully and transcribe the content.",
+            "Please listen to the audio snippet carefully and transcribe the content in English.",
             "ambient",
         )
         assert "bird" in audio_response
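With a server launched the way TestQwen2AudioServer does, the fix can be exercised end to end. A sketch against the OpenAI-compatible endpoint; the base URL and audio URL are placeholders, and the audio_url content part mirrors the convention these audio tests exercise:

import openai

# Assumes a server started roughly as the test does, e.g.:
#   python -m sglang.launch_server --model-path Qwen/Qwen2-Audio-7B-Instruct \
#       --trust-remote-code --mem-fraction-static 0.70
client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")

response = client.chat.completions.create(
    model="Qwen/Qwen2-Audio-7B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                # Placeholder clip; any reachable audio URL works here.
                {"type": "audio_url", "audio_url": {"url": "https://example.com/bird_song.wav"}},
                {
                    "type": "text",
                    "text": "Please listen to the audio snippet carefully and transcribe the content in English.",
                },
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)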