import torch
import torch.distributed

from typing import List, Optional, Type

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedTokenizerBase,
)
from transformers.models.bloom.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception:
    HAS_BITS_AND_BYTES = False


class BloomCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super(BloomCausalLMBatch, cls).from_pb(
            pb=pb, tokenizer=tokenizer, device=device
        )
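        # BLOOM caches past keys as [batch * heads, head_dim, seq_len], i.e. head_dim
        # is not the last axis, so tell CausalLM not to assume the default key layout.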
        batch.keys_head_dim_last = False
        return batch


class BLOOM(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        super(BLOOM, self).__init__(
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return BloomCausalLMBatch


class BLOOMSharded(BLOOM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
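        # One process per shard: rank selects this process's GPU, world_size is the
        # tensor-parallel degree used to split the weights below.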
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            slow_but_exact=False,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
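        # BLOOM's tokenizer reserves token id 3 for <pad>.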
        config.pad_token_id = 3

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

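        # Instantiate the model on the meta device so no memory is allocated;
        # real (sharded) weights are loaded into it afterwards.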
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=trust_remote_code
            )

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=rank,
            world_size=world_size,
        )
        torch.distributed.barrier(group=self.process_group)
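        # Skip CausalLM.__init__ (which would load the full model itself) and call the
        # base class initializer directly with the already-sharded model.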
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
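            # When quantizing, load tensors on CPU first; they are moved to the GPU
            # only after the int8 conversion below.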
            with safe_open(
                file, framework="pt", device=str(device) if quantize is None else "cpu"
            ) as f:
                for name in f.keys():
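                    # Checkpoint tensor names may omit the "transformer." prefix;
                    # normalize them so they match the instantiated module tree.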
                    if name.startswith("transformer.") or name.startswith("lm_head."):
                        full_name = name
                    else:
                        full_name = f"transformer.{name}"

                    module_name, param_name = full_name.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_tensor = parameters[full_name]

                    slice_ = f.get_slice(name)

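                    # Column-parallel linears are split along dim 0 (output features):
                    # each rank keeps one contiguous block of rows.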
                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
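                    # Row-parallel linears are split along dim 1 (input features); the
                    # bias is kept whole and added on rank 0 only, since the partial
                    # outputs are summed across ranks.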
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
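                    # Embeddings and the lm_head are sharded along the vocabulary
                    # dimension.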
                    elif (
                        isinstance(module, TensorParallelEmbedding)
                        or name == "lm_head.weight"
                    ):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        tensor = slice_[:]

                    if current_tensor.shape != tensor.shape:
                        raise ValueError(
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

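                    # Optionally quantize eligible linear weights to int8 with
                    # bitsandbytes (LLM.int8()).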
                    if quantize == "bitsandbytes":
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

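                            # Monkey-patch the module's linear op so it runs the int8
                            # matmul with the captured quantization state.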
                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)
                        else:
                            tensor = tensor.to(device)
                    elif quantize == "gptq":
                        raise NotImplementedError("`gptq` is not implemented for now")
                    elif quantize is None:
                        tensor = tensor.to(device)
                    else:
                        raise ValueError(f"Unexpected quantize `{quantize}`")

                    module._parameters[param_name] = tensor
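                    # The lm_head weight is tied to the word embeddings; reuse the
                    # sharded tensor.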
                    if name == "word_embeddings.weight":
                        model.lm_head._parameters["weight"] = tensor

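    # Each rank only produces its slice of the vocab logits, so forward gathers them
    # before returning.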
    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return logits, outputs.past_key_values