Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
45ac4ff2
Unverified
Commit
45ac4ff2
authored
Nov 25, 2024
by
youkaichao
Committed by
GitHub
Nov 25, 2024
Browse files
[bugfix] fix aria model and add torch.compile (#10645)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
6e9ff050
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
28 deletions
+14
-28
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+4
-22
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+10
-6
No files found.
vllm/model_executor/models/aria.py
View file @
45ac4ff2
...
...
@@ -29,7 +29,7 @@ from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP,
LlamaModel
)
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
is_pp_missing_parameter
,
make_layers
,
maybe_prefix
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
...
...
@@ -363,27 +363,9 @@ class AriaMoELMModel(LlamaModel):
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
# FIXME: this is a hack to disable the compilation of the model
self
.
do_not_compile
=
True
self
.
layers
=
None
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
MoEDecoderLayer
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
layer_type
=
MoEDecoderLayer
)
# Adapted from LlamaModel.load_weights with the modification of adding
# the expert weights mapping to `stacked_params_mapping`
...
...
vllm/model_executor/models/llama.py
View file @
45ac4ff2
...
...
@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
torch
from
torch
import
nn
...
...
@@ -273,7 +273,11 @@ class LlamaDecoderLayer(nn.Module):
@
support_torch_compile
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
layer_type
:
Type
[
LlamaDecoderLayer
]
=
LlamaDecoderLayer
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -299,7 +303,7 @@ class LlamaModel(nn.Module):
self
.
embed_tokens
=
PPMissingLayer
()
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
LlamaDecoderLayer
(
config
=
config
,
lambda
prefix
:
layer_type
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment