act.py 299 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
import torch.nn as nn

class ScaledActivation(nn.Module):
    """Wrap an activation module and divide its output by per-channel scales.

    The wrapped module's output is divided elementwise by ``scales``,
    broadcast along the LAST dimension of the output. Inputs of any rank
    are supported, for example ``(tokens, hidden)`` or
    ``(batch, seq, hidden)``. The output keeps the shape of the wrapped
    activation's output.

    Args:
        module: The activation module to wrap. It is called as ``module(x)``.
        scales: A tensor whose element count equals the size of the last
            dimension of ``module(x)``. It is flattened before use. It is
            stored as an ``nn.Parameter`` built from ``scales.data``, so the
            parameter shares storage with the given tensor rather than
            copying it.
    """

    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        # `.data` drops any autograd history from `scales`.
        # The new Parameter reuses the same underlying storage.
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        """Return ``self.act(x)`` divided by ``scales`` along the last dim.

        Args:
            x: Input tensor passed to the wrapped activation.

        Returns:
            A tensor with the same shape as ``self.act(x)``.
        """
        # A flat (H,) vector broadcasts against the trailing dimension of
        # any-rank input. The previous `view(1, 1, -1)` turned a 2-D
        # (N, H) input into a 3-D (1, N, H) output. `reshape` also
        # tolerates non-contiguous scale tensors, where `view` would raise.
        # The move to x's device is a no-op when the devices already match.
        scales = self.scales.reshape(-1).to(x.device)
        return self.act(x) / scales