flash_causal_lm.py 20.6 KB
Newer Older
1
2
3
4
5
6
7
8
import torch
import torch.distributed

from torch.nn import functional as F

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
9
from typing import Optional, Tuple, List, Type, Union, Dict
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    Sampling,
)

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
    """Batch of generation requests for flash-attention causal language models.

    Per-request values are stored in parallel Python lists indexed by the
    position of the request in the batch; ``requests_idx_mapping`` translates
    a request id into that position.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: List[torch.Tensor]
    position_ids: List[torch.Tensor]
    # cumulative sequence lengths
    cu_seqlens: List[int]
    max_seqlen: int
    # None until prefill has run. Afterwards: a single tensor when the batch
    # holds one request, otherwise a list alternating each sequence's past
    # with `past_pad` ([past_0, pad, past_1, pad, ...]).
    past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Constant shared tensor, ref here just so that it's accessible in concatenate()
    past_pad: Optional[torch.Tensor]

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.Batch:
        """Return the protobuf description (id, requests, size, max_tokens) of this batch."""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Tokenize every request of ``pb`` and build a batch that has not been prefilled yet.

        Args:
            pb: protobuf batch whose requests are tokenized one by one.
            tokenizer: tokenizer applied to each request's ``inputs``, truncating
                to the request's ``truncate`` length.
            device: device on which the input and position tensors are created.

        Returns:
            A batch with ``past_key_values`` and ``past_pad`` set to ``None``.
        """
        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        all_input_ids_tensor = []
        requests_idx_mapping = {}

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        max_tokens = 0

        # Parse batch
        for i, r in enumerate(pb.requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenizer(
                r.inputs, truncation=True, max_length=r.truncate
            )["input_ids"]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            offsets.append(None)
            token_offsets.append(None)
            all_input_ids.append(tokenized_input)

            input_tensor = torch.tensor(tokenized_input, device=device)
            input_ids.append(input_tensor)

            # Position ids
            position_ids.append(
                torch.arange(0, input_length, dtype=torch.int32, device=device)
            )

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            # Reserve room after the prompt for every token that may be generated
            all_input_ids_tensor.append(F.pad(input_tensor, (0, max_new_tokens)))

            # Update
            cumulative_length += input_length
            max_tokens += input_length + max_new_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            past_pad=None,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
        """Return a batch restricted to ``requests``.

        Args:
            requests: the requests to keep; each id must be present in
                ``requests_idx_mapping``.

        Returns:
            ``self`` when as many requests are kept as the batch holds,
            otherwise a new batch built from the kept entries.

        Raises:
            ValueError: if ``requests`` is empty.
        """
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(requests) == len(self):
            return self

        single_request = len(requests) == 1

        # Cumulative length
        cumulative_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        max_tokens = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i

            # Get length
            request_input_length = self.input_lengths[idx]

            input_ids.append(self.input_ids[idx])
            position_ids.append(self.position_ids[idx])
            cu_seqlens.append(cumulative_length + request_input_length)
            max_seqlen = max(max_seqlen, request_input_length)

            # True index for past: entries alternate [past, pad, past, pad, ...]
            past_key_values.append(self.past_key_values[2 * idx])

            if not single_request:
                # Add one padding
                past_key_values.append(self.past_pad)

            all_input_ids.append(self.all_input_ids[idx])
            all_input_ids_tensor.append(self.all_input_ids_tensor[idx])

            input_lengths.append(request_input_length)
            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])

            next_token_choosers.append(self.next_token_choosers[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            cumulative_length += request_input_length
            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            max_tokens += request_input_length + remaining_tokens

        if single_request:
            # Preallocate tensor for bs = 1 case
            remaining_tokens = (
                stopping_criterias[0].max_new_tokens
                - stopping_criterias[0].current_tokens
            )
            # F.pad lists (left, right) pairs starting from the last dimension:
            # only the fourth-from-last dimension grows, by `remaining_tokens`.
            # NOTE(review): presumably that is the sequence dimension of the
            # past tensor - confirm against the model implementation.
            past_key_values = F.pad(
                past_key_values[0],
                (0, 0, 0, 0, 0, 0, 0, remaining_tokens),
            )

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            past_pad=self.past_pad,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several batches into one, keeping the id and ``past_pad`` of the first.

        The ``requests_idx_mapping`` of the input batches is left untouched: the
        merged mapping is always a new dictionary.

        Args:
            batches: batches to merge, in order.

        Returns:
            A new batch containing every request of ``batches``.

        Raises:
            ValueError: if ``batches`` is empty.
        """
        if not batches:
            raise ValueError("Cannot concatenate an empty list of batches")

        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0
        max_tokens = 0

        for batch in batches:
            requests.extend(batch.requests)

            # We need to offset the mapping for each batch by the cumulative batch size.
            # The offset is 0 for the first batch; copying its entries (instead of
            # reusing its dict) keeps the input batch from being mutated below.
            for request_id, idx in batch.requests_idx_mapping.items():
                requests_idx_mapping[request_id] = idx + cumulative_batch_size

            input_ids.extend(batch.input_ids)
            position_ids.extend(batch.position_ids)
            # Add cumulative lengths of all previous inputs
            cu_seqlens.extend(
                seqlen + cumulative_length for seqlen in batch.cu_seqlens[1:]
            )
            max_seqlen = max(max_seqlen, batch.max_seqlen)

            if len(batch) != 1:
                past_key_values.extend(batch.past_key_values)
            else:
                # past was pre-allocated for this batch
                # We need to slice to remove the padding
                past_key_values.append(
                    batch.past_key_values[:, : batch.input_lengths[0]]
                )
                # Add one padding
                past_key_values.append(batch.past_pad)

            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)

            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)

            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_length += batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)
            max_tokens += batch.max_tokens

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            past_pad=batches[0].past_pad,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    def __len__(self):
        """Return the number of requests held by the batch."""
        return len(self.requests)


class FlashCausalLM(Model):
    """GPU-only causal language model whose forward pass takes unpadded,
    concatenated sequences described by cumulative sequence lengths."""

    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        """Load the tokenizer and the model on the CUDA device.

        Args:
            model_cls: model class whose ``from_pretrained`` is used for loading.
            model_id: identifier passed to the tokenizer and model loaders.
            revision: optional revision passed to both loaders.
            quantize: forwarded to the model loader as ``load_in_8bit``.
            decode_buffer: forwarded to the parent ``Model`` constructor.

        Raises:
            NotImplementedError: if CUDA is not available.
        """
        # Shared padding tensor, created on the first prefill in generate_token()
        self.past_pad = None

        if not torch.cuda.is_available():
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        device = torch.device("cuda")
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = (
            model_cls.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                load_in_8bit=quantize,
            )
            .eval()
            .to(device)
        )

        super().__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """Batch class used with this model."""
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        """Decode token ids to text, skipping special tokens and leaving
        tokenization spaces untouched."""
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward every argument, by keyword, to the wrapped model.

        Returns:
            Whatever the wrapped model returns; ``generate_token`` unpacks it as
            ``(out, present)``.
        """
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_s=max_s,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        """Run one forward pass and generate the next token of every request.

        ``batch`` is updated in place (inputs, lengths, offsets, past key values,
        cumulative sequence lengths) so that it can be fed to the next call.

        Args:
            batch: the batch to advance by one token.

        Returns:
            The list of per-request generations and the updated batch, or
            ``None`` instead of the batch when every request has stopped.
        """
        # Shortcut when batch_size == 1
        if len(batch) == 1:
            input_ids = batch.input_ids[0].view(-1)
            # No need to slice as flash attention will take care of it with cu_seqlens
            past_key_values = batch.past_key_values
        else:
            # Concatenate tensors
            input_ids = torch.cat(batch.input_ids).view(-1)
            past_key_values = (
                torch.cat(batch.past_key_values, dim=1)
                if batch.past_key_values is not None
                else None
            )

        # if prefill and bs == 1
        if past_key_values is None and len(batch) == 1:
            # Ask to pre-allocate kv to its max size
            # == number of tokens + max_new_tokens
            pre_allocate_past_size = (
                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
            )
        else:
            pre_allocate_past_size = None

        # Concatenate when prefill, torch.tensor when decode
        # (after the first step the batch stores positions as plain ints)
        position_ids = (
            torch.tensor(batch.position_ids, device=self.device)
            if batch.past_key_values is not None
            else torch.cat(batch.position_ids)
        )
        cu_seqlens = torch.tensor(
            batch.cu_seqlens, device=self.device, dtype=torch.int32
        )

        out, present = self.forward(
            input_ids,
            position_ids,
            cu_seqlens,
            batch.max_seqlen,
            past_key_values,
            pre_allocate_past_size,
        )

        # Initialize past_key_values in prefill
        if batch.past_key_values is None:
            # Initialize past padding tensor
            if self.past_pad is None:
                self.past_pad = present.new_zeros(
                    present.shape[0], 1, *present.shape[2:]
                )
            # Set in batch in case it needs to be used later in concatenate()
            batch.past_pad = self.past_pad
            if len(batch) == 1:
                # present is already pre-padded
                batch.past_key_values = present
            else:
                # Add padding after each sequence
                # This will have the correct shape after the final past_key_values concatenation before the model
                # forward
                batch.past_key_values = [None, self.past_pad] * len(batch)

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            # Must be read before the stopping criteria is evaluated below
            prefill = stopping_criteria.current_tokens == 0
            if prefill:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                logits = out[start_index:end_index]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].unsqueeze(0)

            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids_tensor[None, :input_length], logits
            )
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_id_item = next_token_id_squeezed.item()

            # Append next token to all tokens
            all_input_ids.append(next_token_id_item)
            all_input_ids_tensor[input_length] = next_token_id_item

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id_item]
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids,
                offset,
                token_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_item,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                stopped = False
                generated_text = None

            # Prefill
            if prefill:
                # NaN for the first prompt token (nothing precedes it), followed
                # by the log-probability given to each subsequent prompt token.
                # The gather index stops at `input_length`, so it already leaves
                # out the generated token; trimming the result any further would
                # drop the last prompt token and leave this list one entry
                # shorter than `prefill_token_ids`.
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids_tensor[1:input_length].unsqueeze(1)
                ).squeeze(1).tolist()
                # Remove generated token to only have prefill
                prefill_token_ids = all_input_ids[:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_item,
                next_token_logprob,
                next_token_text,
                next_token_id_item in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)
            cumulative_length += input_length
            new_input_length = input_length + 1

            # Update values
            batch.input_ids[i] = next_token_id
            batch.position_ids[i] = input_length
            batch.input_lengths[i] = new_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.all_input_ids[i] = all_input_ids
            batch.all_input_ids_tensor[i] = all_input_ids_tensor
            batch.max_seqlen = max(batch.max_seqlen, new_input_length)
            if len(batch) != 1:
                # Add each sequence before its padding
                batch.past_key_values[i * 2] = present[:, start_index:end_index]
            # Cumulative sum
            batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length

        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None