scale.py 266 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
import torch
import torch.nn as nn


class Scale(nn.Module):
    """Multiply the input by a single learnable factor.

    The factor is stored as an ``nn.Parameter`` named ``scale``. It is
    therefore registered with the module and updated by optimizers like any
    other weight.

    Args:
        scale: Initial value of the factor (default ``1.0``). It is converted
            with ``torch.tensor(..., dtype=torch.float)``. A plain Python
            number therefore yields a 0-dim float32 parameter that broadcasts
            against any input.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        initial = torch.tensor(scale, dtype=torch.float)
        self.scale = nn.Parameter(initial)

    def forward(self, x):
        """Return ``x`` multiplied elementwise by the learnable factor."""
        factor = self.scale
        return x * factor