exl2.py 469 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from dataclasses import dataclass


@dataclass
class Exl2Weight:
    """
    Exllama2 exl2 quantized weights.

    Groups the tensors that make up one exl2-quantized weight. After
    construction, ``q_scale_max`` holds the supplied values divided by 256
    and ``q_invperm`` holds the supplied values converted to ``torch.int16``.
    """

    q_weight: torch.Tensor
    q_scale: torch.Tensor
    q_invperm: torch.Tensor
    q_scale_max: torch.Tensor
    q_groups: torch.Tensor

    def __post_init__(self) -> None:
        """Rescale ``q_scale_max`` and narrow ``q_invperm`` after ``__init__``."""
        # Divide out of place and rebind the attribute. An in-place ``/=``
        # would also rescale the tensor object the caller passed in (so
        # reusing that tensor would divide it again) and raises for
        # integer-dtype tensors, whose float result cannot be stored back.
        # NOTE(review): the 256 divisor presumably matches the exl2 storage
        # format for q_scale_max -- confirm against the quantizer.
        self.q_scale_max = self.q_scale_max / 256
        self.q_invperm = self.q_invperm.short()

    @property
    def device(self) -> torch.device:
        """Device on which ``q_weight`` is stored."""
        return self.q_weight.device