Unverified Commit 4f2ad111 authored by Philipp Moritz, committed by GitHub

Fix DeciLM (#2883)

parent d7afab6d
@@ -28,6 +28,7 @@ from typing import Optional
 import torch
 from transformers import PretrainedConfig
+from vllm.config import LoRAConfig
 from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.weight_utils import (default_weight_loader,
@@ -56,10 +57,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
         self,
         config: Optional[PretrainedConfig] = None,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
         delattr(config, "num_key_value_heads_per_layer")
-        super().__init__(config=config, linear_method=linear_method)
+        super().__init__(config=config,
+                         linear_method=linear_method,
+                         lora_config=lora_config)
 
     def load_weights(self,
                      model_name_or_path: str,
...
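
For context, a minimal sketch of the failure mode this commit addresses, assuming only that the LlamaForCausalLM base constructor gained a lora_config keyword that the model loader now passes to every architecture; BaseModel and SubModel below are hypothetical stand-ins, not vLLM classes. A subclass that overrides __init__ without accepting the new keyword raises a TypeError as soon as a caller supplies it; accepting and forwarding it through super().__init__, as the diff above does for DeciLMForCausalLM, restores compatibility.

from typing import Any, Optional


class BaseModel:
    # Hypothetical stand-in for LlamaForCausalLM after it gained lora_config.
    def __init__(self,
                 config: Optional[Any] = None,
                 linear_method: Optional[Any] = None,
                 lora_config: Optional[Any] = None) -> None:
        self.config = config
        self.linear_method = linear_method
        self.lora_config = lora_config


class SubModel(BaseModel):
    # Hypothetical stand-in for DeciLMForCausalLM: the override must accept
    # the new keyword and forward it, mirroring the change in this diff.
    def __init__(self,
                 config: Optional[Any] = None,
                 linear_method: Optional[Any] = None,
                 lora_config: Optional[Any] = None) -> None:
        super().__init__(config=config,
                         linear_method=linear_method,
                         lora_config=lora_config)


# The loader instantiates every architecture with the same keywords, so a
# subclass whose __init__ lacked lora_config would fail here with a TypeError.
model = SubModel(config={"num_key_value_heads": 4},
                 lora_config={"max_lora_rank": 8})
assert model.lora_config == {"max_lora_rank": 8}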