norm.py
import torch
from torch import nn

# Optional CUDA extension: use the fused awq_ext kernels when they are installed.
try:
    import awq_ext  # with CUDA kernels
    AWQ_INSTALLED = True
except ImportError:
    AWQ_INSTALLED = False

class FasterTransformerRMSNorm(nn.Module):
    """RMSNorm module backed by the fused CUDA kernel in awq_ext."""

    def __init__(self, weight, eps=1e-6):
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    def forward(self, x):
        output = torch.empty_like(x)
        # The fused kernel writes the normalized result into `output` in place.
        awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
        return output
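

# The block below is an illustrative sketch, not part of the original file.
# It assumes a caller wants a pure-PyTorch fallback when awq_ext is missing
# (AWQ_INSTALLED is False). The class name TorchRMSNorm and the dispatch
# helper make_rms_norm are hypothetical; they only mirror the interface of
# FasterTransformerRMSNorm defined above.
class TorchRMSNorm(nn.Module):
    def __init__(self, weight, eps=1e-6):
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    def forward(self, x):
        # RMSNorm: scale by the reciprocal root-mean-square over the last dim,
        # then apply the learned per-channel weight.
        variance = x.float().pow(2).mean(-1, keepdim=True)
        x_normed = x.float() * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * x_normed).to(x.dtype)


def make_rms_norm(weight, eps=1e-6):
    # Hypothetical dispatch: prefer the CUDA kernel when awq_ext is available.
    if AWQ_INSTALLED:
        return FasterTransformerRMSNorm(weight, eps)
    return TorchRMSNorm(weight, eps)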