Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d62a076e
Unverified
Commit
d62a076e
authored
May 14, 2025
by
Cyrus Leung
Committed by
GitHub
May 14, 2025
Browse files
[Model] GritLM supports other attention backends (#18109)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
259127f8
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
84 additions
and
107 deletions
+84
-107
tests/models/language/pooling/test_gritlm.py
tests/models/language/pooling/test_gritlm.py
+31
-46
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+12
-34
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+28
-14
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+13
-13
No files found.
tests/models/language/pooling/test_gritlm.py
View file @
d62a076e
...
@@ -11,7 +11,6 @@ from scipy.spatial.distance import cosine
...
@@ -11,7 +11,6 @@ from scipy.spatial.distance import cosine
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.utils
import
STR_BACKEND_ENV_VAR
from
....utils
import
RemoteOpenAIServer
from
....utils
import
RemoteOpenAIServer
...
@@ -117,12 +116,7 @@ def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
...
@@ -117,12 +116,7 @@ def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
assert
math
.
isclose
(
cosine_sim_q1_d1
,
0.534
,
abs_tol
=
0.001
)
assert
math
.
isclose
(
cosine_sim_q1_d1
,
0.534
,
abs_tol
=
0.001
)
def
test_gritlm_offline_embedding
(
monkeypatch
:
pytest
.
MonkeyPatch
,
def
test_gritlm_offline_embedding
(
vllm_runner
):
vllm_runner
):
# GritLM embedding implementation is only supported by XFormers backend.
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"XFORMERS"
)
queries
,
q_instruction
,
documents
,
d_instruction
=
get_test_data
()
queries
,
q_instruction
,
documents
,
d_instruction
=
get_test_data
()
with
vllm_runner
(
with
vllm_runner
(
...
@@ -150,11 +144,9 @@ def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch,
...
@@ -150,11 +144,9 @@ def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch,
async
def
test_gritlm_api_server_embedding
():
async
def
test_gritlm_api_server_embedding
():
queries
,
q_instruction
,
documents
,
d_instruction
=
get_test_data
()
queries
,
q_instruction
,
documents
,
d_instruction
=
get_test_data
()
# GritLM embedding implementation is only supported by XFormers backend.
args
=
[
"--task"
,
"embed"
,
"--max_model_len"
,
str
(
MAX_MODEL_LEN
)]
args
=
[
"--task"
,
"embed"
,
"--max_model_len"
,
str
(
MAX_MODEL_LEN
)]
env_dict
=
{
STR_BACKEND_ENV_VAR
:
"XFORMERS"
}
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
,
env_dict
=
env_dict
)
as
server
:
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
server
:
client_embedding
=
server
.
get_async_client
()
client_embedding
=
server
.
get_async_client
()
d_rep
=
await
run_client_embeddings
(
d_rep
=
await
run_client_embeddings
(
...
@@ -172,11 +164,6 @@ async def test_gritlm_api_server_embedding():
...
@@ -172,11 +164,6 @@ async def test_gritlm_api_server_embedding():
def
test_gritlm_offline_generate
(
monkeypatch
:
pytest
.
MonkeyPatch
,
vllm_runner
):
def
test_gritlm_offline_generate
(
monkeypatch
:
pytest
.
MonkeyPatch
,
vllm_runner
):
# GritLM embedding implementation is only supported by XFormers backend.
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"XFORMERS"
)
input
=
"<|user|>
\n
What is the capital of France?
\n
<|assistant|>
\n
"
input
=
"<|user|>
\n
What is the capital of France?
\n
<|assistant|>
\n
"
with
vllm_runner
(
with
vllm_runner
(
...
@@ -196,11 +183,9 @@ def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner):
...
@@ -196,11 +183,9 @@ def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner):
async
def
test_gritlm_api_server_generate
():
async
def
test_gritlm_api_server_generate
():
input
=
"<|user|>
\n
What is the capital of France?
\n
<|assistant|>
\n
"
input
=
"<|user|>
\n
What is the capital of France?
\n
<|assistant|>
\n
"
# GritLM embedding implementation is only supported by XFormers backend.
args
=
[
"--task"
,
"generate"
,
"--max_model_len"
,
str
(
MAX_MODEL_LEN
)]
args
=
[
"--task"
,
"generate"
,
"--max_model_len"
,
str
(
MAX_MODEL_LEN
)]
env_dict
=
{
"VLLM_USE_V1"
:
"0"
,
STR_BACKEND_ENV_VAR
:
"XFORMERS"
}
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
,
env_dict
=
env_dict
)
as
server
:
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
server
:
client_generate
=
server
.
get_async_client
()
client_generate
=
server
.
get_async_client
()
outputs
=
await
client_generate
.
completions
.
create
(
outputs
=
await
client_generate
.
completions
.
create
(
...
...
vllm/model_executor/models/gritlm.py
View file @
d62a076e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
array
import
array
from
array
import
array
from
typing
import
Optional
,
Union
from
typing
import
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
from
vllm.attention.backends.xformers
import
XFormersImpl
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
PoolerHead
from
vllm.model_executor.layers.pooler
import
PoolerHead
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.pooling_metadata
import
(
PoolingMetadata
,
from
vllm.model_executor.pooling_metadata
import
(
PoolingMetadata
,
PoolingTensors
)
PoolingTensors
)
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
PoolingSequenceGroupOutput
)
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
.interfaces
import
SupportsV0Only
from
.interfaces
import
SupportsV0Only
...
@@ -204,38 +200,20 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
...
@@ -204,38 +200,20 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
prefix
:
str
=
""
,
prefix
:
str
=
""
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
# Use full attention for pooling
if
vllm_config
.
model_config
.
runner_type
==
"pooling"
:
self
.
runner_type
=
vllm_config
.
model_config
.
runner_type
hf_config
=
vllm_config
.
model_config
.
hf_config
hf_config
.
is_causal
=
False
self
.
_pooler
=
GritLMPooler
(
vllm_config
.
model
_config
)
vllm_config
.
cache
_config
.
sliding_window
=
None
for
layer
in
self
.
model
.
layers
:
for
attr
in
(
"sliding_window"
,
"interleaved_sliding_window"
):
if
self
.
runner_type
==
"pooling"
and
hasattr
(
layer
,
"self_attn"
):
if
hasattr
(
hf_config
,
attr
):
assert
isinstance
(
layer
.
self_attn
.
attn
.
impl
,
XFormersImpl
),
(
delattr
(
hf_config
,
attr
)
"GritLM embedding is only supported by XFormers backend, "
"which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS"
)
def
forward
(
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
**
kwargs
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
# Change attention to non-causal for pooling tasks.
if
self
.
runner_type
==
"pooling"
:
attn_metadata
=
get_forward_context
().
attn_metadata
assert
attn_metadata
.
prefill_metadata
.
attn_bias
is
None
attn_metadata
.
prefill_metadata
.
attn_bias
=
[
BlockDiagonalMask
.
from_seqlens
(
attn_metadata
.
seq_lens
)
]
return
super
().
forward
(
self
.
_pooler
=
GritLMPooler
(
vllm_config
.
model_config
)
input_ids
=
input_ids
,
positions
=
positions
,
**
kwargs
,
)
def
pooler
(
def
pooler
(
self
,
self
,
...
...
vllm/model_executor/models/llama.py
View file @
d62a076e
...
@@ -28,7 +28,7 @@ import torch
...
@@ -28,7 +28,7 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
LlamaConfig
from
transformers
import
LlamaConfig
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -96,7 +96,8 @@ class LlamaMLP(nn.Module):
...
@@ -96,7 +96,8 @@ class LlamaMLP(nn.Module):
class
LlamaAttention
(
nn
.
Module
):
class
LlamaAttention
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
LlamaConfig
,
config
:
LlamaConfig
,
hidden_size
:
int
,
hidden_size
:
int
,
num_heads
:
int
,
num_heads
:
int
,
...
@@ -108,7 +109,9 @@ class LlamaAttention(nn.Module):
...
@@ -108,7 +109,9 @@ class LlamaAttention(nn.Module):
bias
:
bool
=
False
,
bias
:
bool
=
False
,
bias_o_proj
:
bool
=
False
,
bias_o_proj
:
bool
=
False
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
prefix
:
str
=
""
)
->
None
:
prefix
:
str
=
""
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
layer_idx
=
extract_layer_index
(
prefix
)
layer_idx
=
extract_layer_index
(
prefix
)
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -194,6 +197,7 @@ class LlamaAttention(nn.Module):
...
@@ -194,6 +197,7 @@ class LlamaAttention(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
per_layer_sliding_window
=
sliding_window
,
per_layer_sliding_window
=
sliding_window
,
attn_type
=
attn_type
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
)
...
@@ -238,6 +242,15 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -238,6 +242,15 @@ class LlamaDecoderLayer(nn.Module):
if
hasattr
(
config
,
'qkv_bias'
):
if
hasattr
(
config
,
'qkv_bias'
):
attention_bias
=
config
.
qkv_bias
attention_bias
=
config
.
qkv_bias
# By default, Llama uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. parasail-ai/GritLM-7B-vllm)
if
getattr
(
config
,
"is_causal"
,
True
):
attn_type
=
AttentionType
.
DECODER
else
:
attn_type
=
AttentionType
.
ENCODER_ONLY
self
.
self_attn
=
LlamaAttention
(
self
.
self_attn
=
LlamaAttention
(
config
=
config
,
config
=
config
,
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
...
@@ -252,6 +265,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -252,6 +265,7 @@ class LlamaDecoderLayer(nn.Module):
bias_o_proj
=
bias_o_proj
,
bias_o_proj
=
bias_o_proj
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
attn_type
=
attn_type
,
)
)
self
.
mlp
=
LlamaMLP
(
self
.
mlp
=
LlamaMLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
...
...
vllm/model_executor/models/qwen2.py
View file @
d62a076e
...
@@ -111,8 +111,8 @@ class Qwen2Attention(nn.Module):
...
@@ -111,8 +111,8 @@ class Qwen2Attention(nn.Module):
rope_scaling
:
Optional
[
Tuple
]
=
None
,
rope_scaling
:
Optional
[
Tuple
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
dual_chunk_attention_config
:
Optional
[
dict
[
str
,
dual_chunk_attention_config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
Any
]]
=
None
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment