# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import  Optional

import numpy as np
import torch


@dataclass
class SpecDecodeMetadata:
    """Container for per-batch draft-token tensors and counts.

    The shape noted above each field is the one the file's original
    comments declare. ``max_spec_len`` is not a constructor argument; it is
    derived from ``num_draft_tokens`` in ``__post_init__``.
    """

    # [num_tokens]
    draft_token_ids: torch.Tensor
    # [batch_size]
    num_draft_tokens: list[int]
    # [batch_size]
    # make_dummy() fills this with the running sum of num_draft_tokens.
    cu_num_draft_tokens: torch.Tensor
    # [batch_size]
    # make_dummy() fills this with the running sum of (num_draft_tokens + 1).
    cu_num_sampled_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor
    # [batch_size]
    # NOTE(review): presumably one identifier per request -- confirm with
    # the producer of this field; nothing in this class reads it.
    spec_decode_ids: Optional[list[str]] = None

    def __post_init__(self):
        """Derive ``max_spec_len``, the largest per-request draft count."""
        # default=0 keeps an empty batch constructible; a bare max() raises
        # ValueError on an empty list.
        self.max_spec_len = max(self.num_draft_tokens, default=0)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        """Build an instance from nested draft token ids with zeroed indices.

        Args:
            draft_token_ids: One list of draft token ids per request; an
                inner list may be empty.
            device: Device on which every tensor of the result is placed.

        Returns:
            A ``SpecDecodeMetadata`` whose ``draft_token_ids`` and cumulative
            counts reflect the input, and whose three ``*logits_indices``
            tensors are all-zero placeholders of the documented lengths.
        """
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]
        # Each request contributes one more sampled token than it has drafts.
        num_sampled_tokens = [count + 1 for count in num_draft_tokens]

        # Single-pass flatten; sum(lists, []) re-copies the accumulator on
        # every step and is quadratic in the total token count.
        flattened_draft_token_ids = [
            token for ids in draft_token_ids for token in ids
        ]
        num_tokens = len(flattened_draft_token_ids)

        draft_token_ids_tensor = torch.tensor(
            flattened_draft_token_ids, dtype=torch.int32, device=device
        )

        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
            device
        )

        # Placeholder index tensors: correct length and dtype, zero contents.
        target_logits_indices = torch.zeros(
            num_tokens, dtype=torch.int32, device=device
        )
        bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device)
        logits_indices = torch.zeros(
            num_tokens + batch_size, dtype=torch.int32, device=device
        )

        return cls(
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
            cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )