Unverified Commit 53475674 authored by Mick, committed by GitHub

chore: improvements on mm_utils (#7737)

parent ce32bc2b
@@ -85,8 +85,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
"No data_token_pairs provided, RadixAttention might be influenced."
)
return input_ids
start_token_ids = [s for s, _e in data_token_pairs]
end_tokens_ids = [e for _s, e in data_token_pairs]
start_token_ids = {s for s, _e in data_token_pairs}
end_tokens_ids = {e for _s, e in data_token_pairs}
padded_ids = []
last_idx = 0
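Note on this hunk: the start/end token ids are only used for membership tests while scanning `input_ids`, so building them as sets makes each check O(1) on average instead of O(n) with a list. A minimal sketch with hypothetical token ids (not the real model's):

```python
data_token_pairs = [(32000, 32001), (32010, 32011)]  # hypothetical <im_start>/<im_end> pairs
start_token_ids = {s for s, _e in data_token_pairs}
end_tokens_ids = {e for _s, e in data_token_pairs}

input_ids = [1, 32000, 7, 7, 32001, 2]
for idx, token_id in enumerate(input_ids):
    if token_id in start_token_ids:   # O(1) set lookup instead of scanning a list
        print("multimodal data starts after index", idx)
    elif token_id in end_tokens_ids:
        print("multimodal data ends at index", idx)
```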
@@ -135,7 +135,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
if not input_ids or not mm_inputs.mm_items:
return input_ids
input_ids_tensor = torch.tensor(input_ids)
input_ids_tensor = torch.as_tensor(input_ids)
# Create mapping of token_ids to pad_values for each modality
token_to_pad_mapping = {}
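The `torch.tensor` → `torch.as_tensor` changes in this commit share one rationale: `torch.tensor` always copies, while `torch.as_tensor` reuses the existing storage when the input is already a tensor. A small illustration:

```python
import torch

ids_list = [101, 32000, 32000, 102]
copied = torch.tensor(ids_list)        # always allocates new storage
from_list = torch.as_tensor(ids_list)  # a plain list still has to be copied ...

existing = torch.arange(4)
reused = torch.as_tensor(existing)     # ... but an existing tensor is returned as-is
print(reused is existing)              # True: no extra copy when input_ids is already a tensor
```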
@@ -211,7 +211,7 @@ def get_embedding_chunk(
end_index += extend_end_index - start + 1
elif extend_end_index > end:
end_index += end - start + 1
# some models embedding is 3-dim, reshape it to 2-dim
# some models' embedding is 3-dim, reshape it to 2-dim
embedding = embedding.reshape(-1, embedding.shape[-1])
embedding_chunk = embedding[start_index:end_index]
return embedding_chunk, start_index, end_index
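For the reshape in `get_embedding_chunk`, a sketch with made-up shapes: a per-item 3-dim embedding is flattened to one row per multimodal token before slicing out the chunk that falls inside the extend range.

```python
import torch

embedding = torch.randn(2, 3, 8)                   # 3-dim: 2 items x 3 tokens x hidden 8
flat = embedding.reshape(-1, embedding.shape[-1])  # 2-dim: [6, 8], one row per mm token
start_index, end_index = 1, 4                      # assumed offsets for this chunk
embedding_chunk = flat[start_index:end_index]
print(embedding_chunk.shape)                       # torch.Size([3, 8])
```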
@@ -428,7 +428,7 @@ def embed_mm_inputs(
modality_id = modality.name.lower()
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
if len(items) != 0 and embedder is not None:
placeholder_tensor = torch.tensor(
placeholder_tensor = torch.as_tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
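For context, the per-modality embedder lookup in the surrounding lines resolves a `get_<modality>_feature` method dynamically and skips modalities the model does not implement. A toy sketch (the class and feature shapes are made up):

```python
import torch

class ToyMultimodalModel:
    # Illustrative stand-in; real models expose their own feature extractors.
    def get_image_feature(self, items):
        return torch.randn(len(items), 8)

multimodal_model = ToyMultimodalModel()
for modality_id in ("image", "audio"):
    embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
    if embedder is not None:
        print(modality_id, embedder(["a", "b"]).shape)  # only "image" prints
```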
@@ -473,11 +473,9 @@ def embed_mm_inputs(
for embedding, mask in zip(embeddings, masks):
if embedding is None or mask is None:
continue
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(
mask,
embedding.to(inputs_embeds.device, inputs_embeds.dtype),
)
# in-place update
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
return inputs_embeds
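The rewritten scatter produces the same result as the old `masked_scatter` path when the mask marks whole rows (one entry per sequence position), but it writes in place and avoids expanding the mask to the full hidden size. A self-contained check under that shape assumption:

```python
import torch

seq_len, hidden = 6, 4
inputs_embeds = torch.zeros(seq_len, hidden)
mask = torch.tensor([[0], [1], [1], [0], [1], [0]], dtype=torch.bool)  # [seq_len, 1]
embedding = torch.ones(3, hidden)                                      # one row per masked position

# old path: broadcast the mask and scatter element-wise into a new tensor
scattered = inputs_embeds.masked_scatter(
    mask.expand_as(inputs_embeds), embedding.to(inputs_embeds.dtype)
)

# new path: turn the mask into row indices and assign in place
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)

print(torch.equal(scattered, inputs_embeds))  # True
```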
@@ -561,34 +559,36 @@ def get_multimodal_data_bounds(
[bounds_count, 2]
"""
# All the multimodal data in the batch should share the same special bound token ids.
start_tokens = [s for s, _e in token_pairs]
end_tokens = [e for _s, e in token_pairs]
start_tokens = {s for s, _e in token_pairs}
end_tokens = {e for _s, e in token_pairs}
assert all(isinstance(t, int) for t in start_tokens)
assert all(isinstance(t, int) for t in end_tokens)
start_cond = torch.isin(
input_ids, torch.tensor(start_tokens, device=input_ids.device)
input_ids, torch.as_tensor(list(start_tokens), device=input_ids.device)
)
end_cond = torch.isin(
input_ids, torch.as_tensor(list(end_tokens), device=input_ids.device)
)
end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
(data_start_tokens,) = torch.where(start_cond)
(data_end_tokens,) = torch.where(end_cond)
data_start_tokens_cpu = data_start_tokens.cpu().tolist()
data_end_tokens_cpu = data_end_tokens.cpu().tolist()
# the im_start_id can sometimes be cached as part of the prefix, but it is still needed to embed the multimodal data
if len(data_start_tokens) != len(data_end_tokens):
if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
if (
len(data_start_tokens) + 1 == len(data_end_tokens)
and input_ids[0] in pad_values
and data_end_tokens[0] < data_start_tokens[0]
len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
and input_ids[0].item() in pad_values
and data_end_tokens_cpu
and data_start_tokens_cpu
and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
):
data_start_tokens = torch.cat(
[
torch.tensor([0], device=data_start_tokens.device),
data_start_tokens,
]
)
valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
data_start_tokens_cpu.insert(0, 0)
valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))
if valid_mm_data_nums == 0:
return torch.zeros((0, 2), device=input_ids.device)
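A runnable sketch of the bound detection in this hunk, with made-up token ids (90/91 as start/end, 7 as a pad value): it shows the prefix-caching case where the leading start token was cached away and index 0 is inserted to compensate.

```python
import torch

start_tokens, end_tokens = {90}, {91}   # hypothetical bound token ids
pad_values = [7]                        # hypothetical pad value

# prefix caching stripped the first start token, so the sequence opens with pads
input_ids = torch.tensor([7, 7, 91, 3, 90, 7, 91, 4])

start_cond = torch.isin(input_ids, torch.as_tensor(list(start_tokens)))
end_cond = torch.isin(input_ids, torch.as_tensor(list(end_tokens)))
data_start_tokens_cpu = torch.where(start_cond)[0].tolist()  # [4]
data_end_tokens_cpu = torch.where(end_cond)[0].tolist()      # [2, 6]

if (
    len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
    and input_ids[0].item() in pad_values
    and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
):
    data_start_tokens_cpu.insert(0, 0)  # treat position 0 as the missing start token

print(data_start_tokens_cpu, data_end_tokens_cpu)  # [0, 4] [2, 6]
```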
@@ -596,8 +596,8 @@ def get_multimodal_data_bounds(
# Filter out pairs where start_token >= end_token
valid_pairs = []
for i in range(valid_mm_data_nums):
start_token = data_start_tokens[i]
end_token = data_end_tokens[i]
start_token = data_start_tokens_cpu[i]
end_token = data_end_tokens_cpu[i]
if start_token < end_token:
valid_pairs.append((start_token + 1, end_token - 1))
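Continuing that sketch, the pairing below keeps only the span strictly between each start/end position, since the bound tokens themselves carry no data:

```python
data_start_tokens_cpu = [0, 4]
data_end_tokens_cpu = [2, 6]
valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))

valid_pairs = []
for i in range(valid_mm_data_nums):
    start_token = data_start_tokens_cpu[i]
    end_token = data_end_tokens_cpu[i]
    if start_token < end_token:
        valid_pairs.append((start_token + 1, end_token - 1))

print(valid_pairs)  # [(1, 1), (5, 5)]
```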
@@ -605,7 +605,7 @@ def get_multimodal_data_bounds(
return torch.zeros((0, 2), device=input_ids.device)
# Convert valid pairs to tensor
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
return valid_pairs_tensor
@@ -634,11 +634,7 @@ def tensor_hash(tensor_list) -> int:
tensor = tensor.float()
assert isinstance(tensor, torch.Tensor)
if tensor.is_cuda:
# TODO: improve this
tensor_cpu = tensor.cpu()
else:
tensor_cpu = tensor
tensor_cpu = tensor.cpu()
mv = memoryview(tensor_cpu.numpy())
return data_hash(mv.tobytes())
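The dropped branch was redundant because `.cpu()` already returns the same tensor (no copy) when it lives on the CPU; only CUDA tensors get transferred. A minimal sketch of the hashing path, with `hashlib` standing in for the project's `data_hash` helper:

```python
import hashlib
import torch

def _data_hash(raw: bytes) -> int:
    # stand-in for data_hash, for illustration only
    return int.from_bytes(hashlib.sha256(raw).digest()[:8], "big")

tensor = torch.arange(16, dtype=torch.float32)
tensor_cpu = tensor.cpu()              # no-op for CPU tensors, device transfer for CUDA tensors
mv = memoryview(tensor_cpu.numpy())
print(_data_hash(mv.tobytes()))
```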