medusa.py 1.68 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
2
3
4
from dataclasses import dataclass
from typing import Optional

import torch
from text_generation_server.utils.layers import TensorParallelHead, FastLinear

OlivierDehaene's avatar
OlivierDehaene committed
5

Nicolas Patry's avatar
Nicolas Patry committed
6
7
8
9
10
11
12
13
14
@dataclass
class Output:
    """Optional pair of logits tensors: regular and speculative.

    Both fields default to ``None``, so the annotations are ``Optional``
    (a bare ``torch.FloatTensor`` annotation with a ``None`` default is an
    implicit-Optional that PEP 484 type checkers reject).
    """

    # Regular logits tensor, or None when not set.
    logits: Optional[torch.FloatTensor] = None
    # Speculative logits tensor, or None when not set.
    speculative_logits: Optional[torch.FloatTensor] = None


class ResBlock(torch.nn.Module):
    """Residual block computing ``x + SiLU(linear(x))``.

    Args:
        config: Forwarded unchanged to ``FastLinear.load``.
        prefix: Weight-name prefix; the projection (with bias) is loaded
            from ``"{prefix}.linear"``.
        weights: Weight source forwarded to ``FastLinear.load``.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        linear_prefix = f"{prefix}.linear"
        self.linear = FastLinear.load(
            config, prefix=linear_prefix, weights=weights, bias=True
        )
        self.act = torch.nn.SiLU()

    def forward(self, x):
        """Project and activate ``x``, then add ``x`` back (skip connection)."""
        projected = self.linear(x)
        activated = self.act(projected)
        return x + activated


class MedusaModel(torch.nn.Module):
    """Pair a base ``lm_head`` with a list of Medusa heads.

    Args:
        config: Mapping read for ``"medusa_num_heads"`` (how many heads to
            build); it is also forwarded unchanged to every ``MedusaHead``.
        weights: Weight source forwarded to every ``MedusaHead``.
        lm_head: Module that produces the regular logits in ``forward``.
    """

    def __init__(self, config, weights, lm_head):
        super().__init__()
        head_count = config["medusa_num_heads"]
        # Head number k loads its tensors under the weight prefix "k".
        self.heads = torch.nn.ModuleList(
            [
                MedusaHead(config, prefix=f"{head_idx}", weights=weights)
                for head_idx in range(head_count)
            ]
        )
        self.lm_head = lm_head

    def forward(self, x):
        """Return ``(logits, speculative_logits)`` for input ``x``.

        ``logits`` is ``lm_head(x)``. ``speculative_logits`` stacks every
        head's output along a new dimension 1, in head order (for 2-D ``x``
        this is presumably ``(batch, num_heads, vocab)`` — TODO confirm
        against callers). Raises if there are no heads, since
        ``torch.stack`` needs at least one tensor.
        """
        base_logits = self.lm_head(x)
        head_outputs = [medusa_head(x) for medusa_head in self.heads]
        speculative_logits = torch.stack(head_outputs, dim=1)
        return base_logits, speculative_logits


class MedusaHead(torch.nn.Module):
    """A stack of ``ResBlock``s followed by a bias-free output projection.

    Args:
        config: Mapping read for ``"medusa_num_layers"`` (how many
            ``ResBlock``s to build); also forwarded to every loader.
        prefix: Weight-name prefix for this head. Block ``j`` loads from
            ``"{prefix}.{j}"``; the output projection loads from
            ``"{prefix}.{n}"`` where ``n`` is the number of blocks.
        weights: Weight source forwarded to every loader.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        layer_count = config["medusa_num_layers"]
        self.blocks = torch.nn.ModuleList(
            [
                ResBlock(config, prefix=f"{prefix}.{layer_idx}", weights=weights)
                for layer_idx in range(layer_count)
            ]
        )
        # The projection's weight index continues the block numbering:
        # blocks use 0..n-1, the projection uses n.
        out_index = len(self.blocks)
        self.out = FastLinear.load(
            config, prefix=f"{prefix}.{out_index}", weights=weights, bias=False
        )

    def forward(self, x):
        """Run ``x`` through every block in order, then the output projection."""
        hidden = x
        for res_block in self.blocks:
            hidden = res_block(hidden)
        return self.out(hidden)