# flash_causal_lm.py
import torch
import torch.distributed

from torch.nn import functional as F

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
9
from typing import Optional, Tuple, List, Type, Union, Dict
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    Sampling,
)

# OpenTelemetry tracer for this module; its start_as_current_span decorators wrap
# the batch filter/concatenate methods and generate_token below.
tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
    """Un-padded batch state for flash-attention causal LM generation.

    Per-request token tensors are kept separately and concatenated before the
    model forward; ``cu_seqlens`` delimits each request inside the concatenation.

    ``past_key_values`` takes one of three forms:

    * ``None`` before the first (prefill) forward pass;
    * a single tensor when the batch holds exactly one request, padded along
      the sequence dimension so later decode steps have room to grow;
    * a list ``[seq_0, pad, seq_1, pad, ...]`` otherwise, where ``seq_i`` is the
      past of request ``i`` (``input_lengths[i] - 1`` positions) and ``pad`` is
      ``past_pad``, a one-position slot reserved for the next token.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: List[torch.Tensor]
    position_ids: List[torch.Tensor]
    # cumulative sequence lengths
    cu_seqlens: List[int]
    max_seqlen: int
    past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    # Incremental detokenization offsets (None until the first decoded token)
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Constant shared tensor, ref here just so that it's accessible in concatenate().
    # Defaults to None so batches built before the first forward pass (and
    # callers that do not know it yet) can omit it.
    past_pad: Optional[torch.Tensor] = None

    def to_pb(self) -> generate_pb2.Batch:
        """Serialize the batch id, its requests and its size to protobuf."""
        return generate_pb2.Batch(
            id=self.batch_id, requests=self.requests, size=len(self)
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Tokenize every request of ``pb`` and build a fresh (pre-prefill) batch.

        Args:
            pb: protobuf batch carrying the requests.
            tokenizer: tokenizer used for the prompts and stopping criteria.
            device: device on which the token / position tensors are created.

        Returns:
            A batch whose ``past_key_values`` is ``None`` (prefill not run yet).
        """
        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        offsets = []
        token_offsets = []

        all_input_ids = []
        all_input_ids_tensor = []
        requests_idx_mapping = {}

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        # Parse batch
        for i, r in enumerate(pb.requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenizer(
                r.inputs, truncation=True, max_length=r.truncate
            )["input_ids"]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            offsets.append(None)
            token_offsets.append(None)

            all_input_ids.append(tokenized_input)

            tokenized_input = torch.tensor(tokenized_input, device=device)
            input_ids.append(tokenized_input)

            # Position ids
            position_ids.append(
                torch.arange(0, input_length, dtype=torch.int32, device=device)
            )

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            # Right-pad with room for every token this request may generate
            all_input_ids_tensor.append(
                F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
            )

            # Update
            cumulative_length += input_length

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            past_pad=None,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
        """Return a batch containing only ``requests`` (a subset of this batch).

        When a single request survives, its past is turned into one tensor padded
        with room for the tokens it may still generate (the bs = 1 layout).

        Raises:
            ValueError: if ``requests`` is empty.
            KeyError: if a request id is not part of this batch.
        """
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(requests) == len(self):
            return self

        single_request = len(requests) == 1

        # Cumulative length
        cumulative_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i

            # Get length
            request_input_length = self.input_lengths[idx]

            input_ids.append(self.input_ids[idx])
            position_ids.append(self.position_ids[idx])
            cu_seqlens.append(cumulative_length + request_input_length)
            max_seqlen = max(max_seqlen, request_input_length)

            # True index for past: self.past_key_values is [seq_0, pad, seq_1, pad, ...]
            past_key_values.append(self.past_key_values[2 * idx])
            if not single_request:
                # Add one padding
                past_key_values.append(self.past_pad)

            all_input_ids.append(self.all_input_ids[idx])
            all_input_ids_tensor.append(self.all_input_ids_tensor[idx])

            input_lengths.append(request_input_length)
            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criterias.append(self.stopping_criterias[idx])

            cumulative_length += request_input_length

        if single_request:
            # Preallocate tensor for bs = 1 case: grow the surviving request's own
            # past (not necessarily the one at index 0 of the old batch) by the
            # number of tokens it may still generate. The last pair of the pad
            # tuple applies to the 4th-from-last dimension.
            past_key_values = F.pad(
                past_key_values[0],
                (
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    stopping_criterias[0].max_new_tokens
                    - stopping_criterias[0].current_tokens,
                ),
            )

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            past_pad=self.past_pad,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several batches into one, keeping the ``[seq, pad, ...]`` past layout.

        Input batches are not mutated; the result takes the id of ``batches[0]``.
        """
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset each batch's mapping by the cumulative batch size. Building a
            # new dict (instead of reusing the first batch's) avoids mutating it.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            input_ids.extend(batch.input_ids)
            position_ids.extend(batch.position_ids)
            # Add cumulative lengths of all previous inputs
            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
            max_seqlen = max(max_seqlen, batch.max_seqlen)

            if len(batch) != 1:
                past_key_values.extend(batch.past_key_values)
            else:
                # past was pre-allocated for this batch: keep only the
                # input_lengths[0] - 1 cached positions, then add the one-slot
                # padding, so the sequence spans exactly input_lengths[0]
                # positions like the entries of a multi-request batch (and like
                # its cu_seqlens entry).
                past_key_values.append(
                    batch.past_key_values[:, : batch.input_lengths[0] - 1]
                )
                # Add one padding
                past_key_values.append(batch.past_pad)

            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)

            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)

            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_length += batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)

        # All batches share the model's padding tensor; take the first one set
        past_pad = next((b.past_pad for b in batches if b.past_pad is not None), None)

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            past_pad=past_pad,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class FlashCausalLM(Model):
    """Causal LM served with un-padded (flash-attention style) batches.

    Operates on :class:`FlashCausalLMBatch` and only runs on CUDA devices.
    """

    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        """Load tokenizer and model weights onto the GPU.

        Args:
            model_cls: model class whose ``from_pretrained`` loads the weights.
            model_id: hub id or local path of the model.
            revision: optional model revision.
            quantize: forwarded to ``from_pretrained`` as ``load_in_8bit``.
            decode_buffer: forwarded to the ``Model`` base class.

        Raises:
            NotImplementedError: if CUDA is not available.
        """
        # One-position zero padding for the past, created lazily from the first
        # prefill output in generate_token()
        self.past_pad = None

        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = (
            model_cls.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                load_in_8bit=quantize,
            )
            .eval()
            .to(device)
        )

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """Batch class handled by this model."""
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        """Decode token ids to text, skipping special tokens.

        Uses the correct ``clean_up_tokenization_spaces`` keyword (the former
        ``cleanup_tokenization_spaces`` spelling was silently ignored), matching
        the ``batch_decode`` call used for prefill texts.
        """
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        past_key_values: Optional = None,
        pre_allocate_past_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the underlying model; returns ``(logits, present)``."""
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_s=max_s,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        """Run one forward pass and choose the next token for every request.

        The pass is a prefill when ``batch.past_key_values`` is None, otherwise a
        single decode step. ``batch`` is updated in place.

        Returns:
            The per-request generations, and the updated batch, or ``None`` in
            its place when every request has stopped.
        """
        # Shortcut when batch_size == 1
        if len(batch) == 1:
            input_ids = batch.input_ids[0].view(-1)
            # The pre-allocated past is passed unsliced; presumably the model only
            # uses the positions delimited by cu_seqlens — TODO confirm in model
            past_key_values = batch.past_key_values
        else:
            # Concatenate tensors
            input_ids = torch.cat(batch.input_ids).view(-1)
            past_key_values = (
                torch.cat(batch.past_key_values, dim=1)
                if batch.past_key_values is not None
                else None
            )

        # if prefill and bs == 1
        if past_key_values is None and len(batch) == 1:
            # Ask to pre-allocate kv to its max size
            # == number of tokens + max_new_tokens
            pre_allocate_past_size = (
                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
            )
        else:
            pre_allocate_past_size = None

        # Concatenate when prefill, torch.tensor when decode
        position_ids = (
            torch.tensor(batch.position_ids, device=self.device)
            if batch.past_key_values is not None
            else torch.cat(batch.position_ids)
        )
        cu_seqlens = torch.tensor(
            batch.cu_seqlens, device=self.device, dtype=torch.int32
        )

        out, present = self.forward(
            input_ids,
            position_ids,
            cu_seqlens,
            batch.max_seqlen,
            past_key_values,
            pre_allocate_past_size,
        )

        # Initialize past_key_values in prefill
        if batch.past_key_values is None:
            # Initialize past padding tensor
            if self.past_pad is None:
                self.past_pad = present.new_zeros(
                    present.shape[0], 1, *present.shape[2:]
                )
            # Set in batch in case it needs to be used later in concatenate()
            batch.past_pad = self.past_pad
            if len(batch) == 1:
                # Preallocate tensor for bs = 1 case
                # NOTE(review): forward() was already asked to pre-allocate
                # input_length + max_new_tokens positions; if the model honours
                # pre_allocate_past_size, this pad doubles the tail — confirm.
                batch.past_key_values = torch.nn.functional.pad(
                    present,
                    (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
                )
            else:
                # Add padding after each sequence
                # This will have the correct shape after the final past_key_values concatenation before the model
                # forward
                batch.past_key_values = [None, self.past_pad] * len(batch)

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            prefill = stopping_criteria.current_tokens == 0
            if prefill:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                logits = out[start_index:end_index]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].unsqueeze(0)

            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids_tensor[None, :input_length], logits
            )
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_id_item = next_token_id_squeezed.item()

            # Append next token to all tokens
            all_input_ids.append(next_token_id_item)
            all_input_ids_tensor[input_length] = next_token_id_item

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id_item]
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids,
                offset,
                token_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_item,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                stopped = False
                generated_text = None

            # Prefill
            if prefill:
                # Row j of logprobs scores token j + 1, so gathering with the
                # prompt tokens 1..input_length-1 yields one logprob per prompt
                # token after the first; the generated token is not included
                # because the index stops at input_length. The first prompt
                # token has no logprob, hence the leading nan.
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids_tensor[1:input_length].unsqueeze(1)
                ).squeeze(1).tolist()
                prefill_token_ids = all_input_ids[:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_item,
                next_token_logprob,
                next_token_text,
                next_token_id_item in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)
            cumulative_length += input_length
            new_input_length = input_length + 1

            # Update values
            batch.input_ids[i] = next_token_id
            batch.position_ids[i] = input_length
            batch.input_lengths[i] = new_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.all_input_ids[i] = all_input_ids
            batch.all_input_ids_tensor[i] = all_input_ids_tensor
            batch.max_seqlen = max(batch.max_seqlen, new_input_length)
            if len(batch) != 1:
                # Add each sequence before its padding
                batch.past_key_values[i * 2] = present[:, start_index:end_index]
            # Cumulative sum
            batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length

        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None