qwen.py 3.38 KB
Newer Older
Qing's avatar
Qing committed
1
2
3
4
5
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
6
"""Inference-only QWen model compatible with HuggingFace weights."""
Roy's avatar
Roy committed
7
from typing import Optional
Qing's avatar
Qing committed
8

Roy's avatar
Roy committed
9
10
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
Qing's avatar
Qing committed
11

Roy's avatar
Roy committed
12
from vllm.model_executor.layers.linear import LinearMethodBase
13
from vllm.model_executor.layers.layernorm import RMSNorm
Roy's avatar
Roy committed
14
from vllm.model_executor.models.llama import LlamaForCausalLM
15
16
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
Qing's avatar
Qing committed
17
18


Roy's avatar
Roy committed
19
class QWenLMHeadModel(LlamaForCausalLM):
    """QWen causal-LM implemented by reusing the Llama model.

    QWen's architecture matches Llama closely enough that this subclass
    only (1) adapts the HF ``PretrainedConfig`` fields that differ and
    (2) remaps QWen checkpoint weight names onto Llama parameter names
    at load time.
    """

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        # QWen configs expose the RMSNorm epsilon as `layer_norm_epsilon`
        # (Llama reads `rms_norm_eps`), so build the final norm here and
        # hand it to the Llama constructor instead.
        norm = RMSNorm(config.hidden_size, config.layer_norm_epsilon)
        # QWen uses a bias on the fused QKV projection (Llama does not).
        config.use_qkv_bias = True
        # NOTE(review): QWen checkpoints report `intermediate_size` as the
        # combined width of the two MLP branches (w1 + w2); halving it
        # gives the per-branch size Llama's gate/up projections expect —
        # confirm against the upstream modeling_qwen.py if this changes.
        config.intermediate_size = config.intermediate_size // 2
        super().__init__(config=config,
                         linear_method=linear_method,
                         norm=norm,
                         lora_config=lora_config)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None) -> None:
        """Load HF QWen weights, renaming them to Llama parameter names.

        Args:
            model_name_or_path: HF repo id or local checkpoint path.
            cache_dir: Optional HF download cache directory.
            load_format: Checkpoint format selector passed through to
                ``hf_model_weights_iterator`` (default ``"auto"``).
            revision: Optional HF revision (branch/tag/commit).
        """
        # QWen fuses the MLP gate/up weights into w2/w1 shards of a
        # single stacked `gate_up_proj` parameter.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        # Substring renames from QWen checkpoint names to this model's
        # (Llama-style) parameter names. Applied sequentially and in
        # order, so ordering matters: e.g. ".attn." -> ".self_attn."
        # must run before the ".self_attn.c_proj" -> "o_proj" rename,
        # and "c_attn" -> "qkv_proj" before it as well.
        param_weight_map = [
            ("model", "transformer"),
            (".self_attn.", ".attn."),
            (".layers.", ".h."),
            ("qkv_proj", "c_attn"),
            (".self_attn.o_proj", ".self_attn.c_proj"),
            ("norm", "ln_f"),
            ("mlp.down_proj", "mlp.c_proj"),
            ("input_layernorm", "ln_1"),
            ("post_attention_layernorm", "ln_2"),
            ("embed_tokens", "wte"),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # First translate the checkpoint name into this model's
            # parameter namespace.
            for (param_name, weight_name) in param_weight_map:
                name = name.replace(weight_name, param_name)

            if "rotary_emb.inv_freq" in name:
                # Rotary inverse frequencies are recomputed at runtime,
                # not loaded as a parameter.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                # Stacked parameters carry a shard-aware weight_loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked parameter: load it whole.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)