Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f0fe4fe8
Unverified
Commit
f0fe4fe8
authored
Oct 14, 2024
by
Xiang Xu
Committed by
GitHub
Oct 14, 2024
Browse files
[Model] Make llama3.2 support multiple and interleaved images (#9095)
parent
4d31cd42
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
384 additions
and
42 deletions
+384
-42
examples/offline_inference_vision_language_multi_image.py
examples/offline_inference_vision_language_multi_image.py
+23
-0
tests/models/encoder_decoder/vision_language/test_mllama.py
tests/models/encoder_decoder/vision_language/test_mllama.py
+82
-3
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+279
-39
No files found.
examples/offline_inference_vision_language_multi_image.py
View file @
f0fe4fe8
...
...
@@ -234,12 +234,35 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
)
def
load_mllama
(
question
,
image_urls
:
List
[
str
])
->
ModelRequestData
:
model_name
=
"meta-llama/Llama-3.2-11B-Vision-Instruct"
# The configuration below has been confirmed to launch on a single L40 GPU.
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
4096
,
max_num_seqs
=
16
,
enforce_eager
=
True
,
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
)
prompt
=
f
"<|image|><|image|><|begin_of_text|>
{
question
}
"
return
ModelRequestData
(
llm
=
llm
,
prompt
=
prompt
,
stop_token_ids
=
None
,
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
],
chat_template
=
None
,
)
model_example_map
=
{
"phi3_v"
:
load_phi3v
,
"internvl_chat"
:
load_internvl
,
"NVLM_D"
:
load_nvlm_d
,
"qwen2_vl"
:
load_qwen2_vl
,
"qwen_vl_chat"
:
load_qwenvl_chat
,
"mllama"
:
load_mllama
,
}
...
...
tests/models/encoder_decoder/vision_language/test_mllama.py
View file @
f0fe4fe8
...
...
@@ -12,7 +12,7 @@ from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
from
....utils
import
large_gpu_test
from
...utils
import
check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT
=
1
_LIMIT_IMAGE_PER_PROMPT
=
3
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
"stop_sign"
:
...
...
@@ -244,8 +244,9 @@ def _run_test(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
sizes
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
def
test_models_single_leading_image
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
sizes
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
run_test
(
hf_runner
,
vllm_runner
,
...
...
@@ -257,3 +258,81 @@ def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models_multi_leading_images
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
stop_sign
=
image_assets
[
0
].
pil_image
cherry_blossom
=
image_assets
[
1
].
pil_image
inputs
=
[(
[
"<|image|><|image|><|begin_of_text|>Describe 2 images."
,
# noqa: E501
"<|image|><|image|><|begin_of_text|>Describe 2 images."
,
# noqa: E501
"<|image|><|image|><|image|><|begin_of_text|>Describe 3 images."
,
# noqa: E501
],
[
[
stop_sign
,
cherry_blossom
],
# Images with different sizes.
[
stop_sign
.
resize
((
512
,
512
)),
stop_sign
,
],
[
stop_sign
,
stop_sign
.
resize
((
512
,
1536
)),
cherry_blossom
.
resize
((
512
,
1024
)),
],
])]
_run_test
(
hf_runner
,
vllm_runner
,
inputs
,
model
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models_interleaved_images
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
stop_sign
=
image_assets
[
0
].
pil_image
cherry_blossom
=
image_assets
[
1
].
pil_image
inputs
=
[(
[
"<|begin_of_text|>The content of the image <|image|> is"
,
# noqa: E501
"<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, "
# noqa: E501
"which is a stop sign and which is a cherry blossom?"
,
# noqa: E501
],
[
[
stop_sign
],
[
stop_sign
,
cherry_blossom
],
])]
_run_test
(
hf_runner
,
vllm_runner
,
inputs
,
model
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
vllm/model_executor/models/mllama.py
View file @
f0fe4fe8
...
...
@@ -18,6 +18,7 @@ from array import array
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
...
...
@@ -28,9 +29,12 @@ from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast
)
from
transformers.models.mllama.image_processing_mllama
import
(
get_optimal_tiled_canvas
)
from
transformers.models.mllama.processing_mllama
import
(
get_cross_attention_token_mask
)
import
vllm.distributed.parallel_state
as
ps
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
...
...
@@ -72,6 +76,16 @@ class MllamaImagePixelInputs(TypedDict):
# TODO: support LlamaImageEmbeddingInputs
def
_get_num_image_in_last_group
(
prompt_token_ids
:
List
[
int
])
->
int
:
num_images
=
0
for
token_id
in
prompt_token_ids
[::
-
1
]:
if
token_id
==
MLLAMA_IMAGE_TOKEN_ID
:
num_images
+=
1
elif
num_images
>
0
:
break
return
num_images
def
input_processor_for_mllama
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
# move encoder_prompt to prompt
if
llm_inputs
.
get
(
"prompt"
)
is
None
:
...
...
@@ -91,12 +105,16 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
llm_inputs
[
"encoder_multi_modal_data"
]
=
{}
return
llm_inputs
# get num_tiles
if
isinstance
(
multi_modal_data
[
'image'
],
Image
.
Image
):
multi_modal_data
[
'image'
]
=
[
multi_modal_data
[
'image'
]]
# Since only the last group of consecutive images
# are attended by the decoded tokens, we only need to
# get the number of tiles for those images.
num_decode_images
=
_get_num_image_in_last_group
(
llm_inputs
[
"prompt_token_ids"
])
hf_config
=
ctx
.
model_config
.
hf_config
num_tiles
=
0
for
image
in
multi_modal_data
[
"image"
]:
for
image
in
multi_modal_data
[
"image"
]
[::
-
1
]
:
width
,
height
=
image
.
size
tile_size
=
hf_config
.
vision_config
.
image_size
canvas_height
,
canvas_width
=
get_optimal_tiled_canvas
(
...
...
@@ -108,8 +126,13 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
num_tiles_height
=
canvas_height
//
tile_size
num_tiles_width
=
canvas_width
//
tile_size
num_tiles
+=
num_tiles_height
*
num_tiles_width
num_decode_images
-=
1
if
num_decode_images
==
0
:
break
# set encoder prompt based on num_tiles
# Set encoder prompt length based on the number of tiles.
# This tells the block manager to allocate correct number
# of slots for encoder tokens.
assert
hf_config
.
vision_config
.
image_size
%
14
==
0
,
\
"chunk size should be multiple of 14"
token_per_chunk
=
(
hf_config
.
vision_config
.
image_size
//
14
)
**
2
+
1
...
...
@@ -675,6 +698,7 @@ class MllamaTextCrossAttention(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
cross_attention_states
:
Optional
[
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
...
...
@@ -697,15 +721,71 @@ class MllamaTextCrossAttention(nn.Module):
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
head_dim
)
q
=
self
.
q_norm
(
q
)
output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
AttentionType
.
ENCODER_DECODER
)
if
attention_mask
is
not
None
:
output
=
self
.
attention_with_mask
(
q
,
k
,
v
,
kv_cache
,
attention_mask
,
kv_range_for_decode
,
attn_metadata
)
else
:
output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
AttentionType
.
ENCODER_DECODER
)
out
,
_
=
self
.
o_proj
(
output
)
return
out
def
attention_with_mask
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
kv_range_for_decode
:
List
[
Tuple
[
int
,
int
]],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
# Skip writing kv-cache for the initial profiling run.
if
len
(
kv_cache
.
shape
)
==
3
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_local_key_value_heads
,
self
.
head_dim
)
cached_k
=
torch
.
cat
([
k
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
cached_v
=
torch
.
cat
([
v
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
PagedAttention
.
write_to_paged_cache
(
cached_k
,
cached_v
,
key_cache
,
value_cache
,
attn_metadata
.
cross_slot_mapping
,
"auto"
,
1.0
,
1.0
)
# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
# standard causal mask, neither a block diagonal mask which
# can be optimized by xformers.BlockDiagonalMask.
# The mask is specially calculated for supporting multi
# images and interleaved images.
q_len
=
q
.
shape
[
0
]
kv_len
=
k
.
shape
[
0
]
q
=
q
.
transpose
(
0
,
1
).
view
(
self
.
num_local_key_value_heads
,
self
.
num_key_value_groups
,
q_len
,
self
.
head_dim
)
k
=
k
.
transpose
(
0
,
1
)[:,
None
,
:,
:].
expand
(
self
.
num_local_key_value_heads
,
self
.
num_key_value_groups
,
kv_len
,
self
.
head_dim
)
v
=
v
.
transpose
(
0
,
1
)[:,
None
,
:,
:].
expand
(
self
.
num_local_key_value_heads
,
self
.
num_key_value_groups
,
kv_len
,
self
.
head_dim
)
attention_mask
=
attention_mask
.
view
(
1
,
1
,
q_len
,
kv_len
)
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
attention_mask
,
is_causal
=
False
)
output
=
output
.
permute
(
2
,
0
,
1
,
3
).
reshape
(
q_len
,
self
.
num_local_heads
*
self
.
head_dim
)
return
output
class
MllamaCrossAttentionDecoderLayer
(
torch
.
nn
.
Module
):
"""Cross-attention transformer block with tanh-gated attention
...
...
@@ -741,6 +821,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_states
:
torch
.
Tensor
,
cross_attention_states
:
torch
.
Tensor
,
cross_attention_mask
:
torch
.
Tensor
,
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
torch
.
Tensor
,
kv_cache
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
...
...
@@ -751,6 +832,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_states
=
self
.
cross_attn
(
hidden_states
=
hidden_states
,
attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
cross_attention_states
=
cross_attention_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
...
...
@@ -804,6 +886,7 @@ class MllamaTextModel(nn.Module):
positions
:
Optional
[
torch
.
LongTensor
],
cross_attention_states
:
Optional
[
torch
.
LongTensor
],
cross_attention_mask
:
Optional
[
torch
.
LongTensor
],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
List
[
torch
.
Tensor
],
...
...
@@ -820,6 +903,7 @@ class MllamaTextModel(nn.Module):
hidden_states
=
hidden_states
,
cross_attention_states
=
cross_attention_states
,
cross_attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_cache
=
kv_caches
[
idx
],
...
...
@@ -868,6 +952,7 @@ class MllamaForCausalLM(nn.Module):
positions
:
Optional
[
torch
.
LongTensor
],
cross_attention_states
:
Optional
[
torch
.
LongTensor
],
cross_attention_mask
:
Optional
[
torch
.
LongTensor
],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
List
[
torch
.
Tensor
],
...
...
@@ -879,6 +964,7 @@ class MllamaForCausalLM(nn.Module):
positions
=
positions
,
cross_attention_states
=
cross_attention_states
,
cross_attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
...
...
@@ -1026,36 +1112,102 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
raise
AssertionError
(
"This line should be unreachable."
)
def
flat_encoder_result
(
self
,
cross_attention_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
):
attn_metadata
:
AttentionMetadata
,
actual_encoder_seq_lens
:
List
[
int
]):
cross_attention_states_flat
=
torch
.
zeros
(
sum
(
a
ttn_metadata
.
encoder_seq_lens
),
sum
(
a
ctual_
encoder_seq_lens
),
cross_attention_states
.
shape
[
-
1
],
device
=
cross_attention_states
.
device
,
dtype
=
cross_attention_states
.
dtype
)
start_pos
=
0
for
seq_len
,
vision_token_in_batch
in
zip
(
attn_metadata
.
encoder_seq_lens
,
cross_attention_states
):
for
seq_len
,
vision_token_in_batch
in
zip
(
actual_encoder_seq_lens
,
cross_attention_states
):
end_pos
=
start_pos
+
seq_len
cross_attention_states_flat
[
start_pos
:
end_pos
]
=
vision_token_in_batch
[:
seq_len
]
start_pos
=
end_pos
cross_attention_states
=
cross_attention_states_flat
return
cross_attention_states
def
get_cross_attention_states
(
self
,
image_inputs
:
MllamaImagePixelInputs
,
attn_metadata
:
AttentionMetadata
,
actual_encoder_seq_lens
:
List
[
int
],
)
->
Tuple
[
torch
.
Tensor
]:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values
=
image_inputs
[
'data'
]
aspect_ratio_ids
=
image_inputs
[
'aspect_ratio_ids'
]
aspect_ratio_mask
=
image_inputs
[
'aspect_ratio_mask'
]
cross_attention_states
=
self
.
vision_model
(
pixel_values
,
aspect_ratio_ids
,
aspect_ratio_mask
)
cross_attention_states
=
self
.
multi_modal_projector
(
cross_attention_states
)
bsz
,
_
,
_
,
_
,
image_token_dim
=
tuple
(
cross_attention_states
.
shape
)
cross_attention_states
=
cross_attention_states
.
view
(
bsz
,
-
1
,
image_token_dim
)
cross_attention_states
=
self
.
flat_encoder_result
(
cross_attention_states
,
attn_metadata
,
actual_encoder_seq_lens
)
return
cross_attention_states
def
get_cross_attention_mask
(
self
,
input_ids
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
num_tiles
:
List
[
List
[
int
]],
num_tokens_per_tile
:
int
,
dtype
:
torch
.
dtype
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
token_ids
=
input_ids
.
tolist
()
start
=
0
batch_token_ids
=
[]
for
seq_len
in
attn_metadata
.
seq_lens
:
batch_token_ids
.
append
(
token_ids
[
start
:
start
+
seq_len
])
start
+=
seq_len
sparse_mask
=
[
get_cross_attention_token_mask
(
t
,
MLLAMA_IMAGE_TOKEN_ID
)
for
t
in
batch_token_ids
]
# Skip generating cross-attention mask if all samples
# are text-only or have only 1 leading image.
if
skip_attention_mask
(
sparse_mask
):
return
None
,
None
dense_mask
,
tile_range_for_decode
=
\
convert_sparse_cross_attention_mask_to_dense
(
sparse_mask
,
num_tiles
,
attn_metadata
.
seq_lens
)
cross_attention_mask
=
\
convert_dense_cross_attention_mask_to_tensor
(
dense_mask
,
num_tokens_per_tile
,
input_ids
.
device
,
dtype
)
kv_range_for_decode
=
[[
t
[
0
]
*
num_tokens_per_tile
,
t
[
1
]
*
num_tokens_per_tile
]
for
t
in
tile_range_for_decode
]
return
cross_attention_mask
,
kv_range_for_decode
def
get_full_text_row_masked_out_mask
(
self
,
attn_metadata
:
AttentionMetadata
,
device
:
torch
.
device
,
)
->
torch
.
Tensor
:
full_text_row_masked_out_mask
=
torch
.
ones
(
(
attn_metadata
.
num_prefill_tokens
,
1
),
dtype
=
torch
.
bool
)
start_pos
=
0
for
seq_len
,
encoder_seq_len
in
zip
(
attn_metadata
.
seq_lens_tensor
.
cpu
(),
attn_metadata
.
encoder_seq_lens
):
for
seq_len
,
encoder_seq_len
in
zip
(
attn_metadata
.
seq_lens
,
attn_metadata
.
encoder_seq_lens
):
if
encoder_seq_len
==
0
:
full_text_row_masked_out_mask
[
start_pos
:
start_pos
+
seq_len
]
=
False
start_pos
+=
seq_len
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
.
to
(
cross_attention_states
.
device
)
return
cross_attention_states
,
full_text_row_masked_out_mask
device
)
return
full_text_row_masked_out_mask
def
forward
(
self
,
...
...
@@ -1069,39 +1221,54 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata
.
num_decode_tokens
>
0
:
raise
ValueError
(
"Chunk prefill not supported"
)
image_inputs
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
cross_attention_states
=
None
cross_attention_mask
=
None
kv_range_for_decode
=
None
# For 1) text-only prefill and decode, 2) image-present decode.
if
image_inputs
is
None
:
cross_attention_mask
=
None
full_text_row_masked_out_mask
=
(
attn_metadata
.
encoder_seq_lens_tensor
!=
0
).
reshape
(
-
1
,
1
).
to
(
input_ids
.
device
)
cross_attention_states
=
None
skip_cross_attention
=
max
(
attn_metadata
.
encoder_seq_lens
)
==
0
# For image-present prefill.
else
:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values
=
image_inputs
[
'data'
]
aspect_ratio_ids
=
image_inputs
[
'aspect_ratio_ids'
]
aspect_ratio_mask
=
image_inputs
[
'aspect_ratio_mask'
]
cross_attention_states
=
self
.
vision_model
(
pixel_values
,
aspect_ratio_ids
,
aspect_ratio_mask
)
cross_attention_states
=
self
.
multi_modal_projector
(
cross_attention_states
)
bsz
,
_
,
_
,
_
,
image_token_dim
=
tuple
(
cross_attention_states
.
shape
)
cross_attention_states
=
cross_attention_states
.
view
(
bsz
,
-
1
,
image_token_dim
)
cross_attention_states
,
full_text_row_masked_out_mask
=
\
self
.
flat_encoder_result
(
cross_attention_states
,
attn_metadata
)
skip_cross_attention
=
False
# TODO: support multi-image by this mask
cross_attention_mask
=
None
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See input_processor_for_mllama() for more details.
num_tiles_tensor
=
kwargs
.
pop
(
"num_tiles"
)
num_tiles
=
[
t
[
0
].
tolist
()
for
t
in
num_tiles_tensor
]
num_tokens_per_tile
=
(
self
.
image_size
//
14
)
**
2
+
1
actual_encoder_seq_lens
=
[
sum
(
num_tile
)
*
num_tokens_per_tile
for
num_tile
in
num_tiles
]
for
actual_len
,
last_group_len
in
zip
(
actual_encoder_seq_lens
,
attn_metadata
.
encoder_seq_lens
):
assert
actual_len
>=
last_group_len
cross_attention_states
=
self
.
get_cross_attention_states
(
image_inputs
,
attn_metadata
,
actual_encoder_seq_lens
)
full_text_row_masked_out_mask
=
\
self
.
get_full_text_row_masked_out_mask
(
attn_metadata
,
input_ids
.
device
)
cross_attention_mask
,
kv_range_for_decode
=
\
self
.
get_cross_attention_mask
(
input_ids
,
attn_metadata
,
num_tiles
,
num_tokens_per_tile
,
cross_attention_states
.
dtype
)
outputs
=
self
.
language_model
(
input_ids
=
input_ids
,
positions
=
positions
,
cross_attention_states
=
cross_attention_states
,
cross_attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
...
...
@@ -1140,3 +1307,76 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
def
skip_attention_mask
(
sparse_mask
:
List
[
List
[
int
]])
->
bool
:
for
mask
in
sparse_mask
:
# Skip text-only samples.
if
len
(
mask
)
==
0
:
continue
# If the sample contains more than 1 images,
# we can't skip mask.
if
len
(
mask
)
!=
1
:
return
False
# If the sample contains only 1 image,
# but the image is not the leading one,
# we can't skip mask.
if
mask
[
0
][
0
]
!=
0
or
mask
[
0
][
1
]
!=
-
1
:
return
False
return
True
def
convert_sparse_cross_attention_mask_to_dense
(
sparse_mask
:
List
[
List
[
List
[
int
]]],
num_tiles
:
List
[
List
[
int
]],
lengths
:
List
[
int
],
)
->
Tuple
[
np
.
ndarray
,
List
[
Tuple
[
int
,
int
]]]:
total_length
=
sum
(
lengths
)
total_tiles
=
sum
([
sum
(
tiles
)
for
tiles
in
num_tiles
])
dense_mask
=
np
.
zeros
(
shape
=
(
total_length
,
total_tiles
),
dtype
=
np
.
int64
)
# A list of ranges, range[i] = [start, end] means
# if the i-th sample has N tiles in total, the tiles[start, end]
# will be used for cross-attention decoding.
tile_range_for_decode
=
[]
seq_start
=
0
tile_start
=
0
for
masks
,
tiles
,
length
in
zip
(
sparse_mask
,
num_tiles
,
lengths
):
ts
,
td
=
-
1
,
0
for
mask
,
tile
in
zip
(
masks
,
tiles
):
if
len
(
mask
)
!=
2
:
continue
start
,
end
=
mask
end
=
min
(
end
,
length
)
if
end
==
-
1
:
end
=
length
if
end
==
length
:
if
ts
==
-
1
:
ts
=
tile_start
td
+=
tile
dense_mask
[
seq_start
+
start
:
seq_start
+
end
,
tile_start
:
tile_start
+
tile
]
=
1
tile_start
+=
tile
tile_range_for_decode
.
append
((
ts
,
ts
+
td
))
seq_start
+=
length
return
dense_mask
,
tile_range_for_decode
def
convert_dense_cross_attention_mask_to_tensor
(
cross_attention_token_mask
:
np
.
ndarray
,
num_tokens_per_tile
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
mask
=
torch
.
tensor
(
cross_attention_token_mask
,
dtype
=
dtype
,
device
=
device
)
mask
=
mask
.
repeat_interleave
(
num_tokens_per_tile
,
dim
=
1
)
mask
=
1.0
-
mask
mask
=
mask
.
masked_fill
(
mask
.
to
(
torch
.
bool
),
torch
.
finfo
(
dtype
).
min
)
ninf
=
torch
.
finfo
(
dtype
).
min
full_text_mask
=
((
mask
!=
ninf
).
any
(
dim
=-
1
).
type_as
(
mask
)[...,
None
])
mask
*=
full_text_mask
# (num_prompt_tokens, num_encoder_tokens)
return
mask
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment