Unverified commit 4cfe5d2b, authored by Isotr0py, committed by GitHub
Browse files

[Bugfix] `multi_modal_kwargs` broadcast for CPU tensor parallel (#10541)


Signed-off-by: Isotr0py <2037008807@qq.com>
parent c8acd805
......@@ -35,6 +35,7 @@ class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
"input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
......
......@@ -83,6 +83,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment