Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ed46f143
Unverified
Commit
ed46f143
authored
Nov 25, 2024
by
Cyrus Leung
Committed by
GitHub
Nov 25, 2024
Browse files
[Model] Support `is_causal` HF config field for Qwen2 model (#10621)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
05d1f8c9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
51 additions
and
13 deletions
+51
-13
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+10
-3
tests/models/embedding/language/test_embedding.py
tests/models/embedding/language/test_embedding.py
+10
-2
tests/models/embedding/utils.py
tests/models/embedding/utils.py
+2
-2
vllm/config.py
vllm/config.py
+11
-4
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+18
-2
No files found.
docs/source/models/supported_models.rst
View file @
ed46f143
...
...
@@ -342,7 +342,7 @@ Text Embedding
- ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-
1.5
B-instruct`, etc.
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-
7
B-instruct`
(see note)
, etc.
- ✅︎
- ✅︎
* - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
...
...
@@ -363,6 +363,13 @@ Text Embedding
.. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`.
.. note::
Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
You can set `--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention
despite being described otherwise on its model card.
Reward Modeling
---------------
...
...
@@ -606,10 +613,10 @@ Text Generation
| :sup:`+` Multiple items can be inputted per text prompt for this modality.
.. note::
vLLM currently only supports adding LoRA to the language backbone of multimodal models.
vLLM currently only supports adding LoRA to the language backbone of multimodal models.
.. note::
For
:code:`openbmb/MiniCPM-V-2`
, the official repo
doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
The official
:code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
Multimodal Embedding
...
...
tests/models/embedding/language/test_embedding.py
View file @
ed46f143
...
...
@@ -21,6 +21,7 @@ from ..utils import check_embeddings_close
marks
=
[
pytest
.
mark
.
core_model
]),
pytest
.
param
(
"ssmits/Qwen2-7B-Instruct-embed-base"
),
pytest
.
param
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
),
pytest
.
param
(
"Alibaba-NLP/gte-Qwen2-7B-instruct"
),
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
...
...
@@ -31,6 +32,10 @@ def test_models(
model
,
dtype
:
str
,
)
->
None
:
vllm_extra_kwargs
=
{}
if
model
==
"Alibaba-NLP/gte-Qwen2-7B-instruct"
:
vllm_extra_kwargs
[
"hf_overrides"
]
=
{
"is_causal"
:
False
}
# The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
...
...
@@ -43,8 +48,11 @@ def test_models(
is_sentence_transformer
=
True
)
as
hf_model
:
hf_outputs
=
hf_model
.
encode
(
example_prompts
)
with
vllm_runner
(
model
,
task
=
"embedding"
,
dtype
=
dtype
,
max_model_len
=
None
)
as
vllm_model
:
with
vllm_runner
(
model
,
task
=
"embedding"
,
dtype
=
dtype
,
max_model_len
=
None
,
**
vllm_extra_kwargs
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
...
...
tests/models/embedding/utils.py
View file @
ed46f143
...
...
@@ -24,7 +24,7 @@ def check_embeddings_close(
dim
=
0
)
fail_msg
=
(
f
"Test
{
prompt_idx
}
:"
f
"
\n
{
name_0
}
:
\t
{
embeddings_0
!
r
}
"
f
"
\n
{
name_1
}
:
\t
{
embeddings_1
!
r
}
"
)
f
"
\n
{
name_0
}
:
\t
{
embeddings_0
[:
16
]
!
r
}
"
f
"
\n
{
name_1
}
:
\t
{
embeddings_1
[:
16
]
!
r
}
"
)
assert
sim
>=
1
-
tol
,
fail_msg
vllm/config.py
View file @
ed46f143
...
...
@@ -27,7 +27,7 @@ from vllm.transformers_utils.config import (
get_hf_text_config
,
get_pooling_config
,
get_sentence_transformer_tokenizer_config
,
is_encoder_decoder
,
uses_mrope
)
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
identity
,
print_warning_once
,
resolve_obj_by_qualname
)
print_warning_once
,
resolve_obj_by_qualname
)
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
...
...
@@ -183,7 +183,7 @@ class ModelConfig:
hf_overrides_fn
=
hf_overrides
else
:
hf_overrides_kw
=
hf_overrides
hf_overrides_fn
=
identity
hf_overrides_fn
=
None
if
rope_scaling
is
not
None
:
hf_override
:
Dict
[
str
,
Any
]
=
{
"rope_scaling"
:
rope_scaling
}
...
...
@@ -212,8 +212,15 @@ class ModelConfig:
self
.
skip_tokenizer_init
=
skip_tokenizer_init
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
code_revision
,
config_format
,
**
hf_overrides_kw
)
hf_config
=
hf_overrides_fn
(
hf_config
)
code_revision
,
config_format
)
if
hf_overrides_kw
:
logger
.
info
(
"Overriding HF config with %s"
,
hf_overrides_kw
)
hf_config
.
update
(
hf_overrides_kw
)
if
hf_overrides_fn
:
logger
.
info
(
"Overriding HF config with %s"
,
hf_overrides_fn
)
hf_config
=
hf_overrides_fn
(
hf_config
)
self
.
hf_config
=
hf_config
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
...
...
vllm/model_executor/models/qwen2.py
View file @
ed46f143
...
...
@@ -27,7 +27,7 @@ import torch
from
torch
import
nn
from
transformers
import
Qwen2Config
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -164,11 +164,17 @@ class Qwen2Attention(nn.Module):
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
attn_type
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -210,6 +216,15 @@ class Qwen2DecoderLayer(nn.Module):
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if
getattr
(
config
,
"is_causal"
,
True
):
self
.
_attn_type
=
AttentionType
.
DECODER
else
:
self
.
_attn_type
=
AttentionType
.
ENCODER_ONLY
def
forward
(
self
,
positions
:
torch
.
Tensor
,
...
...
@@ -230,6 +245,7 @@ class Qwen2DecoderLayer(nn.Module):
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
attn_type
=
self
.
_attn_type
,
)
# Fully Connected
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment