Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1efce686
Unverified
Commit
1efce686
authored
Dec 12, 2024
by
Pooya Davoodi
Committed by
GitHub
Dec 13, 2024
Browse files
[Bugfix] Use runner_type instead of task in GritLM (#11144)
Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
parent
30870b4f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
7 deletions
+7
-7
tests/models/embedding/language/test_gritlm.py
tests/models/embedding/language/test_gritlm.py
+3
-3
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+4
-4
No files found.
tests/models/embedding/language/test_gritlm.py
View file @
1efce686
...
@@ -35,7 +35,7 @@ def test_find_array(monkeypatch):
...
@@ -35,7 +35,7 @@ def test_find_array(monkeypatch):
from
vllm.model_executor.models.gritlm
import
GritLMPooler
from
vllm.model_executor.models.gritlm
import
GritLMPooler
# Create an LLM object to get the model config.
# Create an LLM object to get the model config.
llm
=
vllm
.
LLM
(
MODEL_NAME
,
task
=
"embedding"
,
max_model_len
=
MAX_MODEL_LEN
)
llm
=
vllm
.
LLM
(
MODEL_NAME
,
task
=
"embed"
,
max_model_len
=
MAX_MODEL_LEN
)
pooler
=
GritLMPooler
(
model_config
=
llm
.
llm_engine
.
model_config
)
pooler
=
GritLMPooler
(
model_config
=
llm
.
llm_engine
.
model_config
)
arr
=
_arr
([
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
])
arr
=
_arr
([
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
])
...
@@ -55,7 +55,7 @@ def server_embedding():
...
@@ -55,7 +55,7 @@ def server_embedding():
with
pytest
.
MonkeyPatch
.
context
()
as
mp
:
with
pytest
.
MonkeyPatch
.
context
()
as
mp
:
mp
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"XFORMERS"
)
mp
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"XFORMERS"
)
args
=
[
"--task"
,
"embedding"
,
"--max_model_len"
,
str
(
MAX_MODEL_LEN
)]
args
=
[
"--task"
,
"embed"
,
"--max_model_len"
,
str
(
MAX_MODEL_LEN
)]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
yield
remote_server
...
@@ -141,7 +141,7 @@ def test_gritlm_offline_embedding(monkeypatch):
...
@@ -141,7 +141,7 @@ def test_gritlm_offline_embedding(monkeypatch):
queries
,
q_instruction
,
documents
,
d_instruction
=
get_test_data
()
queries
,
q_instruction
,
documents
,
d_instruction
=
get_test_data
()
llm
=
vllm
.
LLM
(
MODEL_NAME
,
task
=
"embedding"
,
max_model_len
=
MAX_MODEL_LEN
)
llm
=
vllm
.
LLM
(
MODEL_NAME
,
task
=
"embed"
,
max_model_len
=
MAX_MODEL_LEN
)
d_rep
=
run_llm_encode
(
d_rep
=
run_llm_encode
(
llm
,
llm
,
...
...
vllm/model_executor/models/gritlm.py
View file @
1efce686
...
@@ -203,12 +203,12 @@ class GritLM(LlamaForCausalLM):
...
@@ -203,12 +203,12 @@ class GritLM(LlamaForCausalLM):
)
->
None
:
)
->
None
:
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
task
=
vllm_config
.
model_config
.
task
self
.
runner_type
=
vllm_config
.
model_config
.
runner_type
self
.
_pooler
=
GritLMPooler
(
vllm_config
.
model_config
)
self
.
_pooler
=
GritLMPooler
(
vllm_config
.
model_config
)
for
layer
in
self
.
model
.
layers
:
for
layer
in
self
.
model
.
layers
:
if
self
.
task
==
"embedding"
and
hasattr
(
layer
,
"self_attn"
):
if
self
.
runner_type
==
"pooling"
and
hasattr
(
layer
,
"self_attn"
):
assert
isinstance
(
layer
.
self_attn
.
attn
.
impl
,
XFormersImpl
),
(
assert
isinstance
(
layer
.
self_attn
.
attn
.
impl
,
XFormersImpl
),
(
"GritLM embedding is only supported by XFormers backend, "
"GritLM embedding is only supported by XFormers backend, "
"which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS"
)
"which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS"
)
...
@@ -222,8 +222,8 @@ class GritLM(LlamaForCausalLM):
...
@@ -222,8 +222,8 @@ class GritLM(LlamaForCausalLM):
**
kwargs
,
**
kwargs
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
# Change attention to non-causal for embedding task.
# Change attention to non-causal for pooling tasks.
if
self
.
task
==
"embedding"
:
if
self
.
runner_type
==
"pooling"
:
assert
attn_metadata
.
prefill_metadata
.
attn_bias
is
None
assert
attn_metadata
.
prefill_metadata
.
attn_bias
is
None
attn_metadata
.
prefill_metadata
.
attn_bias
=
[
attn_metadata
.
prefill_metadata
.
attn_bias
=
[
BlockDiagonalMask
.
from_seqlens
(
attn_metadata
.
seq_lens
)
BlockDiagonalMask
.
from_seqlens
(
attn_metadata
.
seq_lens
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment