eetq.py 1.29 KB
Newer Older
jixx's avatar
jixx committed
1
2
from dataclasses import dataclass

jixx's avatar
init  
jixx committed
3
4
import torch
from EETQ import quant_weights, w8_a16_gemm
jixx's avatar
jixx committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from text_generation_server.utils.weights import UnquantizedWeight


@dataclass
class EETQWeight(UnquantizedWeight):
    """Weight wrapper whose linear layer is an ``EETQLinear``.

    Holds a single tensor and defers construction of the linear layer to
    :meth:`get_linear`.
    """

    # Tensor handed to ``EETQLinear`` when the layer is built.
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        """Build an ``EETQLinear`` from the stored weight and ``bias``.

        Args:
            bias: Bias tensor forwarded as-is to ``EETQLinear``.
                NOTE(review): callers may pass ``None`` here -- confirm
                against ``EETQLinear``'s constructor.

        Returns:
            The result of ``EETQLinear(self.weight, bias)``.

        Raises:
            ImportError: If ``text_generation_server.layers.eetq`` cannot be
                imported; the original import failure is attached as the
                cause.
        """
        # Guard only the import: an ImportError raised while constructing
        # the layer should surface as-is rather than be relabelled as a
        # missing-installation problem.
        try:
            from text_generation_server.layers.eetq import EETQLinear
        except ImportError as err:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            ) from err

        return EETQLinear(self.weight, bias)
jixx's avatar
init  
jixx committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43


class EETQLinear(torch.nn.Module):
    """Linear layer that runs its matmul through EETQ's ``w8_a16_gemm``.

    The weight is converted once at construction time with EETQ's
    ``quant_weights``; the resulting tensor and its scale are stored as
    plain attributes (not registered parameters or buffers).
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        """Quantize ``weight`` and place all tensors on ``weight``'s device.

        Args:
            weight: Source weight tensor. It is cast to ``torch.float16`` if
                it has another dtype, then transposed with ``.t()`` (so it
                is assumed to be at most 2-D).
            bias: Optional bias tensor; ``None`` disables the bias add.

        NOTE(review): tensors are moved back with ``.cuda(device)``, so the
        incoming ``weight`` is presumably already on a CUDA device --
        confirm against callers.
        """
        super().__init__()
        target_device = weight.device

        # quant_weights is fed a float16, transposed, contiguous CPU tensor.
        fp16_weight = (
            weight
            if weight.dtype == torch.float16
            else weight.to(dtype=torch.float16)
        )
        host_weight = fp16_weight.t().contiguous().cpu()
        # Drop the intermediate reference before quantizing; when a cast
        # happened this lets the float16 copy be released early.
        del fp16_weight

        # NOTE(review): the meaning of the trailing ``False`` flag is defined
        # by EETQ and is not visible here.
        quantized, scales = quant_weights(host_weight, torch.int8, False)

        self.weight = quantized.cuda(target_device)
        self.scale = scales.cuda(target_device)
        if bias is None:
            self.bias = None
        else:
            self.bias = bias.cuda(target_device)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the quantized linear transform to ``input``.

        Calls ``w8_a16_gemm(input, self.weight, self.scale)`` and adds the
        bias when one was supplied at construction.
        """
        result = w8_a16_gemm(input, self.weight, self.scale)
        if self.bias is None:
            return result
        return result + self.bias