Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1a6fcad4
Unverified
Commit
1a6fcad4
authored
Feb 06, 2025
by
Harry Mellor
Committed by
GitHub
Feb 05, 2025
Browse files
Improve `TransformersModel` UX (#12785)
parent
56534cd5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
21 deletions
+32
-21
vllm/model_executor/models/transformers.py
vllm/model_executor/models/transformers.py
+32
-21
No files found.
vllm/model_executor/models/transformers.py
View file @
1a6fcad4
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
# limitations under the License.
# limitations under the License.
"""Wrapper around `transformers` models"""
"""Wrapper around `transformers` models"""
import
re
import
re
from
typing
import
Iterable
,
Optional
,
Union
from
typing
import
Iterable
,
Literal
,
Optional
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -72,15 +72,24 @@ def vllm_flash_attention_forward(
...
@@ -72,15 +72,24 @@ def vllm_flash_attention_forward(
ALL_ATTENTION_FUNCTIONS
[
"vllm"
]
=
vllm_flash_attention_forward
ALL_ATTENTION_FUNCTIONS
[
"vllm"
]
=
vllm_flash_attention_forward
def
log_replacement
(
name
:
str
,
old_module
:
nn
.
Module
,
new_module
:
nn
.
Module
):
logger
.
debug
(
"%s: %s -> %s"
,
name
,
old_module
,
new_module
)
def
replace_linear_class
(
def
replace_linear_class
(
linear
:
nn
.
Linear
,
linear
:
nn
.
Linear
,
style
:
str
,
style
:
Literal
[
"colwise"
,
"rowwise"
]
,
quant_config
=
None
)
->
Union
[
ColumnParallelLinear
,
RowParallelLinear
]:
quant_config
=
None
)
->
Union
[
ColumnParallelLinear
,
RowParallelLinear
]:
"""
"""
In model configurations, we use a neutral type (string) to specify parallel
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
styles, here we use it to translate nn.Linear into vllm-style tp Linear.
`quant_config` is not yet supported.
Quant config is not supported yet
Args:
linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise".
quant_config (QuantConfig): Quantization config for the new linear.
Returns:
Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
"""
"""
if
not
isinstance
(
style
,
str
):
if
not
isinstance
(
style
,
str
):
...
@@ -93,7 +102,10 @@ def replace_linear_class(
...
@@ -93,7 +102,10 @@ def replace_linear_class(
}.
get
(
style
)
}.
get
(
style
)
if
vllm_linear_cls
is
None
:
if
vllm_linear_cls
is
None
:
raise
ValueError
(
f
"Unsupported parallel style value:
{
style
}
"
)
logger
.
warning
(
"Unsupported parallel style value: %s. "
"This layer will not be tensor parallelized."
,
style
)
return
linear
class
HFCompatibleLinear
(
vllm_linear_cls
):
class
HFCompatibleLinear
(
vllm_linear_cls
):
"""
"""
...
@@ -119,25 +131,24 @@ class TransformersModel(nn.Module):
...
@@ -119,25 +131,24 @@ class TransformersModel(nn.Module):
super
().
__init__
()
super
().
__init__
()
logger
.
info
(
"Using Transformers backend."
)
logger
.
info
(
"Using Transformers backend."
)
self
.
vllm_config
=
vllm_config
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
quant_config
=
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
model
:
PreTrainedModel
=
AutoModel
.
from_config
(
self
.
model
:
PreTrainedModel
=
AutoModel
.
from_config
(
self
.
config
,
self
.
config
,
attn_implementation
=
"vllm"
,
attn_implementation
=
"vllm"
,
torch_dtype
=
vllm_config
.
model_config
.
dtype
,
trust_remote_code
=
vllm_config
.
model_config
.
trust_remote_code
,
trust_remote_code
=
vllm_config
.
model_config
.
trust_remote_code
,
)
)
prefix
=
self
.
model
.
base_model_prefix
prefix
=
self
.
model
.
base_model_prefix
# MLP modifications
# MLP modifications
self
.
tensor_parallelize
(
self
.
model
)
self
.
apply_base_model_tp_plan
(
self
.
model
)
# Attention modifications (assumes 1 attention op per hidden layer)
# Attention modifications (assumes 1 attention op per hidden layer)
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -170,13 +181,13 @@ class TransformersModel(nn.Module):
...
@@ -170,13 +181,13 @@ class TransformersModel(nn.Module):
config
.
vocab_size
,
logit_scale
)
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
get_sampler
()
self
.
sampler
=
get_sampler
()
def
log_replacement
(
self
,
name
:
str
,
old_module
:
nn
.
Module
,
def
apply_base_model_tp_plan
(
self
,
module
:
nn
.
Module
,
prefix
:
str
=
""
):
new_module
:
nn
.
Module
):
"""
logger
.
debug
(
"%s: %s -> %s"
,
name
,
old_module
,
new_
module
)
Apply the base model tensor parallelization plan to a
module
.
Currently only supports linear layers.
def
tensor_parallelize
(
self
,
module
:
nn
.
Module
,
prefix
:
str
=
""
):
"""
if
(
self
.
config
.
base_model_tp_plan
is
None
if
(
self
.
config
.
base_model_tp_plan
is
None
and
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
>
1
):
and
get_
tensor_
model_
parallel_
world_
size
()
>
1
):
raise
ValueError
(
raise
ValueError
(
"Trying to run tensor parallelization but the model does not "
"Trying to run tensor parallelization but the model does not "
"support it yet!"
)
"support it yet!"
)
...
@@ -189,9 +200,9 @@ class TransformersModel(nn.Module):
...
@@ -189,9 +200,9 @@ class TransformersModel(nn.Module):
new_module
=
replace_linear_class
(
child_module
,
style
,
new_module
=
replace_linear_class
(
child_module
,
style
,
self
.
quant_config
)
self
.
quant_config
)
setattr
(
module
,
child_name
,
new_module
)
setattr
(
module
,
child_name
,
new_module
)
self
.
log_replacement
(
qual_name
,
child_module
,
new_module
)
log_replacement
(
qual_name
,
child_module
,
new_module
)
else
:
else
:
self
.
tensor_parallelize
(
child_module
,
prefix
=
qual_name
)
self
.
apply_base_model_tp_plan
(
child_module
,
prefix
=
qual_name
)
def
replace_vocab_embed_class
(
self
,
module
:
nn
.
Module
):
def
replace_vocab_embed_class
(
self
,
module
:
nn
.
Module
):
# Use native set input embeddings
# Use native set input embeddings
...
@@ -201,8 +212,8 @@ class TransformersModel(nn.Module):
...
@@ -201,8 +212,8 @@ class TransformersModel(nn.Module):
org_num_embeddings
=
self
.
config
.
vocab_size
,
org_num_embeddings
=
self
.
config
.
vocab_size
,
quant_config
=
None
,
quant_config
=
None
,
)
)
self
.
log_replacement
(
"input embedding"
,
log_replacement
(
"input embedding"
,
self
.
model
.
get_input_embeddings
(),
self
.
model
.
get_input_embeddings
(),
new_module
)
new_module
)
self
.
model
.
set_input_embeddings
(
new_module
)
self
.
model
.
set_input_embeddings
(
new_module
)
def
forward
(
def
forward
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment