import torch
import torch.distributed

from typing import List, Optional, Type

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedTokenizerBase,
)
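# Tensor-parallel BLOOM layers. Note: `transformers.models.bloom.parallel_layers` only
# exists in the patched transformers build shipped with text-generation-inference, not
# in upstream transformers releases.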
from transformers.models.bloom.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

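# bitsandbytes is optional: 8-bit ("bitsandbytes") quantization is only available when
# the import below succeeds (it requires a working GPU setup).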
HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception as e:
    HAS_BITS_AND_BYTES = False


class BloomCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super(BloomCausalLMBatch, cls).from_pb(
            pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
        )
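        # BLOOM caches past keys as (batch * num_heads, head_dim, seq_len), i.e. head_dim
        # is *not* the last dimension, so tell the generic batch logic about it.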
        batch.keys_head_dim_last = False
        return batch


class BLOOM(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        super(BLOOM, self).__init__(
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return BloomCausalLMBatch


class BLOOMSharded(BLOOM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
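        # One server process per shard; rank and world_size come from the distributed launcher.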
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

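        # slow_but_exact=False picks the faster (slightly less exact) BLOOM forward path;
        # tp_parallel=True is understood by the patched transformers build and makes
        # from_config build the model with the tensor-parallel layers imported above.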
        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            slow_but_exact=False,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
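        # Token id 3 is the <pad> token in the BLOOM tokenizer.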
        config.pad_token_id = 3

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
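            # The model is created on the meta device, so no memory is allocated here;
            # the real weights are loaded shard-by-shard in load_weights below.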
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=trust_remote_code
            )

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=rank,
            world_size=world_size,
        )
        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
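            # When quantizing, read the raw weights on CPU; they are only moved to the
            # GPU after the int8 conversion below.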
            with safe_open(
                file, framework="pt", device=str(device) if quantize is None else "cpu"
            ) as f:
                for name in f.keys():
                    if name.startswith("transformer.") or name.startswith("lm_head."):
                        full_name = name
                    else:
                        full_name = f"transformer.{name}"

                    module_name, param_name = full_name.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_tensor = parameters[full_name]

                    slice_ = f.get_slice(name)

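                    # Shard the tensor according to its module's tensor-parallel layout:
                    # column-parallel weights are split along dim 0, row-parallel weights
                    # along dim 1, embeddings / lm_head along the vocab dim (dim 0), and
                    # everything else is replicated on every rank.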
                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for TensorParallelRowLinear: zero the bias on every
                            # rank except 0 so the reduced output only adds the bias once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif (
                        isinstance(module, TensorParallelEmbedding)
                        or name == "lm_head.weight"
                    ):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        tensor = slice_[:]

                    if current_tensor.shape != tensor.shape:
                        raise ValueError(
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

                    if quantize == "bitsandbytes":
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

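                            # Replace the module's linear op with an int8 matmul that
                            # closes over the quantization state captured above.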
                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)
                        else:
                            tensor = tensor.to(device)
                    elif quantize == "gptq":
                        raise NotImplementedError("`gptq` is not implemented for now")
                    elif quantize is None:
                        tensor = tensor.to(device)
                    else:
                        raise ValueError(f"Unexpected quantize `{quantize}`")

                    module._parameters[param_name] = tensor
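                    # The LM head is tied to the word embeddings; reuse the same
                    # (already sharded) tensor for both.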
                    if name == "word_embeddings.weight":
                        model.lm_head._parameters["weight"] = tensor

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # The logits are sharded along the vocabulary dimension, so gather every rank's
        # shard and concatenate them back along dim 2 (the vocab dimension).
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return logits, outputs.past_key_values