Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7d022fd
Commit
e7d022fd
authored
Nov 07, 2024
by
王敏
Browse files
1.Mixtral模型的gemm由TN改为NN
parent
82aee745
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
1 deletion
+32
-1
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-1
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+31
-0
No files found.
vllm/model_executor/model_loader/utils.py
View file @
e7d022fd
...
...
@@ -23,7 +23,7 @@ def get_model_architecture(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'Qwen2MoeForCausalLM'
,
'Qwen2VLForConditionalGeneration'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
]
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'Qwen2MoeForCausalLM'
,
'Qwen2VLForConditionalGeneration'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
...
...
vllm/model_executor/models/mixtral.py
View file @
e7d022fd
...
...
@@ -21,6 +21,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
import
os
import
re
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
...
...
@@ -46,6 +48,8 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
...
...
@@ -366,6 +370,12 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -483,3 +493,24 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"block_sparse_moe.gate.weight"
,
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"lm_head.weight"
,
]
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment