import torch
import torch.distributed

from typing import List, Optional, Type

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedTokenizerBase,
)
from transformers.models.bloom.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

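# bitsandbytes is optional: it is only needed when weights are loaded with
# `quantize=True` (int8 weights via `Int8Params` and `bnb.matmul` below).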
HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception:
    HAS_BITS_AND_BYTES = False


class BloomCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super(BloomCausalLMBatch, cls).from_pb(
            pb=pb, tokenizer=tokenizer, device=device
        )
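        # BLOOM's past keys are laid out with head_dim before the sequence
        # length, so the head dimension is not last.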
        batch.keys_head_dim_last = False
        return batch


class BLOOM(CausalLM):
    def __init__(self, model_id: str, revision: Optional[str] = None, quantize: bool = False):
        super(BLOOM, self).__init__(
            model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return BloomCausalLMBatch


class BLOOMSharded(BLOOM):
    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, slow_but_exact=False, tp_parallel=True
        )
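        # The BLOOM tokenizer defines <pad> as token id 3.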
        config.pad_token_id = 3

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

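        # Instantiate the model on the meta device so no memory is allocated;
        # the real (sharded) weights are loaded from the safetensors files below.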
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval()
        torch.distributed.barrier(group=self.process_group)
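        # Skip CausalLM.__init__ (which would load an unsharded copy of the
        # model) and call the next initializer in the MRO directly.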
        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            decode_buffer=1,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
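            # When quantizing, keep the loaded tensors on CPU; Int8Params
            # performs the int8 conversion when moved to the GPU below.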
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    full_name = f"transformer.{name}"

                    module_name, param_name = full_name.rsplit(".", 1)
                    module = model.get_submodule(module_name)
                    current_tensor = parameters[full_name]

                    slice_ = f.get_slice(name)

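                    # Shard according to the layer's tensor-parallel layout:
                    # column-parallel linears and embeddings are split on dim 0,
                    # row-parallel linears on dim 1 of the weight; everything
                    # else is loaded whole.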
                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        tensor = slice_[:]

                    if current_tensor.shape != tensor.shape:
                        raise ValueError(
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

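                    # With quantization enabled, linear weights are converted to
                    # bitsandbytes Int8Params and the module's `linear` call is
                    # patched to run the int8 matmul through `bnb.matmul`.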
                    if quantize:
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)

                        else:
                            tensor = tensor.to(device)

                    module._parameters[param_name] = tensor
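                    # BLOOM ties the output projection to the word embeddings,
                    # so reuse the sharded embedding weight for lm_head.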
                    if name == "word_embeddings.weight":
                        model.lm_head._parameters["weight"] = tensor

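        # Every parameter should have been materialized from the checkpoint
        # files above; anything left on the meta device was never loaded.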
        uninitialized_parameters = []
        for n, p in model.named_parameters():
            if p.data.device == torch.device("meta"):
                uninitialized_parameters.append(n)
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model: {uninitialized_parameters}"
            )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional[List] = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
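        # Each rank only computed logits for its vocabulary shard, so
        # concatenate the gathered shards along the last (vocab) dimension.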
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return logits, outputs.past_key_values