import torch
import torch.distributed

from accelerate import init_empty_weights
from opentelemetry import trace
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
    FlashSantacoderForCausalLM,
)
from text_generation_server.utils import (
    weight_files,
    download_weights,
    weight_hub_files,
    LocalEntryNotFoundError,
)

tracer = trace.get_tracer(__name__)


class FlashSantacoder(FlashCausalLM):
    def __init__(self, model_id: str, revision: Optional[str] = None, quantize: bool = False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashSantacoder is only available on GPU")

        if quantize:
            raise NotImplementedError("FlashSantacoder does not support quantization")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=True,  # Needed as the config is not part of Transformers
        )

        # We do not use from_pretrained as we modified the model's internal module layout
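        # Weights are instead loaded shard by shard from the checkpoint .bin files
        # by load_weights, which also remaps checkpoint keys to the modified layout.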
        try:
            filenames = weight_files(model_id, revision, ".bin")
        # Local files not found
        except LocalEntryNotFoundError:
            hub_files = weight_hub_files(model_id, revision, ".bin")
            filenames = download_weights(hub_files, model_id, revision)

        with init_empty_weights():
            model = FlashSantacoderForCausalLM(config)

        self.load_weights(
            model,
            filenames,
            device,
            dtype,
        )
        self.model = model.eval().to(device).to(dtype)
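        # Note: super(FlashCausalLM, self) below intentionally skips
        # FlashCausalLM.__init__, since the model has already been built and
        # loaded manually above.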

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer, device=device, decode_buffer=1
        )

    @staticmethod
    def load_weights(
        model: FlashSantacoderForCausalLM,
        filenames: List[Path],
        device: torch.device,
        dtype: torch.dtype,
    ):
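        # Load each checkpoint shard on CPU, move each tensor to the target
        # device/dtype, and copy it into the (initially meta) model parameters,
        # fusing the separate q_attn / kv_attn tensors into a single attn tensor.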
        for filename in filenames:
            state_dict = torch.load(filename, map_location="cpu")
            for key, value in state_dict.items():
                value = value.to(device).to(dtype)

                layer_name = ".".join(key.split(".")[:4])

                # Fused qkv
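                # The checkpoint stores the query and key/value projections as
                # separate tensors (q_attn / kv_attn); the model expects a single
                # fused attn tensor per layer.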
                if "q_attn.weight" in key or "kv_attn.weight" in key:
                    final_key = layer_name + ".attn.weight"
                elif "q_attn.bias" in key or "kv_attn.bias" in key:
                    final_key = layer_name + ".attn.bias"

                else:
                    final_key = key

                module_name, param_name = final_key.rsplit(".", 1)
                module = model.get_submodule(module_name)

                try:
                    current_parameter_tensor = module._parameters[param_name]
                except KeyError:
                    current_parameter_tensor = None

                if current_parameter_tensor is not None:
                    if (
                        "c_fc.weight" in key
                        or "c_proj.weight" in key
                        or "q_attn.weight" in key
                        or "kv_attn.weight" in key
                    ):
                        # Transpose as we use nn.Linear instead of Conv1D
                        value = value.T

                    if current_parameter_tensor.device == torch.device("meta"):
                        # Init qkv
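                        # Allocate the fused tensor on first use: head_size * (num_heads + 2)
                        # rows, the extra 2 presumably being the single shared key and
                        # value head of multi-query attention.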
                        if "attn.weight" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (
                                    model.transformer.head_size
                                    * (model.transformer.num_heads + 2),
                                    value.shape[1],
                                )
                            )
                        elif "attn.bias" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (
                                    model.transformer.head_size
                                    * (model.transformer.num_heads + 2)
                                )
                            )

                    # Copy to correct slice
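                    # q_attn fills the first num_heads * head_size rows of the fused
                    # tensor; kv_attn fills the remaining rows after that offset.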
                    if "q_attn.weight" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "q_attn.bias" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "kv_attn.weight" in key:
                        module._parameters[param_name][
                            model.transformer.head_size * model.transformer.num_heads :
                        ] = value
                    elif "kv_attn.bias" in key:
                        module._parameters[param_name][
                            model.transformer.head_size * model.transformer.num_heads :
                        ] = value
                    else:
                        if current_parameter_tensor.shape != value.shape:
                            raise ValueError(
                                f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
                            )
                        module._parameters[param_name] = value
                else:
                    module._buffers[param_name] = value

                del value

        torch.cuda.empty_cache()
        model.post_load_weights()

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )