Unverified Commit a8c0f599 authored by bingchen-mi's avatar bingchen-mi Committed by GitHub
Browse files

[Bugfix] MiDashengLM model contact error under concurrent testing (#24738)


Signed-off-by: default avatarchenbing8 <chenbing8@xiaomi.com>
Signed-off-by: default avatarbingchen-mi <chenbing8@xiaomi.com>
parent f4a948f3
...@@ -497,8 +497,11 @@ class MiDashengLMDummyInputsBuilder( ...@@ -497,8 +497,11 @@ class MiDashengLMDummyInputsBuilder(
hf_processor = self.info.get_hf_processor() hf_processor = self.info.get_hf_processor()
audio_token = hf_processor.audio_token audio_token = hf_processor.audio_token
audio_bos_token = hf_processor.audio_bos_token
audio_eos_token = hf_processor.audio_eos_token
return audio_token * num_audios single_audio_text = f"{audio_bos_token}{audio_token}{audio_eos_token}"
return single_audio_text * num_audios
def get_dummy_mm_data( def get_dummy_mm_data(
self, self,
...@@ -577,14 +580,7 @@ class MiDashengLMMultiModalProcessor( ...@@ -577,14 +580,7 @@ class MiDashengLMMultiModalProcessor(
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
audio_token = getattr(processor, "audio_token", "<|AUDIO|>") audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
audio_bos_token = getattr(processor, "audio_bos_token",
"<|audio_bos|>")
audio_eos_token = getattr(processor, "audio_eos_token",
"<|audio_eos|>")
audio_token_id = vocab[audio_token] audio_token_id = vocab[audio_token]
audio_bos_id = vocab[audio_bos_token]
audio_eos_id = vocab[audio_eos_token]
out_mm_data = out_mm_kwargs.get_data() out_mm_data = out_mm_kwargs.get_data()
audio_length = out_mm_data.get("audio_length") audio_length = out_mm_data.get("audio_length")
...@@ -604,7 +600,7 @@ class MiDashengLMMultiModalProcessor( ...@@ -604,7 +600,7 @@ class MiDashengLMMultiModalProcessor(
audio_tokens = [audio_token_id] * num_features audio_tokens = [audio_token_id] * num_features
return PromptUpdateDetails.select_token_id( return PromptUpdateDetails.select_token_id(
[audio_bos_id] + audio_tokens + [audio_eos_id], audio_tokens,
embed_token_id=audio_token_id, embed_token_id=audio_token_id,
) )
...@@ -670,7 +666,17 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -670,7 +666,17 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(mm_input)}") f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor): if isinstance(mm_input, torch.Tensor):
return mm_input.reshape(-1, *mm_input.shape[2:]) return mm_input.reshape(-1, *mm_input.shape[2:])
else:
if name == "input_values":
max_length = max(tensor.shape[1] for tensor in mm_input)
padded_mm_input = [
torch.nn.functional.pad(tensor,
(0, max_length - tensor.shape[1]))
if tensor.shape[1] < max_length else tensor
for tensor in mm_input
]
return torch.concat(padded_mm_input)
return torch.concat(mm_input) return torch.concat(mm_input)
def _parse_and_validate_audio_input( def _parse_and_validate_audio_input(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment