"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "b74b887bad92eb3e31891dffd805b61d9fcdec63"
Unverified commit e5949e5a, authored by Chenxi Yang, committed by GitHub
Browse files

Remove index_put from MM embeddings merging (#22105)


Co-authored-by: Chenxi Yang <cxyang@meta.com>
parent 49bcd893
...@@ -393,7 +393,7 @@ def merge_multimodal_embeddings_from_map( ...@@ -393,7 +393,7 @@ def merge_multimodal_embeddings_from_map(
inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor: placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
""" """
Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
placeholder map . placeholder map .
Note: Note:
...@@ -418,17 +418,23 @@ def _merge_multimodal_embeddings( ...@@ -418,17 +418,23 @@ def _merge_multimodal_embeddings(
Note: Note:
This updates ``inputs_embeds`` in place. This updates ``inputs_embeds`` in place.
""" """
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
flattened = _flatten_embeddings(multimodal_embeddings) flattened = _flatten_embeddings(multimodal_embeddings)
if flattened.shape[0] != num_expected_tokens: try:
expr = _embedding_count_expression(multimodal_embeddings) # This is equivalent to: inputs_embeds[is_multimodal] = flattened.
raise ValueError( inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
f"Attempted to assign {expr} = {flattened.shape[0]} " except RuntimeError as e:
f"multimodal tokens to {num_expected_tokens} placeholders") num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
if flattened.shape[0] != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError(
f"Attempted to assign {expr} = {flattened.shape[0]} "
f"multimodal tokens to {num_expected_tokens} placeholders"
) from e
else:
raise ValueError("Error during masked scatter operation") from e
inputs_embeds[is_multimodal] = flattened
return inputs_embeds return inputs_embeds
...@@ -478,11 +484,11 @@ def merge_multimodal_embeddings( ...@@ -478,11 +484,11 @@ def merge_multimodal_embeddings(
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``. ``input_ids``.
``placeholder_token_id`` can be a list of token ids (e.g, token ids ``placeholder_token_id`` can be a list of token ids (e.g, token ids
of img_start, img_break, and img_end tokens) when needed: This means of img_start, img_break, and img_end tokens) when needed: This means
the order of these tokens in the ``input_ids`` MUST MATCH the order of the order of these tokens in the ``input_ids`` MUST MATCH the order of
their embeddings in ``multimodal_embeddings`` since we need to their embeddings in ``multimodal_embeddings`` since we need to
slice-merge instead of individually scattering. slice-merge instead of individually scattering.
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
...@@ -491,9 +497,9 @@ def merge_multimodal_embeddings( ...@@ -491,9 +497,9 @@ def merge_multimodal_embeddings(
- I is image embedding token - I is image embedding token
- B is image break token - B is image break token
- E is image end token. - E is image end token.
Then the image embeddings (that correspond to I's) from vision encoder Then the image embeddings (that correspond to I's) from vision encoder
must be padded with embeddings of S, B, and E in the same order of must be padded with embeddings of S, B, and E in the same order of
input_ids for a correct embedding merge. input_ids for a correct embedding merge.
Note: Note:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment