Unverified Commit ed46f143 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Support `is_causal` HF config field for Qwen2 model (#10621)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 05d1f8c9
...@@ -342,7 +342,7 @@ Text Embedding ...@@ -342,7 +342,7 @@ Text Embedding
- ✅︎ - ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM` * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based - Qwen2-based
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc. - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- ✅︎ - ✅︎
- ✅︎ - ✅︎
* - :code:`RobertaModel`, :code:`RobertaForMaskedLM` * - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
...@@ -363,6 +363,13 @@ Text Embedding ...@@ -363,6 +363,13 @@ Text Embedding
.. tip:: .. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`. You can override the model's pooling method by passing :code:`--override-pooler-config`.
.. note::
Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
You can set `--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention
despite being described otherwise on its model card.
Reward Modeling Reward Modeling
--------------- ---------------
...@@ -609,7 +616,7 @@ Text Generation ...@@ -609,7 +616,7 @@ Text Generation
vLLM currently only supports adding LoRA to the language backbone of multimodal models. vLLM currently only supports adding LoRA to the language backbone of multimodal models.
.. note:: .. note::
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
Multimodal Embedding Multimodal Embedding
......
...@@ -21,6 +21,7 @@ from ..utils import check_embeddings_close ...@@ -21,6 +21,7 @@ from ..utils import check_embeddings_close
marks=[pytest.mark.core_model]), marks=[pytest.mark.core_model]),
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"), pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
], ],
) )
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
...@@ -31,6 +32,10 @@ def test_models( ...@@ -31,6 +32,10 @@ def test_models(
model, model,
dtype: str, dtype: str,
) -> None: ) -> None:
vllm_extra_kwargs = {}
if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}
# The example_prompts has ending "\n", for example: # The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n" # "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see: # sentence_transformers will strip the input texts, see:
...@@ -43,8 +48,11 @@ def test_models( ...@@ -43,8 +48,11 @@ def test_models(
is_sentence_transformer=True) as hf_model: is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts) hf_outputs = hf_model.encode(example_prompts)
with vllm_runner(model, task="embedding", dtype=dtype, with vllm_runner(model,
max_model_len=None) as vllm_model: task="embedding",
dtype=dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts) vllm_outputs = vllm_model.encode(example_prompts)
# This test is for verifying whether the model's extra_repr # This test is for verifying whether the model's extra_repr
# can be printed correctly. # can be printed correctly.
......
...@@ -24,7 +24,7 @@ def check_embeddings_close( ...@@ -24,7 +24,7 @@ def check_embeddings_close(
dim=0) dim=0)
fail_msg = (f"Test{prompt_idx}:" fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{embeddings_0!r}" f"\n{name_0}:\t{embeddings_0[:16]!r}"
f"\n{name_1}:\t{embeddings_1!r}") f"\n{name_1}:\t{embeddings_1[:16]!r}")
assert sim >= 1 - tol, fail_msg assert sim >= 1 - tol, fail_msg
...@@ -27,7 +27,7 @@ from vllm.transformers_utils.config import ( ...@@ -27,7 +27,7 @@ from vllm.transformers_utils.config import (
get_hf_text_config, get_pooling_config, get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
identity, print_warning_once, resolve_obj_by_qualname) print_warning_once, resolve_obj_by_qualname)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
...@@ -183,7 +183,7 @@ class ModelConfig: ...@@ -183,7 +183,7 @@ class ModelConfig:
hf_overrides_fn = hf_overrides hf_overrides_fn = hf_overrides
else: else:
hf_overrides_kw = hf_overrides hf_overrides_kw = hf_overrides
hf_overrides_fn = identity hf_overrides_fn = None
if rope_scaling is not None: if rope_scaling is not None:
hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling} hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
...@@ -212,8 +212,15 @@ class ModelConfig: ...@@ -212,8 +212,15 @@ class ModelConfig:
self.skip_tokenizer_init = skip_tokenizer_init self.skip_tokenizer_init = skip_tokenizer_init
hf_config = get_config(self.model, trust_remote_code, revision, hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, config_format, **hf_overrides_kw) code_revision, config_format)
if hf_overrides_kw:
logger.info("Overriding HF config with %s", hf_overrides_kw)
hf_config.update(hf_overrides_kw)
if hf_overrides_fn:
logger.info("Overriding HF config with %s", hf_overrides_fn)
hf_config = hf_overrides_fn(hf_config) hf_config = hf_overrides_fn(hf_config)
self.hf_config = hf_config self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
......
...@@ -27,7 +27,7 @@ import torch ...@@ -27,7 +27,7 @@ import torch
from torch import nn from torch import nn
from transformers import Qwen2Config from transformers import Qwen2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -164,11 +164,17 @@ class Qwen2Attention(nn.Module): ...@@ -164,11 +164,17 @@ class Qwen2Attention(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=attn_type)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -210,6 +216,15 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -210,6 +216,15 @@ class Qwen2DecoderLayer(nn.Module):
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if getattr(config, "is_causal", True):
self._attn_type = AttentionType.DECODER
else:
self._attn_type = AttentionType.ENCODER_ONLY
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -230,6 +245,7 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -230,6 +245,7 @@ class Qwen2DecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
attn_type=self._attn_type,
) )
# Fully Connected # Fully Connected
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment