import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os

# if SYSTEM == "rocm":
#     ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
#         "true",
#         "1",
#     )

ROCM_USE_SKINNY_GEMM = False

# if ROCM_USE_SKINNY_GEMM:
#     try:
#         # from vllm import _custom_C
#         from vllm import _custom_ops
#     except Exception as e:
#         raise ImportError(
#             f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
#         )


class FastLinear(torch.nn.Module):
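    """Non-trainable linear layer for inference, dispatching to F.linear."""
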
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        if bias is not None:
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)


class FastLinearROCm(torch.nn.Module):
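    """ROCm variant of FastLinear with an optional skinny GEMM fast path for
    small float16 batches; the custom kernels are currently disabled, so it
    falls back to F.linear."""
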
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight)
        if bias is not None:
            self.bias = torch.nn.Parameter(bias)
        else:
            self.bias = None

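        # Skinny GEMM is skipped on gfx1* (RDNA) architectures and whenever
        # ROCM_USE_SKINNY_GEMM is disabled above.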
        self.cu_count = torch.cuda.get_device_properties(
            device="cuda"
        ).multi_processor_count
        self.use_skinny_gemm = (
            ROCM_USE_SKINNY_GEMM
            and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
        )

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        bias = self.bias

        if (
            self.use_skinny_gemm
            and inp.dtype == torch.float16
            and inp.shape[-1] % 8 == 0
        ):
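            # NOTE: with ROCM_USE_SKINNY_GEMM hard-coded to False above, this
            # branch is effectively disabled; the custom-kernel dispatch below
            # is commented out and a plain F.linear is used instead.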
            batched = False
            inp_shape = inp.shape

            if inp.dim() == 3:
                inp = inp.view(-1, inp_shape[-1])
                batched = True

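            # Problem sizes consumed by the commented-out kernel dispatch below.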
            m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
            # if m > 8 and n <= 4:
            #     out = torch.empty(
            #         inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
            #     )
            #     _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
            # elif m % 4 == 0 and n == 1 and k <= 8192:
            #     out = torch.empty(
            #         inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
            #     )
            #     _custom_C.LLMM1(weight, inp, out, 4)
            # else:
            #     out = F.linear(inp, weight)

            out = F.linear(inp, weight)

            if batched:
                out = out.view(*inp_shape[:-1], out.shape[-1])

            if bias is not None:
                out = out + bias
            return out
        return F.linear(inp, self.weight, self.bias)


def get_linear(weight, bias):
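    """Return a linear layer for `weight`: bare tensors are wrapped in
    FastLinear (or FastLinearROCm on ROCm), other weight objects provide
    their own `get_linear`."""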
    # Weights that are loaded through methods that are not
    # quantization-aware are still bare tensors. We may want
    # to change this in the future.
    if isinstance(weight, torch.Tensor):
        if SYSTEM == "rocm":
            return FastLinearROCm(weight, bias)
        else:
            return FastLinear(weight, bias)

    return weight.get_linear(bias)