# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code for constructing the model."""
from typing import Any, Mapping, Optional, Union
from absl import logging
from alphafold.common import confidence
from alphafold.model import features
from alphafold.model import modules
from alphafold.model import modules_multimer
import haiku as hk
import jax
import ml_collections
import numpy as np
import tensorflow.compat.v1 as tf
import tree
def get_confidence_metrics(
prediction_result: Mapping[str, Any],
multimer_mode: bool) -> Mapping[str, Any]:
"""Post processes prediction_result to get confidence metrics."""
confidence_metrics = {}
confidence_metrics['plddt'] = confidence.compute_plddt(
prediction_result['predicted_lddt']['logits'])
if 'predicted_aligned_error' in prediction_result:
confidence_metrics.update(confidence.compute_predicted_aligned_error(
logits=prediction_result['predicted_aligned_error']['logits'],
breaks=prediction_result['predicted_aligned_error']['breaks']))
confidence_metrics['ptm'] = confidence.predicted_tm_score(
logits=prediction_result['predicted_aligned_error']['logits'],
breaks=prediction_result['predicted_aligned_error']['breaks'],
asym_id=None)
if multimer_mode:
# Compute the ipTM only for the multimer model.
confidence_metrics['iptm'] = confidence.predicted_tm_score(
logits=prediction_result['predicted_aligned_error']['logits'],
breaks=prediction_result['predicted_aligned_error']['breaks'],
asym_id=prediction_result['predicted_aligned_error']['asym_id'],
interface=True)
confidence_metrics['ranking_confidence'] = (
0.8 * confidence_metrics['iptm'] + 0.2 * confidence_metrics['ptm'])
if not multimer_mode:
# Monomer models use mean pLDDT for model ranking.
confidence_metrics['ranking_confidence'] = np.mean(
confidence_metrics['plddt'])
return confidence_metrics
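# Illustrative sketch (not part of the original module): how the metrics
# returned above might be consumed. Key names follow the dict built in
# get_confidence_metrics; `prediction_result` is assumed to come from
# RunModel.predict below.
#
#   metrics = get_confidence_metrics(prediction_result, multimer_mode=False)
#   per_residue_plddt = metrics['plddt']           # [N_res], 0-100 scale.
#   ranking_score = metrics['ranking_confidence']  # Mean pLDDT for monomers.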
class RunModel:
"""Container for JAX model."""
def __init__(self,
config: ml_collections.ConfigDict,
params: Optional[Mapping[str, Mapping[str, jax.Array]]] = None):
self.config = config
self.params = params
self.multimer_mode = config.model.global_config.multimer_mode
if self.multimer_mode:
def _forward_fn(batch):
model = modules_multimer.AlphaFold(self.config.model)
return model(
batch,
is_training=False)
else:
def _forward_fn(batch):
model = modules.AlphaFold(self.config.model)
return model(
batch,
is_training=False,
compute_loss=False,
ensemble_representations=True)
self.apply = jax.jit(hk.transform(_forward_fn).apply)
self.init = jax.jit(hk.transform(_forward_fn).init)
def init_params(self, feat: features.FeatureDict, random_seed: int = 0):
"""Initializes the model parameters.
If none were provided when this class was instantiated then the parameters
are randomly initialized.
Args:
feat: A dictionary of NumPy feature arrays as output by
RunModel.process_features.
random_seed: A random seed to use to initialize the parameters if none
were set when this class was initialized.
"""
if not self.params:
# Init params randomly.
rng = jax.random.PRNGKey(random_seed)
self.params = hk.data_structures.to_mutable_dict(
self.init(rng, feat))
logging.warning('Initialized parameters randomly')
def process_features(
self,
raw_features: Union[tf.train.Example, features.FeatureDict],
random_seed: int) -> features.FeatureDict:
"""Processes features to prepare for feeding them into the model.
Args:
raw_features: The output of the data pipeline either as a dict of NumPy
arrays or as a tf.train.Example.
random_seed: The random seed to use when processing the features.
Returns:
A dict of NumPy feature arrays suitable for feeding into the model.
"""
if self.multimer_mode:
return raw_features
# Single-chain mode.
if isinstance(raw_features, dict):
return features.np_example_to_features(
np_example=raw_features,
config=self.config,
random_seed=random_seed)
else:
return features.tf_example_to_features(
tf_example=raw_features,
config=self.config,
random_seed=random_seed)
def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct:
self.init_params(feat)
logging.info('Running eval_shape with shape(feat) = %s',
tree.map_structure(lambda x: x.shape, feat))
shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat)
logging.info('Output shape was %s', shape)
return shape
def predict(self,
feat: features.FeatureDict,
random_seed: int,
) -> Mapping[str, Any]:
"""Makes a prediction by inferencing the model on the provided features.
Args:
feat: A dictionary of NumPy feature arrays as output by
RunModel.process_features.
random_seed: The random seed to use when running the model. In the
multimer model this controls the MSA sampling.
Returns:
A dictionary of model outputs.
"""
self.init_params(feat)
logging.info('Running predict with shape(feat) = %s',
tree.map_structure(lambda x: x.shape, feat))
result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
# This block is to ensure benchmark timings are accurate. Some blocking is
# already happening when computing get_confidence_metrics, and this ensures
# all outputs are blocked on.
jax.tree_map(lambda x: x.block_until_ready(), result)
result.update(
get_confidence_metrics(result, multimer_mode=self.multimer_mode))
logging.info('Output shape was %s',
tree.map_structure(lambda x: x.shape, result))
return result
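# Illustrative usage sketch (not part of the original file). `model_config`
# and `model_params` are placeholder names; in the AlphaFold pipeline they
# are typically built via alphafold.model.config and loaded from the
# downloaded parameter files.
#
#   model_runner = RunModel(model_config, model_params)
#   processed_feat = model_runner.process_features(raw_features, random_seed=0)
#   prediction = model_runner.predict(processed_feat, random_seed=0)
#   plddt = prediction['plddt']  # Per-residue confidence, [N_res].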
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modules and code used in the core part of AlphaFold.
The structure generation code is in 'folding.py'.
"""
import functools
from alphafold.common import residue_constants
from alphafold.model import all_atom
from alphafold.model import common_modules
from alphafold.model import folding
from alphafold.model import layer_stack
from alphafold.model import lddt
from alphafold.model import mapping
from alphafold.model import prng
from alphafold.model import quat_affine
from alphafold.model import utils
import haiku as hk
import jax
import jax.numpy as jnp
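# Large negative value substituted for masked attention logits so that masked
# positions receive (numerically) zero weight after the softmax.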
_SOFTMAX_MASK = -1e9
def softmax_cross_entropy(logits, labels):
"""Computes softmax cross entropy given logits and one-hot class labels."""
loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
return jnp.asarray(loss)
def sigmoid_cross_entropy(logits, labels):
"""Computes sigmoid cross entropy given logits and multiple class labels."""
log_p = jax.nn.log_sigmoid(logits)
# log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically stable
log_not_p = jax.nn.log_sigmoid(-logits)
loss = -labels * log_p - (1. - labels) * log_not_p
return jnp.asarray(loss)
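# Toy example (illustrative only): for a single position with three classes,
#   logits = jnp.array([2.0, 0.5, -1.0]); labels = jnp.array([1.0, 0.0, 0.0])
#   softmax_cross_entropy(logits, labels) == -jax.nn.log_softmax(logits)[0]
# i.e. the negative log-probability of the true class, whereas
# sigmoid_cross_entropy treats each class as an independent binary label.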
def apply_dropout(*, tensor, safe_key, rate, is_training, broadcast_dim=None):
"""Applies dropout to a tensor."""
if is_training and rate != 0.0:
shape = list(tensor.shape)
if broadcast_dim is not None:
shape[broadcast_dim] = 1
keep_rate = 1.0 - rate
keep = jax.random.bernoulli(safe_key.get(), keep_rate, shape=shape)
return keep * tensor / keep_rate
else:
return tensor
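# Note on shared dropout (descriptive, added for clarity): with
# broadcast_dim=0 the Bernoulli keep mask is sampled with shape
# [1, *tensor.shape[1:]] and broadcast along axis 0, so every slice along
# that axis is kept or dropped together, as used by dropout_wrapper below.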
def dropout_wrapper(module,
input_act,
mask,
safe_key,
global_config,
output_act=None,
is_training=True,
**kwargs):
"""Applies module + dropout + residual update."""
if output_act is None:
output_act = input_act
gc = global_config
residual = module(input_act, mask, is_training=is_training, **kwargs)
dropout_rate = 0.0 if gc.deterministic else module.config.dropout_rate
# `eval_dropout` overrides `is_training` so dropout can also be applied at eval time.
should_apply_dropout = True if gc.eval_dropout else is_training
if module.config.shared_dropout:
if module.config.orientation == 'per_row':
broadcast_dim = 0
else:
broadcast_dim = 1
else:
broadcast_dim = None
residual = apply_dropout(tensor=residual,
safe_key=safe_key,
rate=dropout_rate,
is_training=should_apply_dropout,
broadcast_dim=broadcast_dim)
new_act = output_act + residual
return new_act
def create_extra_msa_feature(batch):
"""Expand extra_msa into 1hot and concat with other extra msa features.
We do this as late as possible since the one-hot extra MSA can be very large.
Arguments:
batch: a dictionary with the following keys:
* 'extra_msa': [N_extra_seq, N_res] MSA that wasn't selected as a cluster
centre. Note that this is not one-hot encoded.
* 'extra_has_deletion': [N_extra_seq, N_res] Whether there is a deletion to
the left of each position in the extra MSA.
* 'extra_deletion_value': [N_extra_seq, N_res] The number of deletions to
the left of each position in the extra MSA.
Returns:
Concatenated tensor of extra MSA features.
"""
# 23 = 20 amino acids + 'X' for unknown + gap + bert mask
msa_1hot = jax.nn.one_hot(batch['extra_msa'], 23)
msa_feat = [msa_1hot,
jnp.expand_dims(batch['extra_has_deletion'], axis=-1),
jnp.expand_dims(batch['extra_deletion_value'], axis=-1)]
return jnp.concatenate(msa_feat, axis=-1)
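# Resulting per-position feature size: 23 one-hot channels plus the two
# scalar deletion features, i.e. the output is [N_extra_seq, N_res, 25].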
class AlphaFoldIteration(hk.Module):
"""A single recycling iteration of AlphaFold architecture.
Computes ensembled (averaged) representations from the provided features.
These representations are then passed to the various heads
that have been requested by the configuration file. Each head also returns a
loss which is combined as a weighted sum to produce the total loss.
Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 3-22
"""
def __init__(self, config, global_config, name='alphafold_iteration'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
ensembled_batch,
non_ensembled_batch,
is_training,
compute_loss=False,
ensemble_representations=False,
return_representations=False):
num_ensemble = jnp.asarray(ensembled_batch['seq_length'].shape[0])
if not ensemble_representations:
assert ensembled_batch['seq_length'].shape[0] == 1
def slice_batch(i):
b = {k: v[i] for k, v in ensembled_batch.items()}
b.update(non_ensembled_batch)
return b
# Compute representations for each batch element and average.
evoformer_module = EmbeddingsAndEvoformer(
self.config.embeddings_and_evoformer, self.global_config)
batch0 = slice_batch(0)
representations = evoformer_module(batch0, is_training)
# MSA representations are not ensembled so
# we don't pass the tensor into the loop.
msa_representation = representations['msa']
del representations['msa']
# Average the representations (except MSA) over the batch dimension.
if ensemble_representations:
def body(x):
"""Add one element to the representations ensemble."""
i, current_representations = x
feats = slice_batch(i)
representations_update = evoformer_module(
feats, is_training)
new_representations = {}
for k in current_representations:
new_representations[k] = (
current_representations[k] + representations_update[k])
return i+1, new_representations
if hk.running_init():
# When initializing the Haiku module, run one iteration of the
# while_loop to initialize the Haiku modules used in `body`.
_, representations = body((1, representations))
else:
_, representations = hk.while_loop(
lambda x: x[0] < num_ensemble,
body,
(1, representations))
for k in representations:
if k != 'msa':
representations[k] /= num_ensemble.astype(representations[k].dtype)
representations['msa'] = msa_representation
batch = batch0 # We are not ensembled from here on.
heads = {}
for head_name, head_config in sorted(self.config.heads.items()):
if not head_config.weight:
continue # Do not instantiate zero-weight heads.
head_factory = {
'masked_msa': MaskedMsaHead,
'distogram': DistogramHead,
'structure_module': functools.partial(
folding.StructureModule, compute_loss=compute_loss),
'predicted_lddt': PredictedLDDTHead,
'predicted_aligned_error': PredictedAlignedErrorHead,
'experimentally_resolved': ExperimentallyResolvedHead,
}[head_name]
heads[head_name] = (head_config,
head_factory(head_config, self.global_config))
total_loss = 0.
ret = {}
ret['representations'] = representations
def loss(module, head_config, ret, name, filter_ret=True):
if filter_ret:
value = ret[name]
else:
value = ret
loss_output = module.loss(value, batch)
ret[name].update(loss_output)
loss = head_config.weight * ret[name]['loss']
return loss
for name, (head_config, module) in heads.items():
# Skip PredictedLDDTHead and PredictedAlignedErrorHead until
# StructureModule is executed.
if name in ('predicted_lddt', 'predicted_aligned_error'):
continue
else:
ret[name] = module(representations, batch, is_training)
if 'representations' in ret[name]:
# Extra representations from the head. Used by the structure module
# to provide activations for the PredictedLDDTHead.
representations.update(ret[name].pop('representations'))
if compute_loss:
total_loss += loss(module, head_config, ret, name)
if self.config.heads.get('predicted_lddt.weight', 0.0):
# Add PredictedLDDTHead after StructureModule executes.
name = 'predicted_lddt'
# Feed all previous results to give access to structure_module result.
head_config, module = heads[name]
ret[name] = module(representations, batch, is_training)
if compute_loss:
total_loss += loss(module, head_config, ret, name, filter_ret=False)
if ('predicted_aligned_error' in self.config.heads
and self.config.heads.get('predicted_aligned_error.weight', 0.0)):
# Add PredictedAlignedErrorHead after StructureModule executes.
name = 'predicted_aligned_error'
# Feed all previous results to give access to structure_module result.
head_config, module = heads[name]
ret[name] = module(representations, batch, is_training)
if compute_loss:
total_loss += loss(module, head_config, ret, name, filter_ret=False)
if compute_loss:
return ret, total_loss
else:
return ret
class AlphaFold(hk.Module):
"""AlphaFold model with recycling.
Jumper et al. (2021) Suppl. Alg. 2 "Inference"
"""
def __init__(self, config, name='alphafold'):
super().__init__(name=name)
self.config = config
self.global_config = config.global_config
def __call__(
self,
batch,
is_training,
compute_loss=False,
ensemble_representations=False,
return_representations=False):
"""Run the AlphaFold model.
Arguments:
batch: Dictionary with inputs to the AlphaFold model.
is_training: Whether the system is in training or inference mode.
compute_loss: Whether to compute losses (requires extra features
to be present in the batch and knowing the true structure).
ensemble_representations: Whether to use ensembling of representations.
return_representations: Whether to also return the intermediate
representations.
Returns:
When compute_loss is True:
a tuple of loss and output of AlphaFoldIteration.
When compute_loss is False:
just output of AlphaFoldIteration.
The output of AlphaFoldIteration is a nested dictionary containing
predictions from the various heads.
"""
impl = AlphaFoldIteration(self.config, self.global_config)
batch_size, num_residues = batch['aatype'].shape
def get_prev(ret):
new_prev = {
'prev_pos':
ret['structure_module']['final_atom_positions'],
'prev_msa_first_row': ret['representations']['msa_first_row'],
'prev_pair': ret['representations']['pair'],
}
return jax.tree_map(jax.lax.stop_gradient, new_prev)
def do_call(prev,
recycle_idx,
compute_loss=compute_loss):
if self.config.resample_msa_in_recycling:
num_ensemble = batch_size // (self.config.num_recycle + 1)
def slice_recycle_idx(x):
start = recycle_idx * num_ensemble
size = num_ensemble
return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0)
ensembled_batch = jax.tree_map(slice_recycle_idx, batch)
else:
num_ensemble = batch_size
ensembled_batch = batch
non_ensembled_batch = jax.tree_map(lambda x: x, prev)
return impl(
ensembled_batch=ensembled_batch,
non_ensembled_batch=non_ensembled_batch,
is_training=is_training,
compute_loss=compute_loss,
ensemble_representations=ensemble_representations)
prev = {}
emb_config = self.config.embeddings_and_evoformer
if emb_config.recycle_pos:
prev['prev_pos'] = jnp.zeros(
[num_residues, residue_constants.atom_type_num, 3])
if emb_config.recycle_features:
prev['prev_msa_first_row'] = jnp.zeros(
[num_residues, emb_config.msa_channel])
prev['prev_pair'] = jnp.zeros(
[num_residues, num_residues, emb_config.pair_channel])
if self.config.num_recycle:
if 'num_iter_recycling' in batch:
# Training time: num_iter_recycling is in batch.
# The value for each ensemble batch is the same, so arbitrarily taking
# 0-th.
num_iter = batch['num_iter_recycling'][0]
# Ensure we do not run more recycling iterations than the model is
# configured for.
num_iter = jnp.minimum(num_iter, self.config.num_recycle)
else:
# Eval mode or tests: use the maximum number of iterations.
num_iter = self.config.num_recycle
body = lambda x: (x[0] + 1, # pylint: disable=g-long-lambda
get_prev(do_call(x[1], recycle_idx=x[0],
compute_loss=False)))
if hk.running_init():
# When initializing the Haiku module, run one iteration of the
# while_loop to initialize the Haiku modules used in `body`.
_, prev = body((0, prev))
else:
_, prev = hk.while_loop(
lambda x: x[0] < num_iter,
body,
(0, prev))
else:
num_iter = 0
ret = do_call(prev=prev, recycle_idx=num_iter)
if compute_loss:
ret = ret[0], [ret[1]]
if not return_representations:
del (ret[0] if compute_loss else ret)['representations'] # pytype: disable=unsupported-operands
return ret
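# Recycling summary (descriptive, added for clarity): the loop above runs the
# network `num_iter` times, each time feeding the previous structure and
# representations back in through `prev` (with gradients stopped in
# get_prev), and the final do_call produces the returned outputs.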
class TemplatePairStack(hk.Module):
"""Pair stack for the templates.
Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack"
"""
def __init__(self, config, global_config, name='template_pair_stack'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, pair_act, pair_mask, is_training, safe_key=None):
"""Builds TemplatePairStack module.
Arguments:
pair_act: Pair activations for single template, shape [N_res, N_res, c_t].
pair_mask: Pair mask, shape [N_res, N_res].
is_training: Whether the module is in training mode.
safe_key: Safe key object encapsulating the random number generation key.
Returns:
Updated pair_act, shape [N_res, N_res, c_t].
"""
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
gc = self.global_config
c = self.config
if not c.num_block:
return pair_act
def block(x):
"""One block of the template pair stack."""
pair_act, safe_key = x
dropout_wrapper_fn = functools.partial(
dropout_wrapper, is_training=is_training, global_config=gc)
safe_key, *sub_keys = safe_key.split(6)
sub_keys = iter(sub_keys)
pair_act = dropout_wrapper_fn(
TriangleAttention(c.triangle_attention_starting_node, gc,
name='triangle_attention_starting_node'),
pair_act,
pair_mask,
next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleAttention(c.triangle_attention_ending_node, gc,
name='triangle_attention_ending_node'),
pair_act,
pair_mask,
next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
name='triangle_multiplication_outgoing'),
pair_act,
pair_mask,
next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleMultiplication(c.triangle_multiplication_incoming, gc,
name='triangle_multiplication_incoming'),
pair_act,
pair_mask,
next(sub_keys))
pair_act = dropout_wrapper_fn(
Transition(c.pair_transition, gc, name='pair_transition'),
pair_act,
pair_mask,
next(sub_keys))
return pair_act, safe_key
if gc.use_remat:
block = hk.remat(block)
res_stack = layer_stack.layer_stack(c.num_block)(block)
pair_act, safe_key = res_stack((pair_act, safe_key))
return pair_act
class Transition(hk.Module):
"""Transition layer.
Jumper et al. (2021) Suppl. Alg. 9 "MSATransition"
Jumper et al. (2021) Suppl. Alg. 15 "PairTransition"
"""
def __init__(self, config, global_config, name='transition_block'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, act, mask, is_training=True):
"""Builds Transition module.
Arguments:
act: A tensor of queries of size [batch_size, N_res, N_channel].
mask: A tensor denoting the mask of size [batch_size, N_res].
is_training: Whether the module is in training mode.
Returns:
A float32 tensor of size [batch_size, N_res, N_channel].
"""
_, _, nc = act.shape
num_intermediate = int(nc * self.config.num_intermediate_factor)
mask = jnp.expand_dims(mask, axis=-1)
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='input_layer_norm')(
act)
transition_module = hk.Sequential([
common_modules.Linear(
num_intermediate,
initializer='relu',
name='transition1'), jax.nn.relu,
common_modules.Linear(
nc,
initializer=utils.final_init(self.global_config),
name='transition2')
])
act = mapping.inference_subbatch(
transition_module,
self.global_config.subbatch_size,
batched_args=[act],
nonbatched_args=[],
low_memory=not is_training)
return act
def glorot_uniform():
return hk.initializers.VarianceScaling(scale=1.0,
mode='fan_avg',
distribution='uniform')
class Attention(hk.Module):
"""Multihead attention."""
def __init__(self, config, global_config, output_dim, name='attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
self.output_dim = output_dim
def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
"""Builds Attention module.
Arguments:
q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
m_data: A tensor of memories from which the keys and values are
projected, shape [batch_size, N_keys, m_channels].
mask: A mask for the attention, shape [batch_size, N_queries, N_keys].
nonbatched_bias: Shared bias, shape [N_queries, N_keys].
Returns:
A float32 tensor of shape [batch_size, N_queries, output_dim].
"""
# Sensible default for when the config keys are missing
key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
num_head = self.config.num_head
assert key_dim % num_head == 0
assert value_dim % num_head == 0
key_dim = key_dim // num_head
value_dim = value_dim // num_head
q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
k_weights = hk.get_parameter(
'key_w', shape=(m_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
v_weights = hk.get_parameter(
'value_w', shape=(m_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=glorot_uniform())
q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
if nonbatched_bias is not None:
logits += jnp.expand_dims(nonbatched_bias, axis=0)
logits = jnp.where(mask, logits, _SOFTMAX_MASK)
weights = utils.stable_softmax(logits)
weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)
if self.global_config.zero_init:
init = hk.initializers.Constant(0.0)
else:
init = glorot_uniform()
if self.config.gating:
gating_weights = hk.get_parameter(
'gating_w',
shape=(q_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
gating_bias = hk.get_parameter(
'gating_b',
shape=(num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(1.0))
gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
gating_weights) + gating_bias
gate_values = jax.nn.sigmoid(gate_values)
weighted_avg *= gate_values
o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim),
dtype=q_data.dtype,
init=init)
o_bias = hk.get_parameter(
'output_b', shape=(self.output_dim,),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
return output
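# Shape summary for the attention above (illustrative), with H heads and
# per-head channel size C:
#   q: [batch, N_queries, H, C]   k, v: [batch, N_keys, H, C]
#   logits, weights: [batch, H, N_queries, N_keys]
#   weighted_avg: [batch, N_queries, H, C] -> output: [batch, N_queries, output_dim]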
class GlobalAttention(hk.Module):
"""Global attention.
Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7
"""
def __init__(self, config, global_config, output_dim, name='attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
self.output_dim = output_dim
def __call__(self, q_data, m_data, q_mask):
"""Builds GlobalAttention module.
Arguments:
q_data: A tensor of queries with size [batch_size, N_queries,
q_channels]
m_data: A tensor of memories from which the keys and values
projected. Size [batch_size, N_keys, m_channels]
q_mask: A binary mask for q_data with zeros in the padded sequence
elements and ones otherwise. Size [batch_size, N_queries, q_channels]
(or broadcastable to this shape).
Returns:
A float32 tensor of size [batch_size, N_queries, output_dim].
"""
# Sensible default for when the config keys are missing
key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
num_head = self.config.num_head
assert key_dim % num_head == 0
assert value_dim % num_head == 0
key_dim = key_dim // num_head
value_dim = value_dim // num_head
q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
k_weights = hk.get_parameter(
'key_w', shape=(m_data.shape[-1], key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
v_weights = hk.get_parameter(
'value_w', shape=(m_data.shape[-1], value_dim),
dtype=q_data.dtype,
init=glorot_uniform())
v = jnp.einsum('bka,ac->bkc', m_data, v_weights)
q_avg = utils.mask_mean(q_mask, q_data, axis=1)
q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5)
k = jnp.einsum('bka,ac->bkc', m_data, k_weights)
bias = q_mask[:, None, :, 0]
logits = jnp.einsum('bhc,bkc->bhk', q, k)
logits = jnp.where(bias, logits, _SOFTMAX_MASK)
weights = utils.stable_softmax(logits)
weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v)
if self.global_config.zero_init:
init = hk.initializers.Constant(0.0)
else:
init = glorot_uniform()
o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim),
dtype=q_data.dtype,
init=init)
o_bias = hk.get_parameter(
'output_b', shape=(self.output_dim,),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
if self.config.gating:
gating_weights = hk.get_parameter(
'gating_w',
shape=(q_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
gating_bias = hk.get_parameter(
'gating_b',
shape=(num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(1.0))
gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights)
gate_values = jax.nn.sigmoid(gate_values + gating_bias)
weighted_avg = weighted_avg[:, None] * gate_values
output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
else:
output = jnp.einsum('bhc,hco->bo', weighted_avg, o_weights) + o_bias
output = output[:, None]
return output
class MSARowAttentionWithPairBias(hk.Module):
"""MSA per-row attention biased by the pair representation.
Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias"
"""
def __init__(self, config, global_config,
name='msa_row_attention_with_pair_bias'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
msa_act,
msa_mask,
pair_act,
is_training=False):
"""Builds MSARowAttentionWithPairBias module.
Arguments:
msa_act: [N_seq, N_res, c_m] MSA representation.
msa_mask: [N_seq, N_res] mask of non-padded regions.
pair_act: [N_res, N_res, c_z] pair representation.
is_training: Whether the module is in training mode.
Returns:
Update to msa_act, shape [N_seq, N_res, c_m].
"""
c = self.config
assert len(msa_act.shape) == 3
assert len(msa_mask.shape) == 2
assert c.orientation == 'per_row'
mask = msa_mask[:, None, None, :]
assert len(mask.shape) == 4
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
pair_act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='feat_2d_norm')(
pair_act)
init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1]))
weights = hk.get_parameter(
'feat_2d_weights',
shape=(pair_act.shape[-1], c.num_head),
dtype=msa_act.dtype,
init=hk.initializers.RandomNormal(stddev=init_factor))
nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
attn_mod = Attention(
c, self.global_config, msa_act.shape[-1])
msa_act = mapping.inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[msa_act, msa_act, mask],
nonbatched_args=[nonbatched_bias],
low_memory=not is_training)
return msa_act
class MSAColumnAttention(hk.Module):
"""MSA per-column attention.
Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention"
"""
def __init__(self, config, global_config, name='msa_column_attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
msa_act,
msa_mask,
is_training=False):
"""Builds MSAColumnAttention module.
Arguments:
msa_act: [N_seq, N_res, c_m] MSA representation.
msa_mask: [N_seq, N_res] mask of non-padded regions.
is_training: Whether the module is in training mode.
Returns:
Update to msa_act, shape [N_seq, N_res, c_m]
"""
c = self.config
assert len(msa_act.shape) == 3
assert len(msa_mask.shape) == 2
assert c.orientation == 'per_column'
msa_act = jnp.swapaxes(msa_act, -2, -3)
msa_mask = jnp.swapaxes(msa_mask, -1, -2)
mask = msa_mask[:, None, None, :]
assert len(mask.shape) == 4
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
attn_mod = Attention(
c, self.global_config, msa_act.shape[-1])
msa_act = mapping.inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[msa_act, msa_act, mask],
nonbatched_args=[],
low_memory=not is_training)
msa_act = jnp.swapaxes(msa_act, -2, -3)
return msa_act
class MSAColumnGlobalAttention(hk.Module):
"""MSA per-column global attention.
Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention"
"""
def __init__(self, config, global_config, name='msa_column_global_attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
msa_act,
msa_mask,
is_training=False):
"""Builds MSAColumnGlobalAttention module.
Arguments:
msa_act: [N_seq, N_res, c_m] MSA representation.
msa_mask: [N_seq, N_res] mask of non-padded regions.
is_training: Whether the module is in training mode.
Returns:
Update to msa_act, shape [N_seq, N_res, c_m].
"""
c = self.config
assert len(msa_act.shape) == 3
assert len(msa_mask.shape) == 2
assert c.orientation == 'per_column'
msa_act = jnp.swapaxes(msa_act, -2, -3)
msa_mask = jnp.swapaxes(msa_mask, -1, -2)
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
attn_mod = GlobalAttention(
c, self.global_config, msa_act.shape[-1],
name='attention')
# [N_seq, N_res, 1]
msa_mask = jnp.expand_dims(msa_mask, axis=-1)
msa_act = mapping.inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[msa_act, msa_act, msa_mask],
nonbatched_args=[],
low_memory=not is_training)
msa_act = jnp.swapaxes(msa_act, -2, -3)
return msa_act
class TriangleAttention(hk.Module):
"""Triangle Attention.
Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode"
Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode"
"""
def __init__(self, config, global_config, name='triangle_attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, pair_act, pair_mask, is_training=False):
"""Builds TriangleAttention module.
Arguments:
pair_act: [N_res, N_res, c_z] pair activations tensor
pair_mask: [N_res, N_res] mask of non-padded regions in the tensor.
is_training: Whether the module is in training mode.
Returns:
Update to pair_act, shape [N_res, N_res, c_z].
"""
c = self.config
assert len(pair_act.shape) == 3
assert len(pair_mask.shape) == 2
assert c.orientation in ['per_row', 'per_column']
if c.orientation == 'per_column':
pair_act = jnp.swapaxes(pair_act, -2, -3)
pair_mask = jnp.swapaxes(pair_mask, -1, -2)
mask = pair_mask[:, None, None, :]
assert len(mask.shape) == 4
pair_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
pair_act)
init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1]))
weights = hk.get_parameter(
'feat_2d_weights',
shape=(pair_act.shape[-1], c.num_head),
dtype=pair_act.dtype,
init=hk.initializers.RandomNormal(stddev=init_factor))
nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
attn_mod = Attention(
c, self.global_config, pair_act.shape[-1])
pair_act = mapping.inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[pair_act, pair_act, mask],
nonbatched_args=[nonbatched_bias],
low_memory=not is_training)
if c.orientation == 'per_column':
pair_act = jnp.swapaxes(pair_act, -2, -3)
return pair_act
class MaskedMsaHead(hk.Module):
"""Head to predict MSA at the masked locations.
The MaskedMsaHead employs a BERT-style objective to reconstruct a masked
version of the full MSA, based on a linear projection of
the MSA representation.
Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction"
"""
def __init__(self, config, global_config, name='masked_msa_head'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
if global_config.multimer_mode:
self.num_output = len(residue_constants.restypes_with_x_and_gap)
else:
self.num_output = config.num_output
def __call__(self, representations, batch, is_training):
"""Builds MaskedMsaHead module.
Arguments:
representations: Dictionary of representations, must contain:
* 'msa': MSA representation, shape [N_seq, N_res, c_m].
batch: Batch, unused.
is_training: Whether the module is in training mode.
Returns:
Dictionary containing:
* 'logits': logits of shape [N_seq, N_res, N_aatype] with
(unnormalized) log probabilities of the predicted aatype at each position.
"""
del batch
logits = common_modules.Linear(
self.num_output,
initializer=utils.final_init(self.global_config),
name='logits')(
representations['msa'])
return dict(logits=logits)
def loss(self, value, batch):
errors = softmax_cross_entropy(
labels=jax.nn.one_hot(batch['true_msa'], num_classes=self.num_output),
logits=value['logits'])
loss = (jnp.sum(errors * batch['bert_mask'], axis=(-2, -1)) /
(1e-8 + jnp.sum(batch['bert_mask'], axis=(-2, -1))))
return {'loss': loss}
class PredictedLDDTHead(hk.Module):
"""Head to predict the per-residue LDDT to be used as a confidence measure.
Jumper et al. (2021) Suppl. Sec. 1.9.6 "Model confidence prediction (pLDDT)"
Jumper et al. (2021) Suppl. Alg. 29 "predictPerResidueLDDT_Ca"
"""
def __init__(self, config, global_config, name='predicted_lddt_head'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, representations, batch, is_training):
"""Builds PredictedLDDTHead module.
Arguments:
representations: Dictionary of representations, must contain:
* 'structure_module': Single representation from the structure module,
shape [N_res, c_s].
batch: Batch, unused.
is_training: Whether the module is in training mode.
Returns:
Dictionary containing:
* 'logits': logits of shape [N_res, N_bins] with
(unnormalized) log probabilities of binned predicted lDDT.
"""
act = representations['structure_module']
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='input_layer_norm')(
act)
act = common_modules.Linear(
self.config.num_channels,
initializer='relu',
name='act_0')(
act)
act = jax.nn.relu(act)
act = common_modules.Linear(
self.config.num_channels,
initializer='relu',
name='act_1')(
act)
act = jax.nn.relu(act)
logits = common_modules.Linear(
self.config.num_bins,
initializer=utils.final_init(self.global_config),
name='logits')(
act)
# Shape (batch_size, num_res, num_bins)
return dict(logits=logits)
def loss(self, value, batch):
# Shape (num_res, 37, 3)
pred_all_atom_pos = value['structure_module']['final_atom_positions']
# Shape (num_res, 37, 3)
true_all_atom_pos = batch['all_atom_positions']
# Shape (num_res, 37)
all_atom_mask = batch['all_atom_mask']
# Shape (num_res,)
lddt_ca = lddt.lddt(
# Shape (batch_size, num_res, 3)
predicted_points=pred_all_atom_pos[None, :, 1, :],
# Shape (batch_size, num_res, 3)
true_points=true_all_atom_pos[None, :, 1, :],
# Shape (batch_size, num_res, 1)
true_points_mask=all_atom_mask[None, :, 1:2].astype(jnp.float32),
cutoff=15.,
per_residue=True)
lddt_ca = jax.lax.stop_gradient(lddt_ca)
num_bins = self.config.num_bins
bin_index = jnp.floor(lddt_ca * num_bins).astype(jnp.int32)
# protect against out of range for lddt_ca == 1
bin_index = jnp.minimum(bin_index, num_bins - 1)
lddt_ca_one_hot = jax.nn.one_hot(bin_index, num_classes=num_bins)
# Shape (num_res, num_channel)
logits = value['predicted_lddt']['logits']
errors = softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits)
# Shape (num_res,)
mask_ca = all_atom_mask[:, residue_constants.atom_order['CA']]
mask_ca = mask_ca.astype(jnp.float32)
loss = jnp.sum(errors * mask_ca) / (jnp.sum(mask_ca) + 1e-8)
if self.config.filter_by_resolution:
# NMR & distillation have resolution = 0
loss *= ((batch['resolution'] >= self.config.min_resolution)
& (batch['resolution'] <= self.config.max_resolution)).astype(
jnp.float32)
output = {'loss': loss}
return output
class PredictedAlignedErrorHead(hk.Module):
"""Head to predict the distance errors in the backbone alignment frames.
Can be used to compute predicted TM-Score.
Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction"
"""
def __init__(self, config, global_config,
name='predicted_aligned_error_head'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, representations, batch, is_training):
"""Builds PredictedAlignedErrorHead module.
Arguments:
representations: Dictionary of representations, must contain:
* 'pair': pair representation, shape [N_res, N_res, c_z].
batch: Batch, unused.
is_training: Whether the module is in training mode.
Returns:
Dictionary containing:
* logits: logits for aligned error, shape [N_res, N_res, N_bins].
* bin_breaks: array containing bin breaks, shape [N_bins - 1].
"""
act = representations['pair']
# Shape (num_res, num_res, num_bins)
logits = common_modules.Linear(
self.config.num_bins,
initializer=utils.final_init(self.global_config),
name='logits')(act)
# Shape (num_bins,)
breaks = jnp.linspace(
0., self.config.max_error_bin, self.config.num_bins - 1)
return dict(logits=logits, breaks=breaks)
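# Note (descriptive, added for clarity): `breaks` holds the num_bins - 1 left
# bin edges, evenly spaced from 0 to max_error_bin (Angstroms); these breaks
# and the logits are what confidence.compute_predicted_aligned_error consumes.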
def loss(self, value, batch):
# Shape (num_res, 7)
predicted_affine = quat_affine.QuatAffine.from_tensor(
value['structure_module']['final_affines'])
# Shape (num_res, 7)
true_affine = quat_affine.QuatAffine.from_tensor(
batch['backbone_affine_tensor'])
# Shape (num_res)
mask = batch['backbone_affine_mask']
# Shape (num_res, num_res)
square_mask = mask[:, None] * mask[None, :]
num_bins = self.config.num_bins
# (1, num_bins - 1)
breaks = value['predicted_aligned_error']['breaks']
# (1, num_bins)
logits = value['predicted_aligned_error']['logits']
# Compute the squared error for each alignment.
def _local_frame_points(affine):
points = [jnp.expand_dims(x, axis=-2) for x in affine.translation]
return affine.invert_point(points, extra_dims=1)
error_dist2_xyz = [
jnp.square(a - b)
for a, b in zip(_local_frame_points(predicted_affine),
_local_frame_points(true_affine))]
error_dist2 = sum(error_dist2_xyz)
# Shape (num_res, num_res)
# First num_res are alignment frames, second num_res are the residues.
error_dist2 = jax.lax.stop_gradient(error_dist2)
sq_breaks = jnp.square(breaks)
true_bins = jnp.sum((
error_dist2[..., None] > sq_breaks).astype(jnp.int32), axis=-1)
errors = softmax_cross_entropy(
labels=jax.nn.one_hot(true_bins, num_bins, axis=-1), logits=logits)
loss = (jnp.sum(errors * square_mask, axis=(-2, -1)) /
(1e-8 + jnp.sum(square_mask, axis=(-2, -1))))
if self.config.filter_by_resolution:
# NMR & distillation have resolution = 0
loss *= ((batch['resolution'] >= self.config.min_resolution)
& (batch['resolution'] <= self.config.max_resolution)).astype(
jnp.float32)
output = {'loss': loss}
return output
class ExperimentallyResolvedHead(hk.Module):
"""Predicts if an atom is experimentally resolved in a high-res structure.
Only trained on high-resolution X-ray crystals & cryo-EM.
Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction'
"""
def __init__(self, config, global_config,
name='experimentally_resolved_head'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, representations, batch, is_training):
"""Builds ExperimentallyResolvedHead module.
Arguments:
representations: Dictionary of representations, must contain:
* 'single': Single representation, shape [N_res, c_s].
batch: Batch, unused.
is_training: Whether the module is in training mode.
Returns:
Dictionary containing:
* 'logits': logits of shape [N_res, 37],
log probability that an atom is resolved in atom37 representation,
can be converted to probability by applying sigmoid.
"""
logits = common_modules.Linear(
37, # atom_exists.shape[-1]
initializer=utils.final_init(self.global_config),
name='logits')(representations['single'])
return dict(logits=logits)
def loss(self, value, batch):
logits = value['logits']
assert len(logits.shape) == 2
# Does the atom appear in the amino acid?
atom_exists = batch['atom37_atom_exists']
# Is the atom resolved in the experiment? Subset of atom_exists,
# *except for OXT*
all_atom_mask = batch['all_atom_mask'].astype(jnp.float32)
xent = sigmoid_cross_entropy(labels=all_atom_mask, logits=logits)
loss = jnp.sum(xent * atom_exists) / (1e-8 + jnp.sum(atom_exists))
if self.config.filter_by_resolution:
# NMR & distillation examples have resolution = 0.
loss *= ((batch['resolution'] >= self.config.min_resolution)
& (batch['resolution'] <= self.config.max_resolution)).astype(
jnp.float32)
output = {'loss': loss}
return output
def _layer_norm(axis=-1, name='layer_norm'):
return common_modules.LayerNorm(
axis=axis,
create_scale=True,
create_offset=True,
eps=1e-5,
use_fast_variance=True,
scale_init=hk.initializers.Constant(1.),
offset_init=hk.initializers.Constant(0.),
param_axis=axis,
name=name)
class TriangleMultiplication(hk.Module):
"""Triangle multiplication layer ("outgoing" or "incoming").
Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing"
Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming"
"""
def __init__(self, config, global_config, name='triangle_multiplication'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, left_act, left_mask, is_training=True):
"""Builds TriangleMultiplication module.
Arguments:
left_act: Pair activations, shape [N_res, N_res, c_z]
left_mask: Pair mask, shape [N_res, N_res].
is_training: Whether the module is in training mode.
Returns:
Outputs, same shape/type as left_act.
"""
del is_training
if self.config.fuse_projection_weights:
return self._fused_triangle_multiplication(left_act, left_mask)
else:
return self._triangle_multiplication(left_act, left_mask)
@hk.transparent
def _triangle_multiplication(self, left_act, left_mask):
"""Implementation of TriangleMultiplication used in AF2 and AF-M<2.3."""
c = self.config
gc = self.global_config
mask = left_mask[..., None]
act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
name='layer_norm_input')(left_act)
input_act = act
left_projection = common_modules.Linear(
c.num_intermediate_channel,
name='left_projection')
left_proj_act = mask * left_projection(act)
right_projection = common_modules.Linear(
c.num_intermediate_channel,
name='right_projection')
right_proj_act = mask * right_projection(act)
left_gate_values = jax.nn.sigmoid(common_modules.Linear(
c.num_intermediate_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='left_gate')(act))
right_gate_values = jax.nn.sigmoid(common_modules.Linear(
c.num_intermediate_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='right_gate')(act))
left_proj_act *= left_gate_values
right_proj_act *= right_gate_values
# "Outgoing" edges equation: 'ikc,jkc->ijc'
# "Incoming" edges equation: 'kjc,kic->ijc'
# Note on the Suppl. Alg. 11 & 12 notation:
# For the "outgoing" edges, a = left_proj_act and b = right_proj_act
# For the "incoming" edges, it's swapped:
# b = left_proj_act and a = right_proj_act
act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='center_layer_norm')(
act)
output_channel = int(input_act.shape[-1])
act = common_modules.Linear(
output_channel,
initializer=utils.final_init(gc),
name='output_projection')(act)
gate_values = jax.nn.sigmoid(common_modules.Linear(
output_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='gating_linear')(input_act))
act *= gate_values
return act
@hk.transparent
def _fused_triangle_multiplication(self, left_act, left_mask):
"""TriangleMultiplication with fused projection weights."""
mask = left_mask[..., None]
c = self.config
gc = self.global_config
left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)
# Both left and right projections are fused into projection.
projection = common_modules.Linear(
2*c.num_intermediate_channel, name='projection')
proj_act = mask * projection(left_act)
# Both left + right gate are fused into gate_values.
gate_values = common_modules.Linear(
2 * c.num_intermediate_channel,
name='gate',
bias_init=1.,
initializer=utils.final_init(gc))(left_act)
proj_act *= jax.nn.sigmoid(gate_values)
left_proj_act = proj_act[:, :, :c.num_intermediate_channel]
right_proj_act = proj_act[:, :, c.num_intermediate_channel:]
act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
act = _layer_norm(axis=-1, name='center_norm')(act)
output_channel = int(left_act.shape[-1])
act = common_modules.Linear(
output_channel,
initializer=utils.final_init(gc),
name='output_projection')(act)
gate_values = common_modules.Linear(
output_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='gating_linear')(left_act)
act *= jax.nn.sigmoid(gate_values)
return act
class DistogramHead(hk.Module):
"""Head to predict a distogram.
Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction"
"""
def __init__(self, config, global_config, name='distogram_head'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, representations, batch, is_training):
"""Builds DistogramHead module.
Arguments:
representations: Dictionary of representations, must contain:
* 'pair': pair representation, shape [N_res, N_res, c_z].
batch: Batch, unused.
is_training: Whether the module is in training mode.
Returns:
Dictionary containing:
* logits: logits for distogram, shape [N_res, N_res, N_bins].
* bin_breaks: array containing bin breaks, shape [N_bins - 1,].
"""
half_logits = common_modules.Linear(
self.config.num_bins,
initializer=utils.final_init(self.global_config),
name='half_logits')(
representations['pair'])
logits = half_logits + jnp.swapaxes(half_logits, -2, -3)
breaks = jnp.linspace(self.config.first_break, self.config.last_break,
self.config.num_bins - 1)
return dict(logits=logits, bin_edges=breaks)
def loss(self, value, batch):
return _distogram_log_loss(value['logits'], value['bin_edges'],
batch, self.config.num_bins)
def _distogram_log_loss(logits, bin_edges, batch, num_bins):
"""Log loss of a distogram."""
assert len(logits.shape) == 3
positions = batch['pseudo_beta']
mask = batch['pseudo_beta_mask']
assert positions.shape[-1] == 3
sq_breaks = jnp.square(bin_edges)
dist2 = jnp.sum(
jnp.square(
jnp.expand_dims(positions, axis=-2) -
jnp.expand_dims(positions, axis=-3)),
axis=-1,
keepdims=True)
true_bins = jnp.sum(dist2 > sq_breaks, axis=-1)
errors = softmax_cross_entropy(
labels=jax.nn.one_hot(true_bins, num_bins), logits=logits)
square_mask = jnp.expand_dims(mask, axis=-2) * jnp.expand_dims(mask, axis=-1)
avg_error = (
jnp.sum(errors * square_mask, axis=(-2, -1)) /
(1e-6 + jnp.sum(square_mask, axis=(-2, -1))))
dist2 = dist2[..., 0]
return dict(loss=avg_error, true_dist=jnp.sqrt(1e-6 + dist2))
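# Bin assignment detail (descriptive, added for clarity): `true_bins` counts
# how many squared bin edges each squared pairwise distance exceeds, which
# buckets the true distance into one of `num_bins` distogram bins.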
class OuterProductMean(hk.Module):
"""Computes mean outer product.
Jumper et al. (2021) Suppl. Alg. 10 "OuterProductMean"
"""
def __init__(self,
config,
global_config,
num_output_channel,
name='outer_product_mean'):
super().__init__(name=name)
self.global_config = global_config
self.config = config
self.num_output_channel = num_output_channel
def __call__(self, act, mask, is_training=True):
"""Builds OuterProductMean module.
Arguments:
act: MSA representation, shape [N_seq, N_res, c_m].
mask: MSA mask, shape [N_seq, N_res].
is_training: Whether the module is in training mode.
Returns:
Update to pair representation, shape [N_res, N_res, c_z].
"""
gc = self.global_config
c = self.config
mask = mask[..., None]
act = common_modules.LayerNorm([-1], True, True, name='layer_norm_input')(act)
left_act = mask * common_modules.Linear(
c.num_outer_channel,
initializer='linear',
name='left_projection')(
act)
right_act = mask * common_modules.Linear(
c.num_outer_channel,
initializer='linear',
name='right_projection')(
act)
if gc.zero_init:
init_w = hk.initializers.Constant(0.0)
else:
init_w = hk.initializers.VarianceScaling(scale=2., mode='fan_in')
output_w = hk.get_parameter(
'output_w',
shape=(c.num_outer_channel, c.num_outer_channel,
self.num_output_channel),
dtype=act.dtype,
init=init_w)
output_b = hk.get_parameter(
'output_b', shape=(self.num_output_channel,),
dtype=act.dtype,
init=hk.initializers.Constant(0.0))
def compute_chunk(left_act):
# This is equivalent to
#
# act = jnp.einsum('abc,ade->dceb', left_act, right_act)
# act = jnp.einsum('dceb,cef->bdf', act, output_w) + output_b
#
# but faster.
left_act = jnp.transpose(left_act, [0, 2, 1])
act = jnp.einsum('acb,ade->dceb', left_act, right_act)
act = jnp.einsum('dceb,cef->dbf', act, output_w) + output_b
return jnp.transpose(act, [1, 0, 2])
act = mapping.inference_subbatch(
compute_chunk,
c.chunk_size,
batched_args=[left_act],
nonbatched_args=[],
low_memory=True,
input_subbatch_dim=1,
output_subbatch_dim=0)
epsilon = 1e-3
norm = jnp.einsum('abc,adc->bdc', mask, mask)
act /= epsilon + norm
return act
def dgram_from_positions(positions, num_bins, min_bin, max_bin):
"""Compute distogram from amino acid positions.
Arguments:
positions: [N_res, 3] Position coordinates.
num_bins: The number of bins in the distogram.
min_bin: The left edge of the first bin.
max_bin: The left edge of the final bin. The final bin catches
everything larger than `max_bin`.
Returns:
Distogram with the specified number of bins.
"""
def squared_difference(x, y):
return jnp.square(x - y)
lower_breaks = jnp.linspace(min_bin, max_bin, num_bins)
lower_breaks = jnp.square(lower_breaks)
upper_breaks = jnp.concatenate([lower_breaks[1:],
jnp.array([1e8], dtype=jnp.float32)], axis=-1)
dist2 = jnp.sum(
squared_difference(
jnp.expand_dims(positions, axis=-2),
jnp.expand_dims(positions, axis=-3)),
axis=-1, keepdims=True)
dgram = ((dist2 > lower_breaks).astype(jnp.float32) *
(dist2 < upper_breaks).astype(jnp.float32))
return dgram
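# Minimal usage sketch (illustrative; the bin settings shown are placeholders,
# the real values come from the model config, e.g. the
# embeddings_and_evoformer.prev_pos entry used below):
#   dgram = dgram_from_positions(prev_pseudo_beta, num_bins=15,
#                                min_bin=3.25, max_bin=20.75)
#   # dgram: [N_res, N_res, 15] one-hot distance histogram.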
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
"""Create pseudo beta features."""
is_gly = jnp.equal(aatype, residue_constants.restype_order['G'])
ca_idx = residue_constants.atom_order['CA']
cb_idx = residue_constants.atom_order['CB']
pseudo_beta = jnp.where(
jnp.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :])
if all_atom_masks is not None:
pseudo_beta_mask = jnp.where(
is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
pseudo_beta_mask = pseudo_beta_mask.astype(jnp.float32)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
class EvoformerIteration(hk.Module):
"""Single iteration (block) of Evoformer stack.
Jumper et al. (2021) Suppl. Alg. 6 "EvoformerStack" lines 2-10
"""
def __init__(self, config, global_config, is_extra_msa,
name='evoformer_iteration'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
self.is_extra_msa = is_extra_msa
def __call__(self, activations, masks, is_training=True, safe_key=None):
"""Builds EvoformerIteration module.
Arguments:
activations: Dictionary containing activations:
* 'msa': MSA activations, shape [N_seq, N_res, c_m].
* 'pair': pair activations, shape [N_res, N_res, c_z].
masks: Dictionary of masks:
* 'msa': MSA mask, shape [N_seq, N_res].
* 'pair': pair mask, shape [N_res, N_res].
is_training: Whether the module is in training mode.
safe_key: prng.SafeKey encapsulating rng key.
Returns:
Outputs, same shape/type as act.
"""
c = self.config
gc = self.global_config
msa_act, pair_act = activations['msa'], activations['pair']
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
msa_mask, pair_mask = masks['msa'], masks['pair']
dropout_wrapper_fn = functools.partial(
dropout_wrapper,
is_training=is_training,
global_config=gc)
safe_key, *sub_keys = safe_key.split(10)
sub_keys = iter(sub_keys)
outer_module = OuterProductMean(
config=c.outer_product_mean,
global_config=self.global_config,
num_output_channel=int(pair_act.shape[-1]),
name='outer_product_mean')
if c.outer_product_mean.first:
pair_act = dropout_wrapper_fn(
outer_module,
msa_act,
msa_mask,
safe_key=next(sub_keys),
output_act=pair_act)
msa_act = dropout_wrapper_fn(
MSARowAttentionWithPairBias(
c.msa_row_attention_with_pair_bias, gc,
name='msa_row_attention_with_pair_bias'),
msa_act,
msa_mask,
safe_key=next(sub_keys),
pair_act=pair_act)
if not self.is_extra_msa:
attn_mod = MSAColumnAttention(
c.msa_column_attention, gc, name='msa_column_attention')
else:
attn_mod = MSAColumnGlobalAttention(
c.msa_column_attention, gc, name='msa_column_global_attention')
msa_act = dropout_wrapper_fn(
attn_mod,
msa_act,
msa_mask,
safe_key=next(sub_keys))
msa_act = dropout_wrapper_fn(
Transition(c.msa_transition, gc, name='msa_transition'),
msa_act,
msa_mask,
safe_key=next(sub_keys))
if not c.outer_product_mean.first:
pair_act = dropout_wrapper_fn(
outer_module,
msa_act,
msa_mask,
safe_key=next(sub_keys),
output_act=pair_act)
pair_act = dropout_wrapper_fn(
TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
name='triangle_multiplication_outgoing'),
pair_act,
pair_mask,
safe_key=next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleMultiplication(c.triangle_multiplication_incoming, gc,
name='triangle_multiplication_incoming'),
pair_act,
pair_mask,
safe_key=next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleAttention(c.triangle_attention_starting_node, gc,
name='triangle_attention_starting_node'),
pair_act,
pair_mask,
safe_key=next(sub_keys))
pair_act = dropout_wrapper_fn(
TriangleAttention(c.triangle_attention_ending_node, gc,
name='triangle_attention_ending_node'),
pair_act,
pair_mask,
safe_key=next(sub_keys))
pair_act = dropout_wrapper_fn(
Transition(c.pair_transition, gc, name='pair_transition'),
pair_act,
pair_mask,
safe_key=next(sub_keys))
return {'msa': msa_act, 'pair': pair_act}
class EmbeddingsAndEvoformer(hk.Module):
"""Embeds the input data and runs Evoformer.
Produces the MSA, single and pair representations.
Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18
"""
def __init__(self, config, global_config, name='evoformer'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, batch, is_training, safe_key=None):
c = self.config
gc = self.global_config
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
# Embed clustered MSA.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5
# Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder"
preprocess_1d = common_modules.Linear(
c.msa_channel, name='preprocess_1d')(
batch['target_feat'])
preprocess_msa = common_modules.Linear(
c.msa_channel, name='preprocess_msa')(
batch['msa_feat'])
msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
left_single = common_modules.Linear(
c.pair_channel, name='left_single')(
batch['target_feat'])
right_single = common_modules.Linear(
c.pair_channel, name='right_single')(
batch['target_feat'])
pair_activations = left_single[:, None] + right_single[None]
mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
# Inject previous outputs for recycling.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6
# Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"
if c.recycle_pos:
prev_pseudo_beta = pseudo_beta_fn(
batch['aatype'], batch['prev_pos'], None)
dgram = dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
pair_activations += common_modules.Linear(
c.pair_channel, name='prev_pos_linear')(
dgram)
if c.recycle_features:
prev_msa_first_row = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_msa_first_row_norm')(
batch['prev_msa_first_row'])
msa_activations = msa_activations.at[0].add(prev_msa_first_row)
pair_activations += common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_pair_norm')(
batch['prev_pair'])
# Relative position encoding.
# Jumper et al. (2021) Suppl. Alg. 4 "relpos"
# Jumper et al. (2021) Suppl. Alg. 5 "one_hot"
if c.max_relative_feature:
# Add one-hot-encoded clipped residue distances to the pair activations.
pos = batch['residue_index']
offset = pos[:, None] - pos[None, :]
rel_pos = jax.nn.one_hot(
jnp.clip(
offset + c.max_relative_feature,
a_min=0,
a_max=2 * c.max_relative_feature),
2 * c.max_relative_feature + 1)
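      # NB: the misspelled parameter name 'pair_activiations' below is kept
      # as-is; renaming it would change the Haiku parameter path and break
      # loading of existing checkpoints.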
pair_activations += common_modules.Linear(
c.pair_channel, name='pair_activiations')(
rel_pos)
# Embed templates into the pair activations.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13
if c.template.enabled:
template_batch = {k: batch[k] for k in batch if k.startswith('template_')}
template_pair_representation = TemplateEmbedding(c.template, gc)(
pair_activations,
template_batch,
mask_2d,
is_training=is_training)
pair_activations += template_pair_representation
# Embed extra MSA features.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16
extra_msa_feat = create_extra_msa_feature(batch)
extra_msa_activations = common_modules.Linear(
c.extra_msa_channel,
name='extra_msa_activations')(
extra_msa_feat)
# Extra MSA Stack.
# Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack"
extra_msa_stack_input = {
'msa': extra_msa_activations,
'pair': pair_activations,
}
extra_msa_stack_iteration = EvoformerIteration(
c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
def extra_msa_stack_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
extra_evoformer_output = extra_msa_stack_iteration(
activations=act,
masks={
'msa': batch['extra_msa_mask'],
'pair': mask_2d
},
is_training=is_training,
safe_key=safe_subkey)
return (extra_evoformer_output, safe_key)
if gc.use_remat:
extra_msa_stack_fn = hk.remat(extra_msa_stack_fn)
extra_msa_stack = layer_stack.layer_stack(
c.extra_msa_stack_num_block)(
extra_msa_stack_fn)
extra_msa_output, safe_key = extra_msa_stack(
(extra_msa_stack_input, safe_key))
pair_activations = extra_msa_output['pair']
evoformer_input = {
'msa': msa_activations,
'pair': pair_activations,
}
evoformer_masks = {'msa': batch['msa_mask'], 'pair': mask_2d}
# Append num_templ rows to msa_activations with template embeddings.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 7-8
if c.template.enabled and c.template.embed_torsion_angles:
num_templ, num_res = batch['template_aatype'].shape
# Embed the templates aatypes.
aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)
# Embed the templates aatype, torsion angles and masks.
# Shape (templates, residues, msa_channels)
ret = all_atom.atom37_to_torsion_angles(
aatype=batch['template_aatype'],
all_atom_pos=batch['template_all_atom_positions'],
all_atom_mask=batch['template_all_atom_masks'],
# Ensure consistent behaviour during testing:
placeholder_for_undefined=not gc.zero_init)
template_features = jnp.concatenate([
aatype_one_hot,
jnp.reshape(
ret['torsion_angles_sin_cos'], [num_templ, num_res, 14]),
jnp.reshape(
ret['alt_torsion_angles_sin_cos'], [num_templ, num_res, 14]),
ret['torsion_angles_mask']], axis=-1)
template_activations = common_modules.Linear(
c.msa_channel,
initializer='relu',
name='template_single_embedding')(
template_features)
template_activations = jax.nn.relu(template_activations)
template_activations = common_modules.Linear(
c.msa_channel,
initializer='relu',
name='template_projection')(
template_activations)
# Concatenate the templates to the msa.
evoformer_input['msa'] = jnp.concatenate(
[evoformer_input['msa'], template_activations], axis=0)
# Concatenate templates masks to the msa masks.
# Use mask from the psi angle, as it only depends on the backbone atoms
# from a single residue.
torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2]
torsion_angle_mask = torsion_angle_mask.astype(
evoformer_masks['msa'].dtype)
evoformer_masks['msa'] = jnp.concatenate(
[evoformer_masks['msa'], torsion_angle_mask], axis=0)
# Main trunk of the network
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18
evoformer_iteration = EvoformerIteration(
c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
def evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
evoformer_output = evoformer_iteration(
activations=act,
masks=evoformer_masks,
is_training=is_training,
safe_key=safe_subkey)
return (evoformer_output, safe_key)
if gc.use_remat:
evoformer_fn = hk.remat(evoformer_fn)
evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
evoformer_fn)
evoformer_output, safe_key = evoformer_stack(
(evoformer_input, safe_key))
msa_activations = evoformer_output['msa']
pair_activations = evoformer_output['pair']
single_activations = common_modules.Linear(
c.seq_channel, name='single_activations')(
msa_activations[0])
num_sequences = batch['msa_feat'].shape[0]
output = {
'single': single_activations,
'pair': pair_activations,
# Crop away template rows such that they are not used in MaskedMsaHead.
'msa': msa_activations[:num_sequences, :, :],
'msa_first_row': msa_activations[0],
}
return output
class SingleTemplateEmbedding(hk.Module):
"""Embeds a single template.
Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9+11
"""
def __init__(self, config, global_config, name='single_template_embedding'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, query_embedding, batch, mask_2d, is_training):
"""Build the single template embedding.
Arguments:
query_embedding: Query pair representation, shape [N_res, N_res, c_z].
batch: A batch of template features (note the template dimension has been
stripped out as this module only runs over a single template).
mask_2d: Padding mask (Note: this doesn't care if a template exists,
unlike the template_pseudo_beta_mask).
is_training: Whether the module is in training mode.
Returns:
A template embedding [N_res, N_res, c_z].
"""
assert mask_2d.dtype == query_embedding.dtype
dtype = query_embedding.dtype
num_res = batch['template_aatype'].shape[0]
num_channels = (self.config.template_pair_stack
.triangle_attention_ending_node.value_dim)
template_mask = batch['template_pseudo_beta_mask']
template_mask_2d = template_mask[:, None] * template_mask[None, :]
template_mask_2d = template_mask_2d.astype(dtype)
template_dgram = dgram_from_positions(batch['template_pseudo_beta'],
**self.config.dgram_features)
template_dgram = template_dgram.astype(dtype)
to_concat = [template_dgram, template_mask_2d[:, :, None]]
aatype = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1, dtype=dtype)
to_concat.append(jnp.tile(aatype[None, :, :], [num_res, 1, 1]))
to_concat.append(jnp.tile(aatype[:, None, :], [1, num_res, 1]))
n, ca, c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')]
rot, trans = quat_affine.make_transform_from_reference(
n_xyz=batch['template_all_atom_positions'][:, n],
ca_xyz=batch['template_all_atom_positions'][:, ca],
c_xyz=batch['template_all_atom_positions'][:, c])
affines = quat_affine.QuatAffine(
quaternion=quat_affine.rot_to_quat(rot, unstack_inputs=True),
translation=trans,
rotation=rot,
unstack_inputs=True)
points = [jnp.expand_dims(x, axis=-2) for x in affines.translation]
affine_vec = affines.invert_point(points, extra_dims=1)
inv_distance_scalar = jax.lax.rsqrt(
1e-6 + sum([jnp.square(x) for x in affine_vec]))
# Backbone affine mask: whether the residue has C, CA, N
# (the template mask defined above only considers pseudo CB).
template_mask = (
batch['template_all_atom_masks'][..., n] *
batch['template_all_atom_masks'][..., ca] *
batch['template_all_atom_masks'][..., c])
template_mask_2d = template_mask[:, None] * template_mask[None, :]
inv_distance_scalar *= template_mask_2d.astype(inv_distance_scalar.dtype)
unit_vector = [(x * inv_distance_scalar)[..., None] for x in affine_vec]
unit_vector = [x.astype(dtype) for x in unit_vector]
template_mask_2d = template_mask_2d.astype(dtype)
if not self.config.use_template_unit_vector:
unit_vector = [jnp.zeros_like(x) for x in unit_vector]
to_concat.extend(unit_vector)
to_concat.append(template_mask_2d[..., None])
act = jnp.concatenate(to_concat, axis=-1)
# Mask out non-template regions so we don't get arbitrary values in the
# distogram for these regions.
act *= template_mask_2d[..., None]
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 9
act = common_modules.Linear(
num_channels,
initializer='relu',
name='embedding2d')(
act)
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 11
act = TemplatePairStack(
self.config.template_pair_stack, self.global_config)(
act, mask_2d, is_training)
    act = common_modules.LayerNorm(
        [-1], True, True, name='output_layer_norm')(
            act)
return act
class TemplateEmbedding(hk.Module):
"""Embeds a set of templates.
Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12
Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention"
"""
def __init__(self, config, global_config, name='template_embedding'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, query_embedding, template_batch, mask_2d, is_training):
"""Build TemplateEmbedding module.
Arguments:
query_embedding: Query pair representation, shape [N_res, N_res, c_z].
template_batch: A batch of template features.
mask_2d: Padding mask (Note: this doesn't care if a template exists,
unlike the template_pseudo_beta_mask).
is_training: Whether the module is in training mode.
Returns:
A template embedding [N_res, N_res, c_z].
"""
num_templates = template_batch['template_mask'].shape[0]
num_channels = (self.config.template_pair_stack
.triangle_attention_ending_node.value_dim)
num_res = query_embedding.shape[0]
dtype = query_embedding.dtype
template_mask = template_batch['template_mask']
template_mask = template_mask.astype(dtype)
query_num_channels = query_embedding.shape[-1]
# Make sure the weights are shared across templates by constructing the
# embedder here.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12
template_embedder = SingleTemplateEmbedding(self.config, self.global_config)
def map_fn(batch):
return template_embedder(query_embedding, batch, mask_2d, is_training)
template_pair_representation = mapping.sharded_map(map_fn, in_axes=0)(
template_batch)
# Cross attend from the query to the templates along the residue
# dimension by flattening everything else into the batch dimension.
# Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention"
flat_query = jnp.reshape(query_embedding,
[num_res * num_res, 1, query_num_channels])
flat_templates = jnp.reshape(
jnp.transpose(template_pair_representation, [1, 2, 0, 3]),
[num_res * num_res, num_templates, num_channels])
mask = template_mask[None, None, None, :]
template_pointwise_attention_module = Attention(
self.config.attention, self.global_config, query_num_channels)
nonbatched_args = [mask]
batched_args = [flat_query, flat_templates]
embedding = mapping.inference_subbatch(
template_pointwise_attention_module,
self.config.subbatch_size,
batched_args=batched_args,
nonbatched_args=nonbatched_args,
low_memory=not is_training)
embedding = jnp.reshape(embedding,
[num_res, num_res, query_num_channels])
# No gradients if no templates.
embedding *= (jnp.sum(template_mask) > 0.).astype(embedding.dtype)
return embedding
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core modules, which have been refactored in AlphaFold-Multimer.
The main difference is that the MSA sampling pipeline is moved inside the JAX
model for easier implementation of recycling and ensembling.
Lower-level modules up to EvoformerIteration are reused from modules.py.
"""
import functools
from typing import Sequence
from alphafold.common import residue_constants
from alphafold.model import all_atom_multimer
from alphafold.model import common_modules
from alphafold.model import folding_multimer
from alphafold.model import geometry
from alphafold.model import layer_stack
from alphafold.model import modules
from alphafold.model import prng
from alphafold.model import utils
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
def reduce_fn(x, mode):
if mode == 'none' or mode is None:
return jnp.asarray(x)
elif mode == 'sum':
return jnp.asarray(x).sum()
elif mode == 'mean':
return jnp.mean(jnp.asarray(x))
else:
raise ValueError('Unsupported reduction option.')
def gumbel_noise(key: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray:
"""Generate Gumbel Noise of given Shape.
This generates samples from Gumbel(0, 1).
Args:
key: Jax random number key.
shape: Shape of noise to return.
Returns:
Gumbel noise of given shape.
"""
epsilon = 1e-6
uniform = utils.padding_consistent_rng(jax.random.uniform)
uniform_noise = uniform(
key, shape=shape, dtype=jnp.float32, minval=0., maxval=1.)
gumbel = -jnp.log(-jnp.log(uniform_noise + epsilon) + epsilon)
return gumbel
def gumbel_max_sample(key: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
"""Samples from a probability distribution given by 'logits'.
  This uses the Gumbel-max trick to implement the sampling in an efficient
  manner.
Args:
key: prng key.
    logits: Logarithm of probabilities to sample from; probabilities can be
      unnormalized.
Returns:
Sample from logprobs in one-hot form.
"""
z = gumbel_noise(key, logits.shape)
return jax.nn.one_hot(
jnp.argmax(logits + z, axis=-1),
logits.shape[-1],
dtype=logits.dtype)
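# Illustrative sketch (ours, not part of the original module): the Gumbel-max
# trick samples from softmax(logits) by adding Gumbel(0, 1) noise and taking
# the argmax. `_demo_gumbel_max_sample` is a hypothetical helper showing the
# call pattern only.
def _demo_gumbel_max_sample():
  # Three categories given by their log-probabilities.
  logits = jnp.log(jnp.array([0.7, 0.2, 0.1]))
  key = jax.random.PRNGKey(0)
  # Returns a one-hot vector; over many keys the argmax frequencies
  # approximate softmax(logits).
  return gumbel_max_sample(key, logits)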
def gumbel_argsort_sample_idx(key: jnp.ndarray,
logits: jnp.ndarray) -> jnp.ndarray:
"""Samples with replacement from a distribution given by 'logits'.
This uses Gumbel trick to implement the sampling an efficient manner. For a
distribution over k items this samples k times without replacement, so this
is effectively sampling a random permutation with probabilities over the
permutations derived from the logprobs.
Args:
key: prng key.
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Sample from logprobs in one-hot form.
"""
z = gumbel_noise(key, logits.shape)
  # This construction is equivalent to jnp.argsort, but uses a non-stable sort,
  # since stable sorts aren't supported by jax2tf.
axis = len(logits.shape) - 1
iota = jax.lax.broadcasted_iota(jnp.int64, logits.shape, axis)
_, perm = jax.lax.sort_key_val(
logits + z, iota, dimension=-1, is_stable=False)
return perm[::-1]
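# Illustrative sketch (ours, not part of the original module):
# `gumbel_argsort_sample_idx` returns a full permutation in which higher-logit
# entries tend to come first; `sample_msa` below uses this to pick which MSA
# rows become cluster centres. `_demo_gumbel_argsort` is a hypothetical helper.
def _demo_gumbel_argsort():
  # The third entry is effectively masked out by its very low logit.
  logits = jnp.array([0., 0., -1e6, 0.])
  perm = gumbel_argsort_sample_idx(jax.random.PRNGKey(0), logits)
  # `perm` is a permutation of [0, 1, 2, 3]; index 2 will almost surely be last.
  return perm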
def make_masked_msa(batch, key, config, epsilon=1e-6):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
random_aa = jnp.array([0.05] * 20 + [0., 0.], dtype=jnp.float32)
categorical_probs = (
config.uniform_prob * random_aa +
config.profile_prob * batch['msa_profile'] +
config.same_prob * jax.nn.one_hot(batch['msa'], 22))
# Put all remaining probability on [MASK] which is a new column.
pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
pad_shapes[-1][1] = 1
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
assert mask_prob >= 0.
categorical_probs = jnp.pad(
categorical_probs, pad_shapes, constant_values=mask_prob)
sh = batch['msa'].shape
key, mask_subkey, gumbel_subkey = key.split(3)
uniform = utils.padding_consistent_rng(jax.random.uniform)
mask_position = uniform(mask_subkey.get(), sh) < config.replace_fraction
mask_position *= batch['msa_mask']
logits = jnp.log(categorical_probs + epsilon)
bert_msa = gumbel_max_sample(gumbel_subkey.get(), logits)
bert_msa = jnp.where(mask_position,
jnp.argmax(bert_msa, axis=-1), batch['msa'])
bert_msa *= batch['msa_mask']
# Mix real and masked MSA.
if 'bert_mask' in batch:
batch['bert_mask'] *= mask_position.astype(jnp.float32)
else:
batch['bert_mask'] = mask_position.astype(jnp.float32)
batch['true_msa'] = batch['msa']
batch['msa'] = bert_msa
return batch
def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
"""Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
# Determine how much weight we assign to each agreement. In theory, we could
# use a full blosum matrix here, but right now let's just down-weight gap
# agreement because it could be spurious.
# Never put weight on agreeing on BERT mask.
weights = jnp.array(
[1.] * 21 + [gap_agreement_weight] + [0.], dtype=jnp.float32)
msa_mask = batch['msa_mask']
msa_one_hot = jax.nn.one_hot(batch['msa'], 23)
extra_mask = batch['extra_msa_mask']
extra_one_hot = jax.nn.one_hot(batch['extra_msa'], 23)
msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot
agreement = jnp.einsum('mrc, nrc->nm', extra_one_hot_masked,
weights * msa_one_hot_masked)
cluster_assignment = jax.nn.softmax(1e3 * agreement, axis=0)
cluster_assignment *= jnp.einsum('mr, nr->mn', msa_mask, extra_mask)
cluster_count = jnp.sum(cluster_assignment, axis=-1)
cluster_count += 1. # We always include the sequence itself.
msa_sum = jnp.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
msa_sum += msa_one_hot_masked
cluster_profile = msa_sum / cluster_count[:, None, None]
extra_deletion_matrix = batch['extra_deletion_matrix']
deletion_matrix = batch['deletion_matrix']
del_sum = jnp.einsum('nm, mc->nc', cluster_assignment,
extra_mask * extra_deletion_matrix)
del_sum += deletion_matrix # Original sequence.
cluster_deletion_mean = del_sum / cluster_count[:, None]
return cluster_profile, cluster_deletion_mean
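# Illustrative sketch (ours, not part of the original module): on a tiny batch,
# each extra MSA row is softly assigned to the sampled MSA row it agrees with
# most, and per-cluster profiles and mean deletions are averaged over that
# assignment. `_demo_nearest_neighbor_clusters` and the toy shapes are
# hypothetical.
def _demo_nearest_neighbor_clusters():
  batch = {
      'msa': jnp.array([[0, 1, 2, 3], [4, 5, 6, 7]]),
      'msa_mask': jnp.ones((2, 4)),
      'deletion_matrix': jnp.zeros((2, 4)),
      'extra_msa': jnp.array([[0, 1, 2, 3], [0, 1, 2, 7], [4, 5, 6, 7]]),
      'extra_msa_mask': jnp.ones((3, 4)),
      'extra_deletion_matrix': jnp.zeros((3, 4)),
  }
  cluster_profile, cluster_deletion_mean = nearest_neighbor_clusters(batch)
  # cluster_profile has shape [2, 4, 23], cluster_deletion_mean shape [2, 4].
  return cluster_profile, cluster_deletion_mean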
def create_msa_feat(batch):
"""Create and concatenate MSA features."""
msa_1hot = jax.nn.one_hot(batch['msa'], 23)
deletion_matrix = batch['deletion_matrix']
has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]
deletion_mean_value = (jnp.arctan(batch['cluster_deletion_mean'] / 3.) *
(2. / jnp.pi))[..., None]
msa_feat = [
msa_1hot,
has_deletion,
deletion_value,
batch['cluster_profile'],
deletion_mean_value
]
return jnp.concatenate(msa_feat, axis=-1)
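# Illustrative sketch (ours, not part of the original module): deletion counts
# are squashed into [0, 1) via (2 / pi) * arctan(d / 3) before being fed to the
# network, so 0 deletions maps to 0.0, 3 deletions to 0.5, and large counts
# approach 1. `_demo_deletion_value` is a hypothetical helper.
def _demo_deletion_value():
  deletion_counts = jnp.array([0., 3., 30.])
  return jnp.arctan(deletion_counts / 3.) * (2. / jnp.pi)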
def create_extra_msa_feature(batch, num_extra_msa):
"""Expand extra_msa into 1hot and concat with other extra msa features.
We do this as late as possible as the one_hot extra msa can be very large.
Args:
batch: a dictionary with the following keys:
* 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
centre. Note - This isn't one-hotted.
* 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
position.
    num_extra_msa: Number of extra MSA sequences to use.
  Returns:
    A tuple of the concatenated extra MSA features and the corresponding
    extra MSA mask.
"""
# 23 = 20 amino acids + 'X' for unknown + gap + bert mask
extra_msa = batch['extra_msa'][:num_extra_msa]
deletion_matrix = batch['extra_deletion_matrix'][:num_extra_msa]
msa_1hot = jax.nn.one_hot(extra_msa, 23)
has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]
extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa]
return jnp.concatenate([msa_1hot, has_deletion, deletion_value],
axis=-1), extra_msa_mask
def sample_msa(key, batch, max_seq):
"""Sample MSA randomly, remaining sequences are stored as `extra_*`.
Args:
key: safe key for random number generation.
batch: batch to sample msa from.
max_seq: number of sequences to sample.
Returns:
    Batch with the sampled MSA; unsampled rows are moved to `extra_*` features.
"""
# Sample uniformly among sequences with at least one non-masked position.
logits = (jnp.clip(jnp.sum(batch['msa_mask'], axis=-1), 0., 1.) - 1.) * 1e6
# The cluster_bias_mask can be used to preserve the first row (target
# sequence) for each chain, for example.
if 'cluster_bias_mask' not in batch:
cluster_bias_mask = jnp.pad(
jnp.zeros(batch['msa'].shape[0] - 1), (1, 0), constant_values=1.)
else:
cluster_bias_mask = batch['cluster_bias_mask']
logits += cluster_bias_mask * 1e6
index_order = gumbel_argsort_sample_idx(key.get(), logits)
sel_idx = index_order[:max_seq]
extra_idx = index_order[max_seq:]
for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
if k in batch:
batch['extra_' + k] = batch[k][extra_idx]
batch[k] = batch[k][sel_idx]
return batch
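# Illustrative sketch (ours, not part of the original module): `sample_msa`
# keeps the first row (the target sequence, via the default cluster_bias_mask),
# selects `max_seq` rows in total, and moves everything else to `extra_*` keys.
# `_demo_sample_msa` and the toy shapes are hypothetical.
def _demo_sample_msa():
  batch = {
      'msa': jnp.zeros((8, 5), dtype=jnp.int32),
      'msa_mask': jnp.ones((8, 5)),
      'deletion_matrix': jnp.zeros((8, 5)),
  }
  key = prng.SafeKey(jax.random.PRNGKey(0))
  batch = sample_msa(key, batch, max_seq=4)
  # batch['msa'] now has 4 rows (the original first row almost surely among
  # them) and batch['extra_msa'] holds the remaining 4 rows.
  return batch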
def make_msa_profile(batch):
"""Compute the MSA profile."""
# Compute the profile for every residue (over all MSA sequences).
return utils.mask_mean(
batch['msa_mask'][:, :, None], jax.nn.one_hot(batch['msa'], 22), axis=0)
class AlphaFoldIteration(hk.Module):
"""A single recycling iteration of AlphaFold architecture.
Computes ensembled (averaged) representations from the provided features.
These representations are then passed to the various heads
that have been requested by the configuration file.
"""
def __init__(self, config, global_config, name='alphafold_iteration'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
batch,
is_training,
return_representations=False,
safe_key=None):
if is_training:
num_ensemble = np.asarray(self.config.num_ensemble_train)
else:
num_ensemble = np.asarray(self.config.num_ensemble_eval)
# Compute representations for each MSA sample and average.
embedding_module = EmbeddingsAndEvoformer(
self.config.embeddings_and_evoformer, self.global_config)
repr_shape = hk.eval_shape(
lambda: embedding_module(batch, is_training))
representations = {
k: jnp.zeros(v.shape, v.dtype) for (k, v) in repr_shape.items()
}
def ensemble_body(x, unused_y):
"""Add into representations ensemble."""
del unused_y
representations, safe_key = x
safe_key, safe_subkey = safe_key.split()
representations_update = embedding_module(
batch, is_training, safe_key=safe_subkey)
for k in representations:
if k not in {'msa', 'true_msa', 'bert_mask'}:
representations[k] += representations_update[k] * (
1. / num_ensemble).astype(representations[k].dtype)
else:
representations[k] = representations_update[k]
return (representations, safe_key), None
(representations, _), _ = hk.scan(
ensemble_body, (representations, safe_key), None, length=num_ensemble)
self.representations = representations
self.batch = batch
self.heads = {}
for head_name, head_config in sorted(self.config.heads.items()):
if not head_config.weight:
continue # Do not instantiate zero-weight heads.
head_factory = {
'masked_msa':
modules.MaskedMsaHead,
'distogram':
modules.DistogramHead,
'structure_module':
folding_multimer.StructureModule,
'predicted_aligned_error':
modules.PredictedAlignedErrorHead,
'predicted_lddt':
modules.PredictedLDDTHead,
'experimentally_resolved':
modules.ExperimentallyResolvedHead,
}[head_name]
self.heads[head_name] = (head_config,
head_factory(head_config, self.global_config))
structure_module_output = None
if 'entity_id' in batch and 'all_atom_positions' in batch:
_, fold_module = self.heads['structure_module']
structure_module_output = fold_module(representations, batch, is_training)
ret = {}
ret['representations'] = representations
for name, (head_config, module) in self.heads.items():
if name == 'structure_module' and structure_module_output is not None:
ret[name] = structure_module_output
representations['structure_module'] = structure_module_output.pop('act')
# Skip confidence heads until StructureModule is executed.
elif name in {'predicted_lddt', 'predicted_aligned_error',
'experimentally_resolved'}:
continue
else:
ret[name] = module(representations, batch, is_training)
# Add confidence heads after StructureModule is executed.
if self.config.heads.get('predicted_lddt.weight', 0.0):
name = 'predicted_lddt'
head_config, module = self.heads[name]
ret[name] = module(representations, batch, is_training)
if self.config.heads.experimentally_resolved.weight:
name = 'experimentally_resolved'
head_config, module = self.heads[name]
ret[name] = module(representations, batch, is_training)
if self.config.heads.get('predicted_aligned_error.weight', 0.0):
name = 'predicted_aligned_error'
head_config, module = self.heads[name]
ret[name] = module(representations, batch, is_training)
# Will be used for ipTM computation.
ret[name]['asym_id'] = batch['asym_id']
return ret
class AlphaFold(hk.Module):
"""AlphaFold-Multimer model with recycling.
"""
def __init__(self, config, name='alphafold'):
super().__init__(name=name)
self.config = config
self.global_config = config.global_config
def __call__(
self,
batch,
is_training,
return_representations=False,
safe_key=None):
c = self.config
impl = AlphaFoldIteration(c, self.global_config)
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
elif isinstance(safe_key, jnp.ndarray):
safe_key = prng.SafeKey(safe_key)
assert isinstance(batch, dict)
num_res = batch['aatype'].shape[0]
def get_prev(ret):
new_prev = {
'prev_pos':
ret['structure_module']['final_atom_positions'],
'prev_msa_first_row': ret['representations']['msa_first_row'],
'prev_pair': ret['representations']['pair'],
}
return jax.tree_map(jax.lax.stop_gradient, new_prev)
def apply_network(prev, safe_key):
recycled_batch = {**batch, **prev}
return impl(
batch=recycled_batch,
is_training=is_training,
safe_key=safe_key)
prev = {}
emb_config = self.config.embeddings_and_evoformer
if emb_config.recycle_pos:
prev['prev_pos'] = jnp.zeros(
[num_res, residue_constants.atom_type_num, 3])
if emb_config.recycle_features:
prev['prev_msa_first_row'] = jnp.zeros(
[num_res, emb_config.msa_channel])
prev['prev_pair'] = jnp.zeros(
[num_res, num_res, emb_config.pair_channel])
if self.config.num_recycle:
if 'num_iter_recycling' in batch:
# Training time: num_iter_recycling is in batch.
# Value for each ensemble batch is the same, so arbitrarily taking 0-th.
num_iter = batch['num_iter_recycling'][0]
# Add insurance that even when ensembling, we will not run more
# recyclings than the model is configured to run.
num_iter = jnp.minimum(num_iter, c.num_recycle)
else:
# Eval mode or tests: use the maximum number of iterations.
num_iter = c.num_recycle
def distances(points):
"""Compute all pairwise distances for a set of points."""
return jnp.sqrt(jnp.sum((points[:, None] - points[None, :])**2,
axis=-1))
def recycle_body(x):
i, _, prev, safe_key = x
safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate() # pylint: disable=line-too-long
ret = apply_network(prev=prev, safe_key=safe_key2)
return i+1, prev, get_prev(ret), safe_key1
def recycle_cond(x):
i, prev, next_in, _ = x
ca_idx = residue_constants.atom_order['CA']
sq_diff = jnp.square(distances(prev['prev_pos'][:, ca_idx, :]) -
distances(next_in['prev_pos'][:, ca_idx, :]))
mask = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
sq_diff = utils.mask_mean(mask, sq_diff)
        # Early stopping criterion based on the one used in
        # AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
diff = jnp.sqrt(sq_diff + 1e-8) # avoid bad numerics giving negatives
less_than_max_recycles = (i < num_iter)
has_exceeded_tolerance = (
(i == 0) | (diff > c.recycle_early_stop_tolerance))
return less_than_max_recycles & has_exceeded_tolerance
if hk.running_init():
num_recycles, _, prev, safe_key = recycle_body(
(0, prev, prev, safe_key))
else:
num_recycles, _, prev, safe_key = hk.while_loop(
recycle_cond,
recycle_body,
(0, prev, prev, safe_key))
else:
# No recycling.
num_recycles = 0
# Run extra iteration.
ret = apply_network(prev=prev, safe_key=safe_key)
if not return_representations:
del ret['representations']
ret['num_recycles'] = num_recycles
return ret
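# Illustrative sketch (ours, not part of the original module): the recycling
# loop above stops early once the CA-CA distance map stops changing. The
# hypothetical helper below restates that criterion for two [num_res, 3] CA
# coordinate arrays and a [num_res] sequence mask.
def _demo_recycle_distance_change(prev_ca, next_ca, seq_mask):
  def distances(points):
    return jnp.sqrt(
        jnp.sum((points[:, None] - points[None, :])**2, axis=-1))
  sq_diff = jnp.square(distances(prev_ca) - distances(next_ca))
  mask = seq_mask[:, None] * seq_mask[None, :]
  # Root-mean-square change of the pairwise distance map, in the same units as
  # the coordinates; recycling stops once this drops below the configured
  # tolerance.
  return jnp.sqrt(utils.mask_mean(mask, sq_diff) + 1e-8)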
class EmbeddingsAndEvoformer(hk.Module):
"""Embeds the input data and runs Evoformer.
Produces the MSA, single and pair representations.
"""
def __init__(self, config, global_config, name='evoformer'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def _relative_encoding(self, batch):
"""Add relative position encodings.
For position (i, j), the value is (i-j) clipped to [-k, k] and one-hotted.
When not using 'use_chain_relative' the residue indices are used as is, e.g.
for heteromers relative positions will be computed using the positions in
the corresponding chains.
When using 'use_chain_relative' we add an extra bin that denotes
'different chain'. Furthermore we also provide the relative chain index
(i.e. sym_id) clipped and one-hotted to the network. And an extra feature
which denotes whether they belong to the same chain type, i.e. it's 0 if
they are in different heteromer chains and 1 otherwise.
Args:
batch: batch.
Returns:
Feature embedding using the features as described before.
"""
c = self.config
gc = self.global_config
rel_feats = []
pos = batch['residue_index']
asym_id = batch['asym_id']
asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
offset = pos[:, None] - pos[None, :]
dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32
clipped_offset = jnp.clip(
offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)
if c.use_chain_relative:
final_offset = jnp.where(asym_id_same, clipped_offset,
(2 * c.max_relative_idx + 1) *
jnp.ones_like(clipped_offset))
rel_pos = jax.nn.one_hot(final_offset, 2 * c.max_relative_idx + 2)
rel_feats.append(rel_pos)
entity_id = batch['entity_id']
entity_id_same = jnp.equal(entity_id[:, None], entity_id[None, :])
rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None])
sym_id = batch['sym_id']
rel_sym_id = sym_id[:, None] - sym_id[None, :]
max_rel_chain = c.max_relative_chain
clipped_rel_chain = jnp.clip(
rel_sym_id + max_rel_chain, a_min=0, a_max=2 * max_rel_chain)
final_rel_chain = jnp.where(entity_id_same, clipped_rel_chain,
(2 * max_rel_chain + 1) *
jnp.ones_like(clipped_rel_chain))
rel_chain = jax.nn.one_hot(final_rel_chain, 2 * c.max_relative_chain + 2)
rel_feats.append(rel_chain)
else:
rel_pos = jax.nn.one_hot(clipped_offset, 2 * c.max_relative_idx + 1)
rel_feats.append(rel_pos)
rel_feat = jnp.concatenate(rel_feats, axis=-1)
rel_feat = rel_feat.astype(dtype)
return common_modules.Linear(
c.pair_channel,
name='position_activations')(
rel_feat)
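  # Illustrative example (ours, not part of the original module): with
  # max_relative_idx = 2, residue offsets are clipped to [-2, 2] and shifted to
  # bins 0..4; with use_chain_relative, the extra bin 5 ('different chain') is
  # used whenever the two residues have different asym_id, regardless of offset.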
def __call__(self, batch, is_training, safe_key=None):
c = self.config
gc = self.global_config
batch = dict(batch)
dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
output = {}
batch['msa_profile'] = make_msa_profile(batch)
with utils.bfloat16_context():
target_feat = jax.nn.one_hot(batch['aatype'], 21).astype(dtype)
preprocess_1d = common_modules.Linear(
c.msa_channel, name='preprocess_1d')(
target_feat)
safe_key, sample_key, mask_key = safe_key.split(3)
batch = sample_msa(sample_key, batch, c.num_msa)
batch = make_masked_msa(batch, mask_key, c.masked_msa)
(batch['cluster_profile'],
batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)
msa_feat = create_msa_feat(batch).astype(dtype)
preprocess_msa = common_modules.Linear(
c.msa_channel, name='preprocess_msa')(
msa_feat)
msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
left_single = common_modules.Linear(
c.pair_channel, name='left_single')(
target_feat)
right_single = common_modules.Linear(
c.pair_channel, name='right_single')(
target_feat)
pair_activations = left_single[:, None] + right_single[None]
mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
mask_2d = mask_2d.astype(dtype)
if c.recycle_pos:
prev_pseudo_beta = modules.pseudo_beta_fn(
batch['aatype'], batch['prev_pos'], None)
dgram = modules.dgram_from_positions(
prev_pseudo_beta, **self.config.prev_pos)
dgram = dgram.astype(dtype)
pair_activations += common_modules.Linear(
c.pair_channel, name='prev_pos_linear')(
dgram)
if c.recycle_features:
prev_msa_first_row = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_msa_first_row_norm')(
batch['prev_msa_first_row']).astype(dtype)
msa_activations = msa_activations.at[0].add(prev_msa_first_row)
pair_activations += common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_pair_norm')(
batch['prev_pair']).astype(dtype)
if c.max_relative_idx:
pair_activations += self._relative_encoding(batch)
if c.template.enabled:
template_module = TemplateEmbedding(c.template, gc)
template_batch = {
'template_aatype': batch['template_aatype'],
'template_all_atom_positions': batch['template_all_atom_positions'],
'template_all_atom_mask': batch['template_all_atom_mask']
}
# Construct a mask such that only intra-chain template features are
# computed, since all templates are for each chain individually.
multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
safe_key, safe_subkey = safe_key.split()
template_act = template_module(
query_embedding=pair_activations,
template_batch=template_batch,
padding_mask_2d=mask_2d,
multichain_mask_2d=multichain_mask,
is_training=is_training,
safe_key=safe_subkey)
pair_activations += template_act
# Extra MSA stack.
(extra_msa_feat,
extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
extra_msa_activations = common_modules.Linear(
c.extra_msa_channel,
name='extra_msa_activations')(
extra_msa_feat).astype(dtype)
extra_msa_mask = extra_msa_mask.astype(dtype)
extra_evoformer_input = {
'msa': extra_msa_activations,
'pair': pair_activations,
}
extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
extra_evoformer_iteration = modules.EvoformerIteration(
c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
def extra_evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
extra_evoformer_output = extra_evoformer_iteration(
activations=act,
masks=extra_masks,
is_training=is_training,
safe_key=safe_subkey)
return (extra_evoformer_output, safe_key)
if gc.use_remat:
extra_evoformer_fn = hk.remat(extra_evoformer_fn)
safe_key, safe_subkey = safe_key.split()
extra_evoformer_stack = layer_stack.layer_stack(
c.extra_msa_stack_num_block)(
extra_evoformer_fn)
extra_evoformer_output, safe_key = extra_evoformer_stack(
(extra_evoformer_input, safe_subkey))
pair_activations = extra_evoformer_output['pair']
# Get the size of the MSA before potentially adding templates, so we
# can crop out the templates later.
num_msa_sequences = msa_activations.shape[0]
evoformer_input = {
'msa': msa_activations,
'pair': pair_activations,
}
evoformer_masks = {
'msa': batch['msa_mask'].astype(dtype),
'pair': mask_2d
}
if c.template.enabled:
template_features, template_masks = (
template_embedding_1d(
batch=batch, num_channel=c.msa_channel, global_config=gc))
evoformer_input['msa'] = jnp.concatenate(
[evoformer_input['msa'], template_features], axis=0)
evoformer_masks['msa'] = jnp.concatenate(
[evoformer_masks['msa'], template_masks], axis=0)
evoformer_iteration = modules.EvoformerIteration(
c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
def evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
evoformer_output = evoformer_iteration(
activations=act,
masks=evoformer_masks,
is_training=is_training,
safe_key=safe_subkey)
return (evoformer_output, safe_key)
if gc.use_remat:
evoformer_fn = hk.remat(evoformer_fn)
safe_key, safe_subkey = safe_key.split()
evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
evoformer_fn)
def run_evoformer(evoformer_input):
evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
return evoformer_output
evoformer_output = run_evoformer(evoformer_input)
msa_activations = evoformer_output['msa']
pair_activations = evoformer_output['pair']
single_activations = common_modules.Linear(
c.seq_channel, name='single_activations')(
msa_activations[0])
output.update({
'single':
single_activations,
'pair':
pair_activations,
# Crop away template rows such that they are not used in MaskedMsaHead.
'msa':
msa_activations[:num_msa_sequences, :, :],
'msa_first_row':
msa_activations[0],
})
# Convert back to float32 if we're not saving memory.
if not gc.bfloat16_output:
for k, v in output.items():
if v.dtype == jnp.bfloat16:
output[k] = v.astype(jnp.float32)
return output
class TemplateEmbedding(hk.Module):
"""Embed a set of templates."""
def __init__(self, config, global_config, name='template_embedding'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, query_embedding, template_batch, padding_mask_2d,
multichain_mask_2d, is_training,
safe_key=None):
"""Generate an embedding for a set of templates.
Args:
query_embedding: [num_res, num_res, num_channel] a query tensor that will
be used to attend over the templates to remove the num_templates
dimension.
template_batch: A dictionary containing:
`template_aatype`: [num_templates, num_res] aatype for each template.
`template_all_atom_positions`: [num_templates, num_res, 37, 3] atom
positions for all templates.
`template_all_atom_mask`: [num_templates, num_res, 37] mask for each
template.
padding_mask_2d: [num_res, num_res] Pair mask for attention operations.
multichain_mask_2d: [num_res, num_res] Mask indicating which residue pairs
are intra-chain, used to mask out residue distance based features
between chains.
      is_training: bool indicating whether we are running in training mode.
safe_key: random key generator.
Returns:
An embedding of size [num_res, num_res, num_channels]
"""
c = self.config
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
num_templates = template_batch['template_aatype'].shape[0]
num_res, _, query_num_channels = query_embedding.shape
# Embed each template separately.
template_embedder = SingleTemplateEmbedding(self.config, self.global_config)
def partial_template_embedder(template_aatype,
template_all_atom_positions,
template_all_atom_mask,
unsafe_key):
safe_key = prng.SafeKey(unsafe_key)
return template_embedder(query_embedding,
template_aatype,
template_all_atom_positions,
template_all_atom_mask,
padding_mask_2d,
multichain_mask_2d,
is_training,
safe_key)
safe_key, unsafe_key = safe_key.split()
unsafe_keys = jax.random.split(unsafe_key._key, num_templates)
def scan_fn(carry, x):
return carry + partial_template_embedder(*x), None
scan_init = jnp.zeros((num_res, num_res, c.num_channels),
dtype=query_embedding.dtype)
summed_template_embeddings, _ = hk.scan(
scan_fn, scan_init,
(template_batch['template_aatype'],
template_batch['template_all_atom_positions'],
template_batch['template_all_atom_mask'], unsafe_keys))
embedding = summed_template_embeddings / num_templates
embedding = jax.nn.relu(embedding)
embedding = common_modules.Linear(
query_num_channels,
initializer='relu',
name='output_linear')(embedding)
return embedding
class SingleTemplateEmbedding(hk.Module):
"""Embed a single template."""
def __init__(self, config, global_config, name='single_template_embedding'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, query_embedding, template_aatype,
template_all_atom_positions, template_all_atom_mask,
padding_mask_2d, multichain_mask_2d, is_training,
safe_key):
"""Build the single template embedding graph.
Args:
query_embedding: (num_res, num_res, num_channels) - embedding of the
query sequence/msa.
      template_aatype: [num_res] aatype for the template.
      template_all_atom_positions: [num_res, 37, 3] atom positions for the
        template.
      template_all_atom_mask: [num_res, 37] mask for the template.
padding_mask_2d: Padding mask (Note: this doesn't care if a template
exists, unlike the template_pseudo_beta_mask).
multichain_mask_2d: A mask indicating intra-chain residue pairs, used
to mask out between chain distances/features when templates are for
single chains.
is_training: Are we in training mode.
safe_key: Random key generator.
Returns:
A template embedding (num_res, num_res, num_channels).
"""
gc = self.global_config
c = self.config
assert padding_mask_2d.dtype == query_embedding.dtype
dtype = query_embedding.dtype
num_channels = self.config.num_channels
def construct_input(query_embedding, template_aatype,
template_all_atom_positions, template_all_atom_mask,
multichain_mask_2d):
# Compute distogram feature for the template.
template_positions, pseudo_beta_mask = modules.pseudo_beta_fn(
template_aatype, template_all_atom_positions, template_all_atom_mask)
pseudo_beta_mask_2d = (pseudo_beta_mask[:, None] *
pseudo_beta_mask[None, :])
pseudo_beta_mask_2d *= multichain_mask_2d
template_dgram = modules.dgram_from_positions(
template_positions, **self.config.dgram_features)
template_dgram *= pseudo_beta_mask_2d[..., None]
template_dgram = template_dgram.astype(dtype)
pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype)
to_concat = [(template_dgram, 1), (pseudo_beta_mask_2d, 0)]
aatype = jax.nn.one_hot(template_aatype, 22, axis=-1, dtype=dtype)
to_concat.append((aatype[None, :, :], 1))
to_concat.append((aatype[:, None, :], 1))
      # Compute a feature representing the normalized vector between each pair
      # of backbone affines - i.e. in each residue's local frame, the direction
      # in which each of the other residues lies.
raw_atom_pos = template_all_atom_positions
if gc.bfloat16:
# Vec3Arrays are required to be float32
raw_atom_pos = raw_atom_pos.astype(jnp.float32)
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = folding_multimer.make_backbone_affine(
atom_pos,
template_all_atom_mask,
template_aatype)
points = rigid.translation
rigid_vec = rigid[:, None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]
if gc.bfloat16:
unit_vector = [x.astype(jnp.bfloat16) for x in unit_vector]
backbone_mask = backbone_mask.astype(jnp.bfloat16)
backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
backbone_mask_2d *= multichain_mask_2d
unit_vector = [x*backbone_mask_2d for x in unit_vector]
# Note that the backbone_mask takes into account C, CA and N (unlike
# pseudo beta mask which just needs CB) so we add both masks as features.
to_concat.extend([(x, 0) for x in unit_vector])
to_concat.append((backbone_mask_2d, 0))
query_embedding = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='query_embedding_norm')(
query_embedding)
# Allow the template embedder to see the query embedding. Note this
# contains the position relative feature, so this is how the network knows
# which residues are next to each other.
to_concat.append((query_embedding, 1))
act = 0
for i, (x, n_input_dims) in enumerate(to_concat):
act += common_modules.Linear(
num_channels,
num_input_dims=n_input_dims,
initializer='relu',
name=f'template_pair_embedding_{i}')(x)
return act
act = construct_input(query_embedding, template_aatype,
template_all_atom_positions, template_all_atom_mask,
multichain_mask_2d)
template_iteration = TemplateEmbeddingIteration(
c.template_pair_stack, gc, name='template_embedding_iteration')
def template_iteration_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
act = template_iteration(
act=act,
pair_mask=padding_mask_2d,
is_training=is_training,
safe_key=safe_subkey)
return (act, safe_key)
if gc.use_remat:
template_iteration_fn = hk.remat(template_iteration_fn)
safe_key, safe_subkey = safe_key.split()
template_stack = layer_stack.layer_stack(
c.template_pair_stack.num_block)(
template_iteration_fn)
act, safe_key = template_stack((act, safe_subkey))
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='output_layer_norm')(
act)
return act
class TemplateEmbeddingIteration(hk.Module):
"""Single Iteration of Template Embedding."""
def __init__(self, config, global_config,
name='template_embedding_iteration'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, act, pair_mask, is_training=True,
safe_key=None):
"""Build a single iteration of the template embedder.
Args:
act: [num_res, num_res, num_channel] Input pairwise activations.
pair_mask: [num_res, num_res] padding mask.
is_training: Whether to run in training mode.
safe_key: Safe pseudo-random generator key.
Returns:
[num_res, num_res, num_channel] tensor of activations.
"""
c = self.config
gc = self.global_config
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
dropout_wrapper_fn = functools.partial(
modules.dropout_wrapper,
is_training=is_training,
global_config=gc)
safe_key, *sub_keys = safe_key.split(20)
sub_keys = iter(sub_keys)
act = dropout_wrapper_fn(
modules.TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
name='triangle_multiplication_outgoing'),
act,
pair_mask,
safe_key=next(sub_keys))
act = dropout_wrapper_fn(
modules.TriangleMultiplication(c.triangle_multiplication_incoming, gc,
name='triangle_multiplication_incoming'),
act,
pair_mask,
safe_key=next(sub_keys))
act = dropout_wrapper_fn(
modules.TriangleAttention(c.triangle_attention_starting_node, gc,
name='triangle_attention_starting_node'),
act,
pair_mask,
safe_key=next(sub_keys))
act = dropout_wrapper_fn(
modules.TriangleAttention(c.triangle_attention_ending_node, gc,
name='triangle_attention_ending_node'),
act,
pair_mask,
safe_key=next(sub_keys))
act = dropout_wrapper_fn(
modules.Transition(c.pair_transition, gc,
name='pair_transition'),
act,
pair_mask,
safe_key=next(sub_keys))
return act
def template_embedding_1d(batch, num_channel, global_config):
"""Embed templates into an (num_res, num_templates, num_channels) embedding.
Args:
batch: A batch containing:
template_aatype, (num_templates, num_res) aatype for the templates.
template_all_atom_positions, (num_templates, num_residues, 37, 3) atom
positions for the templates.
template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
each template.
num_channel: The number of channels in the output.
global_config: The global_config.
Returns:
An embedding of shape (num_templates, num_res, num_channels) and a mask of
shape (num_templates, num_res).
"""
# Embed the templates aatypes.
aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)
num_templates = batch['template_aatype'].shape[0]
all_chi_angles = []
all_chi_masks = []
for i in range(num_templates):
atom_pos = geometry.Vec3Array.from_array(
batch['template_all_atom_positions'][i, :, :, :])
template_chi_angles, template_chi_mask = all_atom_multimer.compute_chi_angles(
atom_pos,
batch['template_all_atom_mask'][i, :, :],
batch['template_aatype'][i, :])
all_chi_angles.append(template_chi_angles)
all_chi_masks.append(template_chi_mask)
chi_angles = jnp.stack(all_chi_angles, axis=0)
chi_mask = jnp.stack(all_chi_masks, axis=0)
template_features = jnp.concatenate([
aatype_one_hot,
jnp.sin(chi_angles) * chi_mask,
jnp.cos(chi_angles) * chi_mask,
chi_mask], axis=-1)
template_mask = chi_mask[:, :, 0]
if global_config.bfloat16:
template_features = template_features.astype(jnp.bfloat16)
template_mask = template_mask.astype(jnp.bfloat16)
template_activations = common_modules.Linear(
num_channel,
initializer='relu',
name='template_single_embedding')(
template_features)
template_activations = jax.nn.relu(template_activations)
template_activations = common_modules.Linear(
num_channel,
initializer='relu',
name='template_projection')(
template_activations)
return template_activations, template_mask
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of utilities surrounding PRNG usage in protein folding."""
import haiku as hk
import jax
def safe_dropout(*, tensor, safe_key, rate, is_deterministic, is_training):
if is_training and rate != 0.0 and not is_deterministic:
return hk.dropout(safe_key.get(), rate, tensor)
else:
return tensor
class SafeKey:
"""Safety wrapper for PRNG keys."""
def __init__(self, key):
self._key = key
self._used = False
def _assert_not_used(self):
if self._used:
raise RuntimeError('Random key has been used previously.')
def get(self):
self._assert_not_used()
self._used = True
return self._key
def split(self, num_keys=2):
self._assert_not_used()
self._used = True
new_keys = jax.random.split(self._key, num_keys)
return jax.tree_map(SafeKey, tuple(new_keys))
def duplicate(self, num_keys=2):
self._assert_not_used()
self._used = True
return tuple(SafeKey(self._key) for _ in range(num_keys))
def _safe_key_flatten(safe_key):
# Flatten transfers "ownership" to the tree
return (safe_key._key,), safe_key._used # pylint: disable=protected-access
def _safe_key_unflatten(aux_data, children):
ret = SafeKey(children[0])
ret._used = aux_data # pylint: disable=protected-access
return ret
jax.tree_util.register_pytree_node(
SafeKey, _safe_key_flatten, _safe_key_unflatten)
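# Illustrative sketch (ours, not part of the original module): a SafeKey may be
# consumed exactly once, via get(), split() or duplicate(); a second use raises
# RuntimeError, which catches accidental PRNG key reuse. `_demo_safe_key_usage`
# is a hypothetical helper.
def _demo_safe_key_usage():
  safe_key = SafeKey(jax.random.PRNGKey(0))
  safe_key, subkey = safe_key.split()  # Consumes the original key.
  raw = subkey.get()  # Consumes the subkey.
  del raw
  # Calling subkey.get() again here would raise
  # RuntimeError('Random key has been used previously.').
  return safe_key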
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for prng."""
from absl.testing import absltest
from alphafold.model import prng
import jax
class PrngTest(absltest.TestCase):
def test_key_reuse(self):
init_key = jax.random.PRNGKey(42)
safe_key = prng.SafeKey(init_key)
_, safe_key = safe_key.split()
raw_key = safe_key.get()
self.assertFalse((raw_key == init_key).all())
with self.assertRaises(RuntimeError):
safe_key.get()
with self.assertRaises(RuntimeError):
safe_key.split()
with self.assertRaises(RuntimeError):
safe_key.duplicate()
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quaternion geometry modules.
This introduces a representation of coordinate frames that is based around a
‘QuatAffine’ object. This object describes an array of coordinate frames.
It consists of vectors corresponding to the origin of the frames as well as
orientations, which are stored in two ways: as unit quaternions and as
rotation matrices.
The rotation matrices are derived from the unit quaternions and the two are
kept in sync.
For an explanation of the relation between unit quaternions and rotations see
https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
This representation is used in the model for the backbone frames.
One important thing to note here is that, while we update both representations,
the jit compiler will ensure that only the parts that are actually used are
executed.
"""
import functools
from typing import Tuple
import jax
import jax.numpy as jnp
import numpy as np
# pylint: disable=bad-whitespace
QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)
QUAT_TO_ROT[0, 0] = [[ 1, 0, 0], [ 0, 1, 0], [ 0, 0, 1]] # rr
QUAT_TO_ROT[1, 1] = [[ 1, 0, 0], [ 0,-1, 0], [ 0, 0,-1]] # ii
QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [ 0, 1, 0], [ 0, 0,-1]] # jj
QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [ 0,-1, 0], [ 0, 0, 1]] # kk
QUAT_TO_ROT[1, 2] = [[ 0, 2, 0], [ 2, 0, 0], [ 0, 0, 0]] # ij
QUAT_TO_ROT[1, 3] = [[ 0, 0, 2], [ 0, 0, 0], [ 2, 0, 0]] # ik
QUAT_TO_ROT[2, 3] = [[ 0, 0, 0], [ 0, 0, 2], [ 0, 2, 0]] # jk
QUAT_TO_ROT[0, 1] = [[ 0, 0, 0], [ 0, 0,-2], [ 0, 2, 0]] # ir
QUAT_TO_ROT[0, 2] = [[ 0, 0, 2], [ 0, 0, 0], [-2, 0, 0]] # jr
QUAT_TO_ROT[0, 3] = [[ 0,-2, 0], [ 2, 0, 0], [ 0, 0, 0]] # kr
QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32)
QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
[ 0,-1, 0, 0],
[ 0, 0,-1, 0],
[ 0, 0, 0,-1]]
QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
[ 1, 0, 0, 0],
[ 0, 0, 0, 1],
[ 0, 0,-1, 0]]
QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
[ 0, 0, 0,-1],
[ 1, 0, 0, 0],
[ 0, 1, 0, 0]]
QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
[ 0, 0, 1, 0],
[ 0,-1, 0, 0],
[ 1, 0, 0, 0]]
QUAT_MULTIPLY_BY_VEC = QUAT_MULTIPLY[:, 1:, :]
# pylint: enable=bad-whitespace
def rot_to_quat(rot, unstack_inputs=False):
"""Convert rotation matrix to quaternion.
  Note that this function performs an eigendecomposition (jnp.linalg.eigh),
  which is extremely expensive on the GPU. If at all possible, this function
  should run on the CPU.
Args:
rot: rotation matrix (see below for format).
unstack_inputs: If true, rotation matrix should be shape (..., 3, 3)
otherwise the rotation matrix should be a list of lists of tensors.
Returns:
Quaternion as (..., 4) tensor.
"""
if unstack_inputs:
rot = [jnp.moveaxis(x, -1, 0) for x in jnp.moveaxis(rot, -2, 0)]
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
# pylint: disable=bad-whitespace
k = [[ xx + yy + zz, zy - yz, xz - zx, yx - xy,],
[ zy - yz, xx - yy - zz, xy + yx, xz + zx,],
[ xz - zx, xy + yx, yy - xx - zz, yz + zy,],
[ yx - xy, xz + zx, yz + zy, zz - xx - yy,]]
# pylint: enable=bad-whitespace
k = (1./3.) * jnp.stack([jnp.stack(x, axis=-1) for x in k],
axis=-2)
  # Get eigenvalues in non-decreasing order and the associated eigenvectors.
_, qs = jnp.linalg.eigh(k)
return qs[..., -1]
def rot_list_to_tensor(rot_list):
"""Convert list of lists to rotation tensor."""
return jnp.stack(
[jnp.stack(rot_list[0], axis=-1),
jnp.stack(rot_list[1], axis=-1),
jnp.stack(rot_list[2], axis=-1)],
axis=-2)
def vec_list_to_tensor(vec_list):
"""Convert list to vector tensor."""
return jnp.stack(vec_list, axis=-1)
def quat_to_rot(normalized_quat):
"""Convert a normalized quaternion to a rotation matrix."""
rot_tensor = jnp.sum(
np.reshape(QUAT_TO_ROT, (4, 4, 9)) *
normalized_quat[..., :, None, None] *
normalized_quat[..., None, :, None],
axis=(-3, -2))
rot = jnp.moveaxis(rot_tensor, -1, 0) # Unstack.
return [[rot[0], rot[1], rot[2]],
[rot[3], rot[4], rot[5]],
[rot[6], rot[7], rot[8]]]
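# Illustrative sketch (ours, not part of the original module): the identity
# rotation corresponds to the quaternion (1, 0, 0, 0), and `quat_to_rot` /
# `rot_to_quat` invert each other up to quaternion sign. `_demo_quat_roundtrip`
# is a hypothetical helper.
def _demo_quat_roundtrip():
  identity_quat = jnp.array([1., 0., 0., 0.])
  rot = quat_to_rot(identity_quat)  # 3x3 list of lists: the identity rotation.
  rot_tensor = rot_list_to_tensor(rot)  # Stack back into a (3, 3) array.
  # rot_to_quat recovers (1, 0, 0, 0) up to sign from the identity matrix.
  return rot_to_quat(rot_tensor, unstack_inputs=True)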
def quat_multiply_by_vec(quat, vec):
"""Multiply a quaternion by a pure-vector quaternion."""
return jnp.sum(
QUAT_MULTIPLY_BY_VEC *
quat[..., :, None, None] *
vec[..., None, :, None],
axis=(-3, -2))
def quat_multiply(quat1, quat2):
"""Multiply a quaternion by another quaternion."""
return jnp.sum(
QUAT_MULTIPLY *
quat1[..., :, None, None] *
quat2[..., None, :, None],
axis=(-3, -2))
def apply_rot_to_vec(rot, vec, unstack=False):
"""Multiply rotation matrix by a vector."""
if unstack:
x, y, z = [vec[:, i] for i in range(3)]
else:
x, y, z = vec
return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z,
rot[1][0] * x + rot[1][1] * y + rot[1][2] * z,
rot[2][0] * x + rot[2][1] * y + rot[2][2] * z]
def apply_inverse_rot_to_vec(rot, vec):
"""Multiply the inverse of a rotation matrix by a vector."""
# Inverse rotation is just transpose
return [rot[0][0] * vec[0] + rot[1][0] * vec[1] + rot[2][0] * vec[2],
rot[0][1] * vec[0] + rot[1][1] * vec[1] + rot[2][1] * vec[2],
rot[0][2] * vec[0] + rot[1][2] * vec[1] + rot[2][2] * vec[2]]
class QuatAffine(object):
"""Affine transformation represented by quaternion and vector."""
def __init__(self, quaternion, translation, rotation=None, normalize=True,
unstack_inputs=False):
"""Initialize from quaternion and translation.
Args:
quaternion: Rotation represented by a quaternion, to be applied
before translation. Must be a unit quaternion unless normalize==True.
translation: Translation represented as a vector.
rotation: Same rotation as the quaternion, represented as a (..., 3, 3)
tensor. If None, rotation will be calculated from the quaternion.
normalize: If True, l2 normalize the quaternion on input.
      unstack_inputs: If True, translation is given as a tensor whose last
        dimension has size 3 and will be unstacked into a list of components.
"""
if quaternion is not None:
assert quaternion.shape[-1] == 4
if unstack_inputs:
if rotation is not None:
rotation = [jnp.moveaxis(x, -1, 0) # Unstack.
for x in jnp.moveaxis(rotation, -2, 0)] # Unstack.
translation = jnp.moveaxis(translation, -1, 0) # Unstack.
if normalize and quaternion is not None:
quaternion = quaternion / jnp.linalg.norm(quaternion, axis=-1,
keepdims=True)
if rotation is None:
rotation = quat_to_rot(quaternion)
self.quaternion = quaternion
self.rotation = [list(row) for row in rotation]
self.translation = list(translation)
assert all(len(row) == 3 for row in self.rotation)
assert len(self.translation) == 3
def to_tensor(self):
return jnp.concatenate(
[self.quaternion] +
[jnp.expand_dims(x, axis=-1) for x in self.translation],
axis=-1)
def apply_tensor_fn(self, tensor_fn):
"""Return a new QuatAffine with tensor_fn applied (e.g. stop_gradient)."""
return QuatAffine(
tensor_fn(self.quaternion),
[tensor_fn(x) for x in self.translation],
rotation=[[tensor_fn(x) for x in row] for row in self.rotation],
normalize=False)
def apply_rotation_tensor_fn(self, tensor_fn):
"""Return a new QuatAffine with tensor_fn applied to the rotation part."""
return QuatAffine(
tensor_fn(self.quaternion),
[x for x in self.translation],
rotation=[[tensor_fn(x) for x in row] for row in self.rotation],
normalize=False)
def scale_translation(self, position_scale):
"""Return a new quat affine with a different scale for translation."""
return QuatAffine(
self.quaternion,
[x * position_scale for x in self.translation],
rotation=[[x for x in row] for row in self.rotation],
normalize=False)
@classmethod
def from_tensor(cls, tensor, normalize=False):
quaternion, tx, ty, tz = jnp.split(tensor, [4, 5, 6], axis=-1)
return cls(quaternion,
[tx[..., 0], ty[..., 0], tz[..., 0]],
normalize=normalize)
def pre_compose(self, update):
"""Return a new QuatAffine which applies the transformation update first.
Args:
      update: Length-6 vector: a 3-vector (x, y, z) such that the quaternion
        update is (1, x, y, z), with the all-zero 3-vector corresponding to the
        identity quaternion, concatenated with a 3-vector for the translation.
Returns:
New QuatAffine object.
"""
vector_quaternion_update, x, y, z = jnp.split(update, [3, 4, 5], axis=-1)
trans_update = [jnp.squeeze(x, axis=-1),
jnp.squeeze(y, axis=-1),
jnp.squeeze(z, axis=-1)]
new_quaternion = (self.quaternion +
quat_multiply_by_vec(self.quaternion,
vector_quaternion_update))
trans_update = apply_rot_to_vec(self.rotation, trans_update)
new_translation = [
self.translation[0] + trans_update[0],
self.translation[1] + trans_update[1],
self.translation[2] + trans_update[2]]
return QuatAffine(new_quaternion, new_translation)
def apply_to_point(self, point, extra_dims=0):
"""Apply affine to a point.
Args:
point: List of 3 tensors to apply affine.
extra_dims: Number of dimensions at the end of the transformed_point
shape that are not present in the rotation and translation. The most
        common use is rotating N points at once with extra_dims=1 for use in a
network.
Returns:
Transformed point after applying affine.
"""
rotation = self.rotation
translation = self.translation
for _ in range(extra_dims):
expand_fn = functools.partial(jnp.expand_dims, axis=-1)
rotation = jax.tree_map(expand_fn, rotation)
translation = jax.tree_map(expand_fn, translation)
rot_point = apply_rot_to_vec(rotation, point)
return [
rot_point[0] + translation[0],
rot_point[1] + translation[1],
rot_point[2] + translation[2]]
def invert_point(self, transformed_point, extra_dims=0):
"""Apply inverse of transformation to a point.
Args:
transformed_point: List of 3 tensors to apply affine
extra_dims: Number of dimensions at the end of the transformed_point
shape that are not present in the rotation and translation. The most
        common use is rotating N points at once with extra_dims=1 for use in a
network.
Returns:
      Transformed point after applying the inverse of the affine transform.
"""
rotation = self.rotation
translation = self.translation
for _ in range(extra_dims):
expand_fn = functools.partial(jnp.expand_dims, axis=-1)
rotation = jax.tree_map(expand_fn, rotation)
translation = jax.tree_map(expand_fn, translation)
rot_point = [
transformed_point[0] - translation[0],
transformed_point[1] - translation[1],
transformed_point[2] - translation[2]]
return apply_inverse_rot_to_vec(rotation, rot_point)
def __repr__(self):
return 'QuatAffine(%r, %r)' % (self.quaternion, self.translation)
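# Illustrative usage sketch (an addition, not part of the original class):
# build a QuatAffine from a flat 7-vector (quaternion followed by translation),
# apply it to a point, undo it with invert_point, and check that an all-zero
# length-6 update passed to pre_compose leaves the transform unchanged. The
# helper name and numeric values are hypothetical.
def _example_quat_affine_usage():
  tensor = jnp.array([2., -1., 0.5, 3., 2., -1., 0.5])
  affine = QuatAffine.from_tensor(tensor, normalize=True)
  point = [jnp.array(0.3), jnp.array(-2.0), jnp.array(1.5)]  # x, y, z tensors.
  moved = affine.apply_to_point(point)
  recovered = affine.invert_point(moved)
  same = affine.pre_compose(jnp.zeros(6))  # Identity update.
  round_trip_ok = all(
      jnp.allclose(a, b, atol=1e-5) for a, b in zip(point, recovered))
  unchanged_ok = jnp.allclose(same.to_tensor(), affine.to_tensor(), atol=1e-5)
  return round_trip_ok, unchanged_ok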
def _multiply(a, b):
return jnp.stack([
jnp.array([a[0][0]*b[0][0] + a[0][1]*b[1][0] + a[0][2]*b[2][0],
a[0][0]*b[0][1] + a[0][1]*b[1][1] + a[0][2]*b[2][1],
a[0][0]*b[0][2] + a[0][1]*b[1][2] + a[0][2]*b[2][2]]),
jnp.array([a[1][0]*b[0][0] + a[1][1]*b[1][0] + a[1][2]*b[2][0],
a[1][0]*b[0][1] + a[1][1]*b[1][1] + a[1][2]*b[2][1],
a[1][0]*b[0][2] + a[1][1]*b[1][2] + a[1][2]*b[2][2]]),
jnp.array([a[2][0]*b[0][0] + a[2][1]*b[1][0] + a[2][2]*b[2][0],
a[2][0]*b[0][1] + a[2][1]*b[1][1] + a[2][2]*b[2][1],
a[2][0]*b[0][2] + a[2][1]*b[1][2] + a[2][2]*b[2][2]])])
def make_canonical_transform(
n_xyz: jnp.ndarray,
ca_xyz: jnp.ndarray,
c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Returns translation and rotation matrices to canonicalize residue atoms.
Note that this method does not take care of symmetries. If you provide the
atom positions in the non-standard way, the N atom will end up not at
[-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
need to take care of such cases in your code.
Args:
n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
Returns:
A tuple (translation, rotation) where:
translation is an array of shape [batch, 3] defining the translation.
rotation is an array of shape [batch, 3, 3] defining the rotation.
After applying the translation and rotation to all atoms in a residue:
* All atoms will be shifted so that CA is at the origin,
      * All atoms will be rotated so that C lies on the x-axis,
      * All atoms will be rotated so that N lies in the xy plane.
"""
assert len(n_xyz.shape) == 2, n_xyz.shape
assert n_xyz.shape[-1] == 3, n_xyz.shape
assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, (
n_xyz.shape, ca_xyz.shape, c_xyz.shape)
# Place CA at the origin.
translation = -ca_xyz
n_xyz = n_xyz + translation
c_xyz = c_xyz + translation
# Place C on the x-axis.
c_x, c_y, c_z = [c_xyz[:, i] for i in range(3)]
# Rotate by angle c1 in the x-y plane (around the z-axis).
sin_c1 = -c_y / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
cos_c1 = c_x / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
zeros = jnp.zeros_like(sin_c1)
ones = jnp.ones_like(sin_c1)
# pylint: disable=bad-whitespace
c1_rot_matrix = jnp.stack([jnp.array([cos_c1, -sin_c1, zeros]),
jnp.array([sin_c1, cos_c1, zeros]),
jnp.array([zeros, zeros, ones])])
# Rotate by angle c2 in the x-z plane (around the y-axis).
sin_c2 = c_z / jnp.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2)
cos_c2 = jnp.sqrt(c_x**2 + c_y**2) / jnp.sqrt(
1e-20 + c_x**2 + c_y**2 + c_z**2)
c2_rot_matrix = jnp.stack([jnp.array([cos_c2, zeros, sin_c2]),
jnp.array([zeros, ones, zeros]),
jnp.array([-sin_c2, zeros, cos_c2])])
c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix)
n_xyz = jnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T
# Place N in the x-y plane.
_, n_y, n_z = [n_xyz[:, i] for i in range(3)]
# Rotate by angle alpha in the y-z plane (around the x-axis).
sin_n = -n_z / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
cos_n = n_y / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
n_rot_matrix = jnp.stack([jnp.array([ones, zeros, zeros]),
jnp.array([zeros, cos_n, -sin_n]),
jnp.array([zeros, sin_n, cos_n])])
# pylint: enable=bad-whitespace
return (translation,
jnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1]))
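# Illustrative sketch (an addition, not part of the original module):
# canonicalize one toy residue and verify that CA ends up at the origin and C
# on the positive x-axis. The coordinates are arbitrary, and jnp.einsum is
# used only to apply the returned (batch, 3, 3) rotation for this check.
def _example_make_canonical_transform():
  n = jnp.array([[0.0, 1.4, 0.0]])
  ca = jnp.array([[1.0, 2.0, 3.0]])
  c = jnp.array([[2.5, 2.0, 3.0]])
  translation, rotation = make_canonical_transform(n, ca, c)
  ca_new = jnp.einsum('bij,bj->bi', rotation, ca + translation)
  c_new = jnp.einsum('bij,bj->bi', rotation, c + translation)
  return ca_new, c_new  # ca_new ~ [0, 0, 0]; c_new ~ [1.5, 0, 0].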
def make_transform_from_reference(
n_xyz: jnp.ndarray,
ca_xyz: jnp.ndarray,
c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Returns rotation and translation matrices to convert from reference.
Note that this method does not take care of symmetries. If you provide the
atom positions in the non-standard way, the N atom will end up not at
[-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
need to take care of such cases in your code.
Args:
n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
Returns:
A tuple (rotation, translation) where:
rotation is an array of shape [batch, 3, 3] defining the rotation.
translation is an array of shape [batch, 3] defining the translation.
After applying the translation and rotation to the reference backbone,
    the coordinates will be approximately equal to the input coordinates.
The order of translation and rotation differs from make_canonical_transform
because the rotation from this function should be applied before the
translation, unlike make_canonical_transform.
"""
translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz)
return np.transpose(rotation, (0, 2, 1)), -translation
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for quat_affine."""
from absl import logging
from absl.testing import absltest
from alphafold.model import quat_affine
import jax
import jax.numpy as jnp
import numpy as np
VERBOSE = False
np.set_printoptions(precision=3, suppress=True)
r2t = quat_affine.rot_list_to_tensor
v2t = quat_affine.vec_list_to_tensor
q2r = lambda q: r2t(quat_affine.quat_to_rot(q))
class QuatAffineTest(absltest.TestCase):
def _assert_check(self, to_check, tol=1e-5):
for k, (correct, generated) in to_check.items():
if VERBOSE:
logging.info(k)
logging.info('Correct %s', correct)
logging.info('Predicted %s', generated)
self.assertLess(np.max(np.abs(correct - generated)), tol)
def test_conversion(self):
quat = jnp.array([-2., 5., -1., 4.])
rotation = jnp.array([
[0.26087, 0.130435, 0.956522],
[-0.565217, -0.782609, 0.26087],
[0.782609, -0.608696, -0.130435]])
translation = jnp.array([1., -3., 4.])
point = jnp.array([0.7, 3.2, -2.9])
a = quat_affine.QuatAffine(quat, translation, unstack_inputs=True)
true_new_point = jnp.matmul(rotation, point[:, None])[:, 0] + translation
self._assert_check({
'rot': (rotation, r2t(a.rotation)),
'trans': (translation, v2t(a.translation)),
'point': (true_new_point,
v2t(a.apply_to_point(jnp.moveaxis(point, -1, 0)))),
# Because of the double cover, we must be careful and compare rotations
'quat': (q2r(a.quaternion),
q2r(quat_affine.rot_to_quat(a.rotation))),
})
def test_double_cover(self):
"""Test that -q is the same rotation as q."""
rng = jax.random.PRNGKey(42)
keys = jax.random.split(rng)
q = jax.random.normal(keys[0], (2, 4))
trans = jax.random.normal(keys[1], (2, 3))
a1 = quat_affine.QuatAffine(q, trans, unstack_inputs=True)
a2 = quat_affine.QuatAffine(-q, trans, unstack_inputs=True)
self._assert_check({
'rot': (r2t(a1.rotation),
r2t(a2.rotation)),
'trans': (v2t(a1.translation),
v2t(a2.translation)),
})
def test_homomorphism(self):
rng = jax.random.PRNGKey(42)
keys = jax.random.split(rng, 4)
vec_q1 = jax.random.normal(keys[0], (2, 3))
q1 = jnp.concatenate([
jnp.ones_like(vec_q1)[:, :1],
vec_q1], axis=-1)
q2 = jax.random.normal(keys[1], (2, 4))
t1 = jax.random.normal(keys[2], (2, 3))
t2 = jax.random.normal(keys[3], (2, 3))
a1 = quat_affine.QuatAffine(q1, t1, unstack_inputs=True)
a2 = quat_affine.QuatAffine(q2, t2, unstack_inputs=True)
a21 = a2.pre_compose(jnp.concatenate([vec_q1, t1], axis=-1))
rng, key = jax.random.split(rng)
x = jax.random.normal(key, (2, 3))
new_x = a21.apply_to_point(jnp.moveaxis(x, -1, 0))
new_x_apply2 = a2.apply_to_point(a1.apply_to_point(jnp.moveaxis(x, -1, 0)))
self._assert_check({
'quat': (q2r(quat_affine.quat_multiply(a2.quaternion, a1.quaternion)),
q2r(a21.quaternion)),
'rot': (jnp.matmul(r2t(a2.rotation), r2t(a1.rotation)),
r2t(a21.rotation)),
'point': (v2t(new_x_apply2),
v2t(new_x)),
'inverse': (x, v2t(a21.invert_point(new_x))),
})
def test_batching(self):
"""Test that affine applies batchwise."""
rng = jax.random.PRNGKey(42)
keys = jax.random.split(rng, 3)
q = jax.random.uniform(keys[0], (5, 2, 4))
t = jax.random.uniform(keys[1], (2, 3))
x = jax.random.uniform(keys[2], (5, 1, 3))
a = quat_affine.QuatAffine(q, t, unstack_inputs=True)
y = v2t(a.apply_to_point(jnp.moveaxis(x, -1, 0)))
y_list = []
for i in range(5):
for j in range(2):
a_local = quat_affine.QuatAffine(q[i, j], t[j],
unstack_inputs=True)
y_local = v2t(a_local.apply_to_point(jnp.moveaxis(x[i, 0], -1, 0)))
y_list.append(y_local)
y_combine = jnp.reshape(jnp.stack(y_list, axis=0), (5, 2, 3))
self._assert_check({
'batch': (y_combine, y),
'quat': (q2r(a.quaternion),
q2r(quat_affine.rot_to_quat(a.rotation))),
})
def assertAllClose(self, a, b, rtol=1e-06, atol=1e-06):
self.assertTrue(np.allclose(a, b, rtol=rtol, atol=atol))
def assertAllEqual(self, a, b):
self.assertTrue(np.all(np.array(a) == np.array(b)))
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformations for 3D coordinates.
This Module contains objects for representing Vectors (Vecs), Rotation Matrices
(Rots) and proper Rigid transformation (Rigids). These are represented as
named tuples with arrays for each entry, for example a set of
[N, M] points would be represented as a Vecs object with arrays of shape [N, M]
for x, y and z.
This is being done to improve readability by making it very clear what objects
are geometric objects rather than relying on comments and array shapes.
Another reason is to avoid matrix multiplication primitives like matmul or
einsum. On modern accelerator hardware these can end up on specialized cores
such as tensor cores on GPUs or the MXU on cloud TPUs, which often involves
lower computational precision that can be problematic for coordinate geometry.
These cores are also typically optimized for matrices larger than 3x3, so this
code is written to avoid any unintended use of them on both GPUs and TPUs.
"""
import collections
from typing import List
from alphafold.model import quat_affine
import jax.numpy as jnp
import tree
# Array of 3-component vectors, stored as individual arrays for
# each component.
Vecs = collections.namedtuple('Vecs', ['x', 'y', 'z'])
# Array of 3x3 rotation matrices, stored as individual arrays for
# each component.
Rots = collections.namedtuple('Rots', ['xx', 'xy', 'xz',
'yx', 'yy', 'yz',
'zx', 'zy', 'zz'])
# Array of rigid 3D transformations, stored as array of rotations and
# array of translations.
Rigids = collections.namedtuple('Rigids', ['rot', 'trans'])
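# Illustrative sketch (an addition, not part of the original module): each
# named tuple holds one array per scalar component, so a batch of [N] points
# is a Vecs whose x, y and z fields each have shape [N]. The helper name is
# hypothetical.
def _example_component_layout():
  points = jnp.ones((5, 3))                        # 5 points as a (5, 3) array.
  vecs = Vecs(points[..., 0], points[..., 1], points[..., 2])
  identity = Rots(jnp.ones(5), jnp.zeros(5), jnp.zeros(5),
                  jnp.zeros(5), jnp.ones(5), jnp.zeros(5),
                  jnp.zeros(5), jnp.zeros(5), jnp.ones(5))
  return Rigids(identity, vecs)                    # One rigid frame per point.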
def squared_difference(x, y):
return jnp.square(x - y)
def invert_rigids(r: Rigids) -> Rigids:
"""Computes group inverse of rigid transformations 'r'."""
inv_rots = invert_rots(r.rot)
t = rots_mul_vecs(inv_rots, r.trans)
inv_trans = Vecs(-t.x, -t.y, -t.z)
return Rigids(inv_rots, inv_trans)
def invert_rots(m: Rots) -> Rots:
"""Computes inverse of rotations 'm'."""
return Rots(m.xx, m.yx, m.zx,
m.xy, m.yy, m.zy,
m.xz, m.yz, m.zz)
def rigids_from_3_points(
point_on_neg_x_axis: Vecs, # shape (...)
origin: Vecs, # shape (...)
point_on_xy_plane: Vecs, # shape (...)
) -> Rigids: # shape (...)
"""Create Rigids from 3 points.
Jumper et al. (2021) Suppl. Alg. 21 "rigidFrom3Points"
This creates a set of rigid transformations from 3 points by Gram Schmidt
orthogonalization.
Args:
point_on_neg_x_axis: Vecs corresponding to points on the negative x axis
origin: Origin of resulting rigid transformations
point_on_xy_plane: Vecs corresponding to points in the xy plane
Returns:
Rigid transformations from global frame to local frames derived from
the input points.
"""
m = rots_from_two_vecs(
e0_unnormalized=vecs_sub(origin, point_on_neg_x_axis),
e1_unnormalized=vecs_sub(point_on_xy_plane, origin))
return Rigids(rot=m, trans=origin)
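# Illustrative sketch (an addition, not part of the original module): build a
# frame from three arbitrary points and check its defining properties in the
# local frame: the origin maps to zero, the first point lies on the negative
# x-axis and the third point lies in the xy plane. The helper name and
# coordinates are hypothetical.
def _example_rigids_from_3_points():
  neg_x = Vecs(jnp.array([-1.2]), jnp.array([0.4]), jnp.array([2.0]))
  origin = Vecs(jnp.array([0.3]), jnp.array([1.0]), jnp.array([2.5]))
  xy_plane = Vecs(jnp.array([1.0]), jnp.array([2.0]), jnp.array([2.0]))
  frame = rigids_from_3_points(neg_x, origin, xy_plane)
  inv = invert_rigids(frame)
  local_neg_x = rigids_mul_vecs(inv, neg_x)     # y, z ~ 0 and x < 0.
  local_plane = rigids_mul_vecs(inv, xy_plane)  # z ~ 0.
  return local_neg_x, local_plane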
def rigids_from_list(l: List[jnp.ndarray]) -> Rigids:
"""Converts flat list of arrays to rigid transformations."""
assert len(l) == 12
return Rigids(Rots(*(l[:9])), Vecs(*(l[9:])))
def rigids_from_quataffine(a: quat_affine.QuatAffine) -> Rigids:
"""Converts QuatAffine object to the corresponding Rigids object."""
return Rigids(Rots(*tree.flatten(a.rotation)),
Vecs(*a.translation))
def rigids_from_tensor4x4(
m: jnp.ndarray # shape (..., 4, 4)
) -> Rigids: # shape (...)
"""Construct Rigids object from an 4x4 array.
Here the 4x4 is representing the transformation in homogeneous coordinates.
Args:
m: Array representing transformations in homogeneous coordinates.
Returns:
Rigids object corresponding to transformations m
"""
assert m.shape[-1] == 4
assert m.shape[-2] == 4
return Rigids(
Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2],
m[..., 1, 0], m[..., 1, 1], m[..., 1, 2],
m[..., 2, 0], m[..., 2, 1], m[..., 2, 2]),
Vecs(m[..., 0, 3], m[..., 1, 3], m[..., 2, 3]))
def rigids_from_tensor_flat9(
m: jnp.ndarray # shape (..., 9)
) -> Rigids: # shape (...)
"""Flat9 encoding: first two columns of rotation matrix + translation."""
assert m.shape[-1] == 9
e0 = Vecs(m[..., 0], m[..., 1], m[..., 2])
e1 = Vecs(m[..., 3], m[..., 4], m[..., 5])
trans = Vecs(m[..., 6], m[..., 7], m[..., 8])
return Rigids(rot=rots_from_two_vecs(e0, e1),
trans=trans)
def rigids_from_tensor_flat12(
m: jnp.ndarray # shape (..., 12)
) -> Rigids: # shape (...)
"""Flat12 encoding: rotation matrix (9 floats) + translation (3 floats)."""
assert m.shape[-1] == 12
x = jnp.moveaxis(m, -1, 0) # Unstack
return Rigids(Rots(*x[:9]), Vecs(*x[9:]))
def rigids_mul_rigids(a: Rigids, b: Rigids) -> Rigids:
"""Group composition of Rigids 'a' and 'b'."""
return Rigids(
rots_mul_rots(a.rot, b.rot),
vecs_add(a.trans, rots_mul_vecs(a.rot, b.trans)))
def rigids_mul_rots(r: Rigids, m: Rots) -> Rigids:
"""Compose rigid transformations 'r' with rotations 'm'."""
return Rigids(rots_mul_rots(r.rot, m), r.trans)
def rigids_mul_vecs(r: Rigids, v: Vecs) -> Vecs:
"""Apply rigid transforms 'r' to points 'v'."""
return vecs_add(rots_mul_vecs(r.rot, v), r.trans)
def rigids_to_list(r: Rigids) -> List[jnp.ndarray]:
"""Turn Rigids into flat list, inverse of 'rigids_from_list'."""
return list(r.rot) + list(r.trans)
def rigids_to_quataffine(r: Rigids) -> quat_affine.QuatAffine:
"""Convert Rigids r into QuatAffine, inverse of 'rigids_from_quataffine'."""
return quat_affine.QuatAffine(
quaternion=None,
rotation=[[r.rot.xx, r.rot.xy, r.rot.xz],
[r.rot.yx, r.rot.yy, r.rot.yz],
[r.rot.zx, r.rot.zy, r.rot.zz]],
translation=[r.trans.x, r.trans.y, r.trans.z])
def rigids_to_tensor_flat9(
r: Rigids # shape (...)
) -> jnp.ndarray: # shape (..., 9)
"""Flat9 encoding: first two columns of rotation matrix + translation."""
return jnp.stack(
[r.rot.xx, r.rot.yx, r.rot.zx, r.rot.xy, r.rot.yy, r.rot.zy]
+ list(r.trans), axis=-1)
def rigids_to_tensor_flat12(
r: Rigids # shape (...)
) -> jnp.ndarray: # shape (..., 12)
"""Flat12 encoding: rotation matrix (9 floats) + translation (3 floats)."""
return jnp.stack(list(r.rot) + list(r.trans), axis=-1)
def rots_from_tensor3x3(
m: jnp.ndarray, # shape (..., 3, 3)
) -> Rots: # shape (...)
"""Convert rotations represented as (3, 3) array to Rots."""
assert m.shape[-1] == 3
assert m.shape[-2] == 3
return Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2],
m[..., 1, 0], m[..., 1, 1], m[..., 1, 2],
m[..., 2, 0], m[..., 2, 1], m[..., 2, 2])
def rots_from_two_vecs(e0_unnormalized: Vecs, e1_unnormalized: Vecs) -> Rots:
"""Create rotation matrices from unnormalized vectors for the x and y-axes.
This creates a rotation matrix from two vectors using Gram-Schmidt
orthogonalization.
Args:
e0_unnormalized: vectors lying along x-axis of resulting rotation
e1_unnormalized: vectors lying in xy-plane of resulting rotation
Returns:
Rotations resulting from Gram-Schmidt procedure.
"""
# Normalize the unit vector for the x-axis, e0.
e0 = vecs_robust_normalize(e0_unnormalized)
# make e1 perpendicular to e0.
c = vecs_dot_vecs(e1_unnormalized, e0)
e1 = Vecs(e1_unnormalized.x - c * e0.x,
e1_unnormalized.y - c * e0.y,
e1_unnormalized.z - c * e0.z)
e1 = vecs_robust_normalize(e1)
# Compute e2 as cross product of e0 and e1.
e2 = vecs_cross_vecs(e0, e1)
return Rots(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
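# Illustrative sketch (an addition, not part of the original module): the
# Gram-Schmidt construction above yields an orthonormal rotation, which can be
# checked by composing it with its transpose. The helper name and input
# vectors are hypothetical.
def _example_gram_schmidt_rotation():
  e0 = Vecs(jnp.array(1.0), jnp.array(2.0), jnp.array(-0.5))
  e1 = Vecs(jnp.array(0.0), jnp.array(1.0), jnp.array(3.0))
  rot = rots_from_two_vecs(e0, e1)
  # Composing with the inverse (transpose) should give the identity rotation.
  return rots_mul_rots(rot, invert_rots(rot))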
def rots_mul_rots(a: Rots, b: Rots) -> Rots:
"""Composition of rotations 'a' and 'b'."""
c0 = rots_mul_vecs(a, Vecs(b.xx, b.yx, b.zx))
c1 = rots_mul_vecs(a, Vecs(b.xy, b.yy, b.zy))
c2 = rots_mul_vecs(a, Vecs(b.xz, b.yz, b.zz))
return Rots(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
def rots_mul_vecs(m: Rots, v: Vecs) -> Vecs:
"""Apply rotations 'm' to vectors 'v'."""
return Vecs(m.xx * v.x + m.xy * v.y + m.xz * v.z,
m.yx * v.x + m.yy * v.y + m.yz * v.z,
m.zx * v.x + m.zy * v.y + m.zz * v.z)
def vecs_add(v1: Vecs, v2: Vecs) -> Vecs:
"""Add two vectors 'v1' and 'v2'."""
return Vecs(v1.x + v2.x, v1.y + v2.y, v1.z + v2.z)
def vecs_dot_vecs(v1: Vecs, v2: Vecs) -> jnp.ndarray:
"""Dot product of vectors 'v1' and 'v2'."""
return v1.x * v2.x + v1.y * v2.y + v1.z * v2.z
def vecs_cross_vecs(v1: Vecs, v2: Vecs) -> Vecs:
"""Cross product of vectors 'v1' and 'v2'."""
return Vecs(v1.y * v2.z - v1.z * v2.y,
v1.z * v2.x - v1.x * v2.z,
v1.x * v2.y - v1.y * v2.x)
def vecs_from_tensor(x: jnp.ndarray # shape (..., 3)
) -> Vecs: # shape (...)
"""Converts from tensor of shape (3,) to Vecs."""
num_components = x.shape[-1]
assert num_components == 3
return Vecs(x[..., 0], x[..., 1], x[..., 2])
def vecs_robust_normalize(v: Vecs, epsilon: float = 1e-8) -> Vecs:
"""Normalizes vectors 'v'.
Args:
v: vectors to be normalized.
epsilon: small regularizer added to squared norm before taking square root.
Returns:
normalized vectors
"""
norms = vecs_robust_norm(v, epsilon)
return Vecs(v.x / norms, v.y / norms, v.z / norms)
def vecs_robust_norm(v: Vecs, epsilon: float = 1e-8) -> jnp.ndarray:
"""Computes norm of vectors 'v'.
Args:
v: vectors to be normalized.
epsilon: small regularizer added to squared norm before taking square root.
Returns:
norm of 'v'
"""
return jnp.sqrt(jnp.square(v.x) + jnp.square(v.y) + jnp.square(v.z) + epsilon)
def vecs_sub(v1: Vecs, v2: Vecs) -> Vecs:
"""Computes v1 - v2."""
return Vecs(v1.x - v2.x, v1.y - v2.y, v1.z - v2.z)
def vecs_squared_distance(v1: Vecs, v2: Vecs) -> jnp.ndarray:
"""Computes squared euclidean difference between 'v1' and 'v2'."""
return (squared_difference(v1.x, v2.x) +
squared_difference(v1.y, v2.y) +
squared_difference(v1.z, v2.z))
def vecs_to_tensor(v: Vecs # shape (...)
) -> jnp.ndarray: # shape(..., 3)
"""Converts 'v' to tensor with shape 3, inverse of 'vecs_from_tensor'."""
return jnp.stack([v.x, v.y, v.z], axis=-1)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alphafold model TensorFlow code."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data for AlphaFold."""
from alphafold.common import residue_constants
from alphafold.model.tf import shape_helpers
from alphafold.model.tf import shape_placeholders
from alphafold.model.tf import utils
import numpy as np
import tensorflow.compat.v1 as tf
# Pylint gets confused by the curry1 decorator because it changes the number
# of arguments to the function.
# pylint:disable=no-value-for-parameter
NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def cast_64bit_ints(protein):
for k, v in protein.items():
if v.dtype == tf.int64:
protein[k] = tf.cast(v, tf.int32)
return protein
_MSA_FEATURE_NAMES = [
'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask',
'true_msa'
]
def make_seq_mask(protein):
protein['seq_mask'] = tf.ones(
shape_helpers.shape_list(protein['aatype']), dtype=tf.float32)
return protein
def make_template_mask(protein):
protein['template_mask'] = tf.ones(
shape_helpers.shape_list(protein['template_domain_names']),
dtype=tf.float32)
return protein
def curry1(f):
"""Supply all arguments but the first."""
def fc(*args, **kwargs):
return lambda x: f(x, *args, **kwargs)
return fc
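# Illustrative sketch (an addition, not part of the original module): curry1
# turns f(x, *args) into f(*args)(x), which is how the decorated transforms in
# this file are configured first and then mapped over the protein dict. The
# transform below is hypothetical and only demonstrates the calling pattern.
@curry1
def _example_scale_field(protein, key, factor):
  protein[key] = protein[key] * factor
  return protein
# Usage: transform = _example_scale_field('msa_mask', 0.5)
#        protein = transform(protein)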
@curry1
def add_distillation_flag(protein, distillation):
protein['is_distillation'] = tf.constant(float(distillation),
shape=[],
dtype=tf.float32)
return protein
def make_all_atom_aatype(protein):
protein['all_atom_aatype'] = protein['aatype']
return protein
def fix_templates_aatype(protein):
"""Fixes aatype encoding of templates."""
# Map one-hot to indices.
protein['template_aatype'] = tf.argmax(
protein['template_aatype'], output_type=tf.int32, axis=-1)
# Map hhsearch-aatype to our aatype.
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = tf.constant(new_order_list, dtype=tf.int32)
protein['template_aatype'] = tf.gather(params=new_order,
indices=protein['template_aatype'])
return protein
def correct_msa_restypes(protein):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = tf.constant(new_order_list, dtype=protein['msa'].dtype)
protein['msa'] = tf.gather(new_order, protein['msa'], axis=0)
perm_matrix = np.zeros((22, 22), dtype=np.float32)
perm_matrix[range(len(new_order_list)), new_order_list] = 1.
for k in protein:
if 'profile' in k: # Include both hhblits and psiblast profiles
num_dim = protein[k].shape.as_list()[-1]
assert num_dim in [20, 21, 22], (
'num_dim for %s out of expected range: %s' % (k, num_dim))
protein[k] = tf.tensordot(protein[k], perm_matrix[:num_dim, :num_dim], 1)
return protein
def squeeze_features(protein):
"""Remove singleton and repeated dimensions in protein features."""
protein['aatype'] = tf.argmax(
protein['aatype'], axis=-1, output_type=tf.int32)
for k in [
'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
'superfamily', 'deletion_matrix', 'resolution',
'between_segment_residues', 'residue_index', 'template_all_atom_masks']:
if k in protein:
final_dim = shape_helpers.shape_list(protein[k])[-1]
if isinstance(final_dim, int) and final_dim == 1:
protein[k] = tf.squeeze(protein[k], axis=-1)
for k in ['seq_length', 'num_alignments']:
if k in protein:
protein[k] = protein[k][0] # Remove fake sequence dimension
return protein
def make_random_crop_to_size_seed(protein):
"""Random seed for cropping residues and templates."""
protein['random_crop_to_size_seed'] = utils.make_random_seed()
return protein
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
"""Replace a proportion of the MSA with 'X'."""
msa_mask = (tf.random.uniform(shape_helpers.shape_list(protein['msa'])) <
replace_proportion)
x_idx = 20
gap_idx = 21
msa_mask = tf.logical_and(msa_mask, protein['msa'] != gap_idx)
protein['msa'] = tf.where(msa_mask,
tf.ones_like(protein['msa']) * x_idx,
protein['msa'])
aatype_mask = (
tf.random.uniform(shape_helpers.shape_list(protein['aatype'])) <
replace_proportion)
protein['aatype'] = tf.where(aatype_mask,
tf.ones_like(protein['aatype']) * x_idx,
protein['aatype'])
return protein
@curry1
def sample_msa(protein, max_seq, keep_extra):
"""Sample MSA randomly, remaining sequences are stored as `extra_*`.
Args:
protein: batch to sample msa from.
max_seq: number of sequences to sample.
keep_extra: When True sequences not sampled are put into fields starting
with 'extra_*'.
Returns:
Protein with sampled msa.
"""
num_seq = tf.shape(protein['msa'])[0]
shuffled = tf.random_shuffle(tf.range(1, num_seq))
index_order = tf.concat([[0], shuffled], axis=0)
num_sel = tf.minimum(max_seq, num_seq)
sel_seq, not_sel_seq = tf.split(index_order, [num_sel, num_seq - num_sel])
for k in _MSA_FEATURE_NAMES:
if k in protein:
if keep_extra:
protein['extra_' + k] = tf.gather(protein[k], not_sel_seq)
protein[k] = tf.gather(protein[k], sel_seq)
return protein
@curry1
def crop_extra_msa(protein, max_extra_msa):
"""MSA features are cropped so only `max_extra_msa` sequences are kept."""
num_seq = tf.shape(protein['extra_msa'])[0]
num_sel = tf.minimum(max_extra_msa, num_seq)
select_indices = tf.random_shuffle(tf.range(0, num_seq))[:num_sel]
for k in _MSA_FEATURE_NAMES:
if 'extra_' + k in protein:
protein['extra_' + k] = tf.gather(protein['extra_' + k], select_indices)
return protein
def delete_extra_msa(protein):
for k in _MSA_FEATURE_NAMES:
if 'extra_' + k in protein:
del protein['extra_' + k]
return protein
@curry1
def block_delete_msa(protein, config):
"""Sample MSA by deleting contiguous blocks.
Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"
Arguments:
protein: batch dict containing the msa
config: ConfigDict with parameters
Returns:
updated protein
"""
num_seq = shape_helpers.shape_list(protein['msa'])[0]
block_num_seq = tf.cast(
tf.floor(tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block),
tf.int32)
if config.randomize_num_blocks:
nb = tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32)
else:
nb = config.num_blocks
del_block_starts = tf.random.uniform([nb], 0, num_seq, dtype=tf.int32)
del_blocks = del_block_starts[:, None] + tf.range(block_num_seq)
del_blocks = tf.clip_by_value(del_blocks, 0, num_seq - 1)
del_indices = tf.unique(tf.sort(tf.reshape(del_blocks, [-1])))[0]
# Make sure we keep the original sequence
sparse_diff = tf.sets.difference(tf.range(1, num_seq)[None],
del_indices[None])
keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0)
keep_indices = tf.concat([[0], keep_indices], axis=0)
for k in _MSA_FEATURE_NAMES:
if k in protein:
protein[k] = tf.gather(protein[k], keep_indices)
return protein
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
"""Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
# Determine how much weight we assign to each agreement. In theory, we could
# use a full blosum matrix here, but right now let's just down-weight gap
# agreement because it could be spurious.
# Never put weight on agreeing on BERT mask
weights = tf.concat([
tf.ones(21),
gap_agreement_weight * tf.ones(1),
np.zeros(1)], 0)
# Make agreement score as weighted Hamming distance
sample_one_hot = (protein['msa_mask'][:, :, None] *
tf.one_hot(protein['msa'], 23))
extra_one_hot = (protein['extra_msa_mask'][:, :, None] *
tf.one_hot(protein['extra_msa'], 23))
num_seq, num_res, _ = shape_helpers.shape_list(sample_one_hot)
extra_num_seq, _, _ = shape_helpers.shape_list(extra_one_hot)
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# in an optimized fashion to avoid possible memory or computation blowup.
agreement = tf.matmul(
tf.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
tf.reshape(sample_one_hot * weights, [num_seq, num_res * 23]),
transpose_b=True)
# Assign each sequence in the extra sequences to the closest MSA sample
protein['extra_cluster_assignment'] = tf.argmax(
agreement, axis=1, output_type=tf.int32)
return protein
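# Illustrative sketch (an addition, not part of the original module): the
# reshaped matmul above is equivalent to the commented einsum. A small NumPy
# check with random inputs, purely for demonstration; the sizes and helper
# name are hypothetical.
def _example_agreement_equivalence():
  rng = np.random.RandomState(0)
  sample = rng.rand(4, 7, 23).astype(np.float32)  # [num_seq, num_res, 23].
  extra = rng.rand(6, 7, 23).astype(np.float32)   # [extra_num_seq, num_res, 23].
  weights = np.concatenate([np.ones(21), np.zeros(2)]).astype(np.float32)
  direct = np.einsum('mrc,nrc,c->nm', sample, extra, weights)
  fast = np.matmul(extra.reshape(6, -1), (sample * weights).reshape(4, -1).T)
  return np.allclose(direct, fast, atol=1e-4)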
@curry1
def summarize_clusters(protein):
"""Produce profile and deletion_matrix_mean within each cluster."""
num_seq = shape_helpers.shape_list(protein['msa'])[0]
def csum(x):
return tf.math.unsorted_segment_sum(
x, protein['extra_cluster_assignment'], num_seq)
mask = protein['extra_msa_mask']
mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center
msa_sum = csum(mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23))
msa_sum += tf.one_hot(protein['msa'], 23) # Original sequence
protein['cluster_profile'] = msa_sum / mask_counts[:, :, None]
del msa_sum
del_sum = csum(mask * protein['extra_deletion_matrix'])
del_sum += protein['deletion_matrix'] # Original sequence
protein['cluster_deletion_mean'] = del_sum / mask_counts
del del_sum
return protein
def make_msa_mask(protein):
"""Mask features are all ones, but will later be zero-padded."""
protein['msa_mask'] = tf.ones(
shape_helpers.shape_list(protein['msa']), dtype=tf.float32)
protein['msa_row_mask'] = tf.ones(
shape_helpers.shape_list(protein['msa'])[0], dtype=tf.float32)
return protein
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
"""Create pseudo beta features."""
is_gly = tf.equal(aatype, residue_constants.restype_order['G'])
ca_idx = residue_constants.atom_order['CA']
cb_idx = residue_constants.atom_order['CB']
pseudo_beta = tf.where(
tf.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :])
if all_atom_masks is not None:
pseudo_beta_mask = tf.where(
is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
pseudo_beta_mask = tf.cast(pseudo_beta_mask, tf.float32)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
@curry1
def make_pseudo_beta(protein, prefix=''):
"""Create pseudo-beta (alpha for glycine) position and mask."""
assert prefix in ['', 'template_']
protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = (
pseudo_beta_fn(
protein['template_aatype' if prefix else 'all_atom_aatype'],
protein[prefix + 'all_atom_positions'],
protein['template_all_atom_masks' if prefix else 'all_atom_mask']))
return protein
@curry1
def add_constant_field(protein, key, value):
protein[key] = tf.convert_to_tensor(value)
return protein
def shaped_categorical(probs, epsilon=1e-10):
ds = shape_helpers.shape_list(probs)
num_classes = ds[-1]
counts = tf.random.categorical(
tf.reshape(tf.log(probs + epsilon), [-1, num_classes]),
1,
dtype=tf.int32)
return tf.reshape(counts, ds[:-1])
def make_hhblits_profile(protein):
"""Compute the HHblits MSA profile if not already present."""
if 'hhblits_profile' in protein:
return protein
# Compute the profile for every residue (over all MSA sequences).
protein['hhblits_profile'] = tf.reduce_mean(
tf.one_hot(protein['msa'], 22), axis=0)
return protein
@curry1
def make_masked_msa(protein, config, replace_fraction):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly
random_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32)
categorical_probs = (
config.uniform_prob * random_aa +
config.profile_prob * protein['hhblits_profile'] +
config.same_prob * tf.one_hot(protein['msa'], 22))
# Put all remaining probability on [MASK] which is a new column
pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
pad_shapes[-1][1] = 1
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
assert mask_prob >= 0.
categorical_probs = tf.pad(
categorical_probs, pad_shapes, constant_values=mask_prob)
sh = shape_helpers.shape_list(protein['msa'])
mask_position = tf.random.uniform(sh) < replace_fraction
bert_msa = shaped_categorical(categorical_probs)
bert_msa = tf.where(mask_position, bert_msa, protein['msa'])
# Mix real and masked MSA
protein['bert_mask'] = tf.cast(mask_position, tf.float32)
protein['true_msa'] = protein['msa']
protein['msa'] = bert_msa
return protein
@curry1
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
num_res, num_templates=0):
"""Guess at the MSA and sequence dimensions to make fixed size."""
pad_size_map = {
NUM_RES: num_res,
NUM_MSA_SEQ: msa_cluster_size,
NUM_EXTRA_SEQ: extra_msa_size,
NUM_TEMPLATES: num_templates,
}
for k, v in protein.items():
# Don't transfer this to the accelerator.
if k == 'extra_cluster_assignment':
continue
shape = v.shape.as_list()
schema = shape_schema[k]
assert len(shape) == len(schema), (
f'Rank mismatch between shape and shape schema for {k}: '
f'{shape} vs {schema}')
pad_size = [
pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
]
padding = [(0, p - tf.shape(v)[i]) for i, p in enumerate(pad_size)]
if padding:
protein[k] = tf.pad(
v, padding, name=f'pad_to_fixed_{k}')
protein[k].set_shape(pad_size)
return protein
@curry1
def make_msa_feat(protein):
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping
# for compatibility with domain datasets.
has_break = tf.clip_by_value(
tf.cast(protein['between_segment_residues'], tf.float32),
0, 1)
aatype_1hot = tf.one_hot(protein['aatype'], 21, axis=-1)
target_feat = [
tf.expand_dims(has_break, axis=-1),
aatype_1hot, # Everyone gets the original sequence.
]
msa_1hot = tf.one_hot(protein['msa'], 23, axis=-1)
has_deletion = tf.clip_by_value(protein['deletion_matrix'], 0., 1.)
deletion_value = tf.atan(protein['deletion_matrix'] / 3.) * (2. / np.pi)
msa_feat = [
msa_1hot,
tf.expand_dims(has_deletion, axis=-1),
tf.expand_dims(deletion_value, axis=-1),
]
if 'cluster_profile' in protein:
deletion_mean_value = (
tf.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi))
msa_feat.extend([
protein['cluster_profile'],
tf.expand_dims(deletion_mean_value, axis=-1),
])
if 'extra_deletion_matrix' in protein:
protein['extra_has_deletion'] = tf.clip_by_value(
protein['extra_deletion_matrix'], 0., 1.)
protein['extra_deletion_value'] = tf.atan(
protein['extra_deletion_matrix'] / 3.) * (2. / np.pi)
protein['msa_feat'] = tf.concat(msa_feat, axis=-1)
protein['target_feat'] = tf.concat(target_feat, axis=-1)
return protein
@curry1
def select_feat(protein, feature_list):
return {k: v for k, v in protein.items() if k in feature_list}
@curry1
def crop_templates(protein, max_templates):
for k, v in protein.items():
if k.startswith('template_'):
protein[k] = v[:max_templates]
return protein
@curry1
def random_crop_to_size(protein, crop_size, max_templates, shape_schema,
subsample_templates=False):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
seq_length = protein['seq_length']
if 'template_mask' in protein:
num_templates = tf.cast(
shape_helpers.shape_list(protein['template_mask'])[0], tf.int32)
else:
num_templates = tf.constant(0, dtype=tf.int32)
num_res_crop_size = tf.math.minimum(seq_length, crop_size)
# Ensures that the cropping of residues and templates happens in the same way
# across ensembling iterations.
# Do not use for randomness that should vary in ensembling.
seed_maker = utils.SeedMaker(initial_seed=protein['random_crop_to_size_seed'])
if subsample_templates:
templates_crop_start = tf.random.stateless_uniform(
shape=(), minval=0, maxval=num_templates + 1, dtype=tf.int32,
seed=seed_maker())
else:
templates_crop_start = 0
num_templates_crop_size = tf.math.minimum(
num_templates - templates_crop_start, max_templates)
num_res_crop_start = tf.random.stateless_uniform(
shape=(), minval=0, maxval=seq_length - num_res_crop_size + 1,
dtype=tf.int32, seed=seed_maker())
templates_select_indices = tf.argsort(tf.random.stateless_uniform(
[num_templates], seed=seed_maker()))
for k, v in protein.items():
if k not in shape_schema or (
'template' not in k and NUM_RES not in shape_schema[k]):
continue
# randomly permute the templates before cropping them.
if k.startswith('template') and subsample_templates:
v = tf.gather(v, templates_select_indices)
crop_sizes = []
crop_starts = []
for i, (dim_size, dim) in enumerate(zip(shape_schema[k],
shape_helpers.shape_list(v))):
is_num_res = (dim_size == NUM_RES)
if i == 0 and k.startswith('template'):
crop_size = num_templates_crop_size
crop_start = templates_crop_start
else:
crop_start = num_res_crop_start if is_num_res else 0
crop_size = (num_res_crop_size if is_num_res else
(-1 if dim is None else dim))
crop_sizes.append(crop_size)
crop_starts.append(crop_start)
protein[k] = tf.slice(v, crop_starts, crop_sizes)
protein['seq_length'] = num_res_crop_size
return protein
def make_atom14_masks(protein):
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0)
for name in atom_names
])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
])
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = tf.gather(restype_atom14_to_atom37,
protein['aatype'])
residx_atom14_mask = tf.gather(restype_atom14_mask,
protein['aatype'])
protein['atom14_atom_exists'] = residx_atom14_mask
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37
# create the gather indices for mapping back
residx_atom37_to_atom14 = tf.gather(restype_atom37_to_atom14,
protein['aatype'])
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14
# create the corresponding mask
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = tf.gather(restype_atom37_mask,
protein['aatype'])
protein['atom37_atom_exists'] = residx_atom37_mask
return protein
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature pre-processing input pipeline for AlphaFold."""
from alphafold.model.tf import data_transforms
from alphafold.model.tf import shape_placeholders
import tensorflow.compat.v1 as tf
import tree
# Pylint gets confused by the curry1 decorator because it changes the number
# of arguments to the function.
# pylint:disable=no-value-for-parameter
NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def nonensembled_map_fns(data_config):
"""Input pipeline functions which are not ensembled."""
common_cfg = data_config.common
map_fns = [
data_transforms.correct_msa_restypes,
data_transforms.add_distillation_flag(False),
data_transforms.cast_64bit_ints,
data_transforms.squeeze_features,
# Keep to not disrupt RNG.
data_transforms.randomly_replace_msa_with_unknown(0.0),
data_transforms.make_seq_mask,
data_transforms.make_msa_mask,
# Compute the HHblits profile if it's not set. This has to be run before
# sampling the MSA.
data_transforms.make_hhblits_profile,
data_transforms.make_random_crop_to_size_seed,
]
if common_cfg.use_templates:
map_fns.extend([
data_transforms.fix_templates_aatype,
data_transforms.make_template_mask,
data_transforms.make_pseudo_beta('template_')
])
map_fns.extend([
data_transforms.make_atom14_masks,
])
return map_fns
def ensembled_map_fns(data_config):
"""Input pipeline functions that can be ensembled and averaged."""
common_cfg = data_config.common
eval_cfg = data_config.eval
map_fns = []
if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
else:
pad_msa_clusters = eval_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
map_fns.append(
data_transforms.sample_msa(
max_msa_clusters,
keep_extra=True))
if 'masked_msa' in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
map_fns.append(
data_transforms.make_masked_msa(common_cfg.masked_msa,
eval_cfg.masked_msa_replace_fraction))
if common_cfg.msa_cluster_features:
map_fns.append(data_transforms.nearest_neighbor_clusters())
map_fns.append(data_transforms.summarize_clusters())
# Crop after creating the cluster profiles.
if max_extra_msa:
map_fns.append(data_transforms.crop_extra_msa(max_extra_msa))
else:
map_fns.append(data_transforms.delete_extra_msa)
map_fns.append(data_transforms.make_msa_feat())
crop_feats = dict(eval_cfg.feat)
if eval_cfg.fixed_size:
map_fns.append(data_transforms.select_feat(list(crop_feats)))
map_fns.append(data_transforms.random_crop_to_size(
eval_cfg.crop_size,
eval_cfg.max_templates,
crop_feats,
eval_cfg.subsample_templates))
map_fns.append(data_transforms.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
eval_cfg.crop_size,
eval_cfg.max_templates))
else:
map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates))
return map_fns
def process_tensors_from_config(tensors, data_config):
"""Apply filters and maps to an existing dataset, based on the config."""
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_map_fns(data_config)
fn = compose(fns)
d['ensemble_index'] = i
return fn(d)
eval_cfg = data_config.eval
tensors = compose(
nonensembled_map_fns(
data_config))(
tensors)
tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0))
num_ensemble = eval_cfg.num_ensemble
if data_config.common.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step.
num_ensemble *= data_config.common.num_recycle + 1
if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1:
fn_output_signature = tree.map_structure(
tf.TensorSpec.from_tensor, tensors_0)
tensors = tf.map_fn(
lambda x: wrap_ensemble_fn(tensors, x),
tf.range(num_ensemble),
parallel_iterations=1,
fn_output_signature=fn_output_signature)
else:
tensors = tree.map_structure(lambda x: x[None],
tensors_0)
return tensors
@data_transforms.curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x
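# Illustrative sketch (an addition, not part of the original module): compose
# is curried, so compose(fns) returns a callable that threads a feature dict
# through each transform in order; nonensembled_map_fns / ensembled_map_fns
# supply the real transform lists. The toy transforms below are hypothetical.
def _example_compose():
  add_one = lambda d: {**d, 'x': d['x'] + 1}
  double = lambda d: {**d, 'x': d['x'] * 2}
  return compose([add_one, double])({'x': 1})  # -> {'x': 4}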
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains descriptions of various protein features."""
import enum
from typing import Dict, Optional, Sequence, Tuple, Union
from alphafold.common import residue_constants
import tensorflow.compat.v1 as tf
# Type aliases.
FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]]
class FeatureType(enum.Enum):
ZERO_DIM = 0 # Shape [x]
ONE_DIM = 1 # Shape [num_res, x]
TWO_DIM = 2 # Shape [num_res, num_res, x]
MSA = 3 # Shape [msa_length, num_res, x]
# Placeholder values that will be replaced with their true value at runtime.
NUM_RES = "num residues placeholder"
NUM_SEQ = "length msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
# Sizes of the protein features, NUM_RES and NUM_SEQ are allowed as placeholders
# to be replaced with the number of residues and the number of sequences in the
# multiple sequence alignment, respectively.
FEATURES = {
#### Static features of a protein sequence ####
"aatype": (tf.float32, [NUM_RES, 21]),
"between_segment_residues": (tf.int64, [NUM_RES, 1]),
"deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]),
"domain_name": (tf.string, [1]),
"msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]),
"num_alignments": (tf.int64, [NUM_RES, 1]),
"residue_index": (tf.int64, [NUM_RES, 1]),
"seq_length": (tf.int64, [NUM_RES, 1]),
"sequence": (tf.string, [1]),
"all_atom_positions": (tf.float32,
[NUM_RES, residue_constants.atom_type_num, 3]),
"all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]),
"resolution": (tf.float32, [1]),
"template_domain_names": (tf.string, [NUM_TEMPLATES]),
"template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]),
"template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]),
"template_all_atom_positions": (tf.float32, [
NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3
]),
"template_all_atom_masks": (tf.float32, [
NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1
]),
}
FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()}
FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()}
def register_feature(name: str,
type_: tf.dtypes.DType,
shape_: Tuple[Union[str, int]]):
"""Register extra features used in custom datasets."""
FEATURES[name] = (type_, shape_)
FEATURE_TYPES[name] = type_
FEATURE_SIZES[name] = shape_
def shape(feature_name: str,
num_residues: int,
msa_length: int,
num_templates: Optional[int] = None,
features: Optional[FeaturesMetadata] = None):
"""Get the shape for the given feature name.
  This is nearly identical to _get_tf_shape_no_placeholders() but with 2
  differences:
* This method does not calculate a single placeholder from the total number of
elements (eg given <NUM_RES, 3> and size := 12, this won't deduce NUM_RES
must be 4)
* This method will work with tensors
Args:
feature_name: String identifier for the feature. If the feature name ends
with "_unnormalized", this suffix is stripped off.
num_residues: The number of residues in the current domain - some elements
of the shape can be dynamic and will be replaced by this value.
msa_length: The number of sequences in the multiple sequence alignment, some
elements of the shape can be dynamic and will be replaced by this value.
If the number of alignments is unknown / not read, please pass None for
msa_length.
num_templates (optional): The number of templates in this tfexample.
features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES.
Returns:
    List of ints representing the tensor size.
Raises:
ValueError: If a feature is requested but no concrete placeholder value is
given.
"""
features = features or FEATURES
if feature_name.endswith("_unnormalized"):
feature_name = feature_name[:-13]
unused_dtype, raw_sizes = features[feature_name]
replacements = {NUM_RES: num_residues,
NUM_SEQ: msa_length}
if num_templates is not None:
replacements[NUM_TEMPLATES] = num_templates
sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
for dimension in sizes:
if isinstance(dimension, str):
raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
feature_name, raw_sizes, replacements))
return sizes
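# Illustrative sketch (an addition, not part of the original module): resolve
# the placeholder dimensions for a couple of features. The numbers are
# arbitrary and the helper name is hypothetical.
def _example_shape_lookup():
  msa_shape = shape('msa', num_residues=12, msa_length=24, num_templates=3)
  aatype_shape = shape('aatype', num_residues=12, msa_length=24)
  return msa_shape, aatype_shape  # [24, 12, 1] and [12, 21].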
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for protein_features."""
import uuid
from absl.testing import absltest
from absl.testing import parameterized
from alphafold.model.tf import protein_features
import tensorflow.compat.v1 as tf
def _random_bytes():
return str(uuid.uuid4()).encode('utf-8')
class FeaturesTest(parameterized.TestCase, tf.test.TestCase):
def setUp(self):
super().setUp()
tf.disable_v2_behavior()
def testFeatureNames(self):
self.assertEqual(len(protein_features.FEATURE_SIZES),
len(protein_features.FEATURE_TYPES))
sorted_size_names = sorted(protein_features.FEATURE_SIZES.keys())
sorted_type_names = sorted(protein_features.FEATURE_TYPES.keys())
for i, size_name in enumerate(sorted_size_names):
self.assertEqual(size_name, sorted_type_names[i])
def testReplacement(self):
for name in protein_features.FEATURE_SIZES.keys():
sizes = protein_features.shape(name,
num_residues=12,
msa_length=24,
num_templates=3)
for x in sizes:
self.assertEqual(type(x), int)
self.assertGreater(x, 0)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datasets consisting of proteins."""
from typing import Dict, Mapping, Optional, Sequence
from alphafold.model.tf import protein_features
import numpy as np
import tensorflow.compat.v1 as tf
TensorDict = Dict[str, tf.Tensor]
def parse_tfexample(
raw_data: bytes,
features: protein_features.FeaturesMetadata,
key: Optional[str] = None) -> Dict[str, tf.train.Feature]:
"""Read a single TF Example proto and return a subset of its features.
Args:
raw_data: A serialized tf.Example proto.
features: A dictionary of features, mapping string feature names to a tuple
(dtype, shape). This dictionary should be a subset of
protein_features.FEATURES (or the dictionary itself for all features).
key: Optional string with the SSTable key of that tf.Example. This will be
added into features as a 'key' but only if requested in features.
Returns:
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
feature_map = {
k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True)
for k, v in features.items()
}
parsed_features = tf.io.parse_single_example(raw_data, feature_map)
reshaped_features = parse_reshape_logic(parsed_features, features, key=key)
return reshaped_features
def _first(tensor: tf.Tensor) -> tf.Tensor:
"""Returns the 1st element - the input can be a tensor or a scalar."""
return tf.reshape(tensor, shape=(-1,))[0]
def parse_reshape_logic(
parsed_features: TensorDict,
features: protein_features.FeaturesMetadata,
key: Optional[str] = None) -> TensorDict:
"""Transforms parsed serial features to the correct shape."""
# Find out what is the number of sequences and the number of alignments.
num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32)
if "num_alignments" in parsed_features:
num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32)
else:
num_msa = 0
if "template_domain_names" in parsed_features:
num_templates = tf.cast(
tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32)
else:
num_templates = 0
if key is not None and "key" in features:
parsed_features["key"] = [key] # Expand dims from () to (1,).
# Reshape the tensors according to the sequence length and num alignments.
for k, v in parsed_features.items():
new_shape = protein_features.shape(
feature_name=k,
num_residues=num_residues,
msa_length=num_msa,
num_templates=num_templates,
features=features)
new_shape_size = tf.constant(1, dtype=tf.int32)
for dim in new_shape:
new_shape_size *= tf.cast(dim, tf.int32)
assert_equal = tf.assert_equal(
tf.size(v), new_shape_size,
name="assert_%s_shape_correct" % k,
message="The size of feature %s (%s) could not be reshaped "
"into %s" % (k, tf.size(v), new_shape))
if "template" not in k:
# Make sure the feature we are reshaping is not empty.
assert_non_empty = tf.assert_greater(
tf.size(v), 0, name="assert_%s_non_empty" % k,
message="The feature %s is not set in the tf.Example. Either do not "
"request the feature or use a tf.Example that has the "
"feature set." % k)
with tf.control_dependencies([assert_non_empty, assert_equal]):
parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
else:
with tf.control_dependencies([assert_equal]):
parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
return parsed_features
def _make_features_metadata(
feature_names: Sequence[str]) -> protein_features.FeaturesMetadata:
"""Makes a feature name to type and shape mapping from a list of names."""
# Make sure these features are always read.
required_features = ["aatype", "sequence", "seq_length"]
feature_names = list(set(feature_names) | set(required_features))
features_metadata = {name: protein_features.FEATURES[name]
for name in feature_names}
return features_metadata
def create_tensor_dict(
raw_data: bytes,
features: Sequence[str],
key: Optional[str] = None,
) -> TensorDict:
"""Creates a dictionary of tensor features.
Args:
raw_data: A serialized tf.Example proto.
features: A list of strings of feature names to be returned in the dataset.
    key: Optional string with the SSTable key of that tf.Example. This will be
      added to the output as 'key', but only if 'key' is requested in features.
Returns:
    A dictionary mapping feature names to feature tensors. Only the requested
    features are returned; all others are filtered out.
"""
features_metadata = _make_features_metadata(features)
return parse_tfexample(raw_data, features_metadata, key)
def np_to_tensor_dict(
np_example: Mapping[str, np.ndarray],
features: Sequence[str],
) -> TensorDict:
"""Creates dict of tensors from a dict of NumPy arrays.
Args:
np_example: A dict of NumPy feature arrays.
features: A list of strings of feature names to be returned in the dataset.
Returns:
    A dictionary mapping feature names to feature tensors. Only the requested
    features are returned; all others are filtered out.
"""
features_metadata = _make_features_metadata(features)
tensor_dict = {k: tf.constant(v) for k, v in np_example.items()
if k in features_metadata}
# Ensures shapes are as expected. Needed for setting size of empty features
# e.g. when no template hits were found.
tensor_dict = parse_reshape_logic(tensor_dict, features_metadata)
return tensor_dict
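# Usage sketch (illustrative, not part of this module): convert the NumPy
# feature dict produced by the data pipeline into correctly shaped TF tensors.
# The feature names below are examples; any subset of protein_features.FEATURES
# can be requested, and 'aatype', 'sequence' and 'seq_length' are always
# included in the request.
#
#   tensor_dict = np_to_tensor_dict(
#       np_example=np_example,  # hypothetical dict of NumPy feature arrays
#       features=['aatype', 'msa', 'residue_index'])
#
# Only the requested (plus required) features are kept; everything else in
# np_example is dropped before the reshape checks run.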
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with shapes of TensorFlow tensors."""
import tensorflow.compat.v1 as tf
def shape_list(x):
"""Return list of dimensions of a tensor, statically where possible.
Like `x.shape.as_list()` but with tensors instead of `None`s.
Args:
x: A tensor.
Returns:
    A list with length equal to the rank of the tensor. The n-th element of the
    list is an integer when that dimension is statically known; otherwise it is
    the n-th element of `tf.shape(x)`.
"""
x = tf.convert_to_tensor(x)
# If unknown rank, return dynamic shape
if x.get_shape().dims is None:
return tf.shape(x)
static = x.get_shape().as_list()
shape = tf.shape(x)
ret = []
for i in range(len(static)):
dim = static[i]
if dim is None:
dim = shape[i]
ret.append(dim)
return ret
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for shape_helpers."""
from alphafold.model.tf import shape_helpers
import numpy as np
import tensorflow.compat.v1 as tf
class ShapeTest(tf.test.TestCase):
def setUp(self):
super().setUp()
tf.disable_v2_behavior()
def test_shape_list(self):
"""Test that shape_list can allow for reshaping to dynamic shapes."""
a = tf.zeros([10, 4, 4, 2])
p = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4])
shape_dyn = shape_helpers.shape_list(p)[:2] + [4, 4]
b = tf.reshape(a, shape_dyn)
with self.session() as sess:
out = sess.run(b, feed_dict={p: np.ones((20, 1, 1, 4, 4))})
self.assertAllEqual(out.shape, (20, 1, 4, 4))
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Placeholder values for run-time varying dimension sizes."""
NUM_RES = 'num residues placeholder'
NUM_MSA_SEQ = 'msa placeholder'
NUM_EXTRA_SEQ = 'extra msa placeholder'
NUM_TEMPLATES = 'num templates placeholder'
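# These strings stand in for run-time sizes in the shape specs of
# protein_features.FEATURES; protein_features.shape() later replaces them with
# concrete values. Illustrative sketch (feature name and trailing dims are
# examples):
#
#   # FEATURES entry, roughly: 'aatype': (tf.float32, [NUM_RES, 21])
#   protein_features.shape('aatype', num_residues=12,
#                          msa_length=24, num_templates=3)
#   # -> [12, 21]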
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for various components."""
import tensorflow.compat.v1 as tf
def tf_combine_mask(*masks):
"""Take the intersection of float-valued masks."""
ret = 1
for m in masks:
ret *= m
return ret
class SeedMaker(object):
"""Return unique seeds."""
def __init__(self, initial_seed=0):
self.next_seed = initial_seed
def __call__(self):
i = self.next_seed
self.next_seed += 1
return i
seed_maker = SeedMaker()
def make_random_seed():
return tf.random.uniform([2],
tf.int32.min,
tf.int32.max,
tf.int32,
seed=seed_maker())
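# Usage sketch (illustrative): the module-level seed_maker hands out 0, 1, 2,
# ... so that each make_random_seed() op in a graph gets a distinct seed.
#
#   seed_a = make_random_seed()  # op seeded with the next counter value
#   seed_b = make_random_seed()  # op seeded with a different value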
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of JAX utility functions for use in protein folding."""
import collections
import contextlib
import functools
import numbers
from typing import Mapping
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
def stable_softmax(logits: jax.Array) -> jax.Array:
"""Numerically stable softmax for (potential) bfloat 16."""
if logits.dtype == jnp.float32:
output = jax.nn.softmax(logits)
elif logits.dtype == jnp.bfloat16:
# Need to explicitly do softmax in float32 to avoid numerical issues
# with large negatives. Large negatives can occur if trying to mask
# by adding on large negative logits so that things softmax to zero.
output = jax.nn.softmax(logits.astype(jnp.float32)).astype(jnp.bfloat16)
else:
raise ValueError(f'Unexpected input dtype {logits.dtype}')
return output
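# Illustrative use (assumption): masking by adding a large negative logit, the
# case the float32 round-trip guards against for bfloat16 inputs.
#
#   logits = jnp.array([1.0, 2.0, -1e9], dtype=jnp.bfloat16)
#   probs = stable_softmax(logits)  # bfloat16 output; last entry is ~0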
def bfloat16_creator(next_creator, shape, dtype, init, context):
"""Creates float32 variables when bfloat16 is requested."""
if context.original_dtype == jnp.bfloat16:
dtype = jnp.float32
return next_creator(shape, dtype, init)
def bfloat16_getter(next_getter, value, context):
"""Casts float32 to bfloat16 when bfloat16 was originally requested."""
if context.original_dtype == jnp.bfloat16:
assert value.dtype == jnp.float32
value = value.astype(jnp.bfloat16)
return next_getter(value)
@contextlib.contextmanager
def bfloat16_context():
with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter):
yield
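# Sketch of the intended use (hypothetical module): keep master parameters in
# float32 while the computation runs in bfloat16.
#
#   def forward(x):
#     with bfloat16_context():
#       return hk.Linear(8)(x.astype(jnp.bfloat16))
#
# Inside the context, parameters requested as bfloat16 are created as float32
# (bfloat16_creator) and cast back to bfloat16 when read (bfloat16_getter).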
def final_init(config):
if config.zero_init:
return 'zeros'
else:
return 'linear'
def batched_gather(params, indices, axis=0, batch_dims=0):
"""Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`."""
take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode='clip')
for _ in range(batch_dims):
take_fn = jax.vmap(take_fn)
return take_fn(params, indices)
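# Worked example (illustrative): with batch_dims=1 the gather is applied per
# batch element, mirroring tf.gather(..., batch_dims=1).
#
#   params = jnp.zeros((2, 5, 3))                 # [batch, points, channels]
#   indices = jnp.zeros((2, 4), dtype=jnp.int32)  # [batch, selected points]
#   out = batched_gather(params, indices, axis=0, batch_dims=1)
#   # out.shape == (2, 4, 3)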
def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
"""Masked mean."""
if drop_mask_channel:
mask = mask[..., 0]
mask_shape = mask.shape
value_shape = value.shape
assert len(mask_shape) == len(value_shape)
if isinstance(axis, numbers.Integral):
axis = [axis]
elif axis is None:
axis = list(range(len(mask_shape)))
assert isinstance(axis, collections.abc.Iterable), (
'axis needs to be either an iterable, integer or "None"')
broadcast_factor = 1.
for axis_ in axis:
value_size = value_shape[axis_]
mask_size = mask_shape[axis_]
if mask_size == 1:
broadcast_factor *= value_size
else:
assert mask_size == value_size
return (jnp.sum(mask * value, axis=axis) /
(jnp.sum(mask, axis=axis) * broadcast_factor + eps))
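# Worked example (illustrative): average only over positions where the mask is
# set.
#
#   value = jnp.array([[1.0, 2.0, 3.0]])
#   mask = jnp.array([[1.0, 1.0, 0.0]])
#   mask_mean(mask, value, axis=1)  # ~[1.5], i.e. (1.0 + 2.0) / 2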
def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params:
"""Convert a dictionary of NumPy arrays to Haiku parameters."""
hk_params = {}
for path, array in params.items():
scope, name = path.split('//')
if scope not in hk_params:
hk_params[scope] = {}
hk_params[scope][name] = jnp.array(array)
return hk_params
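# Example of the expected path format (parameter names here are hypothetical):
# keys are '<module scope>//<parameter name>'.
#
#   flat = {'model/linear//weights': np.zeros((4, 4))}
#   flat_params_to_haiku(flat)
#   # -> {'model/linear': {'weights': <jnp array of shape (4, 4)>}}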
def padding_consistent_rng(f):
"""Modify any element-wise random function to be consistent with padding.
  Normally if you take a function like jax.random.normal and generate an array,
  say of size (10,10), you will get a different set of random numbers than if
  you add padding and take the first (10,10) sub-array.
This function makes a random function that is consistent regardless of the
amount of padding added.
Note: The padding-consistent function is likely to be slower to compile and
run than the function it is wrapping, but these slowdowns are likely to be
negligible in a large network.
Args:
f: Any element-wise function that takes (PRNG key, shape) as the first 2
arguments.
Returns:
    A function equivalent to f that is consistent for different amounts of
    padding.
"""
def grid_keys(key, shape):
"""Generate a grid of rng keys that is consistent with different padding.
Generate random keys such that the keys will be identical, regardless of
how much padding is added to any dimension.
Args:
key: A PRNG key.
shape: The shape of the output array of keys that will be generated.
Returns:
An array of shape `shape` consisting of random keys.
"""
if not shape:
return key
new_keys = jax.vmap(functools.partial(jax.random.fold_in, key))(
jnp.arange(shape[0]))
return jax.vmap(functools.partial(grid_keys, shape=shape[1:]))(new_keys)
def inner(key, shape, **kwargs):
keys = grid_keys(key, shape)
signature = (
'()->()'
if jax.dtypes.issubdtype(keys.dtype, jax.dtypes.prng_key)
else '(2)->()'
)
return jnp.vectorize(
functools.partial(f, shape=(), **kwargs), signature=signature
)(keys)
return inner
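# Usage sketch (illustrative): wrap an element-wise sampler so that padding the
# shape does not change the values at the original positions.
#
#   uniform = padding_consistent_rng(jax.random.uniform)
#   key = jax.random.PRNGKey(0)
#   small = uniform(key, shape=(5,))
#   large = uniform(key, shape=(8,))
#   # large[:5] == small, since each position derives its key by folding its
#   # grid index into the same base key.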
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AlphaFold Colab notebook."""