metadata.py 3.73 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from dataclasses import dataclass
王敏's avatar
王敏 committed
4
from typing import  Optional
5
6
7
8
9
10
11
12
13
14
15
16
17

import numpy as np
import torch


@dataclass
class SpecDecodeMetadata:
    # [num_tokens]
    draft_token_ids: torch.Tensor
    # [batch_size]
    num_draft_tokens: list[int]
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor
18
19
    # [batch_size]
    cu_num_sampled_tokens: torch.Tensor
20
21
22
23
24
25
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor
王敏's avatar
王敏 committed
26
27
    # [batch_size]
    spec_decode_ids: Optional[list[str]] = None
28
29
30
31
32
33
34
35
36
37
38
39

    def __post_init__(self):
        self.max_spec_len = max(self.num_draft_tokens)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]
40
        num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids]
41
42
43
        flattened_draft_token_ids = sum(draft_token_ids, [])
        num_tokens = len(flattened_draft_token_ids)

44
45
46
        draft_token_ids_tensor = torch.tensor(
            flattened_draft_token_ids, dtype=torch.int32, device=device
        )
47
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
48
        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
49
50
51
52
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
            device
        )
53

54
55
56
57
58
59
60
        target_logits_indices = torch.zeros(
            num_tokens, dtype=torch.int32, device=device
        )
        bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device)
        logits_indices = torch.zeros(
            num_tokens + batch_size, dtype=torch.int32, device=device
        )
61
62
63
64
        return cls(
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
65
            cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
66
67
68
69
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
luopl's avatar
luopl committed
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107


@dataclass
class MultiLayerEagleMetadata:
    # [batch_size]
    cached_len: torch.Tensor | None = None
    # [batch_size, layer_num]
    cached_token_ids: torch.Tensor | None = None
    # [batch_size, layer_num, hidden_size]
    cached_hidden_states: torch.Tensor | None = None
    # [batch_size, layer_num]
    cached_slot_mappings: torch.Tensor | None = None
    # [batch_size, layer_num]
    cached_positions: torch.Tensor | None = None

    @classmethod
    def make_dummy(
        cls,
        layer_num: int,
        hidden_size: int,
        device: torch.device,
    ) -> "MultiLayerEagleMetadata":
        cached_len = torch.zeros((1), dtype=torch.int64, device=device)
        cached_token_ids = torch.zeros((1, layer_num), dtype=torch.int32, device=device)
        cached_hidden_states = torch.zeros(
            (1, layer_num, hidden_size), dtype=torch.float32, device=device
        )
        cached_slot_mappings = torch.zeros(
            (1, layer_num), dtype=torch.int64, device=device
        )
        cached_positions = torch.zeros((1, layer_num), dtype=torch.int64, device=device)
        return cls(
            cached_len=cached_len,
            cached_token_ids=cached_token_ids,
            cached_hidden_states=cached_hidden_states,
            cached_slot_mappings=cached_slot_mappings,
            cached_positions=cached_positions,
        )