causal_lm.py 27.7 KB
Newer Older
1
import time
from dataclasses import dataclass
from typing import Optional, Tuple, List, Type, Dict

import torch
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase

from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling


# Module-level OpenTelemetry tracer; the `filter`, `concatenate` and
# `generate_token` methods below run inside spans created from it.
tracer = trace.get_tracer(__name__)


@dataclass
class CausalLMBatch(Batch):
    """Padded batch state for a decoder-only (causal) language model.

    Besides per-request bookkeeping (tokens, decode offsets, sampling and
    stopping helpers), it holds the batched tensors fed to the model.

    - ``input_ids``: the tokens fed to the model.
    - ``attention_mask``: pre-allocated with ``padding_right_offset`` extra
      columns on the right for tokens that are still to be generated.
    - ``position_ids``: positions matching ``input_ids``.
    - ``past_key_values``: the cached past key/values (``None`` until the
      first forward pass).
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # Request id -> row index of that request in this batch
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    # Past metadata: False when keys store the sequence dimension last
    # (the BLOOM layout handled in `filter` / `concatenate`)
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Summarize this batch as a ``CachedBatch`` protobuf message."""
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a new, not yet prefilled, batch from a protobuf ``Batch``.

        All prompts are tokenized together. They are padded to the longest
        prompt and truncated to the largest ``truncate`` value of the batch.
        The attention mask reserves room on the right for the largest
        ``max_new_tokens`` of the batch.

        Args:
            pb: incoming batch of requests.
            tokenizer: tokenizer used to encode the prompts.
            dtype: not used by this method; kept for interface compatibility.
            device: device on which all tensors are created.

        Returns:
            A batch whose ``past_key_values`` is ``None``.
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(r.inputs)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            # Sum of the decode budgets of all requests
            max_decode_tokens += stopping_criteria.max_new_tokens
            # Right padding must fit the longest generation
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        # Every row is padded to the same width, so all requests share the same
        # decode offsets. Clamp at 0: a negative start would index from the end
        # of the (growing) token sequence instead of covering the short prompt.
        padded_length = tokenized_inputs["input_ids"].shape[1]
        prefix_offsets = [max(padded_length - 5, 0)] * len(pb.requests)
        read_offsets = [padded_length] * len(pb.requests)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        # Plain int: used for tensor shapes, slicing and the token budget below
        max_input_length = int(input_lengths.max())

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        # Positions count real tokens only; padding positions are set to 1
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        # One [seq_len, 1] tensor per request
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        # Same accounting as `filter`: every row holds the longest prompt, plus
        # the decode budget of each request.
        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
        """Keep only ``request_ids``, in the given order, updating ``self`` in place.

        Columns of the attention mask and past key/values that the remaining
        requests no longer need are sliced away. The past is expected to be
        populated, i.e. the batch already went through a forward pass.

        Args:
            request_ids: ids of the requests to keep.

        Returns:
            ``self``, mutated.

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(self.top_n_tokens[idx])
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            new_padding_right_offset = max(
                new_padding_right_offset, remaining_decode_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
        # Keep the last `max_input_length` used columns plus the right padding
        # still needed by the remaining requests
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if isinstance(self.past_key_values[0], tuple):
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.top_n_tokens = top_n_tokens
        self.top_n_tokens_tensor = top_n_tokens_tensor
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several prefilled batches into one.

        Each batch is re-padded on the left to the largest ``max_input_length``
        and on the right to the largest ``padding_right_offset``. Past
        key/values are merged layer by layer. The references held by the input
        batches are set to ``None`` as they are copied, so the input batches
        must not be used afterwards.

        Args:
            batches: batches that already went through a forward pass.

        Returns:
            The merged batch, reusing the first batch's id.

        Raises:
            ValueError: if a batch has not been prefilled yet.
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        max_tokens = 0

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []
        top_n_tokens_tensor = None

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
            top_n_tokens.extend(batch.top_n_tokens)

            # Offset each batch's row indices by the number of rows before it.
            # A fresh dict is built so the input batches' mappings are untouched.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            if top_n_tokens_tensor is None:
                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                    total_batch_size,
                )
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space. Explicit end indices are used
            # instead of `:-offset`, which would select nothing for a 0 offset.
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            batch_right_end = batch.attention_mask.shape[1] - batch.padding_right_offset
            attention_mask[
                start_index:end_index,
                left_offset:max_input_length,
            ] = batch.attention_mask[
                :,
                batch_left_offset:batch_right_end,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if isinstance(batch.past_key_values[0], tuple):
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length - batch.max_input_length
            ) * len(batch)

            start_index = end_index

        first_past_kvs = batches[0].past_key_values
        _, num_heads, _, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[
                        start_index:end_index, :, -past_seq_len:, :
                    ] = past_keys[:, :, -past_seq_len:, :]
                else:
                    # BLOOM case
                    padded_past_keys[
                        start_index:end_index, :, :, -past_seq_len:
                    ] = past_keys[:, :, :, -past_seq_len:]
                del past_keys

                start_index = end_index

            padded_past_values = first_past_kvs[j][1].new_zeros(
                padded_past_values_shape
            )
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[
                    start_index:end_index, :, -past_seq_len:, :
                ] = past_values[:, :, -past_seq_len:, :]
                del past_values

                # Update values
                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class CausalLM(Model):
    """Serve a ``transformers`` causal LM with padded ``CausalLMBatch`` batches."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load the tokenizer and the model for ``model_id``.

        - Default dtype: float16 on CUDA, float32 on CPU.
        - Several GPUs: weights are placed with ``device_map="auto"``.
        - A single GPU: the model is moved to it unless it was loaded in 8-bit
          (``quantize == "bitsandbytes"``).

        Raises:
            ValueError: if ``quantize`` is set while CUDA is unavailable.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        # Left padding/truncation: the newest tokens stay at the right edge
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto"
            if torch.cuda.is_available() and torch.cuda.device_count() > 1
            else None,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if (
            torch.cuda.is_available()
            and torch.cuda.device_count() == 1
            and quantize != "bitsandbytes"
        ):
            model = model.cuda()

        # Batches are padded, so a pad token is required. Reuse the model's pad
        # or eos token when available; add a new "[PAD]" token as a last resort.
        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used by this model."""
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Decode token ids to text, dropping special tokens."""
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run one model forward pass with caching enabled.

        ``position_ids`` is only forwarded when ``self.has_position_ids`` is set.

        Returns:
            The logits and the updated past key/values.
        """
        # Model Forward
        kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
            "return_dict": True,
        }
        if self.has_position_ids:
            kwargs["position_ids"] = position_ids

        outputs = self.model.forward(**kwargs)
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
        """Run one decoding step for every request of ``batch``.

        A ``Generation`` is only emitted for rows where ``i % world_size == rank``.
        The batch is updated in place for the next step.

        Returns:
            A tuple of three items:

            - the generations of this shard;
            - the updated batch, or ``None`` once every request has stopped;
            - ``(forward_ns, decode_ns)`` timings in nanoseconds.
        """
        start = time.time_ns()
        # Slice the attention mask to the columns in use. An explicit end index
        # (rather than `:-offset`) keeps the whole mask when the offset is 0.
        attention_mask = batch.attention_mask[
            :, : batch.attention_mask.shape[1] - batch.padding_right_offset
        ]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            torch.log_softmax(logits[:, -1], -1),
        )

        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.top_n_tokens,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            top_n_tokens,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), request_logits[-1:, :]
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids[:, 0], prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids[:, 0],
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    # Entry k is the log-probability, at position k, of the token
                    # at position k + 1; the last entry scores the token just
                    # generated.
                    shifted_logprobs = (
                        torch.log_softmax(request_logits, -1)
                        .gather(1, all_input_ids[1:])
                        .squeeze(1)
                    )
                    # Rows are left-padded: prompt tokens 2..input_length are scored
                    # by the `input_length - 1` entries just before the last one.
                    # Starting from `-new_input_length` would include one padding
                    # position and misalign logprobs with tokens. The first prompt
                    # token has no prediction, hence the leading nan.
                    first_scored = shifted_logprobs.shape[0] - input_length
                    prefill_logprobs = [float("nan")] + shifted_logprobs[
                        first_scored:-1
                    ].tolist()
                    prefill_token_ids = all_input_ids[-new_input_length:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_ids = prefill_token_ids.squeeze(1).tolist()
                    prefill_tokens = Tokens(
                        prefill_ids,
                        prefill_logprobs,
                        prefill_texts,
                        is_special=[
                            token_id in self.all_special_ids for token_id in prefill_ids
                        ],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    toptoken_texts = self.tokenizer.batch_decode(
                        top_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    special_toptokens = [
                        token_id in self.all_special_ids for token_id in top_token_ids
                    ]
                    top_tokens = Tokens(
                        top_token_ids,
                        top_token_logprobs,
                        toptoken_texts,
                        special_toptokens,
                    )
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        [next_token_id_squeezed],
                        [next_token_logprob],
                        [next_token_text],
                        [next_token_id_squeezed.item() in self.all_special_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)