# kernel.py — last committed by oahzxl
import torch
import torch.nn.functional as F


def bias_sigmod_ele(y, bias, z):
    """Scale ``z`` element-wise by a sigmoid gate computed from ``y + bias``.

    The function name keeps its historical spelling ("sigmod") because
    callers import it by that name.

    Args:
        y: Pre-activation tensor for the gate.
        bias: Bias added to ``y`` before the sigmoid (broadcast by torch).
        z: Tensor multiplied by the gate.

    Returns:
        ``sigmoid(y + bias) * z``.
    """
    gate = torch.sigmoid(y + bias)
    return gate * z


def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                     residual: torch.Tensor, prob: float,
                     training: bool = False) -> torch.Tensor:
    """Compute ``residual + (x + bias) * dropout(dropmask)``.

    Args:
        x: Input tensor.
        bias: Bias added to ``x`` (broadcast by torch).
        dropmask: Tensor passed through ``F.dropout`` and then multiplied
            into ``x + bias``.
        residual: Tensor added to the result.
        prob: Dropout probability handed to ``F.dropout``.
        training: Forwarded to ``F.dropout``. Defaults to ``False``, which
            preserves this function's previous behavior. In that case
            ``F.dropout`` returns ``dropmask`` unchanged and ``prob`` has no
            effect.

    Returns:
        ``residual + (x + bias) * F.dropout(dropmask, p=prob, training=training)``.
    """
    # With training=False F.dropout is the identity, so this reduces to
    # (x + bias) * dropmask; pass training=True to actually apply dropout.
    out = (x + bias) * F.dropout(dropmask, p=prob, training=training)
    out = residual + out
    return out


def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
                              dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
                              prob: float) -> torch.Tensor:
    """Gate a biased tensor, apply a dropout-processed mask, and add a residual.

    Args:
        ab: Input tensor to which ``b`` is added.
        b: Bias (broadcast by torch).
        g: Element-wise gate multiplied into ``ab + b``.
        dropout_mask: Tensor passed through ``F.dropout`` in training mode.
        Z_raw: Residual tensor added to the result.
        prob: Dropout probability.

    Returns:
        ``Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))``.
    """
    gated = g * (ab + b)
    # Dropout always runs in training mode here (random, scaled by 1/(1-prob)).
    kept = F.dropout(dropout_mask, p=prob, training=True)
    return Z_raw + kept * gated