Unverified Commit fd6de37f, authored by Raushan Turganbay, committed by GitHub
Browse files

[BugFix] Fix 3D rope in transformers backend (#35097)


Signed-off-by: raushan <raushan@huggingface.co>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent c8aca0c9
...@@ -218,7 +218,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -218,7 +218,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
if "mm_token_type_ids" in processed_data if "mm_token_type_ids" in processed_data
else "token_type_ids" else "token_type_ids"
) )
mm_token_type_ids = processed_data.pop(token_type_key) mm_token_type_ids = processed_data.get(token_type_key)
# We can infer vLLM style placeholder from token type ids, if we split # We can infer vLLM style placeholder from token type ids, if we split
# it for each input `mm_data`. # it for each input `mm_data`.
...@@ -353,6 +353,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -353,6 +353,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
num_image_patches = kwargs.pop("num_image_patches") num_image_patches = kwargs.pop("num_image_patches")
kwargs.pop("token_type_ids", None) # used only in `forward` kwargs.pop("token_type_ids", None) # used only in `forward`
kwargs.pop("mm_token_type_ids", None) # used only in `model.get_rope_index`
if pixel_values is not None: if pixel_values is not None:
# ROCm: Force math SDP backend for vision encoder to avoid accuracy issues # ROCm: Force math SDP backend for vision encoder to avoid accuracy issues
...@@ -443,6 +444,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -443,6 +444,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
{ {
"image_grid_thw", "image_grid_thw",
"video_grid_thw", "video_grid_thw",
"mm_token_type_ids",
"second_per_grid_ts", "second_per_grid_ts",
"audio_feature_lengths", "audio_feature_lengths",
"use_audio_in_video", "use_audio_in_video",
...@@ -451,7 +453,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -451,7 +453,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
if any( if any(
v v
for k, v in kwargs.items() for k, v in kwargs.items()
if k not in {"image_grid_thw", "video_grid_thw"} if k not in {"image_grid_thw", "mm_token_type_ids"}
): ):
raise NotImplementedError( raise NotImplementedError(
"Transformers modeling backend only supports images." "Transformers modeling backend only supports images."
...@@ -459,6 +461,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -459,6 +461,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
image_grid_thw = kwargs.get("image_grid_thw", []) image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", []) video_grid_thw = kwargs.get("video_grid_thw", [])
mm_token_type_ids = kwargs.get("mm_token_type_ids")
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
image_grid_thw image_grid_thw
...@@ -467,10 +470,17 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -467,10 +470,17 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
video_grid_thw video_grid_thw
) )
# In v4 `get_rope_index` doesn't have wildcard `kwargs`, and
# can't accept arbitrary args, even if their value is `None`
kwargs = {}
if mm_token_type_ids:
kwargs["mm_token_type_ids"] = torch.cat(mm_token_type_ids)
mrope_positions, mrope_position_delta = self.model.get_rope_index( mrope_positions, mrope_position_delta = self.model.get_rope_index(
input_ids=torch.tensor(input_tokens).unsqueeze(0), input_ids=torch.tensor(input_tokens).unsqueeze(0),
image_grid_thw=image_grid_thw, image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw, video_grid_thw=video_grid_thw,
**kwargs,
) )
mrope_positions = mrope_positions[:, 0] mrope_positions = mrope_positions[:, 0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment