linear.py

import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os

# if SYSTEM == "rocm":
#     ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
#         "true",
#         "1",
#     )
ROCM_USE_SKINNY_GEMM = False
#     if ROCM_USE_SKINNY_GEMM:
#         try:
#             from vllm import _custom_C
#         except Exception as e:
#             raise ImportError(
#                 f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
#             )
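# With ROCM_USE_SKINNY_GEMM hardcoded to False above, `use_skinny_gemm` is
# always False, so the `vllm._custom_C` kernels are never imported and
# FastLinearROCm.forward always falls back to F.linear.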


class FastLinear(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        if bias is not None:
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)


class FastLinearROCm(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight)
        if bias is not None:
            self.bias = torch.nn.Parameter(bias)
        else:
            self.bias = None

        self.cu_count = torch.cuda.get_device_properties(
            device="cuda"
        ).multi_processor_count
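        # Skinny-gemm kernels are only considered on non-"gfx1" (non-RDNA)
        # GPUs, and only when ROCM_USE_SKINNY_GEMM is enabled above.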
        self.use_skinny_gemm = (
            ROCM_USE_SKINNY_GEMM
            and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
        )

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        bias = self.bias

        if (
            self.use_skinny_gemm
            and inp.dtype == torch.float16
            and inp.shape[-1] % 8 == 0
        ):
            batched = False
            inp_shape = inp.shape

            if inp.dim() == 3:
                inp = inp.view(-1, inp_shape[-1])
                batched = True

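            # Heuristic dispatch to vLLM's ROCm "skinny gemm" kernels:
            # wvSpltK when the input has at most 4 rows, LLMM1 for single-row
            # inputs with k <= 8192; everything else falls back to F.linear.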
            m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
            if m > 8 and n <= 4:
                out = torch.empty(
                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
                )
                _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
            elif m % 4 == 0 and n == 1 and k <= 8192:
                out = torch.empty(
                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
                )
                _custom_C.LLMM1(weight, inp, out, 4)
            else:
                out = F.linear(inp, weight)

            if batched:
                out = out.view(*inp_shape[:-1], out.shape[-1])

            if bias is not None:
                out = out + bias
            return out
        return F.linear(inp, self.weight, self.bias)


def get_linear(weight, bias):
    # Weights that are loaded through methods that are not
    # quantization-aware are still bare tensors. We may want
    # to change this in the future.
    if isinstance(weight, torch.Tensor):
        if SYSTEM == "rocm":
            return FastLinearROCm(weight, bias)
        else:
            return FastLinear(weight, bias)

    return weight.get_linear(bias)
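

# A minimal usage sketch (assumes a plain, unquantized float weight): wrap bare
# tensors in a FastLinear the same way `get_linear` does for quantization-unaware
# loaders; on a ROCm system, `get_linear` returns FastLinearROCm instead.
if __name__ == "__main__":
    w = torch.randn(8, 16)  # (out_features, in_features)
    b = torch.zeros(8)
    layer = FastLinear(w, b)
    x = torch.randn(2, 16)
    print(layer(x).shape)  # torch.Size([2, 8])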