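"""AWQ-quantized tensor-parallel linear layers.

Each int32 element of `qweight` and `qzeros` packs `pack_factor` quantized
values, and every group of `group_size` input channels shares one row of
`scales` and `qzeros`. The `awq_gemm` kernel dequantizes the packed weights
on the fly while performing the matmul.
"""
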
from typing import Optional

import torch
from torch.nn.parameter import Parameter

from vllm import quantization_ops
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
                                                       RowParallelLinear)


class AWQColumnParallelLinear(ColumnParallelLinear):
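    """ColumnParallelLinear with AWQ-quantized weights.

    The output dimension is partitioned across tensor-parallel ranks, so
    each rank's slice of the packed weight must be a whole number of int32
    columns (i.e. divisible by pack_factor).
    """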

    def create_weights(self, dtype: torch.dtype) -> None:
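        # The whole input dimension must divide into quantization groups.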
        assert self.input_size % self.quant_config.group_size == 0
        if self.output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The tensor parallel size is not aligned with the quantized "
                "weight shape. Please use a different tensor parallel size.")
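        # Quantized weights, packed pack_factor values per int32 element.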
        self.qweight = Parameter(
            torch.empty(
                self.input_size,
                self.output_size_per_partition //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
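        # Packed zero points, one row per quantization group.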
        self.qzeros = Parameter(
            torch.empty(
                self.input_size // self.quant_config.group_size,
                self.output_size_per_partition //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
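        # Dequantization scales in the activation dtype, one row per group.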
        self.scales = Parameter(
            torch.empty(
                self.input_size // self.quant_config.group_size,
                self.output_size_per_partition,
                device="cuda",
                dtype=dtype,
            ),
            requires_grad=False,
        )

    def apply_weights(
        self,
        x: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        pack_factor = self.quant_config.pack_factor
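        # Unpacking restores the full output dimension: packed columns
        # expand by pack_factor.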
        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
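        # Collapse leading dimensions into a 2D GEMM; awq_gemm fuses
        # dequantization with the matmul.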
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                        self.qzeros, pack_factor)
        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)


class AWQRowParallelLinear(RowParallelLinear):
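    """RowParallelLinear with AWQ-quantized weights.

    The input dimension is partitioned across tensor-parallel ranks, so
    each rank must own whole quantization groups of group_size channels.
    """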

    def create_weights(self, dtype: torch.dtype) -> None:
        assert self.output_size % self.quant_config.pack_factor == 0
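        # Each rank's input slice must cover whole quantization groups.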
        if self.input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The tensor parallel size is not aligned with the quantized "
                "weight shape. Please use a different tensor parallel size.")
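        # Same packed layout as the column-parallel case, with the input
        # dimension partitioned instead of the output.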
        self.qweight = Parameter(
            torch.empty(
                self.input_size_per_partition,
                self.output_size // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        self.qzeros = Parameter(
            torch.empty(
                self.input_size_per_partition // self.quant_config.group_size,
                self.output_size // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        self.scales = Parameter(
            torch.empty(
                self.input_size_per_partition // self.quant_config.group_size,
                self.output_size,
                device="cuda",
                dtype=dtype,
            ),
            requires_grad=False,
        )

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                        self.qzeros, pack_factor)
        return out.reshape(out_shape)
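

# Illustrative only, not used at runtime: a reference for the arithmetic that
# a fused kernel like awq_gemm is assumed to perform. It assumes pack_factor
# values of (32 // pack_factor) bits each are packed little-endian into every
# int32; the real kernel may interleave values in a different order, so treat
# this as a sketch of the math rather than of the exact memory layout.
def _dequantize_reference(
    qweight: torch.Tensor,  # [in_dim, out_dim // pack_factor], int32
    qzeros: torch.Tensor,  # [in_dim // group_size, out_dim // pack_factor]
    scales: torch.Tensor,  # [in_dim // group_size, out_dim]
    pack_factor: int,
) -> torch.Tensor:
    bits = 32 // pack_factor
    mask = (1 << bits) - 1
    shifts = torch.arange(0, 32, bits, device=qweight.device)
    # Unpack each int32 into pack_factor small integers along a new axis,
    # then flatten the packed columns back to the full output dimension.
    iweight = ((qweight.unsqueeze(-1) >> shifts) & mask).reshape(
        qweight.shape[0], -1)
    izeros = ((qzeros.unsqueeze(-1) >> shifts) & mask).reshape(
        qzeros.shape[0], -1)
    # Broadcast per-group zero points and scales over the rows of each group.
    group_size = qweight.shape[0] // qzeros.shape[0]
    izeros = izeros.repeat_interleave(group_size, dim=0)
    group_scales = scales.repeat_interleave(group_size, dim=0)
    # Standard affine dequantization: (q - zero) * scale.
    return (iweight - izeros).to(group_scales.dtype) * group_scales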