# coding=utf-8
# Copyright 2022 The OpenBMB team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import Optional
from typing import Tuple

import torch
from transformers.configuration_utils import PretrainedConfig


class CPMDragonflyConfig(PretrainedConfig):
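    """Model configuration for CPM-Dragonfly, built on HuggingFace's ``PretrainedConfig``."""
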
    model_type = "cpmdragonfly"
    keys_to_ignore_at_inference = ["past_key_values"]
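    # Map standard HuggingFace config attribute names onto this config's native field names
    # (e.g. `config.hidden_size` resolves to `config.dim_model`).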
    attribute_map = {
        "num_key_value_heads": "num_kv_heads",
        "hidden_act": "activate_fn",
        "hidden_size": "dim_model",
        "num_attention_heads": "num_heads",
        "intermediate_size": "dim_ff",
        "num_hidden_layers": "num_layers",
        "vocab_size": "vocab_size",
        "rms_norm_eps": "eps",
        "scale_emb": "scale_emb",
        "scale_depth": "scale_depth",
        "scale": "scale",
        "attention_scale": "attention_scale"
    }

    def __init__(
        self,
        vocab_size=32000,
        dim_model=4096,
        num_heads=32,
        num_kv_heads=32,
        dim_head=128,
        dim_ff=11008,
        num_layers=32,
        dropout_p=0.0,
        activate_fn="silu",
        scale=True,
        scale_emb: float = 1.0,
        scale_depth: float = -1,
        dim_model_base: Optional[int] = None,
        eps=1e-5,
        init_std=0.02,
        half: bool = True,
        half_type: str = "bf16",
        mask_modules: Optional[List[Tuple[bool, bool]]] = None,
        use_flash_attn: bool = True,
        flash_attn_mask_shape="1d",
        flash_impl="cuda",
        base=10000,
        non_checkpointing_layers_num: int = 0,
        attention_scale=1,
        max_position_embeddings=8192,
        rope_scaling=None,
        **kwargs,
    ):
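        # Note: the defaults above (4096 hidden, 32 heads, 32 layers, 11008 FFN, 32000 vocab)
        # correspond to a 7B-parameter LLaMA-style layout.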
        self.vocab_size = vocab_size
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.dim_head = dim_head
        self.dim_ff = dim_ff
        self.num_layers = num_layers
        self.dropout_p = dropout_p
        self.activate_fn = activate_fn
        self.scale = scale
        self.scale_emb = scale_emb
        self.half = half
        self.half_type = half_type
        self.dim_model_base = dim_model_base
        self.scale_depth = scale_depth
        self.eps = eps
        self.init_std = init_std
        self.flash_impl = flash_impl
        self.mask_modules = mask_modules
        self.use_flash_attn = use_flash_attn
        self.flash_attn_mask_shape = flash_attn_mask_shape
        self.base = base
        self.attention_scale = attention_scale
        self.max_position_embeddings = max_position_embeddings
        self.non_checkpointing_layers_num = non_checkpointing_layers_num
        self.rope_scaling = rope_scaling
        super().__init__(architectures=["CPMDragonflyForCausalLM"], **kwargs)
    
    @property
    def scale_width(self):
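        """Width multiplier: `dim_model / dim_model_base` when `scale` is enabled, else 1.0."""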
        if self.scale:
            return self.dim_model / self.dim_model_base
        else:
            return 1.0
    
    @property
    def dtype(self):
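        """torch dtype implied by the `half` / `half_type` flags (bf16, fp16, or fp32)."""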
        if self.half:
            if self.half_type == "bf16":
                return torch.bfloat16
            else:
                return torch.half
        else:
            return torch.float
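

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the original module):
    # build a config with the defaults above and inspect the derived properties.
    # `dim_model_base=256` is a hypothetical value passed here only so that
    # `scale_width` is well defined; the real value depends on the training recipe.
    cfg = CPMDragonflyConfig(dim_model_base=256)
    print(cfg.model_type)   # "cpmdragonfly"
    print(cfg.dtype)        # torch.bfloat16, since half=True and half_type="bf16"
    print(cfg.scale_width)  # 16.0 == dim_model / dim_model_base
    print(cfg.hidden_size)  # 4096, resolved to `dim_model` via `attribute_map`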