glu_activations.py 1.29 KB
Newer Older
hepj987's avatar
hepj987 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from torch import nn
from torch.nn import functional as F

from megatron import logging
from megatron.model.utils import log_debug_usage

# Module-level logger for this file; it is handed to log_debug_usage below
# when the scripted GLU activations are created.
logger = logging.get_logger(__name__)

class _GLUBaseModule(nn.Module):
    def __init__(self, activation_fn):
        super().__init__()
        self.activation_fn = activation_fn

    def forward(self, x):
        # dim=-1 breaks in jit for pt<1.10
        x1, x2 = x.chunk(2, dim=(x.ndim - 1))
        return x1 * self.activation_fn(x2)


class LiGLU(_GLUBaseModule):
    """GLU variant whose gate half is passed through ``nn.Identity``."""

    def __init__(self):
        identity = nn.Identity()
        super().__init__(activation_fn=identity)


class GEGLU(_GLUBaseModule):
    """GLU variant whose gate half is passed through ``F.gelu``."""

    def __init__(self):
        super().__init__(activation_fn=F.gelu)


class ReGLU(_GLUBaseModule):
    """GLU variant whose gate half is passed through ``F.relu``."""

    def __init__(self):
        super().__init__(activation_fn=F.relu)


class SwiGLU(_GLUBaseModule):
    """GLU variant whose gate half is passed through ``F.silu``."""

    def __init__(self):
        super().__init__(activation_fn=F.silu)


# Each GLU variant is instantiated and compiled with torch.jit.script once, at
# import time. The scripted module is then passed to the callable returned by
# log_debug_usage(logger, message) -- presumably so the message is logged when
# the activation is used; confirm against megatron.model.utils.
liglu = log_debug_usage(logger, "Using GLU activation: LiGLU.")(torch.jit.script(LiGLU()))
# The message here used to read "GELU", which misnamed the GEGLU variant.
geglu = log_debug_usage(logger, "Using GLU activation: GEGLU.")(torch.jit.script(GEGLU()))
reglu = log_debug_usage(logger, "Using GLU activation: ReGLU.")(torch.jit.script(ReGLU()))
swiglu = log_debug_usage(logger, "Using GLU activation: SwiGLU.")(torch.jit.script(SwiGLU()))


# Registry mapping the lowercase activation name to the corresponding object
# created above.
GLU_ACTIVATIONS = {
    "geglu": geglu,
    "liglu": liglu,
    "reglu": reglu,
    "swiglu": swiglu,
}