import torch
import time
import torch.distributed

from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    PreTrainedTokenizerBase,
)
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class CausalLMBatch(Batch):
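    # Padded (non-flash) batch: per-request bookkeeping lists plus shared tensors;
    # the attention mask is allocated as
    # [batch_size, max_input_length + padding_right_offset] so decode steps can
    # write new tokens in place instead of reallocating.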
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    # Past metadata
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

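    # Tokenize the raw request texts into one left-padded tensor and pre-allocate
    # room on the right for up to max_new_tokens decode steps per request.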
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(concat_text_chunks(r.input_chunks.chunks))

            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(input_len - 5)
            read_offsets.append(input_len)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

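    # Keep only the given request_ids: shrink every per-request list and slice the
    # shared tensors (attention mask, past key values) in place.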
    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(self.top_n_tokens[idx])
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            new_padding_right_offset = max(
                new_padding_right_offset, remaining_decode_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
        ]
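        # The slice above keeps, for the surviving requests, the last
        # max_input_length used columns plus new_padding_right_offset slots of
        # still-unused right padding.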

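        # Most models cache past keys/values as [batch, num_heads, seq, head_dim]
        # (or [batch, num_heads, head_dim, seq] when keys_head_dim_last is False,
        # e.g. BLOOM); 3-D caches fuse batch and heads and are reshaped below.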
        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.top_n_tokens = top_n_tokens
        self.top_n_tokens_tensor = top_n_tokens_tensor
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens

        return self

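    # Merge several already-prefilled batches into one, re-padding the attention
    # masks and past key values to the largest prompt length in the group.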
    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        max_tokens = 0

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []
        top_n_tokens_tensor = None

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
            top_n_tokens.extend(batch.top_n_tokens)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            if top_n_tokens_tensor is None:
                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                    total_batch_size,
                )
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

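            # left_offset: extra left padding needed because the merged batch has a
            # longer max prompt than this batch; batch_left_offset: leading columns
            # of this batch's old mask that are allocated but unused.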
            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length - batch.max_input_length
            ) * len(batch)

            start_index = end_index

        first_past_kvs = batches[0].past_key_values
        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
                        past_keys[:, :, -past_seq_len:, :]
                    )
                else:
                    # BLOOM case
                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
                        past_keys[:, :, :, -past_seq_len:]
                    )
                del past_keys

                start_index = end_index

            padded_past_values = first_past_kvs[j][1].new_zeros(
                padded_past_values_shape
            )
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
                    past_values[:, :, -past_seq_len:, :]
                )
                del past_values

                # Update values
                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


@dataclass
class CausalLMBatchKeysLast(CausalLMBatch):
    keys_head_dim_last: bool = False


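# Generic padded (non-flash) causal LM wrapper: __init__ loads custom modelling
# code from sharded safetensors weights, while fallback() wraps a plain
# transformers AutoModelForCausalLM.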
class CausalLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
        config_class=AutoConfig,
        batch_class=CausalLMBatch,
    ):
        self.batch_class = batch_class
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                # Float16 doesn't exist on target.
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = config_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = config.pad_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = model_class(prefix, config, weights)

        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=(
                "auto"
                if torch.cuda.is_available() and torch.cuda.device_count() > 1
                else None
            ),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if (
            torch.cuda.is_available()
            and torch.cuda.device_count() == 1
            and quantize != "bitsandbytes"
        ):
            model = model.cuda()

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

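        # Build the instance without running CausalLM.__init__ (which loads
        # sharded safetensors weights) and initialize the Model base class directly.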
        self = cls.__new__(
            cls,
        )
        self.batch_class = CausalLMBatch
        super().__init__(
            self,
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )
        return self

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return self.batch_class

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
    ]:
        # Model Forward
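        # position_ids is only forwarded when the wrapped model supports it;
        # some architectures (e.g. BLOOM) compute positions internally.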
        kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
            "return_dict": True,
        }
        if self.has_position_ids:
            kwargs["position_ids"] = position_ids

        outputs = self.model.forward(**kwargs)
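        # Some modelling code returns (outputs, speculative_logits); plain
        # transformers models return a single ModelOutput.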
        if isinstance(outputs, tuple):
            outputs, speculative_logits = outputs
        else:
            speculative_logits = None
        return outputs.logits, speculative_logits, outputs.past_key_values

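    # One decode step for the whole batch: run the model, choose one token per
    # request, emit Generation objects, then update the cached tensors in place
    # for the next step.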
    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

        logits, speculative_logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        # Speculation is not active for causal
        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            torch.log_softmax(logits[:, -1], -1),
            accepted_ids,
        )

        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.top_n_tokens,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            top_n_tokens,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), logits[-1:, :]
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids[:, 0], prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids[:, 0],
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
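                    # Prefill logprobs: for each prompt token t (t >= 1), read
                    # log p(token_t | tokens_<t) from the logits at position t-1;
                    # the first prompt token has no context, hence the leading NaN.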
                    # Remove generated token to only have prefill and add nan for first prompt token
                    prefill_logprobs = [float("nan")] + torch.log_softmax(
                        logits, -1
                    ).gather(1, all_input_ids[1:]).squeeze(1)[
                        -new_input_length:-1
                    ].tolist()
                    prefill_token_ids = all_input_ids[-new_input_length:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    for top_token_ids, top_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        [next_token_id_squeezed],
                        [next_token_logprob],
                        [next_token_text],
                        [next_token_id_squeezed.item() in self.all_special_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
                next_token_id_squeezed.item()
            )
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

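        # Otherwise prepare the batch for the next decode step: from now on the
        # model only needs the single freshly generated token as input.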
        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)