Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9ddee6b1
Commit
9ddee6b1
authored
Dec 04, 2024
by
zhuwenwen
Browse files
support falcon and optimize layout
parent
2d5a25cd
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
62 additions
and
4 deletions
+62
-4
README.md
README.md
+1
-0
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+3
-2
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+2
-2
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+56
-0
No files found.
README.md
View file @
9ddee6b1
...
@@ -19,6 +19,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PagedAttention
...
@@ -19,6 +19,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PagedAttention
| BloomForCausalLM | BLOOM | Yes | No | - |
| BloomForCausalLM | BLOOM | Yes | No | - |
| InternLMForCausalLM | InternLM | Yes | No | - |
| InternLMForCausalLM | InternLM | Yes | No | - |
| InternLM2ForCausalLM | InternLM2 | Yes | No | - |
| InternLM2ForCausalLM | InternLM2 | Yes | No | - |
| FalconForCausalLM | Falcon | Yes | No | - |
| TeleChat12BForCausalLM (#TelechatForCausalLM) | TeleChat-12B | Yes | No | - |
| TeleChat12BForCausalLM (#TelechatForCausalLM) | TeleChat-12B | Yes | No | - |
| MiniCPMForCausalLM | MiniCPM | Yes | No | - |
| MiniCPMForCausalLM | MiniCPM | Yes | No | - |
| MiniCPM3ForCausalLM | MiniCPM3 | Yes | No | - |
| MiniCPM3ForCausalLM | MiniCPM3 | Yes | No | - |
...
...
tests/multimodal/test_utils.py
View file @
9ddee6b1
...
@@ -5,12 +5,13 @@ from typing import Dict, Tuple
...
@@ -5,12 +5,13 @@ from typing import Dict, Tuple
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
os
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
AutoConfig
,
AutoTokenizer
from
transformers
import
AutoConfig
,
AutoTokenizer
from
vllm.multimodal.utils
import
(
async_fetch_image
,
fetch_image
,
from
vllm.multimodal.utils
import
(
async_fetch_image
,
fetch_image
,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
..utils
import
urls_port
from
..utils
import
models_path_prefix
,
urls_port
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS
=
[
TEST_IMAGE_URLS
=
[
...
@@ -85,7 +86,7 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
...
@@ -85,7 +86,7 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
assert
_image_equals
(
data_image_sync
,
data_image_async
)
assert
_image_equals
(
data_image_sync
,
data_image_async
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"llava-hf/llava-v1.6-mistral-7b-hf"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
os
.
path
.
join
(
models_path_prefix
,
"llava-hf/llava-v1.6-mistral-7b-hf"
)
])
def
test_repeat_and_pad_placeholder_tokens
(
model
):
def
test_repeat_and_pad_placeholder_tokens
(
model
):
config
=
AutoConfig
.
from_pretrained
(
model
)
config
=
AutoConfig
.
from_pretrained
(
model
)
image_token_id
=
config
.
image_token_index
image_token_id
=
config
.
image_token_index
...
...
vllm/model_executor/model_loader/utils.py
View file @
9ddee6b1
...
@@ -23,14 +23,14 @@ def get_model_architecture(
...
@@ -23,14 +23,14 @@ def get_model_architecture(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'Qwen2MoeForCausalLM'
,
'Qwen2VLForConditionalGeneration'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
]
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'Qwen2MoeForCausalLM'
,
'Qwen2VLForConditionalGeneration'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
,
'FalconForCausalLM'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LLAMA_NN'
]
=
'0'
else
:
else
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
os
.
environ
[
'LLAMA_NN'
]
=
'1'
if
architectures
==
[
'BloomForCausalLM'
]
or
os
.
getenv
(
'LM_NN'
)
==
'0'
:
if
(
architectures
==
[
'BloomForCausalLM'
]
or
architectures
==
[
'FalconForCausalLM'
])
or
os
.
getenv
(
'LM_NN'
)
==
'0'
:
os
.
environ
[
'LM_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
else
:
else
:
os
.
environ
[
'LM_NN'
]
=
'1'
os
.
environ
[
'LM_NN'
]
=
'1'
...
...
vllm/model_executor/models/falcon.py
View file @
9ddee6b1
...
@@ -21,6 +21,8 @@
...
@@ -21,6 +21,8 @@
import
math
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
os
import
re
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
torch.nn
import
LayerNorm
...
@@ -47,6 +49,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -47,6 +49,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
RWConfig
from
vllm.transformers_utils.configs
import
RWConfig
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
...
@@ -176,6 +181,11 @@ class FalconAttention(nn.Module):
...
@@ -176,6 +181,11 @@ class FalconAttention(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -184,6 +194,8 @@ class FalconAttention(nn.Module):
...
@@ -184,6 +194,8 @@ class FalconAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
bias
=
self
.
query_key_value
(
hidden_states
)
qkv
,
bias
=
self
.
query_key_value
(
hidden_states
)
if
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
and
self
.
quant_method
is
None
:
qkv
=
qkv
[...,:
-
32
]
if
bias
is
not
None
:
if
bias
is
not
None
:
qkv
+=
bias
qkv
+=
bias
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
...
@@ -246,6 +258,9 @@ class FalconDecoderLayer(nn.Module):
...
@@ -246,6 +258,9 @@ class FalconDecoderLayer(nn.Module):
self
.
mlp
=
FalconMLP
(
config
,
quant_config
)
self
.
mlp
=
FalconMLP
(
config
,
quant_config
)
self
.
config
=
config
self
.
config
=
config
if
(
not
hasattr
(
config
,
"num_ln_in_parallel_attn"
)):
config
.
num_ln_in_parallel_attn
=
None
if
(
config
.
num_ln_in_parallel_attn
is
None
if
(
config
.
num_ln_in_parallel_attn
is
None
and
config
.
new_decoder_architecture
):
and
config
.
new_decoder_architecture
):
config
.
num_ln_in_parallel_attn
=
2
config
.
num_ln_in_parallel_attn
=
2
...
@@ -404,6 +419,17 @@ class FalconForCausalLM(nn.Module):
...
@@ -404,6 +419,17 @@ class FalconForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'0'
))
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
LongTensor
,
input_ids
:
torch
.
LongTensor
,
...
@@ -481,3 +507,33 @@ class FalconForCausalLM(nn.Module):
...
@@ -481,3 +507,33 @@ class FalconForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"self_attention.query_key_value.weight"
,
"self_attention.dense.weight"
,
"mlp.dense_h_to_4h.weight"
,
"mlp.dense_4h_to_h.weight"
,
]
combined_words
=
"|"
.
join
(
lay_key_words
)
lay_qkv_words
=
[
"self_attention.query_key_value.weight"
]
qkv_words
=
"|"
.
join
(
lay_qkv_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
self
.
use_gemm_pad
and
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
if
self
.
use_fa_pad
and
(
re
.
findall
(
qkv_words
,
layername
)):
if
not
gemm_bank_conf
(
weight
.
data
.
shape
[
0
]):
weight
.
data
=
pad_weight
(
weight
.
data
,
32
)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment