# act.py
import torch.nn as nn

class ScaledActivation(nn.Module):
    """Wrap an activation module and divide its output by per-channel scales.

    ``forward`` computes ``module(x) / scales``, with ``scales`` broadcast
    against the last dimension of the activation's output.

    Args:
        module: The activation ``nn.Module`` to wrap; stored as ``self.act``.
        scales: Tensor of divisors. It is flattened to 1-D at call time, so
            its element count must equal the size of the activation output's
            last dimension. Only ``scales.data`` is kept, re-wrapped as a new
            ``nn.Parameter`` (so it appears in ``parameters()`` and the
            ``state_dict``).
    """

    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        # `.data` drops any autograd history attached to the incoming tensor.
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        """Return ``self.act(x)`` divided by ``self.scales`` along the last dim.

        Works for activation outputs of any rank >= 1, e.g.
        ``(tokens, hidden)`` or ``(batch, seq, hidden)``; broadcasting does
        not add any dimensions to the result.
        """
        # Flatten to 1-D so broadcasting aligns with the trailing dimension
        # regardless of input rank. A fixed `view(1, 1, -1)` would add a
        # leading dimension to 2-D inputs.
        # `.to(x.device)` moves the scales to the input's device in case they
        # live elsewhere.
        return self.act(x) / self.scales.view(-1).to(x.device)