utils.py 206 Bytes
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
import sgl_kernel


def rms_norm(x, weight, eps):
    """Run ``sgl_kernel.rmsnorm`` on ``x`` and return a tensor shaped like ``x``.

    The input is made contiguous and viewed as a 2-D tensor whose second
    dimension is ``x``'s last dimension. All leading dimensions are collapsed
    into the first. That 2-D view is what gets passed to the kernel. The
    kernel's output is then viewed back to ``x``'s original shape.

    Args:
        x: Tensor to normalize. Presumably a torch tensor on a device that
            ``sgl_kernel`` supports; TODO confirm dtype/device requirements.
        weight: Scale passed straight through to ``sgl_kernel.rmsnorm``.
            Assumed to have length ``x.shape[-1]``; verify against callers.
        eps: Numerical-stability epsilon forwarded to the kernel.

    Returns:
        The kernel's result, viewed with the same shape as ``x``.
    """
    target_shape = x.shape
    hidden_size = target_shape[-1]
    # ``view`` requires a compatible memory layout, hence the explicit
    # ``contiguous()`` before collapsing the leading dimensions.
    rows = x.contiguous().view(-1, hidden_size)
    normalized = sgl_kernel.rmsnorm(rows, weight, eps)
    return normalized.view(target_shape)