activations.py 386 Bytes
Newer Older
wangsen's avatar
wangsen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
import torch.nn.functional as F

# Select the decorator used to fuse the element-wise activations below.
# Prefer torch.compile; PyTorch builds that do not expose it raise
# AttributeError on lookup, in which case fall back to TorchScript.
# Only AttributeError is caught so that unrelated failures (and
# KeyboardInterrupt / SystemExit) are not silently swallowed.
try:
    jit_fuser = torch.compile
except AttributeError:
    jit_fuser = torch.jit.script


@jit_fuser
def squared_relu(x: torch.Tensor) -> torch.Tensor:
    """Apply ReLU to ``x`` and square the result element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor holding ``relu(x) ** 2`` for every element of ``x``.
    """
    rectified = F.relu(x)
    return rectified.pow(2)


@jit_fuser
def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    """Compute ``x * sigmoid(1.702 * x)`` element-wise.

    This is the sigmoid-gated form that the function name refers to as
    "quick GELU".

    Args:
        x: Input tensor.

    Returns:
        A tensor where each element of ``x`` is scaled by its sigmoid gate.
    """
    gate = torch.sigmoid(1.702 * x)
    return x * gate