initializer.py
import math

import numpy as np
import torch.nn as nn


def glorot_uniform_af(x, gain=1.0):
    """
    initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
    In PyTorch:
    [feature_out, feature_in, n_head ...]
    In Jax:
    [... n_head, feature_in, feature_out]
    However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like:
    [feature_in, n_head, feature_out]

    In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
    """
    fan_in, fan_out = x.shape[-2:]  # JAX-style: fan dimensions come from the last two axes
    if len(x.shape) > 2:
        # Fold any leading dimensions into the receptive field size
        receptive_field_size = np.prod(x.shape[:-2])
        fan_in *= receptive_field_size
        fan_out *= receptive_field_size
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    dev = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    nn.init.uniform_(x, -dev, dev)

    return x
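

# A minimal usage sketch, assuming a hypothetical weight tensor with the
# [feature_in, n_head, feature_out] layout described in the docstring above.
if __name__ == "__main__":
    import torch

    weight = torch.empty(128, 8, 32)  # [feature_in, n_head, feature_out]
    glorot_uniform_af(weight, gain=1.0)  # filled in place with U(-bound, bound)
    # bound = gain * sqrt(6 / (fan_in + fan_out)); here fan_in = 8 * 128, fan_out = 32 * 128
    print(weight.abs().max().item())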