# initializers.py
import torch
import torch.nn as nn

def GlorotOrthogonal(tensor, scale=2.0):
    """Fill ``tensor`` in place with a Glorot-scaled orthogonal matrix.

    The tensor is first initialised with :func:`torch.nn.init.orthogonal_`.
    It is then multiplied by a scalar chosen so that the (unbiased) variance
    of its entries becomes ``scale / (tensor.size(-2) + tensor.size(-1))``.

    Args:
        tensor: Tensor or ``nn.Parameter`` to initialise in place, or
            ``None`` (in which case nothing happens).
        scale: Target value of ``variance * (size(-2) + size(-1))``. May be a
            Python number or a scalar tensor; it is never modified.

    Returns:
        The same ``tensor`` object, or ``None`` if ``tensor`` was ``None``.
        This mirrors the ``torch.nn.init`` convention of returning the
        initialised tensor.

    Raises:
        ValueError: If ``tensor`` has fewer than two dimensions (raised by
            ``nn.init.orthogonal_``).
    """
    if tensor is None:
        return None
    # Initialisation must not be tracked by autograd. Otherwise ``var()`` on a
    # Parameter builds a graph that the in-place scaling then drags along.
    with torch.no_grad():
        data = tensor.data
        nn.init.orthogonal_(data)
        fan_sum = data.size(-2) + data.size(-1)
        # Out-of-place division: ``scale /= ...`` would silently mutate a
        # tensor passed in by the caller.
        factor = scale / (fan_sum * data.var())
        data.mul_(factor.sqrt())
    return tensor