common_modules.py 5.82 KB
Newer Older
Augustin-Zidek's avatar
Augustin-Zidek committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A collection of common Haiku modules for use in protein folding."""
16
17
18
import numbers
from typing import Union, Sequence

Augustin-Zidek's avatar
Augustin-Zidek committed
19
20
import haiku as hk
import jax.numpy as jnp
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np


# Standard deviation of a unit normal truncated to [-2, 2], as given by
# scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.). Stored as a float32
# 0-d array.
TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(
    0.87962566103423978, dtype=np.float32)


def get_initializer_scale(initializer_name, input_shape):
  """Builds the Haiku weight initializer selected by `initializer_name`.

  Args:
    initializer_name: 'zeros' selects an all-zero constant initializer. Any
      other name selects a truncated normal whose variance is fan-in scaled;
      'relu' additionally doubles that variance.
    input_shape: Iterable of input dimension sizes. The variance is divided by
      each entry in turn, so an empty shape leaves it at 1.

  Returns:
    A Haiku initializer to use for the weights.
  """
  if initializer_name == 'zeros':
    return hk.initializers.Constant(0.0)

  # Fan-in scaling: start from unit variance and divide by every input dim.
  variance = 1.
  for dim_size in input_shape:
    variance /= dim_size
  if initializer_name == 'relu':
    variance *= 2

  # Adjust the stddev for truncation by dividing by the stddev of the
  # truncated unit normal.
  corrected_stddev = np.sqrt(variance) / TRUNCATED_NORMAL_STDDEV_FACTOR
  return hk.initializers.TruncatedNormal(mean=0.0, stddev=corrected_stddev)
Augustin-Zidek's avatar
Augustin-Zidek committed
50
51
52


class Linear(hk.Module):
  """Protein folding specific Linear module.

  This differs from the standard Haiku Linear in a few ways:
    * It supports inputs and outputs of arbitrary rank
    * Initializers are specified by strings
  """

  # Einsum index letters: one pool for the projected (trailing) input
  # dimensions and a disjoint pool for the output dimensions. The pool sizes
  # bound how many dimensions of each kind the module can handle.
  _IN_LETTERS = 'abcde'
  _OUT_LETTERS = 'hijkl'

  def __init__(self,
               num_output: Union[int, Sequence[int]],
               initializer: str = 'linear',
               num_input_dims: int = 1,
               use_bias: bool = True,
               bias_init: float = 0.,
               precision = None,
               name: str = 'linear'):
    """Constructs Linear Module.

    Args:
      num_output: Number of output channels. Can be tuple when outputting
          multiple dimensions.
      initializer: What initializer to use, should be one of {'linear', 'relu',
        'zeros'}
      num_input_dims: Number of dimensions from the end to project.
      use_bias: Whether to include trainable bias
      bias_init: Value used to initialize bias.
      precision: What precision to use for matrix multiplication, defaults
        to None.
      name: Name of module, used for name scopes.

    Raises:
      ValueError: If `num_input_dims` is negative, or if more input or output
        dimensions are requested than there are einsum index letters for.
    """
    super().__init__(name=name)
    if isinstance(num_output, numbers.Integral):
      self.output_shape = (num_output,)
    else:
      self.output_shape = tuple(num_output)

    # Fail early with a clear message: slicing the letter pools with an
    # out-of-range count would silently produce too few letters and an einsum
    # equation that no longer matches the weight shape.
    if not 0 <= num_input_dims <= len(self._IN_LETTERS):
      raise ValueError(
          f'num_input_dims must be between 0 and {len(self._IN_LETTERS)}, '
          f'got {num_input_dims}.')
    if len(self.output_shape) > len(self._OUT_LETTERS):
      raise ValueError(
          f'num_output may have at most {len(self._OUT_LETTERS)} dimensions, '
          f'got {len(self.output_shape)}.')

    self.initializer = initializer
    self.use_bias = use_bias
    self.bias_init = bias_init
    self.num_input_dims = num_input_dims
    self.num_output_dims = len(self.output_shape)
    self.precision = precision

  def __call__(self, inputs):
    """Connects Module.

    Args:
      inputs: Tensor with at least num_input_dims dimensions.

    Returns:
      output of shape [...] + num_output.
    """
    # `shape[-0:]` would select the whole shape, so zero projected dims needs
    # its own branch.
    if self.num_input_dims > 0:
      in_shape = inputs.shape[-self.num_input_dims:]
    else:
      in_shape = ()

    weight_init = get_initializer_scale(self.initializer, in_shape)

    in_letters = self._IN_LETTERS[:self.num_input_dims]
    out_letters = self._OUT_LETTERS[:self.num_output_dims]

    # Weights contract over all projected input dims and produce all output
    # dims; leading input dims are carried through by the ellipsis.
    weight_shape = in_shape + self.output_shape
    weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
                               weight_init)

    equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'

    output = jnp.einsum(equation, inputs, weights, precision=self.precision)

    if self.use_bias:
      bias = hk.get_parameter('bias', self.output_shape, inputs.dtype,
                              hk.initializers.Constant(self.bias_init))
      output += bias

    return output
130

Augustin Zidek's avatar
Augustin Zidek committed
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191

class LayerNorm(hk.LayerNorm):
  """LayerNorm module.

  Equivalent to hk.LayerNorm but with different parameter shapes: they are
  always vectors rather than possibly higher-rank tensors. This makes it easier
  to change the layout whilst keep the model weight-compatible.
  """

  def __init__(self,
               axis,
               create_scale: bool,
               create_offset: bool,
               eps: float = 1e-5,
               scale_init=None,
               offset_init=None,
               use_fast_variance: bool = False,
               name=None,
               param_axis=None):
    """Constructs the LayerNorm module.

    Args:
      axis: Forwarded to hk.LayerNorm.
      create_scale: Whether to create a trainable 'scale' vector.
      create_offset: Whether to create a trainable 'offset' vector.
      eps: Forwarded to hk.LayerNorm.
      scale_init: Optional initializer for 'scale'. When None the base class
        default is used.
      offset_init: Optional initializer for 'offset'. When None the base class
        default is used.
      use_fast_variance: Forwarded to hk.LayerNorm.
      name: Name of the module.
      param_axis: Forwarded to hk.LayerNorm; its first entry selects the axis
        of the input that the parameter vectors run along.
    """
    # The base class is told not to create any parameters, because this class
    # creates them itself in __call__ as vectors. The base class does not
    # accept initializers for parameters it is not creating, so None is
    # forwarded here and the caller's initializers are installed afterwards.
    super().__init__(
        axis=axis,
        create_scale=False,
        create_offset=False,
        eps=eps,
        scale_init=None,
        offset_init=None,
        use_fast_variance=use_fast_variance,
        name=name,
        param_axis=param_axis)
    # Honour caller-supplied initializers; without this they would be silently
    # dropped and the base-class defaults always used.
    if scale_init is not None:
      self.scale_init = scale_init
    if offset_init is not None:
      self.offset_init = offset_init
    self._temp_create_scale = create_scale
    self._temp_create_offset = create_offset

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Applies layer normalization to `x`.

    Args:
      x: Input array.

    Returns:
      The normalized array, with the same dtype as the input for bfloat16
      inputs (which are processed in float32 and cast back).
    """
    is_bf16 = (x.dtype == jnp.bfloat16)
    if is_bf16:
      x = x.astype(jnp.float32)

    param_axis = self.param_axis[0] if self.param_axis else -1
    param_shape = (x.shape[param_axis],)

    # Parameters are stored as vectors and reshaped so that they broadcast
    # against `x` along `param_axis` only.
    param_broadcast_shape = [1] * x.ndim
    param_broadcast_shape[param_axis] = x.shape[param_axis]
    scale = None
    offset = None
    if self._temp_create_scale:
      scale = hk.get_parameter(
          'scale', param_shape, x.dtype, init=self.scale_init)
      scale = scale.reshape(param_broadcast_shape)

    if self._temp_create_offset:
      offset = hk.get_parameter(
          'offset', param_shape, x.dtype, init=self.offset_init)
      offset = offset.reshape(param_broadcast_shape)

    out = super().__call__(x, scale=scale, offset=offset)

    if is_bf16:
      out = out.astype(jnp.bfloat16)

    return out