"CONTRIBUTING.md" did not exist on "4244d85a6888b8921aa9b369a16d180a41a5c531"
norm.py 671 Bytes
Newer Older
Haotian Tang's avatar
Haotian Tang committed
1
2
import torch
from torch import nn
3
4
5

# Optional compiled extension providing the CUDA kernels. The module stays
# importable without it; AWQ_INSTALLED records whether the import succeeded.
try:
    import awq_ext  # with CUDA kernels

    AWQ_INSTALLED = True
except Exception:
    # Deliberately best-effort: any failure while importing the extension
    # (missing package, broken shared library, ...) just marks it unavailable.
    # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit
    # raised during the import are not silently swallowed.
    AWQ_INSTALLED = False
Haotian Tang's avatar
Haotian Tang committed
10

Ilyas Moutawwakil's avatar
Ilyas Moutawwakil committed
11

Casper Hansen's avatar
Casper Hansen committed
12
class FasterTransformerRMSNorm(nn.Module):
    """Normalization layer whose forward pass is delegated to an AWQ CUDA kernel.

    The computation itself happens in ``awq_ext.layernorm_forward_cuda``; this
    class only stores the scale tensor and epsilon and allocates the output.

    Args:
        weight: Scale tensor handed to the kernel as-is (stored, not copied).
        eps: Epsilon forwarded to the kernel; stored as ``variance_epsilon``.
    """

    def __init__(self, weight, eps=1e-6):
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    def forward(self, x):
        """Run the AWQ normalization kernel on ``x``.

        Args:
            x: Input tensor. Presumably a CUDA tensor whose last dimension
                matches ``weight`` -- TODO confirm against the kernel's
                requirements.

        Returns:
            A new tensor allocated with ``torch.empty_like(x)`` and filled
            by the kernel.

        Raises:
            AssertionError: If the module-level ``AWQ_INSTALLED`` flag is false,
                i.e. the ``awq_ext`` extension could not be imported.
        """
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would surface as an obscure NameError on
        # ``awq_ext`` rather than this install hint. AssertionError is kept so
        # the exception type callers see is unchanged.
        if not AWQ_INSTALLED:
            raise AssertionError(
                "AWQ kernels could not be loaded. "
                "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
            )

        output = torch.empty_like(x)
        # The kernel is expected to write its result into `output`, which is
        # returned unchanged afterwards.
        awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
        return output