Unverified commit 45ac4ff2, authored by youkaichao, committed by GitHub
Browse files

[bugfix] fix aria model and add torch.compile (#10645)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 6e9ff050
...@@ -29,7 +29,7 @@ from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP, ...@@ -29,7 +29,7 @@ from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP,
LlamaModel) LlamaModel)
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
is_pp_missing_parameter, is_pp_missing_parameter,
make_layers, maybe_prefix, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
...@@ -363,27 +363,9 @@ class AriaMoELMModel(LlamaModel): ...@@ -363,27 +363,9 @@ class AriaMoELMModel(LlamaModel):
""" """
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config,
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
# FIXME: this is a hack to disable the compilation of the model
self.do_not_compile = True
self.layers = None
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MoEDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix, prefix=prefix,
), layer_type=MoEDecoderLayer)
prefix=f"{prefix}.layers",
)
# Adapted from LlamaModel.load_weights with the modification of adding # Adapted from LlamaModel.load_weights with the modification of adding
# the expert weights mapping to `stacked_params_mapping` # the expert weights mapping to `stacked_params_mapping`
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import torch import torch
from torch import nn from torch import nn
...@@ -273,7 +273,11 @@ class LlamaDecoderLayer(nn.Module): ...@@ -273,7 +273,11 @@ class LlamaDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -299,7 +303,7 @@ class LlamaModel(nn.Module): ...@@ -299,7 +303,7 @@ class LlamaModel(nn.Module):
self.embed_tokens = PPMissingLayer() self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: LlamaDecoderLayer(config=config, lambda prefix: layer_type(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix), prefix=prefix),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment