marlin.py 2.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

try:
    import marlin
except ImportError:
    marlin = None

# Marlin needs a GPU with CUDA compute capability 8.0 or newer. Any failure to
# query the device (e.g. no usable CUDA device) is treated as "not supported".
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    has_sm_8_0 = False
else:
    has_sm_8_0 = major >= 8

# Number of input features represented by each row of the packed weight
# tensor `B` (MarlinLinear computes in_features = B.shape[0] * MARLIN_TILE_SIZE).
MARLIN_TILE_SIZE = 16


@dataclass
class MarlinWeight:
    """
    Pair of tensors holding one Marlin-quantized weight matrix.

    Attributes:
        B (torch.Tensor): packed int4-quantized weights, stored as int32.
        s (torch.Tensor): quantization scales, stored as float16.
    """

    # Packed quantized weights (int32 storage of int4 values).
    B: torch.Tensor
    # Quantization scales (float16).
    s: torch.Tensor


class MarlinLinear(nn.Module):
    """
    Linear layer whose int4-quantized weights are applied with ``marlin.mul``.

    Args:
        B (torch.Tensor): int32 tensor of packed int4 weights. Its first
            dimension is ``in_features // MARLIN_TILE_SIZE``.
        s (torch.Tensor): float16 scales. ``s.shape[1]`` is ``out_features``;
            ``s.shape[0]`` is the number of scale groups (a single row means
            group size -1, otherwise each group spans
            ``in_features // s.shape[0]`` input features).
        bias (Optional[torch.Tensor]): added to the output when not ``None``.

    Raises:
        NotImplementedError: if the device is older than CUDA capability 8.0
            or the ``marlin`` package is not installed.
        AssertionError: if dtypes or shapes are not supported (in_features
            must be a multiple of 128, out_features a multiple of 256, and
            the group size must be -1 or 128).
    """

    def __init__(
        self, *, B: torch.Tensor, s: torch.Tensor, bias: Optional[torch.Tensor]
    ):
        super().__init__()

        if not has_sm_8_0:
            raise NotImplementedError(
                "Using quantized marlin models requires CUDA capability 8.0 or later"
            )

        if marlin is None:
            raise NotImplementedError(
                "You do not seem to have marlin installed, please install it (cd server && make install-marlin)"
            )

        assert B.dtype == torch.int32
        assert s.dtype == torch.float16

        in_features = B.shape[0] * MARLIN_TILE_SIZE
        out_features = s.shape[1]
        assert (
            in_features % 128 == 0
        ), f"Number of input features ({in_features}) not divisible by 128"
        assert (
            out_features % 256 == 0
        ), f"Number of output features ({out_features}) not divisible by 256"

        # Only used to validate the scales layout; marlin.mul is not given it.
        group_size = self._infer_group_size(in_features, s.shape[0])
        assert group_size in {
            -1,
            128,
        }, f"Group size must be -1 or 128, was {group_size}"

        self.register_buffer("B", B)
        self.register_buffer("s", s)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None

        # Scratch tensor passed to marlin.mul, sized out_features // 128 * 16.
        # Registered as a non-persistent buffer so .to()/.cuda() moves it along
        # with B and s, while keeping it out of the state dict.
        self.register_buffer(
            "workspace",
            torch.zeros(out_features // 128 * 16, dtype=torch.int, device=B.device),
            persistent=False,
        )

    @staticmethod
    def _infer_group_size(in_features: int, num_groups: int) -> int:
        """
        Return the quantization group size implied by the number of scale rows.

        A single row of scales means group size -1. Otherwise ``num_groups``
        must divide ``in_features`` exactly. Floor division alone would accept,
        for example, 129 scale rows for 16640 input features as group size 128.

        Raises:
            AssertionError: if ``num_groups`` is not positive or does not
                divide ``in_features``.
        """
        assert num_groups > 0, "Scales must have at least one row"
        if num_groups == 1:
            return -1
        assert in_features % num_groups == 0, (
            f"Number of input features ({in_features}) not divisible by "
            f"number of scale groups ({num_groups})"
        )
        return in_features // num_groups

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        """
        Multiply ``A`` by the quantized weights with ``marlin.mul``.

        Args:
            A (torch.Tensor): input whose last dimension is the input feature
                dimension. It is flattened to 2-D with ``view``, so it must be
                viewable that way.

        Returns:
            torch.Tensor: tensor of shape ``A.shape[:-1] + (out_features,)``
            with ``A``'s dtype and device, plus ``bias`` if one was given.
        """
        assert marlin is not None
        # NOTE(review): marlin.mul presumably expects a float16, contiguous A;
        # neither is checked here — confirm against the kernel.
        C = torch.empty(
            A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device
        )
        # The kernel writes into C through a 2-D view of the freshly
        # allocated (hence contiguous) output buffer.
        marlin.mul(
            A.view((-1, A.shape[-1])),
            self.B,
            C.view((-1, C.shape[-1])),
            self.s,
            self.workspace,
        )

        if self.bias is not None:
            C += self.bias

        return C