Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
571841b7
Unverified
Commit
571841b7
authored
Nov 24, 2024
by
youkaichao
Committed by
GitHub
Nov 25, 2024
Browse files
[torch.compile] support encoder based models (#10613)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
7ea3cd7c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
10 deletions
+17
-10
tests/compile/test_basic_correctness.py
tests/compile/test_basic_correctness.py
+10
-0
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+7
-10
No files found.
tests/compile/test_basic_correctness.py
View file @
571841b7
...
@@ -62,6 +62,16 @@ test_settings = [
...
@@ -62,6 +62,16 @@ test_settings = [
method
=
"encode"
,
method
=
"encode"
,
fullgraph
=
True
,
fullgraph
=
True
,
),
),
# encoder-based embedding model (BERT)
TestSetting
(
model
=
"BAAI/bge-base-en-v1.5"
,
model_args
=
[
"--task"
,
"embedding"
],
pp_size
=
1
,
tp_size
=
1
,
attn_backend
=
"XFORMERS"
,
method
=
"encode"
,
fullgraph
=
True
,
),
# vision language model
# vision language model
TestSetting
(
TestSetting
(
model
=
"microsoft/Phi-3.5-vision-instruct"
,
model
=
"microsoft/Phi-3.5-vision-instruct"
,
...
...
vllm/model_executor/models/bert.py
View file @
571841b7
...
@@ -5,6 +5,7 @@ from torch import nn
...
@@ -5,6 +5,7 @@ from torch import nn
from
transformers
import
BertConfig
from
transformers
import
BertConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
PoolerConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
PoolerConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
...
@@ -92,14 +93,14 @@ class BertPooler(nn.Module):
...
@@ -92,14 +93,14 @@ class BertPooler(nn.Module):
return
pooled_output
return
pooled_output
@
support_torch_compile
class
BertEncoder
(
nn
.
Module
):
class
BertEncoder
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
:
BertConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
layer
=
nn
.
ModuleList
([
self
.
layer
=
nn
.
ModuleList
([
BertLayer
(
config
=
config
,
BertLayer
(
config
=
config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
...
@@ -336,12 +337,8 @@ class BertModel(nn.Module):
...
@@ -336,12 +337,8 @@ class BertModel(nn.Module):
add_pooling_layer
:
bool
=
False
):
add_pooling_layer
:
bool
=
False
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
embeddings
=
embedding_class
(
config
)
self
.
embeddings
=
embedding_class
(
config
)
self
.
encoder
=
BertEncoder
(
config
,
self
.
encoder
=
BertEncoder
(
vllm_config
=
vllm_config
,
cache_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.encoder"
)
prefix
=
f
"
{
prefix
}
.encoder"
)
self
.
pooler
=
BertPooler
(
config
)
if
add_pooling_layer
else
None
self
.
pooler
=
BertPooler
(
config
)
if
add_pooling_layer
else
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment