Unverified commit 53475674 authored by Mick, committed by GitHub

chore: improvements on mm_utils (#7737)

parent ce32bc2b
@@ -85,8 +85,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
                 "No data_token_pairs provided, RadixAttention might be influenced."
             )
             return input_ids
-        start_token_ids = [s for s, _e in data_token_pairs]
-        end_tokens_ids = [e for _s, e in data_token_pairs]
+        start_token_ids = {s for s, _e in data_token_pairs}
+        end_tokens_ids = {e for _s, e in data_token_pairs}

         padded_ids = []
         last_idx = 0
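For context (not part of the commit): switching the comprehensions from lists to sets presumably helps because the padding loop that follows tests every input token id against these collections, and `in` on a set is a constant-time hash lookup rather than a linear scan. A minimal sketch with made-up token ids:

# Illustrative sketch, not from the patch; token ids below are hypothetical.
data_token_pairs = [(9001, 9002), (9003, 9004)]      # hypothetical (start, end) token ids

start_token_ids = {s for s, _e in data_token_pairs}  # set: O(1) membership
end_tokens_ids = {e for _s, e in data_token_pairs}

for token_id in [1, 9001, 42, 9002]:
    if token_id in start_token_ids:                   # hash lookup instead of a list scan
        print("start marker:", token_id)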
@@ -135,7 +135,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern)
         if not input_ids or not mm_inputs.mm_items:
             return input_ids

-        input_ids_tensor = torch.tensor(input_ids)
+        input_ids_tensor = torch.as_tensor(input_ids)

         # Create mapping of token_ids to pad_values for each modality
         token_to_pad_mapping = {}
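The repeated torch.tensor -> torch.as_tensor swaps throughout this commit most likely target copy semantics: torch.tensor always allocates and copies, while torch.as_tensor reuses the existing storage when given a tensor or a compatible NumPy array. A small sketch (illustrative only):

# Illustrative sketch: as_tensor shares memory with a compatible NumPy array, tensor copies it.
import numpy as np
import torch

arr = np.zeros(4, dtype=np.int64)
copied = torch.tensor(arr)     # always copies the buffer
shared = torch.as_tensor(arr)  # no copy: same dtype, same (CPU) device

arr[0] = 7
print(copied[0].item(), shared[0].item())  # 0 7 -> only the shared view reflects the change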
@@ -211,7 +211,7 @@ def get_embedding_chunk(
             end_index += extend_end_index - start + 1
         elif extend_end_index > end:
             end_index += end - start + 1
-    # some models embedding is 3-dim, reshape it to 2-dim
+    # some models' embedding is 3-dim, reshape it to 2-dim
     embedding = embedding.reshape(-1, embedding.shape[-1])
     embedding_chunk = embedding[start_index:end_index]
     return embedding_chunk, start_index, end_index
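The comment fix here is cosmetic; the surrounding code handles models whose multimodal embedding comes back with an extra leading dimension. A sketch of what the reshape does (shapes are made up):

# Illustrative sketch: flatten a [num_items, tokens_per_item, hidden] embedding to 2-dim
# so a flat [start_index:end_index] slice selects whole token rows.
import torch

embedding = torch.randn(2, 5, 16)                        # hypothetical 3-dim embedding
embedding = embedding.reshape(-1, embedding.shape[-1])   # -> [10, 16]
print(embedding[3:7].shape)                              # torch.Size([4, 16])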
@@ -428,7 +428,7 @@ def embed_mm_inputs(
         modality_id = modality.name.lower()
         embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
         if len(items) != 0 and embedder is not None:
-            placeholder_tensor = torch.tensor(
+            placeholder_tensor = torch.as_tensor(
                 [item.pad_value for item in items],
                 device=input_ids.device,
             )
@@ -473,11 +473,9 @@ def embed_mm_inputs(
     for embedding, mask in zip(embeddings, masks):
         if embedding is None or mask is None:
             continue
-        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-        inputs_embeds = inputs_embeds.masked_scatter(
-            mask,
-            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-        )
+        # in-place update
+        indices = torch.where(mask.squeeze(dim=-1))[0]
+        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)

     return inputs_embeds
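This hunk replaces masked_scatter, which needs the mask expanded to the full [seq_len, hidden] shape and returns a new tensor, with an in-place row assignment driven by the mask's nonzero positions. Assuming the mask has shape [seq_len, 1] (as the squeeze(dim=-1) suggests), the two forms produce the same values; a minimal sketch, not from the patch:

# Illustrative sketch of the two equivalent updates.
import torch

seq_len, hidden = 6, 4
inputs_embeds = torch.zeros(seq_len, hidden)
mask = torch.tensor([[0], [1], [1], [0], [1], [0]], dtype=torch.bool)
embedding = torch.arange(3 * hidden, dtype=torch.float32).reshape(3, hidden)

# Old form: expand the mask to the full shape and scatter into a fresh tensor.
scattered = inputs_embeds.masked_scatter(mask.expand_as(inputs_embeds), embedding)

# New form: convert the mask to row indices and write in place.
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding

print(torch.equal(scattered, inputs_embeds))  # True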
@@ -561,34 +559,36 @@ def get_multimodal_data_bounds(
         [bounds_count, 2]
     """
     # All the multimodal data in the batch should share the same special bound token ids.
-    start_tokens = [s for s, _e in token_pairs]
-    end_tokens = [e for _s, e in token_pairs]
+    start_tokens = {s for s, _e in token_pairs}
+    end_tokens = {e for _s, e in token_pairs}

     assert all(isinstance(t, int) for t in start_tokens)
     assert all(isinstance(t, int) for t in end_tokens)

     start_cond = torch.isin(
-        input_ids, torch.tensor(start_tokens, device=input_ids.device)
+        input_ids, torch.as_tensor(start_tokens, device=input_ids.device)
     )
-    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
+    end_cond = torch.isin(
+        input_ids, torch.as_tensor(end_tokens, device=input_ids.device)
+    )

     (data_start_tokens,) = torch.where(start_cond)
     (data_end_tokens,) = torch.where(end_cond)

+    data_start_tokens_cpu = data_start_tokens.cpu().tolist()
+    data_end_tokens_cpu = data_end_tokens.cpu().tolist()
+
     # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
-    if len(data_start_tokens) != len(data_end_tokens):
+    if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
         if (
-            len(data_start_tokens) + 1 == len(data_end_tokens)
-            and input_ids[0] in pad_values
-            and data_end_tokens[0] < data_start_tokens[0]
+            len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
+            and input_ids[0].item() in pad_values
+            and data_end_tokens_cpu
+            and data_start_tokens_cpu
+            and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
         ):
-            data_start_tokens = torch.cat(
-                [
-                    torch.tensor([0], device=data_start_tokens.device),
-                    data_start_tokens,
-                ]
-            )
-    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
+            data_start_tokens_cpu.insert(0, 0)
+    valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))

     if valid_mm_data_nums == 0:
         return torch.zeros((0, 2), device=input_ids.device)
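The new _cpu variables look like an attempt to pay the device-to-host transfer once: the old code indexed and compared the (possibly CUDA) index tensors element by element inside Python conditionals, each access forcing a synchronization, and prepended an element with torch.cat. A rough sketch of the pattern (device handling simplified, not from the patch):

# Illustrative sketch: move index tensors to the host once, then use plain list operations.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
data_start_tokens = torch.tensor([5, 20], device=device)

starts = data_start_tokens.cpu().tolist()  # single transfer instead of one sync per access
if starts and starts[0] != 0:              # cheap Python comparisons from here on
    starts.insert(0, 0)                    # replaces the old torch.cat(...) prepend
print(starts)                              # [0, 5, 20]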
@@ -596,8 +596,8 @@ def get_multimodal_data_bounds(

     # Filter out pairs where start_token >= end_token
     valid_pairs = []
     for i in range(valid_mm_data_nums):
-        start_token = data_start_tokens[i]
-        end_token = data_end_tokens[i]
+        start_token = data_start_tokens_cpu[i]
+        end_token = data_end_tokens_cpu[i]
         if start_token < end_token:
             valid_pairs.append((start_token + 1, end_token - 1))
@@ -605,7 +605,7 @@ def get_multimodal_data_bounds(
         return torch.zeros((0, 2), device=input_ids.device)

     # Convert valid pairs to tensor
-    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+    valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor
@@ -634,11 +634,7 @@ def tensor_hash(tensor_list) -> int:
         tensor = tensor.float()

     assert isinstance(tensor, torch.Tensor)
-    if tensor.is_cuda:
-        # TODO: improve this
-        tensor_cpu = tensor.cpu()
-    else:
-        tensor_cpu = tensor
+    tensor_cpu = tensor.cpu()

     mv = memoryview(tensor_cpu.numpy())
     return data_hash(mv.tobytes())
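The tensor_hash simplification works because .cpu() is already a no-op for tensors in CPU memory (PyTorch returns the original object in that case), so the explicit is_cuda branch added nothing. Quick check:

# Illustrative check: .cpu() returns the same object for a tensor already on the CPU.
import torch

t = torch.arange(4)
print(t.cpu() is t)  # True -> no extra copy, so the old is_cuda branch was redundant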