import time

from dataclasses import dataclass
from typing import Optional, Tuple, List, Type, Dict

import torch

from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase

from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class CausalLMBatch(Batch):
    """State of a batch of causal language model generation requests.

    An instance is created from a protobuf batch with `from_pb`, shrunk in
    place with `filter`, and merged with other prefilled batches with
    `concatenate`.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # Maps a request id to its row index in the batch tensors and lists
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    # None until the batch went through a first forward pass
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    # Past metadata
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Return the protobuf `CachedBatch` summary of this batch."""
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Tokenize the requests of a protobuf batch and build a new batch.

        Args:
            pb: Protobuf batch; all of its requests are tokenized together.
            tokenizer: Tokenizer called with padding and truncation enabled.
            dtype: Not used by this implementation.
            device: Device the tokenized tensors and `top_n_tokens_tensor`
                are moved to.

        Returns:
            A batch whose `past_key_values` is None (no forward pass yet).
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(r.inputs)
            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            # The attention mask is allocated with room for the longest generation
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        # Every request shares the same padded length, so the initial decoding
        # offsets are identical for all of them.
        input_len = tokenized_inputs["input_ids"].shape[1]
        prefix_offsets = [input_len - 5] * len(inputs)
        read_offsets = [input_len] * len(inputs)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        # Convert to a Python int once so that every derived value (notably
        # `max_tokens`, which is declared as an int) is not a 0-d tensor.
        max_input_length = input_lengths.max().item()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
        """Keep only the given requests, updating this batch in place.

        Args:
            request_ids: Ids of the requests to keep, in their new row order.

        Returns:
            This same batch object. When `request_ids` has as many entries as
            the batch, the batch is returned untouched.

        Raises:
            ValueError: If `request_ids` is empty.
            KeyError: If an id is not part of this batch.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(self.top_n_tokens[idx])
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            new_padding_right_offset = max(
                new_padding_right_offset, remaining_decode_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
        # Drop the left padding that is no longer needed and the right-hand
        # space that the remaining requests can no longer grow into.
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if isinstance(self.past_key_values[0], tuple):
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        # NOTE: `len(self)` below is still the size before filtering because
        # `self.requests` is only replaced at the end of this method.
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.top_n_tokens = top_n_tokens
        self.top_n_tokens_tensor = top_n_tokens_tensor
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several prefilled batches into a single new batch.

        The past key/value tensors of the source batches are released layer by
        layer while they are copied (their entries are set to None), so the
        source batches must not be used for generation afterwards.

        Args:
            batches: Batches to merge; each must have `past_key_values` set.

        Returns:
            A new batch that reuses the `batch_id` and `keys_head_dim_last` of
            the first batch.

        Raises:
            ValueError: If one of the batches has no `past_key_values`.
        """
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        max_tokens = 0

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []
        top_n_tokens_tensor = None

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
            top_n_tokens.extend(batch.top_n_tokens)

            # We need to offset the mapping for each batch by the cumulative batch size.
            # A fresh dict is filled (the offset is 0 for the first batch) so that
            # the mapping of the first source batch is not mutated.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            if top_n_tokens_tensor is None:
                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                    total_batch_size,
                )
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if isinstance(batch.past_key_values[0], tuple):
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length - batch.max_input_length
            ) * len(batch)

            start_index = end_index

        first_past_kvs = batches[0].past_key_values
        # Values are 4-dimensional at this point; only heads and head_dim are reused
        _, num_heads, _, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
                        past_keys[:, :, -past_seq_len:, :]
                    )
                else:
                    # BLOOM case
                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
                        past_keys[:, :, :, -past_seq_len:]
                    )
                del past_keys

                start_index = end_index

            padded_past_values = first_past_kvs[j][1].new_zeros(
                padded_past_values_shape
            )
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
                    past_values[:, :, -past_seq_len:, :]
                )
                del past_values

                # Update values
                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
        )

    def __len__(self):
        """Return the number of requests in the batch."""
        return len(self.requests)


class CausalLM(Model):
    """Causal language model served through `AutoModelForCausalLM`.

    Batches are represented by `CausalLMBatch`; `generate_token` runs one
    forward pass and produces one new token for every request of a batch.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load the tokenizer and the model and initialize the base `Model`.

        Args:
            model_id: Identifier passed to `from_pretrained`.
            revision: Optional revision passed to `from_pretrained`.
            quantize: Quantization scheme; "bitsandbytes" enables 8-bit
                loading. Any truthy value is rejected on CPU.
            use_medusa: Accepted for interface compatibility; not used here.
            dtype: Model dtype. Defaults to float16 when CUDA is available
                and float32 otherwise.
            trust_remote_code: Forwarded to `from_pretrained`.

        Raises:
            ValueError: If `quantize` is set and CUDA is not available.
        """
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        # Left padding/truncation keeps the most recent tokens next to the
        # position where new tokens are appended.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=(
                "auto" if cuda_available and torch.cuda.device_count() > 1 else None
            ),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        # With a single GPU no device map is used, so the model is moved
        # explicitly (except for bitsandbytes-quantized models).
        if (
            cuda_available
            and torch.cuda.device_count() == 1
            and quantize != "bitsandbytes"
        ):
            model = model.cuda()

        # Padding is required for batching: fall back on the first pad/eos
        # token that is defined, or register a new "[PAD]" token.
        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used with this model."""
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Decode token ids to text, skipping special tokens."""
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
    ]:
        """Run the model with caching enabled.

        `position_ids` is only forwarded to the model when
        `self.has_position_ids` is true.

        Returns:
            `(logits, speculative_logits, past_key_values)`.
            `speculative_logits` is None unless the model returns a tuple, in
            which case its second element is used.
        """
        # Model Forward
        kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
            "return_dict": True,
        }
        if self.has_position_ids:
            kwargs["position_ids"] = position_ids

        outputs = self.model.forward(**kwargs)
        if isinstance(outputs, tuple):
            outputs, speculative_logits = outputs
        else:
            speculative_logits = None
        return outputs.logits, speculative_logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
        """Generate one token for every request of `batch`.

        The batch is updated in place (input ids, lengths, offsets, attention
        mask, position ids and past key values).

        Args:
            batch: Batch to advance by one token.

        Returns:
            `(generations, next_batch, (forward_ns, decode_ns))` where
            `generations` only holds the requests assigned to this rank,
            `next_batch` is None when every request of the batch stopped, and
            the timings are in nanoseconds.
        """
        start = time.time_ns()

        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

        logits, speculative_logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        # Speculation is not active for causal
        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            torch.log_softmax(logits[:, -1], -1),
            accepted_ids,
        )

        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.top_n_tokens,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            top_n_tokens,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), request_logits[-1:, :]
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            # Read the token id back as a Python int only once per request
            next_token_id_int = next_token_id_squeezed.item()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids[:, 0], prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids[:, 0],
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    # Remove generated token to only have prefill and add nan for first prompt token
                    prefill_logprobs = [float("nan")] + torch.log_softmax(
                        request_logits, -1
                    ).gather(1, all_input_ids[1:]).squeeze(1)[
                        -new_input_length:-1
                    ].tolist()
                    prefill_token_ids = all_input_ids[-new_input_length:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    # Distinct names are used for the inner loop so that the
                    # sequences being iterated are not rebound while looping.
                    for step_token_ids, step_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            step_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in step_token_ids
                        ]
                        all_top_tokens.append(
                            Tokens(
                                step_token_ids,
                                step_token_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                        )
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        [next_token_id_squeezed],
                        [next_token_logprob],
                        [next_token_text],
                        [next_token_id_int in self.all_special_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
                next_token_id_int
            )
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)