import torch
import torch.nn as nn
from vllm import _custom_ops as ops


class QuantLinearInt8(nn.Module):
    """Linear layer with int8 weights and dynamic per-token int8 activation
    quantization, executed through vLLM's CUTLASS scaled-mm kernel.

    Weights are stored as int8 in an ``(out_features, in_features)`` buffer
    with one float32 scale per output channel (``weight_scale``); activations
    are quantized symmetrically at call time.
    """

    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        """
        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias: if True, allocate a bias buffer of ``dtype``; otherwise the
                ``bias`` buffer is registered as ``None``.
            dtype: dtype of the optional bias buffer.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Buffers (not Parameters): these hold pre-quantized values loaded
        # from a checkpoint and are never trained.
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))

        if bias:
            # NOTE(review): the cutlass kernel presumably requires the bias
            # dtype to match the activation/output dtype — confirm callers
            # pass a `dtype` matching the runtime activations.
            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        """Dynamically quantize activations to int8 (symmetric, no zero point).

        Returns:
            Tuple of (quantized activations, per-token scales); the azp
            output of the op is discarded because quantization is symmetric.
        """
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        """Compute the quantized linear transform.

        Expects a 3-D input whose leading dim is squeezed off before the
        kernel call and restored on return; the output dtype follows the
        input dtype.
        """
        input_tensor = input_tensor.squeeze(0)
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        # cutlass_scaled_mm writes into output_tensor in place, applying the
        # activation scales, weight scales, and (optional) bias.
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight.t(),
            input_tensor_scale,
            self.weight_scale.float(),
            self.bias,
        )
        return output_tensor.unsqueeze(0)

    def _apply(self, fn):
        """Override Module._apply so buffers are only replaced on a device
        change; dtype-only transforms (e.g. ``.half()``) are deliberately
        skipped to keep the int8/float32 quantized storage intact.
        """
        for module in self.children():
            module._apply(fn)

        def maybe_cast(t):
            if t is None:
                return t
            # Bug fix: apply fn once and reuse the result — the original
            # called fn(t) twice (once to probe the device, once to return),
            # materializing the moved tensor twice.
            moved = fn(t)
            return moved if moved.device != t.device else t

        self.weight = maybe_cast(self.weight)
        self.weight_scale = maybe_cast(self.weight_scale)
        self.bias = maybe_cast(self.bias)
        return self


class QuantLinearFp8(nn.Module):
gushiqiao's avatar
gushiqiao committed
58
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
59
60
61
62
63
64
65
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))

        if bias:
gushiqiao's avatar
gushiqiao committed
66
            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        input_tensor = input_tensor.squeeze(0)
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
83
            input_tensor_quant,
84
            self.weight.t(),
85
86
            input_tensor_scale,
            self.weight_scale.float(),
87
            self.bias,
88
        )
gushiqiao's avatar
gushiqiao committed
89

90
        return output_tensor.unsqueeze(0)
gushiqiao's avatar
gushiqiao committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        def maybe_cast(t):
            if t is not None and t.device != fn(t).device:
                return fn(t)
            return t

        self.weight = maybe_cast(self.weight)
        self.weight_scale = maybe_cast(self.weight_scale)
        self.bias = maybe_cast(self.bias)
        return self