Commit 9b18d6a9 authored by Augustin Zidek

Release code for v2.3.0

PiperOrigin-RevId: 494507694
parent 4494af84
......@@ -304,9 +304,9 @@ fractionPlddtVeryHigh | `FLOAT64` | Fraction of the residues in the predi
fractionPlddtVeryLow | `FLOAT64` | Fraction of the residues in the prediction with pLDDT less than 50
gene | `STRING` | The name of the gene if known, e.g. "COII"
geneSynonyms | `ARRAY<STRING>` | Additional synonyms for the gene
globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
isReferenceProteome | `BOOL` | Is this protein part of the reference proteome?
isReviewed | `BOOL` | Has this protein been reviewed, i.e. is it part of SwissProt?
globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
latestVersion | `INT64` | The latest AFDB version for this prediction
modelCreatedDate | `DATE` | The date of creation for this entry, e.g. "2022-06-01"
organismCommonNames | `ARRAY<STRING>` | List of common organism names
......
......@@ -117,7 +117,7 @@ class DataPipeline:
uniref90_database_path: str,
mgnify_database_path: str,
bfd_database_path: Optional[str],
uniclust30_database_path: Optional[str],
uniref30_database_path: Optional[str],
small_bfd_database_path: Optional[str],
template_searcher: TemplateSearcher,
template_featurizer: templates.TemplateHitFeaturizer,
......@@ -135,9 +135,9 @@ class DataPipeline:
binary_path=jackhmmer_binary_path,
database_path=small_bfd_database_path)
else:
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=[bfd_database_path, uniclust30_database_path])
databases=[bfd_database_path, uniref30_database_path])
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path)
......@@ -211,14 +211,14 @@ class DataPipeline:
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
else:
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
hhblits_bfd_uniclust_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniclust_runner,
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
hhblits_bfd_uniref_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniref_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='a3m',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
......
......@@ -128,3 +128,64 @@ class Linear(hk.Module):
return output
class LayerNorm(hk.LayerNorm):
"""LayerNorm module.
Equivalent to hk.LayerNorm but with different parameter shapes: they are
always vectors rather than possibly higher-rank tensors. This makes it easier
to change the layout whilst keeping the model weight-compatible.
"""
def __init__(self,
axis,
create_scale: bool,
create_offset: bool,
eps: float = 1e-5,
scale_init=None,
offset_init=None,
use_fast_variance: bool = False,
name=None,
param_axis=None):
super().__init__(
axis=axis,
create_scale=False,
create_offset=False,
eps=eps,
scale_init=None,
offset_init=None,
use_fast_variance=use_fast_variance,
name=name,
param_axis=param_axis)
self._temp_create_scale = create_scale
self._temp_create_offset = create_offset
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
is_bf16 = (x.dtype == jnp.bfloat16)
if is_bf16:
x = x.astype(jnp.float32)
param_axis = self.param_axis[0] if self.param_axis else -1
param_shape = (x.shape[param_axis],)
param_broadcast_shape = [1] * x.ndim
param_broadcast_shape[param_axis] = x.shape[param_axis]
scale = None
offset = None
if self._temp_create_scale:
scale = hk.get_parameter(
'scale', param_shape, x.dtype, init=self.scale_init)
scale = scale.reshape(param_broadcast_shape)
if self._temp_create_offset:
offset = hk.get_parameter(
'offset', param_shape, x.dtype, init=self.offset_init)
offset = offset.reshape(param_broadcast_shape)
out = super().__call__(x, scale=scale, offset=offset)
if is_bf16:
out = out.astype(jnp.bfloat16)
return out
\ No newline at end of file
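For reference, a minimal usage sketch of the new `common_modules.LayerNorm` defined above (assuming a Haiku transform context; the module path and the call follow the class shown in this hunk):

```python
import haiku as hk
import jax
import jax.numpy as jnp

from alphafold.model import common_modules


def forward(x):
  norm = common_modules.LayerNorm(
      axis=[-1], create_scale=True, create_offset=True, name='example_norm')
  return norm(x)


fn = hk.transform(forward)
x = jnp.ones((4, 8), dtype=jnp.bfloat16)
params = fn.init(jax.random.PRNGKey(0), x)

# Both 'scale' and 'offset' are stored as vectors of shape (8,), never as
# higher-rank tensors, which is what keeps checkpoints weight-compatible when
# the activation layout changes. bfloat16 inputs are normalised in float32
# and cast back on the way out.
print(jax.tree_util.tree_map(lambda p: p.shape, params))
```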
......@@ -26,12 +26,12 @@ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def model_config(name: str) -> ml_collections.ConfigDict:
"""Get the ConfigDict of a CASP14 model."""
if 'multimer' in name:
return CONFIG_MULTIMER
if name not in CONFIG_DIFFS:
raise ValueError(f'Invalid model name {name}.')
cfg = copy.deepcopy(CONFIG)
if 'multimer' in name:
cfg = copy.deepcopy(CONFIG_MULTIMER)
else:
cfg = copy.deepcopy(CONFIG)
cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
return cfg
......@@ -52,11 +52,11 @@ MODEL_PRESETS = {
'model_5_ptm',
),
'multimer': (
'model_1_multimer_v2',
'model_2_multimer_v2',
'model_3_multimer_v2',
'model_4_multimer_v2',
'model_5_multimer_v2',
'model_1_multimer_v3',
'model_2_multimer_v3',
'model_3_multimer_v3',
'model_4_multimer_v3',
'model_5_multimer_v3',
),
}
MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
......@@ -118,8 +118,32 @@ CONFIG_DIFFS = {
},
'model_5_ptm': {
'model.heads.predicted_aligned_error.weight': 0.1
}
},
'model_1_multimer_v3': {},
'model_2_multimer_v3': {},
'model_3_multimer_v3': {},
'model_4_multimer_v3': {
'model.embeddings_and_evoformer.num_extra_msa': 1152
},
'model_5_multimer_v3': {
'model.embeddings_and_evoformer.num_extra_msa': 1152
},
}
# Key differences between multimer v1/v2 and v3, mostly due to numerical
# optimisations in the TriangleMultiplication module.
common_updates = {
'model.embeddings_and_evoformer.num_msa': 252,
'model.embeddings_and_evoformer.num_extra_msa': 1152,
'model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.fuse_projection_weights': False,
'model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.fuse_projection_weights': False,
'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.fuse_projection_weights': False,
'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.fuse_projection_weights': False,
}
CONFIG_DIFFS.update(
{f'model_{i}_multimer': common_updates for i in range(1, 6)})
CONFIG_DIFFS.update(
{f'model_{i}_multimer_v2': common_updates for i in range(1, 6)})
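A hedged sketch of how the updated presets above resolve (names and values taken from this file; the comments show the expected output):

```python
from alphafold.model import config

cfg_v3 = config.model_config('model_1_multimer_v3')
cfg_v2 = config.model_config('model_1_multimer_v2')

# v3 models keep the new CONFIG_MULTIMER defaults.
print(cfg_v3.model.embeddings_and_evoformer.num_msa)        # 508
print(cfg_v3.model.embeddings_and_evoformer.evoformer
      .triangle_multiplication_outgoing.fuse_projection_weights)  # True

# v1/v2 names get `common_updates` applied, restoring the old MSA sizes and
# unfused TriangleMultiplication so existing v1/v2 weights still load.
print(cfg_v2.model.embeddings_and_evoformer.num_msa)        # 252
print(cfg_v2.model.embeddings_and_evoformer.evoformer
      .triangle_multiplication_outgoing.fuse_projection_weights)  # False
```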
CONFIG = ml_collections.ConfigDict({
'data': {
......@@ -260,14 +284,16 @@ CONFIG = ml_collections.ConfigDict({
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': False,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': False,
},
'pair_transition': {
'dropout_rate': 0.0,
......@@ -328,14 +354,16 @@ CONFIG = ml_collections.ConfigDict({
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': False,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': False,
},
'pair_transition': {
'dropout_rate': 0.0,
......@@ -354,7 +382,7 @@ CONFIG = ml_collections.ConfigDict({
'multimer_mode': False,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True
'zero_init': True,
},
'heads': {
'distogram': {
......@@ -483,27 +511,29 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
'gating': True,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': True,
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': True,
}
},
'extra_msa_channel': 64,
'extra_msa_stack_num_block': 4,
'num_msa': 252,
'num_extra_msa': 1152,
'num_msa': 508,
'num_extra_msa': 2048,
'masked_msa': {
'profile_prob': 0.1,
'replace_fraction': 0.15,
......@@ -564,24 +594,28 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': True,
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': True,
}
}
},
},
'global_config': {
'bfloat16': True,
'bfloat16_output': False,
'deterministic': False,
'multimer_mode': True,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True
'zero_init': True,
},
'heads': {
'distogram': {
......@@ -651,7 +685,13 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
}
},
'num_ensemble_eval': 1,
'num_recycle': 3,
'num_recycle': 20,
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `num_recycle` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
'recycle_early_stop_tolerance': 0.5,
'resample_msa_in_recycling': True
}
})
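The multimer defaults above now allow up to 20 recycling iterations with early stopping; a small sketch of inspecting or disabling that behaviour (field names as in the config above):

```python
from alphafold.model import config

cfg = config.model_config('model_1_multimer_v3')
print(cfg.model.num_recycle)                   # 20
print(cfg.model.recycle_early_stop_tolerance)  # 0.5 (see comment above)

# A negative tolerance disables early stopping, so all recycles always run.
cfg.model.recycle_early_stop_tolerance = -1.0
```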
......@@ -331,7 +331,7 @@ class FoldIteration(hk.Module):
safe_key, *sub_keys = safe_key.split(3)
sub_keys = iter(sub_keys)
act = safe_dropout_fn(act, next(sub_keys))
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -353,7 +353,7 @@ class FoldIteration(hk.Module):
act = jax.nn.relu(act)
act += input_act
act = safe_dropout_fn(act, next(sub_keys))
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -410,7 +410,7 @@ def generate_affines(representations, batch, config, global_config,
c = config
sequence_mask = batch['seq_mask'][:, None]
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -433,7 +433,7 @@ def generate_affines(representations, batch, config, global_config,
'affine': affine.to_tensor(),
}
act_2d = hk.LayerNorm(
act_2d = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......
......@@ -427,7 +427,7 @@ class FoldIteration(hk.Module):
safe_key, *sub_keys = safe_key.split(3)
sub_keys = iter(sub_keys)
act = safe_dropout_fn(act, next(sub_keys))
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
......@@ -448,7 +448,7 @@ class FoldIteration(hk.Module):
act = jax.nn.relu(act)
act += input_act
act = safe_dropout_fn(act, next(sub_keys))
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
......@@ -500,7 +500,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
"""
c = config
sequence_mask = batch['seq_mask'][:, None]
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')(
representations['single'])
......@@ -523,7 +523,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
rigid
}
act_2d = hk.LayerNorm(
act_2d = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
......
......@@ -133,7 +133,7 @@ def flatten(instance):
inner_treedefs = []
num_arrays = []
for array_like in array_likes:
flat_array_like, inner_treedef = jax.tree_flatten(array_like)
flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like)
inner_treedefs.append(inner_treedef)
flat_array_likes += flat_array_like
num_arrays.append(len(flat_array_like))
......@@ -206,7 +206,7 @@ class StructOfArray:
for num_array, inner_treedef, array_field in zip(num_arrays,
inner_treedefs,
array_fields):
value_dict[array_field] = jax.tree_unflatten(
value_dict[array_field] = jax.tree_util.tree_unflatten(
inner_treedef, data[array_start:array_start + num_array])
array_start += num_array
metadata_fields = get_metadata_fields(new_cls)
......
......@@ -47,11 +47,11 @@ def _maybe_get_size(array, axis):
def _expand_axes(axes, values, name='sharded_apply'):
values_tree_def = jax.tree_flatten(values)[1]
values_tree_def = jax.tree_util.tree_flatten(values)[1]
flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
# Replace None's with PROXY
flat_axes = [PROXY if x is None else x for x in flat_axes]
return jax.tree_unflatten(values_tree_def, flat_axes)
return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)
def sharded_map(
......@@ -126,7 +126,7 @@ def sharded_apply(
in_axes_ = _expand_axes(in_axes, args)
in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
flat_sizes = jax.tree_flatten(in_sizes)[0]
flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
in_size = max(flat_sizes)
assert all(i in {in_size, -1} for i in flat_sizes)
......
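These changes move from the deprecated `jax.tree_flatten`/`jax.tree_unflatten` aliases to their stable `jax.tree_util` equivalents; behaviour is unchanged, e.g.:

```python
import jax

leaves, treedef = jax.tree_util.tree_flatten({'a': 1, 'b': (2, 3)})
assert leaves == [1, 2, 3]
assert jax.tree_util.tree_unflatten(treedef, leaves) == {'a': 1, 'b': (2, 3)}
```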
......@@ -501,7 +501,7 @@ class Transition(hk.Module):
num_intermediate = int(nc * self.config.num_intermediate_factor)
mask = jnp.expand_dims(mask, axis=-1)
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -569,12 +569,15 @@ class Attention(hk.Module):
q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
k_weights = hk.get_parameter(
'key_w', shape=(m_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
v_weights = hk.get_parameter(
'value_w', shape=(m_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=glorot_uniform())
q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
......@@ -595,10 +598,12 @@ class Attention(hk.Module):
gating_weights = hk.get_parameter(
'gating_w',
shape=(q_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
gating_bias = hk.get_parameter(
'gating_b',
shape=(num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(1.0))
gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
......@@ -610,9 +615,12 @@ class Attention(hk.Module):
o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim),
dtype=q_data.dtype,
init=init)
o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),
init=hk.initializers.Constant(0.0))
o_bias = hk.get_parameter(
'output_b', shape=(self.output_dim,),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
......@@ -658,12 +666,15 @@ class GlobalAttention(hk.Module):
q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
k_weights = hk.get_parameter(
'key_w', shape=(m_data.shape[-1], key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
v_weights = hk.get_parameter(
'value_w', shape=(m_data.shape[-1], value_dim),
dtype=q_data.dtype,
init=glorot_uniform())
v = jnp.einsum('bka,ac->bkc', m_data, v_weights)
......@@ -684,18 +695,23 @@ class GlobalAttention(hk.Module):
o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim),
dtype=q_data.dtype,
init=init)
o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),
init=hk.initializers.Constant(0.0))
o_bias = hk.get_parameter(
'output_b', shape=(self.output_dim,),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
if self.config.gating:
gating_weights = hk.get_parameter(
'gating_w',
shape=(q_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
gating_bias = hk.get_parameter(
'gating_b',
shape=(num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(1.0))
gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights)
......@@ -745,11 +761,11 @@ class MSARowAttentionWithPairBias(hk.Module):
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = hk.LayerNorm(
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
pair_act = hk.LayerNorm(
pair_act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -760,6 +776,7 @@ class MSARowAttentionWithPairBias(hk.Module):
weights = hk.get_parameter(
'feat_2d_weights',
shape=(pair_act.shape[-1], c.num_head),
dtype=msa_act.dtype,
init=hk.initializers.RandomNormal(stddev=init_factor))
nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
......@@ -812,7 +829,7 @@ class MSAColumnAttention(hk.Module):
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = hk.LayerNorm(
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
......@@ -867,7 +884,7 @@ class MSAColumnGlobalAttention(hk.Module):
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = hk.LayerNorm(
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
......@@ -924,7 +941,7 @@ class TriangleAttention(hk.Module):
bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
pair_act = hk.LayerNorm(
pair_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
pair_act)
......@@ -932,6 +949,7 @@ class TriangleAttention(hk.Module):
weights = hk.get_parameter(
'feat_2d_weights',
shape=(pair_act.shape[-1], c.num_head),
dtype=pair_act.dtype,
init=hk.initializers.RandomNormal(stddev=init_factor))
nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
......@@ -1029,7 +1047,7 @@ class PredictedLDDTHead(hk.Module):
"""
act = representations['structure_module']
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -1251,6 +1269,19 @@ class ExperimentallyResolvedHead(hk.Module):
return output
def _layer_norm(axis=-1, name='layer_norm'):
return common_modules.LayerNorm(
axis=axis,
create_scale=True,
create_offset=True,
eps=1e-5,
use_fast_variance=True,
scale_init=hk.initializers.Constant(1.),
offset_init=hk.initializers.Constant(0.),
param_axis=axis,
name=name)
class TriangleMultiplication(hk.Module):
"""Triangle multiplication layer ("outgoing" or "incoming").
......@@ -1263,25 +1294,34 @@ class TriangleMultiplication(hk.Module):
self.config = config
self.global_config = global_config
def __call__(self, act, mask, is_training=True):
def __call__(self, left_act, left_mask, is_training=True):
"""Builds TriangleMultiplication module.
Arguments:
act: Pair activations, shape [N_res, N_res, c_z]
mask: Pair mask, shape [N_res, N_res].
left_act: Pair activations, shape [N_res, N_res, c_z]
left_mask: Pair mask, shape [N_res, N_res].
is_training: Whether the module is in training mode.
Returns:
Outputs, same shape/type as act.
Outputs, same shape/type as left_act.
"""
del is_training
if self.config.fuse_projection_weights:
return self._fused_triangle_multiplication(left_act, left_mask)
else:
return self._triangle_multiplication(left_act, left_mask)
@hk.transparent
def _triangle_multiplication(self, left_act, left_mask):
"""Implementation of TriangleMultiplication used in AF2 and AF-M<2.3."""
c = self.config
gc = self.global_config
mask = mask[..., None]
mask = left_mask[..., None]
act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
name='layer_norm_input')(act)
act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
name='layer_norm_input')(left_act)
input_act = act
left_projection = common_modules.Linear(
......@@ -1317,7 +1357,7 @@ class TriangleMultiplication(hk.Module):
# b = left_proj_act and a = right_proj_act
act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -1340,6 +1380,50 @@ class TriangleMultiplication(hk.Module):
return act
@hk.transparent
def _fused_triangle_multiplication(self, left_act, left_mask):
"""TriangleMultiplication with fused projection weights."""
mask = left_mask[..., None]
c = self.config
gc = self.global_config
left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)
# Both left and right projections are fused into projection.
projection = common_modules.Linear(
2*c.num_intermediate_channel, name='projection')
proj_act = mask * projection(left_act)
# Both left + right gate are fused into gate_values.
gate_values = common_modules.Linear(
2 * c.num_intermediate_channel,
name='gate',
bias_init=1.,
initializer=utils.final_init(gc))(left_act)
proj_act *= jax.nn.sigmoid(gate_values)
left_proj_act = proj_act[:, :, :c.num_intermediate_channel]
right_proj_act = proj_act[:, :, c.num_intermediate_channel:]
act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
act = _layer_norm(axis=-1, name='center_norm')(act)
output_channel = int(left_act.shape[-1])
act = common_modules.Linear(
output_channel,
initializer=utils.final_init(gc),
name='output_projection')(act)
gate_values = common_modules.Linear(
output_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='gating_linear')(left_act)
act *= jax.nn.sigmoid(gate_values)
return act
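Illustrative only: the fused path above packs the left and right projections into one Linear of width 2 * num_intermediate_channel and slices the result, left channels first. A NumPy sketch of that slicing convention (weight names here are hypothetical, not checkpoint keys):

```python
import numpy as np

num_in, c = 128, 64                    # pair channels, num_intermediate_channel
w_left = np.random.randn(num_in, c)    # stand-in for a separate left projection
w_right = np.random.randn(num_in, c)   # stand-in for a separate right projection

w_fused = np.concatenate([w_left, w_right], axis=-1)   # shape (num_in, 2 * c)
x = np.random.randn(10, 10, num_in)
proj = x @ w_fused
left_proj, right_proj = proj[..., :c], proj[..., c:]

assert np.allclose(left_proj, x @ w_left)
assert np.allclose(right_proj, x @ w_right)
```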
class DistogramHead(hk.Module):
"""Head to predict a distogram.
......@@ -1446,7 +1530,7 @@ class OuterProductMean(hk.Module):
c = self.config
mask = mask[..., None]
act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act)
act = common_modules.LayerNorm([-1], True, True, name='layer_norm_input')(act)
left_act = mask * common_modules.Linear(
c.num_outer_channel,
......@@ -1469,9 +1553,11 @@ class OuterProductMean(hk.Module):
'output_w',
shape=(c.num_outer_channel, c.num_outer_channel,
self.num_output_channel),
dtype=act.dtype,
init=init_w)
output_b = hk.get_parameter(
'output_b', shape=(self.num_output_channel,),
dtype=act.dtype,
init=hk.initializers.Constant(0.0))
def compute_chunk(left_act):
......@@ -1738,7 +1824,7 @@ class EmbeddingsAndEvoformer(hk.Module):
dgram)
if c.recycle_features:
prev_msa_first_row = hk.LayerNorm(
prev_msa_first_row = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -1746,7 +1832,7 @@ class EmbeddingsAndEvoformer(hk.Module):
batch['prev_msa_first_row'])
msa_activations = msa_activations.at[0].add(prev_msa_first_row)
pair_activations += hk.LayerNorm(
pair_activations += common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -2020,7 +2106,7 @@ class SingleTemplateEmbedding(hk.Module):
self.config.template_pair_stack, self.global_config)(
act, mask_2d, is_training)
act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act)
act = common_modules.LayerNorm([-1], True, True, name='output_layer_norm')(act)
return act
......
......@@ -475,20 +475,51 @@ class AlphaFold(hk.Module):
# Eval mode or tests: use the maximum number of iterations.
num_iter = c.num_recycle
def recycle_body(i, x):
del i
prev, safe_key = x
def distances(points):
"""Compute all pairwise distances for a set of points."""
return jnp.sqrt(jnp.sum((points[:, None] - points[None, :])**2,
axis=-1))
def recycle_body(x):
i, _, prev, safe_key = x
safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate() # pylint: disable=line-too-long
ret = apply_network(prev=prev, safe_key=safe_key2)
return get_prev(ret), safe_key1
prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key))
return i+1, prev, get_prev(ret), safe_key1
def recycle_cond(x):
i, prev, next_in, _ = x
ca_idx = residue_constants.atom_order['CA']
sq_diff = jnp.square(distances(prev['prev_pos'][:, ca_idx, :]) -
distances(next_in['prev_pos'][:, ca_idx, :]))
mask = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
sq_diff = utils.mask_mean(mask, sq_diff)
# Early stopping criteria based on criteria used in
# AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
diff = jnp.sqrt(sq_diff + 1e-8) # avoid bad numerics giving negatives
less_than_max_recycles = (i < num_iter)
has_exceeded_tolerance = (
(i == 0) | (diff > c.recycle_early_stop_tolerance))
return less_than_max_recycles & has_exceeded_tolerance
if hk.running_init():
num_recycles, _, prev, safe_key = recycle_body(
(0, prev, prev, safe_key))
else:
num_recycles, _, prev, safe_key = hk.while_loop(
recycle_cond,
recycle_body,
(0, prev, prev, safe_key))
else:
# No recycling.
num_recycles = 0
# Run extra iteration.
ret = apply_network(prev=prev, safe_key=safe_key)
if not return_representations:
del ret['representations']
ret['num_recycles'] = num_recycles
return ret
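The `recycle_cond` above follows the AF2Complex stopping rule: keep recycling until the mean change in pairwise CA-CA distances between consecutive iterations drops below `recycle_early_stop_tolerance`. A standalone NumPy sketch of the same test (array names here are hypothetical; the masked mean approximates `utils.mask_mean`):

```python
import numpy as np


def pairwise_distances(points):
  """All pairwise distances for (num_res, 3) coordinates."""
  return np.sqrt(np.sum((points[:, None] - points[None, :]) ** 2, axis=-1))


def keep_recycling(prev_ca, next_ca, seq_mask, step, max_recycles, tol=0.5):
  """True while another recycling iteration should be run."""
  sq_diff = np.square(pairwise_distances(prev_ca) - pairwise_distances(next_ca))
  mask_2d = seq_mask[:, None] * seq_mask[None, :]
  mean_sq_diff = np.sum(mask_2d * sq_diff) / (np.sum(mask_2d) + 1e-10)
  diff = np.sqrt(mean_sq_diff + 1e-8)   # avoid bad numerics giving negatives
  # The first iteration always runs; afterwards stop once the structure has
  # effectively converged (diff <= tol) or the recycle budget is exhausted.
  return (step < max_recycles) and (step == 0 or diff > tol)
```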
......@@ -524,11 +555,13 @@ class EmbeddingsAndEvoformer(hk.Module):
Feature embedding using the features as described before.
"""
c = self.config
gc = self.global_config
rel_feats = []
pos = batch['residue_index']
asym_id = batch['asym_id']
asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
offset = pos[:, None] - pos[None, :]
dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32
clipped_offset = jnp.clip(
offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)
......@@ -568,6 +601,7 @@ class EmbeddingsAndEvoformer(hk.Module):
rel_feat = jnp.concatenate(rel_feats, axis=-1)
rel_feat = rel_feat.astype(dtype)
return common_modules.Linear(
c.pair_channel,
name='position_activations')(
......@@ -579,6 +613,7 @@ class EmbeddingsAndEvoformer(hk.Module):
gc = self.global_config
batch = dict(batch)
dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32
if safe_key is None:
safe_key = prng.SafeKey(hk.next_rng_key())
......@@ -587,177 +622,178 @@ class EmbeddingsAndEvoformer(hk.Module):
batch['msa_profile'] = make_msa_profile(batch)
target_feat = jax.nn.one_hot(batch['aatype'], 21)
preprocess_1d = common_modules.Linear(
c.msa_channel, name='preprocess_1d')(
target_feat)
safe_key, sample_key, mask_key = safe_key.split(3)
batch = sample_msa(sample_key, batch, c.num_msa)
batch = make_masked_msa(batch, mask_key, c.masked_msa)
(batch['cluster_profile'],
batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)
msa_feat = create_msa_feat(batch)
preprocess_msa = common_modules.Linear(
c.msa_channel, name='preprocess_msa')(
msa_feat)
msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
left_single = common_modules.Linear(
c.pair_channel, name='left_single')(
target_feat)
right_single = common_modules.Linear(
c.pair_channel, name='right_single')(
target_feat)
pair_activations = left_single[:, None] + right_single[None]
mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
mask_2d = mask_2d.astype(jnp.float32)
if c.recycle_pos:
prev_pseudo_beta = modules.pseudo_beta_fn(
batch['aatype'], batch['prev_pos'], None)
dgram = modules.dgram_from_positions(
prev_pseudo_beta, **self.config.prev_pos)
pair_activations += common_modules.Linear(
c.pair_channel, name='prev_pos_linear')(
dgram)
if c.recycle_features:
prev_msa_first_row = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_msa_first_row_norm')(
batch['prev_msa_first_row'])
msa_activations = msa_activations.at[0].add(prev_msa_first_row)
pair_activations += hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_pair_norm')(
batch['prev_pair'])
if c.max_relative_idx:
pair_activations += self._relative_encoding(batch)
if c.template.enabled:
template_module = TemplateEmbedding(c.template, gc)
template_batch = {
'template_aatype': batch['template_aatype'],
'template_all_atom_positions': batch['template_all_atom_positions'],
'template_all_atom_mask': batch['template_all_atom_mask']
with utils.bfloat16_context():
target_feat = jax.nn.one_hot(batch['aatype'], 21).astype(dtype)
preprocess_1d = common_modules.Linear(
c.msa_channel, name='preprocess_1d')(
target_feat)
safe_key, sample_key, mask_key = safe_key.split(3)
batch = sample_msa(sample_key, batch, c.num_msa)
batch = make_masked_msa(batch, mask_key, c.masked_msa)
(batch['cluster_profile'],
batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)
msa_feat = create_msa_feat(batch).astype(dtype)
preprocess_msa = common_modules.Linear(
c.msa_channel, name='preprocess_msa')(
msa_feat)
msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
left_single = common_modules.Linear(
c.pair_channel, name='left_single')(
target_feat)
right_single = common_modules.Linear(
c.pair_channel, name='right_single')(
target_feat)
pair_activations = left_single[:, None] + right_single[None]
mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
mask_2d = mask_2d.astype(dtype)
if c.recycle_pos:
prev_pseudo_beta = modules.pseudo_beta_fn(
batch['aatype'], batch['prev_pos'], None)
dgram = modules.dgram_from_positions(
prev_pseudo_beta, **self.config.prev_pos)
dgram = dgram.astype(dtype)
pair_activations += common_modules.Linear(
c.pair_channel, name='prev_pos_linear')(
dgram)
if c.recycle_features:
prev_msa_first_row = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_msa_first_row_norm')(
batch['prev_msa_first_row']).astype(dtype)
msa_activations = msa_activations.at[0].add(prev_msa_first_row)
pair_activations += common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_pair_norm')(
batch['prev_pair']).astype(dtype)
if c.max_relative_idx:
pair_activations += self._relative_encoding(batch)
if c.template.enabled:
template_module = TemplateEmbedding(c.template, gc)
template_batch = {
'template_aatype': batch['template_aatype'],
'template_all_atom_positions': batch['template_all_atom_positions'],
'template_all_atom_mask': batch['template_all_atom_mask']
}
# Construct a mask such that only intra-chain template features are
# computed, since all templates are for each chain individually.
multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
safe_key, safe_subkey = safe_key.split()
template_act = template_module(
query_embedding=pair_activations,
template_batch=template_batch,
padding_mask_2d=mask_2d,
multichain_mask_2d=multichain_mask,
is_training=is_training,
safe_key=safe_subkey)
pair_activations += template_act
# Extra MSA stack.
(extra_msa_feat,
extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
extra_msa_activations = common_modules.Linear(
c.extra_msa_channel,
name='extra_msa_activations')(
extra_msa_feat).astype(dtype)
extra_msa_mask = extra_msa_mask.astype(dtype)
extra_evoformer_input = {
'msa': extra_msa_activations,
'pair': pair_activations,
}
# Construct a mask such that only intra-chain template features are
# computed, since all templates are for each chain individually.
multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
safe_key, safe_subkey = safe_key.split()
template_act = template_module(
query_embedding=pair_activations,
template_batch=template_batch,
padding_mask_2d=mask_2d,
multichain_mask_2d=multichain_mask,
is_training=is_training,
safe_key=safe_subkey)
pair_activations += template_act
# Extra MSA stack.
(extra_msa_feat,
extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
extra_msa_activations = common_modules.Linear(
c.extra_msa_channel,
name='extra_msa_activations')(
extra_msa_feat)
extra_msa_mask = extra_msa_mask.astype(jnp.float32)
extra_evoformer_input = {
'msa': extra_msa_activations,
'pair': pair_activations,
}
extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
extra_evoformer_iteration = modules.EvoformerIteration(
c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
extra_evoformer_iteration = modules.EvoformerIteration(
c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
def extra_evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
extra_evoformer_output = extra_evoformer_iteration(
activations=act,
masks=extra_masks,
is_training=is_training,
safe_key=safe_subkey)
return (extra_evoformer_output, safe_key)
if gc.use_remat:
extra_evoformer_fn = hk.remat(extra_evoformer_fn)
safe_key, safe_subkey = safe_key.split()
extra_evoformer_stack = layer_stack.layer_stack(
c.extra_msa_stack_num_block)(
extra_evoformer_fn)
extra_evoformer_output, safe_key = extra_evoformer_stack(
(extra_evoformer_input, safe_subkey))
pair_activations = extra_evoformer_output['pair']
# Get the size of the MSA before potentially adding templates, so we
# can crop out the templates later.
num_msa_sequences = msa_activations.shape[0]
evoformer_input = {
'msa': msa_activations,
'pair': pair_activations,
}
evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32),
'pair': mask_2d}
if c.template.enabled:
template_features, template_masks = (
template_embedding_1d(batch=batch, num_channel=c.msa_channel))
evoformer_input['msa'] = jnp.concatenate(
[evoformer_input['msa'], template_features], axis=0)
evoformer_masks['msa'] = jnp.concatenate(
[evoformer_masks['msa'], template_masks], axis=0)
def extra_evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
extra_evoformer_output = extra_evoformer_iteration(
activations=act,
masks=extra_masks,
is_training=is_training,
safe_key=safe_subkey)
return (extra_evoformer_output, safe_key)
evoformer_iteration = modules.EvoformerIteration(
c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
if gc.use_remat:
extra_evoformer_fn = hk.remat(extra_evoformer_fn)
def evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
evoformer_output = evoformer_iteration(
activations=act,
masks=evoformer_masks,
is_training=is_training,
safe_key=safe_subkey)
return (evoformer_output, safe_key)
if gc.use_remat:
evoformer_fn = hk.remat(evoformer_fn)
extra_evoformer_stack = layer_stack.layer_stack(
c.extra_msa_stack_num_block)(
extra_evoformer_fn)
extra_evoformer_output, safe_key = extra_evoformer_stack(
(extra_evoformer_input, safe_subkey))
pair_activations = extra_evoformer_output['pair']
# Get the size of the MSA before potentially adding templates, so we
# can crop out the templates later.
num_msa_sequences = msa_activations.shape[0]
evoformer_input = {
'msa': msa_activations,
'pair': pair_activations,
}
evoformer_masks = {
'msa': batch['msa_mask'].astype(dtype),
'pair': mask_2d
}
if c.template.enabled:
template_features, template_masks = (
template_embedding_1d(
batch=batch, num_channel=c.msa_channel, global_config=gc))
evoformer_input['msa'] = jnp.concatenate(
[evoformer_input['msa'], template_features], axis=0)
evoformer_masks['msa'] = jnp.concatenate(
[evoformer_masks['msa'], template_masks], axis=0)
evoformer_iteration = modules.EvoformerIteration(
c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
def evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
evoformer_output = evoformer_iteration(
activations=act,
masks=evoformer_masks,
is_training=is_training,
safe_key=safe_subkey)
return (evoformer_output, safe_key)
if gc.use_remat:
evoformer_fn = hk.remat(evoformer_fn)
safe_key, safe_subkey = safe_key.split()
evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
evoformer_fn)
safe_key, safe_subkey = safe_key.split()
evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
evoformer_fn)
def run_evoformer(evoformer_input):
evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
return evoformer_output
def run_evoformer(evoformer_input):
evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
return evoformer_output
evoformer_output = run_evoformer(evoformer_input)
evoformer_output = run_evoformer(evoformer_input)
msa_activations = evoformer_output['msa']
pair_activations = evoformer_output['pair']
msa_activations = evoformer_output['msa']
pair_activations = evoformer_output['pair']
single_activations = common_modules.Linear(
c.seq_channel, name='single_activations')(
msa_activations[0])
single_activations = common_modules.Linear(
c.seq_channel, name='single_activations')(
msa_activations[0])
output.update({
'single':
......@@ -771,6 +807,12 @@ class EmbeddingsAndEvoformer(hk.Module):
msa_activations[0],
})
# Convert back to float32 if we're not saving memory.
if not gc.bfloat16_output:
for k, v in output.items():
if v.dtype == jnp.bfloat16:
output[k] = v.astype(jnp.float32)
return output
......@@ -917,6 +959,9 @@ class SingleTemplateEmbedding(hk.Module):
# backbone affine - i.e. in each residue's local frame, the direction to
# each of the other residues.
raw_atom_pos = template_all_atom_positions
if gc.bfloat16:
# Vec3Arrays are required to be float32
raw_atom_pos = raw_atom_pos.astype(jnp.float32)
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = folding_multimer.make_backbone_affine(
......@@ -928,6 +973,10 @@ class SingleTemplateEmbedding(hk.Module):
unit_vector = rigid_vec.normalized()
unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]
if gc.bfloat16:
unit_vector = [x.astype(jnp.bfloat16) for x in unit_vector]
backbone_mask = backbone_mask.astype(jnp.bfloat16)
backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
backbone_mask_2d *= multichain_mask_2d
unit_vector = [x*backbone_mask_2d for x in unit_vector]
......@@ -937,7 +986,7 @@ class SingleTemplateEmbedding(hk.Module):
to_concat.extend([(x, 0) for x in unit_vector])
to_concat.append((backbone_mask_2d, 0))
query_embedding = hk.LayerNorm(
query_embedding = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
......@@ -986,12 +1035,13 @@ class SingleTemplateEmbedding(hk.Module):
template_iteration_fn)
act, safe_key = template_stack((act, safe_subkey))
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='output_layer_norm')(
act)
return act
......@@ -1044,21 +1094,18 @@ class TemplateEmbeddingIteration(hk.Module):
act,
pair_mask,
safe_key=next(sub_keys))
act = dropout_wrapper_fn(
modules.TriangleAttention(c.triangle_attention_starting_node, gc,
name='triangle_attention_starting_node'),
act,
pair_mask,
safe_key=next(sub_keys))
act = dropout_wrapper_fn(
modules.TriangleAttention(c.triangle_attention_ending_node, gc,
name='triangle_attention_ending_node'),
act,
pair_mask,
safe_key=next(sub_keys))
act = dropout_wrapper_fn(
modules.Transition(c.pair_transition, gc,
name='pair_transition'),
......@@ -1069,7 +1116,7 @@ class TemplateEmbeddingIteration(hk.Module):
return act
def template_embedding_1d(batch, num_channel):
def template_embedding_1d(batch, num_channel, global_config):
"""Embed templates into an (num_res, num_templates, num_channels) embedding.
Args:
......@@ -1080,6 +1127,7 @@ def template_embedding_1d(batch, num_channel):
template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
each template.
num_channel: The number of channels in the output.
global_config: The global_config.
Returns:
An embedding of shape (num_templates, num_res, num_channels) and a mask of
......@@ -1112,6 +1160,10 @@ def template_embedding_1d(batch, num_channel):
template_mask = chi_mask[:, :, 0]
if global_config.bfloat16:
template_features = template_features.astype(jnp.bfloat16)
template_mask = template_mask.astype(jnp.bfloat16)
template_activations = common_modules.Linear(
num_channel,
initializer='relu',
......
......@@ -15,6 +15,7 @@
"""A collection of JAX utility functions for use in protein folding."""
import collections
import contextlib
import functools
import numbers
from typing import Mapping
......@@ -25,6 +26,27 @@ import jax.numpy as jnp
import numpy as np
def bfloat16_creator(next_creator, shape, dtype, init, context):
"""Creates float32 variables when bfloat16 is requested."""
if context.original_dtype == jnp.bfloat16:
dtype = jnp.float32
return next_creator(shape, dtype, init)
def bfloat16_getter(next_getter, value, context):
"""Casts float32 to bfloat16 when bfloat16 was originally requested."""
if context.original_dtype == jnp.bfloat16:
assert value.dtype == jnp.float32
value = value.astype(jnp.bfloat16)
return next_getter(value)
@contextlib.contextmanager
def bfloat16_context():
with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter):
yield
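A minimal sketch of the creator/getter pair in action, assuming a Haiku transform and a hypothetical module: parameters requested in bfloat16 are created and stored in float32, then cast to bfloat16 whenever they are read inside `bfloat16_context`.

```python
import haiku as hk
import jax
import jax.numpy as jnp

from alphafold.model import utils


def forward(x):
  with utils.bfloat16_context():
    # hk.Linear requests parameters with x.dtype (bfloat16 here); the creator
    # stores them as float32 and the getter casts them back at use time.
    return hk.Linear(4)(x)


fn = hk.transform(forward)
x = jnp.ones((2, 8), dtype=jnp.bfloat16)
params = fn.init(jax.random.PRNGKey(0), x)

print(jax.tree_util.tree_map(lambda p: p.dtype, params))  # float32 leaves
print(fn.apply(params, None, x).dtype)                    # bfloat16
```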
def final_init(config):
if config.zero_init:
return 'zeros'
......
......@@ -13,7 +13,6 @@
# limitations under the License.
"""Helper methods for the AlphaFold Colab notebook."""
import enum
import json
from typing import Any, Mapping, Optional, Sequence, Tuple
......@@ -23,13 +22,7 @@ from matplotlib import pyplot as plt
import numpy as np
@enum.unique
class ModelType(enum.Enum):
MONOMER = 0
MULTIMER = 1
def clean_and_validate_sequence(
def clean_and_validate_single_sequence(
input_sequence: str, min_length: int, max_length: int) -> str:
"""Checks that the input sequence is ok and returns a clean version of it."""
# Remove all whitespaces, tabs and end lines; upper-case.
......@@ -54,41 +47,23 @@ def clean_and_validate_sequence(
return clean_sequence
def validate_input(
def clean_and_validate_input_sequences(
input_sequences: Sequence[str],
min_length: int,
max_length: int,
max_multimer_length: int) -> Tuple[Sequence[str], ModelType]:
"""Validates and cleans input sequences and determines which model to use."""
min_sequence_length: int,
max_sequence_length: int) -> Sequence[str]:
"""Validates and cleans input sequences."""
sequences = []
for input_sequence in input_sequences:
if input_sequence.strip():
input_sequence = clean_and_validate_sequence(
input_sequence = clean_and_validate_single_sequence(
input_sequence=input_sequence,
min_length=min_length,
max_length=max_length)
min_length=min_sequence_length,
max_length=max_sequence_length)
sequences.append(input_sequence)
if len(sequences) == 1:
print('Using the single-chain model.')
return sequences, ModelType.MONOMER
elif len(sequences) > 1:
total_multimer_length = sum([len(seq) for seq in sequences])
if total_multimer_length > max_multimer_length:
raise ValueError(f'The total length of multimer sequences is too long: '
f'{total_multimer_length}, while the maximum is '
f'{max_multimer_length}. Please use the full AlphaFold '
f'system for long multimers.')
elif total_multimer_length > 1536:
print('WARNING: The accuracy of the system has not been fully validated '
'above 1536 residues, and you may experience long running times or '
f'run out of memory for your complex with {total_multimer_length} '
'residues.')
print(f'Using the multimer model with {len(sequences)} sequences.')
return sequences, ModelType.MULTIMER
if sequences:
return sequences
else:
raise ValueError('No input amino acid sequence provided, please provide at '
'least one sequence.')
......
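A hedged example of the renamed notebook helper (signature as in this diff; the input sequence is arbitrary):

```python
from alphafold.notebooks import notebook_utils

sequences = notebook_utils.clean_and_validate_input_sequences(
    input_sequences=['MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ', '  \t', ''],
    min_sequence_length=16,
    max_sequence_length=2500)

# Whitespace-only entries are dropped; the remaining sequences are upper-cased,
# stripped of whitespace and checked against the length bounds and the
# standard amino acid alphabet.
print(sequences)  # ['MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ']
```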
......@@ -85,7 +85,7 @@ class NotebookUtilsTest(parameterized.TestCase):
('DeepMind', 'DEEPMIND'), ('A ', 'A'), ('\tA', 'A'), (' A\t\n', 'A'),
('ACDEFGHIKLMNPQRSTVWY', 'ACDEFGHIKLMNPQRSTVWY'))
def test_clean_and_validate_sequence_ok(self, sequence, exp_clean):
clean = notebook_utils.clean_and_validate_sequence(
clean = notebook_utils.clean_and_validate_single_sequence(
sequence, min_length=1, max_length=100)
self.assertEqual(clean, exp_clean)
......@@ -100,35 +100,29 @@ class NotebookUtilsTest(parameterized.TestCase):
('bad_amino_acids_Z', 'ZZZZ', 'non-amino acid'))
def test_clean_and_validate_sequence_bad(self, sequence, exp_error):
with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
notebook_utils.clean_and_validate_sequence(
notebook_utils.clean_and_validate_single_sequence(
sequence, min_length=4, max_length=8)
@parameterized.parameters(
(['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A'],
notebook_utils.ModelType.MONOMER),
(['', 'A'], ['A'],
notebook_utils.ModelType.MONOMER),
(['A', 'C ', ''], ['A', 'C'],
notebook_utils.ModelType.MULTIMER),
(['', 'A', '', 'C '], ['A', 'C'],
notebook_utils.ModelType.MULTIMER))
def test_validate_input_ok(
self, input_sequences, exp_sequences, exp_model_type):
sequences, model_type = notebook_utils.validate_input(
(['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A']),
(['', 'A'], ['A']),
(['A', 'C ', ''], ['A', 'C']),
(['', 'A', '', 'C '], ['A', 'C']))
def test_validate_input_ok(self, input_sequences, exp_sequences):
sequences = notebook_utils.clean_and_validate_input_sequences(
input_sequences=input_sequences,
min_length=1, max_length=100, max_multimer_length=100)
min_sequence_length=1, max_sequence_length=100)
self.assertSequenceEqual(sequences, exp_sequences)
self.assertEqual(model_type, exp_model_type)
@parameterized.named_parameters(
('no_input_sequence', ['', '\t', '\n'], 'No input amino acid sequence'),
('too_long_single', ['AAAAAAAAA', 'AAAA'], 'Input sequence is too long'),
('too_long_multimer', ['AAAA', 'AAAAA'], 'The total length of multimer'))
('too_short_single', ['AAA', 'AAAA'], 'Input sequence is too short'))
def test_validate_input_bad(self, input_sequences, exp_error):
with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
notebook_utils.validate_input(
input_sequences=input_sequences,
min_length=4, max_length=8, max_multimer_length=6)
notebook_utils.clean_and_validate_input_sequences(
input_sequences=input_sequences, min_sequence_length=4,
max_sequence_length=8)
def test_merge_chunked_msa_no_hits(self):
results = [ONLY_QUERY_HIT, ONLY_QUERY_HIT]
......
......@@ -56,7 +56,8 @@ class AmberRelaxation(object):
self._use_gpu = use_gpu
def process(self, *,
prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
prot: protein.Protein
) -> Tuple[str, Dict[str, Any], Sequence[float]]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline(
prot=prot, max_iterations=self._max_iterations,
......@@ -73,12 +74,11 @@ class AmberRelaxation(object):
'attempts': out['min_attempts'],
'rmsd': rmsd
}
pdb_str = amber_minimize.clean_protein(prot)
min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
min_pdb = out['min_pdb']
min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
utils.assert_equal_nonterminal_atom_types(
protein.from_pdb_string(min_pdb).atom_mask,
prot.atom_mask)
violations = out['structural_violations'][
'total_per_residue_violations_mask']
'total_per_residue_violations_mask'].tolist()
return min_pdb, debug_data, violations
......@@ -82,7 +82,7 @@ class RunAmberRelaxTest(absltest.TestCase):
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0])
# Check no violations were added. Can't check exactly due to stochasticity.
self.assertTrue(np.all(num_violations <= exp_num_violations))
self.assertTrue(np.all(np.array(num_violations) <= exp_num_violations))
if __name__ == '__main__':
......
......@@ -17,17 +17,6 @@ import io
from alphafold.common import residue_constants
from Bio import PDB
import numpy as np
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
pdb_file = io.StringIO(pdb_str)
structure = PdbStructure(pdb_file)
topology = openmm_app.PDBFile(structure).getTopology()
with io.StringIO() as f:
openmm_app.PDBFile.writeFile(topology, pos, f)
return f.getvalue()
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
......
......@@ -21,7 +21,7 @@ ARG CUDA
# Use bash to support string substitution.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
RUN apt-get update \
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
build-essential \
cmake \
......@@ -59,7 +59,7 @@ RUN conda install -qy conda==4.13.0 \
cudatoolkit==${CUDA_VERSION} \
pdbfixer \
pip \
python=3.7 \
python=3.8 \
&& conda clean --all --force-pkgs-dirs --yes
COPY . /app/alphafold
......@@ -75,7 +75,7 @@ RUN pip3 install --upgrade pip --no-cache-dir \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Apply OpenMM patch.
WORKDIR /opt/conda/lib/python3.7/site-packages
WORKDIR /opt/conda/lib/python3.8/site-packages
RUN patch -p0 < /app/alphafold/docker/openmm.patch
# Add SETUID bit to the ldconfig binary so that non-root users can run it.
......
......@@ -133,7 +133,7 @@ def main(argv):
# Path to the MGnify database for use by JackHMMER.
mgnify_database_path = os.path.join(
FLAGS.data_dir, 'mgnify', 'mgy_clusters_2018_12.fa')
FLAGS.data_dir, 'mgnify', 'mgy_clusters_2022_05.fa')
# Path to the BFD database for use by HHblits.
bfd_database_path = os.path.join(
......@@ -144,9 +144,9 @@ def main(argv):
small_bfd_database_path = os.path.join(
FLAGS.data_dir, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta')
# Path to the Uniclust30 database for use by HHblits.
uniclust30_database_path = os.path.join(
FLAGS.data_dir, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08')
# Path to the Uniref30 database for use by HHblits.
uniref30_database_path = os.path.join(
FLAGS.data_dir, 'uniref30', 'UniRef30_2021_03')
# Path to the PDB70 database for use by HHsearch.
pdb70_database_path = os.path.join(FLAGS.data_dir, 'pdb70', 'pdb70')
......@@ -199,7 +199,7 @@ def main(argv):
database_paths.append(('small_bfd_database_path', small_bfd_database_path))
else:
database_paths.extend([
('uniclust30_database_path', uniclust30_database_path),
('uniref30_database_path', uniref30_database_path),
('bfd_database_path', bfd_database_path),
])
for name, path in database_paths:
......
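For reference, the database locations this release now expects under --data_dir, as constructed by the code above (only the paths touched by this diff; the root directory is a placeholder):

```python
import os

data_dir = '/path/to/alphafold_databases'  # hypothetical download root

mgnify_database_path = os.path.join(
    data_dir, 'mgnify', 'mgy_clusters_2022_05.fa')
uniref30_database_path = os.path.join(
    data_dir, 'uniref30', 'UniRef30_2021_03')
small_bfd_database_path = os.path.join(
    data_dir, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta')
```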