speculative.py 1.1 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from typing import Tuple, Optional
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
from text_generation_server.layers.tensor_parallel import TensorParallelHead


class SpeculativeHead(torch.nn.Module):
    """Output head that optionally also produces speculative (Medusa) logits.

    Holds either a plain LM head or a Medusa head; ``load`` sets exactly one
    of them. ``forward`` always returns a ``(logits, speculative_logits)``
    pair, where ``speculative_logits`` is ``None`` when the plain head is used.
    """

    def __init__(self, lm_head, medusa):
        """Store the heads.

        Args:
            lm_head: Callable/module mapping hidden states to logits, or
                ``None`` when ``medusa`` is used instead.
            medusa: Callable/module whose result ``forward`` returns directly
                (annotated as ``(logits, speculative_logits)``), or ``None``
                to use ``lm_head`` alone.
        """
        super().__init__()
        self.head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        """Build a ``SpeculativeHead`` from a model config and weights.

        When ``config.use_medusa`` is truthy, a ``MedusaHeadV1`` is loaded;
        if that raises any ``Exception``, a ``MedusaHeadV2`` is constructed
        from the same arguments instead. Otherwise a ``TensorParallelHead`` is
        loaded and no Medusa head is used.

        Args:
            config: Model config; its ``use_medusa`` attribute is read.
            prefix: Weight-name prefix forwarded to the head loaders.
            weights: Weight source forwarded to the head loaders.

        Returns:
            A new ``SpeculativeHead``.

        Raises:
            Whatever ``MedusaHeadV2`` or ``TensorParallelHead.load`` raise.
            ``KeyboardInterrupt``/``SystemExit`` during the V1 attempt now
            propagate instead of triggering the fallback.
        """
        use_medusa = config.use_medusa
        if use_medusa:
            lm_head = None
            try:
                medusa = MedusaHeadV1.load(config, prefix, weights)
            except Exception:
                # Fall back to the V2 head. NOTE(review): presumably V1 fails
                # when the checkpoint uses the V2 weight layout -- confirm.
                # Catch ``Exception`` rather than a bare ``except`` so that
                # interrupts and interpreter exit are not swallowed here.
                medusa = MedusaHeadV2(config, prefix, weights)
        else:
            lm_head = TensorParallelHead.load(config, prefix, weights)
            medusa = None
        return SpeculativeHead(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute logits (and speculative logits when Medusa is enabled).

        Args:
            input: Hidden states passed to the active head.

        Returns:
            The Medusa head's result when ``self.medusa`` is set; otherwise
            ``(self.head(input), None)``.

        Raises:
            AssertionError: if neither head is set (only when assertions
                are enabled).
        """
        if self.medusa is not None:
            return self.medusa(input)

        # Invariant from ``load``: exactly one of head/medusa is set.
        assert self.head is not None
        logits = self.head(input)
        return logits, None