Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ff57ebf
Unverified
Commit
3ff57ebf
authored
Oct 23, 2024
by
Isotr0py
Committed by
GitHub
Oct 23, 2024
Browse files
[Model] Initialize Florence-2 language backbone support (#9555)
parent
2394962d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
428 additions
and
8 deletions
+428
-8
examples/florence2_inference.py
examples/florence2_inference.py
+44
-0
tests/conftest.py
tests/conftest.py
+20
-8
tests/models/encoder_decoder/vision_language/test_florence2.py
.../models/encoder_decoder/vision_language/test_florence2.py
+102
-0
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+261
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
examples/florence2_inference.py
0 → 100644
View file @
3ff57ebf
'''
Demonstrate prompting of text-to-text
encoder/decoder models, specifically Florence-2
'''
# TODO(Isotr0py):
# Move to offline_inference_vision_language.py after porting vision backbone
from
vllm
import
LLM
,
SamplingParams
# Precision string forwarded to the engine through ``dtype=`` below.
dtype = "float"

# Build the Florence-2 engine. The tokenizer is loaded from the
# "facebook/bart-base" checkpoint rather than from the model repository.
llm = LLM(
    model="microsoft/Florence-2-base",
    tokenizer="facebook/bart-base",
    dtype=dtype,
    trust_remote_code=True,
)

# Florence-2 task tokens; each one is submitted as a separate prompt.
prompts = [
    "<CAPTION>",
    "<DETAILED_CAPTION>",
    "<MORE_DETAILED_CAPTION>",
    "<CAPTION_TO_PHRASE_GROUNDING>",
    "<OD>",
    "<DENSE_REGION_CAPTION>",
    "<REGION_PROPOSAL>",
    "<OCR>",
    "<OCR_WITH_REGION>",
]

# Sampling configuration: temperature 0, at most 20 generated tokens.
sampling_params = SamplingParams(temperature=0,
                                 top_p=1.0,
                                 min_tokens=0,
                                 max_tokens=20)

# ``generate`` returns one result object per prompt; each exposes the
# prompt, the encoder prompt and the generated candidates.
outputs = llm.generate(prompts, sampling_params)

# Report the encoder prompt, decoder prompt and first completion of each
# result on a single line.
for output in outputs:
    prompt, encoder_prompt = output.prompt, output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
tests/conftest.py
View file @
3ff57ebf
...
@@ -253,7 +253,9 @@ class HfRunner:
...
@@ -253,7 +253,9 @@ class HfRunner:
dtype
:
str
=
"half"
,
dtype
:
str
=
"half"
,
*
,
*
,
model_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
model_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
is_embedding_model
:
bool
=
False
,
is_sentence_transformer
:
bool
=
False
,
is_sentence_transformer
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
auto_cls
:
Type
[
_BaseAutoModelClass
]
=
AutoModelForCausalLM
,
auto_cls
:
Type
[
_BaseAutoModelClass
]
=
AutoModelForCausalLM
,
postprocess_inputs
:
Callable
[[
BatchEncoding
],
postprocess_inputs
:
Callable
[[
BatchEncoding
],
BatchEncoding
]
=
identity
,
BatchEncoding
]
=
identity
,
...
@@ -281,11 +283,12 @@ class HfRunner:
...
@@ -281,11 +283,12 @@ class HfRunner:
**
model_kwargs
,
**
model_kwargs
,
))
))
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
if
not
skip_tokenizer_init
:
model_name
,
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
torch_dtype
=
torch_dtype
,
model_name
,
trust_remote_code
=
True
,
torch_dtype
=
torch_dtype
,
)
trust_remote_code
=
True
,
)
# don't put this import at the top level
# don't put this import at the top level
# it will call torch.cuda.device_count()
# it will call torch.cuda.device_count()
...
@@ -295,6 +298,8 @@ class HfRunner:
...
@@ -295,6 +298,8 @@ class HfRunner:
torch_dtype
=
torch_dtype
,
torch_dtype
=
torch_dtype
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
)
)
if
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
processor
.
tokenizer
self
.
postprocess_inputs
=
postprocess_inputs
self
.
postprocess_inputs
=
postprocess_inputs
...
@@ -535,6 +540,7 @@ class HfRunner:
...
@@ -535,6 +540,7 @@ class HfRunner:
encoder_decoder_prompts
:
List
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
encoder_decoder_prompts
:
List
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
images
:
Optional
[
PromptImageInput
]
=
None
,
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
List
[
TokensTextLogprobs
]:
)
->
List
[
TokensTextLogprobs
]:
'''
'''
...
@@ -545,11 +551,17 @@ class HfRunner:
...
@@ -545,11 +551,17 @@ class HfRunner:
all_output_ids
:
List
[
List
[
int
]]
=
[]
all_output_ids
:
List
[
List
[
int
]]
=
[]
all_output_strs
:
List
[
str
]
=
[]
all_output_strs
:
List
[
str
]
=
[]
for
(
encoder_prompt
,
for
i
,
(
encoder_prompt
,
decoder_prompt
)
in
enumerate
(
decoder_prompt
)
in
to_enc_dec_tuple_list
(
encoder_decoder_prompts
):
to_enc_dec_tuple_list
(
encoder_decoder_prompts
)):
processor_kwargs
:
Dict
[
str
,
Any
]
=
{
"text"
:
encoder_prompt
,
"return_tensors"
:
"pt"
,
}
if
images
is
not
None
and
images
[
i
]
is
not
None
:
processor_kwargs
[
"images"
]
=
images
[
i
]
encoder_input_ids
=
self
.
wrap_device
(
encoder_input_ids
=
self
.
wrap_device
(
self
.
tokenizer
(
encoder_prompt
,
return_tensors
=
"pt"
).
input_ids
,
self
.
processor
(
**
processor_kwargs
).
input_ids
,
device
=
self
.
model
.
device
.
type
,
device
=
self
.
model
.
device
.
type
,
)
)
...
...
tests/models/encoder_decoder/vision_language/test_florence2.py
0 → 100644
View file @
3ff57ebf
from
functools
import
partial
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
pytest
from
PIL
import
Image
from
vllm.inputs.data
import
ExplicitEncoderDecoderPrompt
from
vllm.sequence
import
SampleLogprobs
from
....conftest
import
HfRunner
,
VllmRunner
from
...utils
import
check_logprobs_close
# Prompt factory: an explicit encoder/decoder prompt whose decoder prompt and
# multimodal processor kwargs are both fixed to None, so callers only supply
# ``encoder_prompt``.
Florence2Prompt = partial(
    ExplicitEncoderDecoderPrompt,
    decoder_prompt=None,
    mm_processor_kwargs=None,
)

MODELS = ["microsoft/Florence-2-base"]

# Florence-2 ships a BartFastTokenizer that AutoTokenizer cannot load, so the
# tokenizer is borrowed from the original BART checkpoint instead.
TOKENIZER = "facebook/bart-base"

# One prompt per Florence-2 task token, in the order they are evaluated.
PROMPTS = [
    Florence2Prompt(encoder_prompt=task_token) for task_token in (
        "<CAPTION>",
        "<DETAILED_CAPTION>",
        "<MORE_DETAILED_CAPTION>",
        "<CAPTION_TO_PHRASE_GROUNDING>",
        "<DENSE_REGION_CAPTION>",
        "<REGION_PROPOSAL>",
        "<OCR_WITH_REGION>",
        "<OCR>",
        "<OD>",
    )
]
def vllm_to_hf_output(
    vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
):
    """Reshape a vLLM result so it can be compared with the HF result.

    Args:
        vllm_output: ``(token ids, decoded text, logprobs)`` triple.

    Returns:
        The same triple with the text wrapped as ``"</s><s>" + text + "</s>"``;
        the token ids and logprobs are passed through untouched.
    """
    token_ids, text, logprobs = vllm_output
    wrapped_text = "".join(("</s><s>", text, "</s>"))
    return token_ids, wrapped_text, logprobs
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    prompts: List[ExplicitEncoderDecoderPrompt],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
) -> None:
    """Generate with vLLM and with HF, then compare their greedy logprobs.

    Args:
        hf_runner: Context-manager factory for the HF reference model.
        vllm_runner: Context-manager factory for the vLLM model.
        prompts: Encoder/decoder prompts fed to both runners.
        model: Model name handed to both runners.
        dtype: Precision string handed to both runners.
        max_tokens: Generation length limit for both runners.
        num_logprobs: Number of logprobs requested from both runners.
        tensor_parallel_size: Forwarded to the vLLM runner.
        distributed_executor_backend: Forwarded to the vLLM runner.
    """
    vllm_kwargs = dict(
        tokenizer_name=TOKENIZER,
        dtype=dtype,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
        enforce_eager=True,
    )
    with vllm_runner(model, **vllm_kwargs) as vllm_model:
        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
            prompts, max_tokens, num_logprobs)

    # Florence-2 processors require image inputs, so every prompt is paired
    # with the same tiny placeholder image.
    dummy_image = Image.new(mode="RGB", size=(2, 2))
    placeholder_images = [dummy_image] * len(prompts)

    with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:

        def _language_lm_head():
            # Looked up lazily, at call time, on the wrapped HF model.
            return hf_model.model.language_model.lm_head

        # Point the HF model's output-embedding accessor at the language
        # model's LM head.
        hf_model.model.get_output_embeddings = _language_lm_head
        hf_outputs = hf_model.generate_encoder_decoder_greedy_logprobs_limit(
            prompts,
            max_tokens,
            num_logprobs,
            images=placeholder_images,
        )

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=list(map(vllm_to_hf_output, vllm_outputs)),
        name_0="hf",
        name_1="vllm",
    )
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
                num_logprobs) -> None:
    """Run the HF-vs-vLLM comparison over PROMPTS on a single device."""
    comparison_options = dict(
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
    run_test(hf_runner, vllm_runner, PROMPTS, model, **comparison_options)
vllm/model_executor/models/florence2.py
0 → 100644
View file @
3ff57ebf
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.bart
import
(
BartDecoder
,
BartEncoder
,
BartParallelLMHead
,
BartScaledWordEmbedding
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
AutoWeightsLoader
class Florence2LanguageModel(nn.Module):
    """Shared embedding plus a BartEncoder and a BartDecoder.

    All three submodules are built from the same config. When
    ``config.tie_word_embeddings`` is set, the encoder and decoder
    ``embed_tokens`` weights are rebound to the shared embedding weight.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)

        # The encoder and decoder receive identical cache/quant settings.
        stack_kwargs = dict(cache_config=cache_config,
                            quant_config=quant_config)
        self.encoder = BartEncoder(config, **stack_kwargs)
        self.decoder = BartDecoder(config, **stack_kwargs)

        if self.config.tie_word_embeddings:
            for stack in (self.encoder, self.decoder):
                stack.embed_tokens.weight = self.shared.weight

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions:
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            The decoder's output, returned unchanged.
        """
        # The encoder only runs when at least one encoder token was supplied;
        # otherwise the decoder is called with encoder_hidden_states=None.
        has_encoder_tokens = encoder_input_ids.numel() > 0
        if has_encoder_tokens:
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions,
                                                 kv_caches=kv_caches,
                                                 attn_metadata=attn_metadata)
        else:
            encoder_hidden_states = None

        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_hidden_states,
                            kv_caches=kv_caches,
                            attn_metadata=attn_metadata)
class Florence2LanguageForConditionalGeneration(nn.Module):
    """Florence2LanguageModel with an LM head, logits processor and sampler."""

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.model = Florence2LanguageModel(config,
                                            cache_config=cache_config,
                                            quant_config=quant_config)

        # The LM head is scaled by sqrt(d_model) only when the config asks
        # for scaled embeddings.
        if config.scale_embedding:
            embed_scale = math.sqrt(config.d_model)
        else:
            embed_scale = 1.0

        self.vocab_size = config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)
        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output of the wrapped language model. Extra keyword arguments
            are accepted but not forwarded.
        """
        model_args = (input_ids, positions, encoder_input_ids,
                      encoder_positions, kv_caches, attn_metadata)
        return self.model(*model_args)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor over ``hidden_states`` with the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> SamplerOutput:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Weights whose name contains ``q_proj``/``k_proj``/``v_proj`` are
        loaded as the matching shard of the fused ``qkv_proj`` parameter.
        ``final_logits_bias`` entries are skipped, as are ``embed_tokens``
        entries when the config ties word embeddings. Everything else is
        loaded by name with the parameter's own ``weight_loader`` or the
        default loader.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            # First mapping entry whose shard name occurs in the weight name.
            stacked = next(
                (entry
                 for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if stacked is not None:
                param_name, weight_name, shard_id = stacked
                param = params_dict[name.replace(weight_name, param_name)]
                param.weight_loader(param, loaded_weight, shard_id)
                continue

            if "final_logits_bias" in name:
                continue
            if self.config.tie_word_embeddings and "embed_tokens" in name:
                continue

            param = params_dict[name]
            load_fn = getattr(param, "weight_loader", default_weight_loader)
            load_fn(param, loaded_weight)
class Florence2ForConditionalGeneration(nn.Module):
    """Top-level Florence-2 model; currently wraps the language model only.

    Every call is delegated to ``self.language_model``, which is built from
    ``config.text_config``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        # TODO(Isotr0py): Add vision backbone
        text_config = config.text_config
        self.language_model = Florence2LanguageForConditionalGeneration(
            config=text_config,
            cache_config=cache_config,
            quant_config=quant_config)

    @property
    def sampler(self):
        """Sampler owned by the wrapped language model."""
        return self.language_model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output of the wrapped language model. ``intermediate_tensors``
            and extra keyword arguments are accepted but not forwarded.
        """
        language_model_args = (input_ids, positions, encoder_input_ids,
                               encoder_positions, kv_caches, attn_metadata)
        return self.language_model(*language_model_args)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the wrapped language model."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Delegate sampling to the wrapped language model."""
        language_model = self.language_model
        return language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights through AutoWeightsLoader.

        The prefixes listed below are handed to the loader as
        ``skip_prefixes``.
        """
        skip_prefixes = [
            'image_projection',
            "vision_tower",
            "image_proj_norm",
            "image_pos_embed",
            "visual_temporal_embed",
        ]
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        weights_loader.load_weights(weights)
vllm/model_executor/models/registry.py
View file @
3ff57ebf
...
@@ -85,6 +85,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -85,6 +85,7 @@ _TEXT_GENERATION_MODELS = {
# [Encoder-decoder]
# [Encoder-decoder]
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"Florence2ForConditionalGeneration"
:
(
"florence2"
,
"Florence2ForConditionalGeneration"
),
# noqa: E501
}
}
_EMBEDDING_MODELS
=
{
_EMBEDDING_MODELS
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment