Commit 1109480e authored by Augustin-Zidek's avatar Augustin-Zidek
Browse files

Initial release of AlphaFold.

PiperOrigin-RevId: 384954738
parents
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for data pipeline tools."""
import contextlib
import shutil
import tempfile
import time
from typing import Optional
from absl import logging
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
  """Yields a fresh temporary directory and removes it on exit.

  Args:
    base_dir: Optional parent directory in which to create the tmpdir.

  Yields:
    Path of the newly created temporary directory.
  """
  scratch_path = tempfile.mkdtemp(dir=base_dir)
  try:
    yield scratch_path
  finally:
    # Best-effort cleanup: ignore errors (e.g. files removed concurrently).
    shutil.rmtree(scratch_path, ignore_errors=True)
@contextlib.contextmanager
def timing(msg: str):
  """Logs the wall-clock duration of the enclosed block."""
  logging.info('Started %s', msg)
  start = time.time()
  yield
  elapsed = time.time() - start
  logging.info('Finished %s in %.3f seconds', msg, elapsed)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alphafold model."""
This diff is collapsed.
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for all_atom."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from alphafold.model import all_atom
from alphafold.model import r3
# Clamp value passed (twice: as clamp distance and as length scale) to
# all_atom.frame_aligned_point_error in the tests below.
L1_CLAMP_DISTANCE = 10
def get_identity_rigid(shape):
  """Builds the identity rigid transform broadcast to `shape`."""
  one = np.ones(shape)
  zero = np.zeros(shape)
  # Identity rotation: ones on the diagonal, zeros elsewhere.
  identity_rotation = r3.Rots(one, zero, zero,
                              zero, one, zero,
                              zero, zero, one)
  zero_translation = r3.Vecs(zero, zero, zero)
  return r3.Rigids(identity_rotation, zero_translation)
def get_global_rigid_transform(rot_angle, translation, bcast_dims):
  """Returns one rigid transform that rotates/translates everything equally.

  Args:
    rot_angle: Rotation angle in degrees (rotation about the x-axis).
    translation: Length-3 translation vector.
    bcast_dims: Number of leading singleton dims to prepend for broadcasting.
  """
  rot_angle = np.asarray(rot_angle)
  translation = np.asarray(translation)
  # Prepend singleton axes so the transform broadcasts over batch dims.
  if bcast_dims:
    for _ in range(bcast_dims):
      rot_angle = rot_angle[np.newaxis]
      translation = translation[np.newaxis]
  angle_rad = np.deg2rad(rot_angle)
  s = np.sin(angle_rad)
  c = np.cos(angle_rad)
  one = np.ones_like(s)
  zero = np.zeros_like(s)
  # Rotation matrix about the x-axis by `rot_angle` degrees.
  rotation = r3.Rots(one, zero, zero,
                     zero, c, -s,
                     zero, s, c)
  shift = r3.Vecs(translation[..., 0], translation[..., 1],
                  translation[..., 2])
  return r3.Rigids(rotation, shift)
class AllAtomTest(parameterized.TestCase, absltest.TestCase):
  """Tests for all_atom.frame_aligned_point_error (FAPE)."""

  @parameterized.named_parameters(
      ('identity', 0, [0, 0, 0]),
      ('rot_90', 90, [0, 0, 0]),
      ('trans_10', 0, [0, 0, 10]),
      ('rot_174_trans_1', 174, [1, 1, 1]))
  def test_frame_aligned_point_error_perfect_on_global_transform(
      self, rot_angle, translation):
    """Tests global transform between target and preds gives perfect score."""
    # pylint: disable=bad-whitespace
    # Fixed 3-D coordinates used as ground-truth positions.
    target_positions = np.array(
        [[21.182, 23.095, 19.731],
         [22.055, 20.919, 17.294],
         [24.599, 20.005, 15.041],
         [25.567, 18.214, 12.166],
         [28.063, 17.082, 10.043],
         [28.779, 15.569, 6.985],
         [30.581, 13.815, 4.612],
         [29.258, 12.193, 2.296]])
    # pylint: enable=bad-whitespace
    # A single rigid transform, with one broadcast dim prepended.
    global_rigid_transform = get_global_rigid_transform(
        rot_angle, translation, 1)
    target_positions = r3.vecs_from_tensor(target_positions)
    # Predictions are the targets moved by the same global transform.
    pred_positions = r3.rigids_mul_vecs(
        global_rigid_transform, target_positions)
    positions_mask = np.ones(target_positions.x.shape[0])
    # Target frames are identities; predicted frames are transformed by the
    # same global transform, so per-frame errors should vanish.
    target_frames = get_identity_rigid(10)
    pred_frames = r3.rigids_mul_rigids(global_rigid_transform, target_frames)
    frames_mask = np.ones(10)
    fape = all_atom.frame_aligned_point_error(
        pred_frames, target_frames, frames_mask, pred_positions,
        target_positions, positions_mask, L1_CLAMP_DISTANCE,
        L1_CLAMP_DISTANCE, epsilon=0)
    # Expect a perfect (zero) score under a shared global transform.
    self.assertAlmostEqual(fape, 0.)

  @parameterized.named_parameters(
      ('identity',
       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
       0.),
      ('shift_2.5',
       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
       [[2.5, 0, 0], [7.5, 0, 0], [7.5, 0, 0]],
       0.25),
      ('shift_5',
       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
       [[5, 0, 0], [10, 0, 0], [15, 0, 0]],
       0.5),
      ('shift_10',
       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
       [[10, 0, 0], [15, 0, 0], [0, 0, 0]],
       1.))
  def test_frame_aligned_point_error_matches_expected(
      self, target_positions, pred_positions, expected_alddt):
    """Tests score matches expected."""
    # Identical target/pred frames, so the score depends only on positions.
    target_frames = get_identity_rigid(2)
    pred_frames = target_frames
    frames_mask = np.ones(2)
    target_positions = r3.vecs_from_tensor(np.array(target_positions))
    pred_positions = r3.vecs_from_tensor(np.array(pred_positions))
    positions_mask = np.ones(target_positions.x.shape[0])
    alddt = all_atom.frame_aligned_point_error(
        pred_frames, target_frames, frames_mask, pred_positions,
        target_positions, positions_mask, L1_CLAMP_DISTANCE,
        L1_CLAMP_DISTANCE, epsilon=0)
    self.assertAlmostEqual(alddt, expected_alddt)
# Standard absl test entry point.
if __name__ == '__main__':
  absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of common Haiku modules for use in protein folding."""
import haiku as hk
import jax.numpy as jnp
class Linear(hk.Module):
  """Protein folding specific Linear Module.

  This differs from the standard Haiku Linear in a few ways:
    * It supports inputs of arbitrary rank
    * Initializers are specified by strings
  """

  def __init__(self,
               num_output: int,
               initializer: str = 'linear',
               use_bias: bool = True,
               bias_init: float = 0.,
               name: str = 'linear'):
    """Constructs Linear Module.

    Args:
      num_output: number of output channels.
      initializer: What initializer to use, should be one of {'linear', 'relu',
        'zeros'}
      use_bias: Whether to include trainable bias
      bias_init: Value used to initialize bias.
      name: name of module, used for name scopes.

    Raises:
      ValueError: If `initializer` is not one of the supported names.
    """
    super().__init__(name=name)
    if initializer not in ('linear', 'relu', 'zeros'):
      # Fail fast here instead of letting __call__ hit an UnboundLocalError
      # on `weight_init` the first time the module is connected.
      raise ValueError(
          f'Unknown initializer {initializer!r}; expected one of '
          "'linear', 'relu', 'zeros'.")
    self.num_output = num_output
    self.initializer = initializer
    self.use_bias = use_bias
    self.bias_init = bias_init

  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Connects Module.

    Args:
      inputs: Tensor of shape [..., num_channel]

    Returns:
      output of shape [..., num_output]
    """
    n_channels = int(inputs.shape[-1])
    weight_shape = [n_channels, self.num_output]
    # Fan-in variance scaling: scale 1.0 for plain linear layers, 2.0 for
    # layers followed by a ReLU; 'zeros' initializes weights to zero.
    if self.initializer == 'linear':
      weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.)
    elif self.initializer == 'relu':
      weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.)
    else:  # 'zeros' -- validated in __init__.
      weight_init = hk.initializers.Constant(0.0)

    weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
                               weight_init)

    # this is equivalent to einsum('...c,cd->...d', inputs, weights)
    # but turns out to be slightly faster
    inputs = jnp.swapaxes(inputs, -1, -2)
    output = jnp.einsum('...cb,cd->...db', inputs, weights)
    output = jnp.swapaxes(output, -1, -2)

    if self.use_bias:
      bias = hk.get_parameter('bias', [self.num_output], inputs.dtype,
                              hk.initializers.Constant(self.bias_init))
      output += bias

    return output
This diff is collapsed.
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convenience functions for reading data."""
import io
import os
from typing import List
import haiku as hk
import numpy as np
from alphafold.model import utils
# Internal import (7716).
def casp_model_names(data_dir: str) -> List[str]:
  """Lists the model names available under `<data_dir>/params`."""
  param_files = os.listdir(os.path.join(data_dir, 'params'))
  # Strip the file extension, keeping only the bare model name.
  return [os.path.splitext(name)[0] for name in param_files]
def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params:
  """Get the Haiku parameters from a model name."""
  params_path = os.path.join(data_dir, 'params', f'params_{model_name}.npz')
  # Read the whole file first, then parse from an in-memory buffer;
  # allow_pickle=False keeps loading restricted to plain arrays.
  with open(params_path, 'rb') as f:
    raw_bytes = f.read()
  flat_params = np.load(io.BytesIO(raw_bytes), allow_pickle=False)
  return utils.flat_params_to_haiku(flat_params)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to generate processed features."""
import copy
from typing import List, Mapping, Tuple
import ml_collections
import numpy as np
import tensorflow.compat.v1 as tf
from alphafold.model.tf import input_pipeline
from alphafold.model.tf import proteins_dataset
# Mapping from feature name to its NumPy array value.
FeatureDict = Mapping[str, np.ndarray]
def make_data_config(
    config: ml_collections.ConfigDict,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
  """Makes a data config for the input pipeline.

  Args:
    config: Model config; only `config.data` is consulted.
    num_res: Number of residues; used as the evaluation crop size.

  Returns:
    A (deep-copied data config, requested feature names) pair.
  """
  data_cfg = copy.deepcopy(config.data)
  # Always request the unsupervised features; add template features only
  # when template use is enabled.
  feature_names = data_cfg.common.unsupervised_features
  if data_cfg.common.use_templates:
    feature_names += data_cfg.common.template_features
  with data_cfg.unlocked():
    data_cfg.eval.crop_size = num_res
  return data_cfg, feature_names
def tf_example_to_features(tf_example: tf.train.Example,
                           config: ml_collections.ConfigDict,
                           random_seed: int = 0) -> FeatureDict:
  """Converts tf_example to numpy feature dictionary.

  Args:
    tf_example: tf.train.Example holding the raw input features; must
      contain a 'seq_length' feature.
    config: Model config; `config.data` drives the input pipeline.
    random_seed: Seed for the TF graph-level random generator.

  Returns:
    Dict of processed NumPy feature arrays; object-dtype entries (e.g.
    string features) are dropped.
  """
  num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0])
  cfg, feature_names = make_data_config(config, num_res=num_res)

  if 'deletion_matrix_int' in set(tf_example.features.feature):
    # Convert the integer deletion matrix into the float feature the
    # pipeline expects, then drop the integer variant.
    deletion_matrix_int = (
        tf_example.features.feature['deletion_matrix_int'].int64_list.value)
    feat = tf.train.Feature(float_list=tf.train.FloatList(
        value=map(float, deletion_matrix_int)))
    tf_example.features.feature['deletion_matrix'].CopyFrom(feat)
    del tf_example.features.feature['deletion_matrix_int']

  tf_graph = tf.Graph()
  with tf_graph.as_default(), tf.device('/device:CPU:0'):
    # Pin preprocessing to CPU and seed the graph for determinism.
    tf.compat.v1.set_random_seed(random_seed)
    tensor_dict = proteins_dataset.create_tensor_dict(
        raw_data=tf_example.SerializeToString(),
        features=feature_names)
    processed_batch = input_pipeline.process_tensors_from_config(
        tensor_dict, cfg)

  # Freeze the graph before running it.
  tf_graph.finalize()

  with tf.Session(graph=tf_graph) as sess:
    features = sess.run(processed_batch)

  # Keep only numeric outputs; 'O' (object) dtype arrays are removed.
  return {k: v for k, v in features.items() if v.dtype != 'O'}
def np_example_to_features(np_example: FeatureDict,
                           config: ml_collections.ConfigDict,
                           random_seed: int = 0) -> FeatureDict:
  """Preprocesses NumPy feature dict using TF pipeline."""
  # Shallow-copy so the caller's dict is never mutated.
  example = dict(np_example)
  num_res = int(example['seq_length'][0])
  cfg, feature_names = make_data_config(config, num_res=num_res)

  if 'deletion_matrix_int' in example:
    # Replace the integer deletion matrix with its float equivalent.
    example['deletion_matrix'] = (
        example.pop('deletion_matrix_int').astype(np.float32))

  graph = tf.Graph()
  with graph.as_default(), tf.device('/device:CPU:0'):
    # Pin preprocessing to CPU and seed the graph for determinism.
    tf.compat.v1.set_random_seed(random_seed)
    tensor_dict = proteins_dataset.np_to_tensor_dict(
        np_example=example, features=feature_names)
    processed_batch = input_pipeline.process_tensors_from_config(
        tensor_dict, cfg)

  graph.finalize()

  with tf.Session(graph=graph) as sess:
    features = sess.run(processed_batch)

  # Keep only numeric outputs; 'O' (object) dtype arrays are removed.
  return {name: arr for name, arr in features.items() if arr.dtype != 'O'}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment