Commit 9b18d6a9 authored by Augustin Zidek

Release code for v2.3.0

PiperOrigin-RevId: 494507694
parent 4494af84
@@ -304,9 +304,9 @@ fractionPlddtVeryHigh | `FLOAT64` | Fraction of the residues in the predi
fractionPlddtVeryLow | `FLOAT64` | Fraction of the residues in the prediction with pLDDT less than 50
gene | `STRING` | The name of the gene if known, e.g. "COII"
geneSynonyms | `ARRAY<STRING>` | Additional synonyms for the gene
+globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
isReferenceProteome | `BOOL` | Is this protein part of the reference proteome?
isReviewed | `BOOL` | Has this protein been reviewed, i.e. is it part of SwissProt?
-globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
latestVersion | `INT64` | The latest AFDB version for this prediction
modelCreatedDate | `DATE` | The date of creation for this entry, e.g. "2022-06-01"
organismCommonNames | `ARRAY<STRING>` | List of common organism names
...
@@ -117,7 +117,7 @@ class DataPipeline:
               uniref90_database_path: str,
               mgnify_database_path: str,
               bfd_database_path: Optional[str],
-              uniclust30_database_path: Optional[str],
+              uniref30_database_path: Optional[str],
               small_bfd_database_path: Optional[str],
               template_searcher: TemplateSearcher,
               template_featurizer: templates.TemplateHitFeaturizer,
@@ -135,9 +135,9 @@ class DataPipeline:
          binary_path=jackhmmer_binary_path,
          database_path=small_bfd_database_path)
    else:
-      self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
+      self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
          binary_path=hhblits_binary_path,
-          databases=[bfd_database_path, uniclust30_database_path])
+          databases=[bfd_database_path, uniref30_database_path])
    self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
        binary_path=jackhmmer_binary_path,
        database_path=mgnify_database_path)
@@ -211,14 +211,14 @@ class DataPipeline:
          use_precomputed_msas=self.use_precomputed_msas)
      bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
    else:
-      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
-      hhblits_bfd_uniclust_result = run_msa_tool(
-          msa_runner=self.hhblits_bfd_uniclust_runner,
+      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
+      hhblits_bfd_uniref_result = run_msa_tool(
+          msa_runner=self.hhblits_bfd_uniref_runner,
          input_fasta_path=input_fasta_path,
          msa_out_path=bfd_out_path,
          msa_format='a3m',
          use_precomputed_msas=self.use_precomputed_msas)
-      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
+      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])
    templates_result = self.template_featurizer.get_templates(
        query_sequence=input_sequence,
...
@@ -128,3 +128,64 @@ class Linear(hk.Module):
    return output
class LayerNorm(hk.LayerNorm):
  """LayerNorm module.

  Equivalent to hk.LayerNorm but with different parameter shapes: they are
  always vectors rather than possibly higher-rank tensors. This makes it easier
  to change the layout whilst keeping the model weight-compatible.
  """

  def __init__(self,
               axis,
               create_scale: bool,
               create_offset: bool,
               eps: float = 1e-5,
               scale_init=None,
               offset_init=None,
               use_fast_variance: bool = False,
               name=None,
               param_axis=None):
    super().__init__(
        axis=axis,
        create_scale=False,
        create_offset=False,
        eps=eps,
        scale_init=None,
        offset_init=None,
        use_fast_variance=use_fast_variance,
        name=name,
        param_axis=param_axis)
    self._temp_create_scale = create_scale
    self._temp_create_offset = create_offset

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    is_bf16 = (x.dtype == jnp.bfloat16)
    if is_bf16:
      x = x.astype(jnp.float32)

    param_axis = self.param_axis[0] if self.param_axis else -1
    param_shape = (x.shape[param_axis],)

    param_broadcast_shape = [1] * x.ndim
    param_broadcast_shape[param_axis] = x.shape[param_axis]
    scale = None
    offset = None
    if self._temp_create_scale:
      scale = hk.get_parameter(
          'scale', param_shape, x.dtype, init=self.scale_init)
      scale = scale.reshape(param_broadcast_shape)

    if self._temp_create_offset:
      offset = hk.get_parameter(
          'offset', param_shape, x.dtype, init=self.offset_init)
      offset = offset.reshape(param_broadcast_shape)

    out = super().__call__(x, scale=scale, offset=offset)

    if is_bf16:
      out = out.astype(jnp.bfloat16)

    return out
\ No newline at end of file
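For reference, a minimal usage sketch (not part of the commit; it assumes the class above is available as alphafold.model.common_modules.LayerNorm) showing that the 'scale' and 'offset' parameters always come out as plain vectors, which is what keeps the bfloat16-related layout change weight-compatible with existing checkpoints:

import haiku as hk
import jax
import jax.numpy as jnp
from alphafold.model import common_modules

def forward(x):
  norm = common_modules.LayerNorm(
      axis=[-1], create_scale=True, create_offset=True, name='example_norm')
  return norm(x)

fn = hk.transform(forward)
params = fn.init(jax.random.PRNGKey(0), jnp.zeros((4, 8)))
# params['example_norm']['scale'] has shape (8,): always a vector, never a
# higher-rank tensor, regardless of the input rank.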
@@ -26,12 +26,12 @@ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES

def model_config(name: str) -> ml_collections.ConfigDict:
  """Get the ConfigDict of a CASP14 model."""
-  if 'multimer' in name:
-    return CONFIG_MULTIMER
  if name not in CONFIG_DIFFS:
    raise ValueError(f'Invalid model name {name}.')
-  cfg = copy.deepcopy(CONFIG)
+  if 'multimer' in name:
+    cfg = copy.deepcopy(CONFIG_MULTIMER)
+  else:
+    cfg = copy.deepcopy(CONFIG)
  cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
  return cfg
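A small usage sketch (not part of the diff; it assumes the standard alphafold.model.config import path) of the rewritten model_config:

from alphafold.model import config

cfg = config.model_config('model_1_multimer_v3')
# CONFIG_MULTIMER is deep-copied, then the per-model overrides in
# CONFIG_DIFFS['model_1_multimer_v3'] (empty for this model) are applied.
print(cfg.model.num_recycle)  # 20 for the v3 multimer models in this release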
@@ -52,11 +52,11 @@ MODEL_PRESETS = {
        'model_5_ptm',
    ),
    'multimer': (
-        'model_1_multimer_v2',
-        'model_2_multimer_v2',
-        'model_3_multimer_v2',
-        'model_4_multimer_v2',
-        'model_5_multimer_v2',
+        'model_1_multimer_v3',
+        'model_2_multimer_v3',
+        'model_3_multimer_v3',
+        'model_4_multimer_v3',
+        'model_5_multimer_v3',
    ),
}
MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
@@ -118,8 +118,32 @@ CONFIG_DIFFS = {
    },
    'model_5_ptm': {
        'model.heads.predicted_aligned_error.weight': 0.1
-    }
+    },
+    'model_1_multimer_v3': {},
+    'model_2_multimer_v3': {},
+    'model_3_multimer_v3': {},
+    'model_4_multimer_v3': {
+        'model.embeddings_and_evoformer.num_extra_msa': 1152
+    },
+    'model_5_multimer_v3': {
+        'model.embeddings_and_evoformer.num_extra_msa': 1152
+    },
}

+# Key differences between multimer v1/v2 and v3, mostly due to numerical
+# optimisations in the TriangleMultiplication module.
+common_updates = {
+    'model.embeddings_and_evoformer.num_msa': 252,
+    'model.embeddings_and_evoformer.num_extra_msa': 1152,
+    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.fuse_projection_weights': False,
+}
+
+CONFIG_DIFFS.update(
+    {f'model_{i}_multimer': common_updates for i in range(1, 6)})
+CONFIG_DIFFS.update(
+    {f'model_{i}_multimer_v2': common_updates for i in range(1, 6)})

CONFIG = ml_collections.ConfigDict({
    'data': {
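The keys above are flattened paths into the nested ConfigDict; a short sketch (an illustration, not code from the commit) of how update_from_flattened_dict applies one of them:

import ml_collections

cfg = ml_collections.ConfigDict(
    {'model': {'embeddings_and_evoformer': {'num_extra_msa': 2048}}})
# The v1/v2 override from common_updates restores the old extra-MSA depth.
cfg.update_from_flattened_dict(
    {'model.embeddings_and_evoformer.num_extra_msa': 1152})
assert cfg.model.embeddings_and_evoformer.num_extra_msa == 1152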
@@ -260,14 +284,16 @@ CONFIG = ml_collections.ConfigDict({
                'equation': 'ikc,jkc->ijc',
                'num_intermediate_channel': 128,
                'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': False,
            },
            'triangle_multiplication_incoming': {
                'dropout_rate': 0.25,
                'equation': 'kjc,kic->ijc',
                'num_intermediate_channel': 128,
                'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': False,
            },
            'pair_transition': {
                'dropout_rate': 0.0,
@@ -328,14 +354,16 @@ CONFIG = ml_collections.ConfigDict({
                'equation': 'ikc,jkc->ijc',
                'num_intermediate_channel': 64,
                'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': False,
            },
            'triangle_multiplication_incoming': {
                'dropout_rate': 0.25,
                'equation': 'kjc,kic->ijc',
                'num_intermediate_channel': 64,
                'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': False,
            },
            'pair_transition': {
                'dropout_rate': 0.0,
@@ -354,7 +382,7 @@ CONFIG = ml_collections.ConfigDict({
        'multimer_mode': False,
        'subbatch_size': 4,
        'use_remat': False,
-        'zero_init': True
+        'zero_init': True,
    },
    'heads': {
        'distogram': {
@@ -483,27 +511,29 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
                'gating': True,
                'num_head': 4,
                'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
            },
            'triangle_multiplication_incoming': {
                'dropout_rate': 0.25,
                'equation': 'kjc,kic->ijc',
                'num_intermediate_channel': 128,
                'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': True,
            },
            'triangle_multiplication_outgoing': {
                'dropout_rate': 0.25,
                'equation': 'ikc,jkc->ijc',
                'num_intermediate_channel': 128,
                'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': True,
            }
        },
        'extra_msa_channel': 64,
        'extra_msa_stack_num_block': 4,
-        'num_msa': 252,
-        'num_extra_msa': 1152,
+        'num_msa': 508,
+        'num_extra_msa': 2048,
        'masked_msa': {
            'profile_prob': 0.1,
            'replace_fraction': 0.15,
@@ -564,24 +594,28 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
                    'equation': 'kjc,kic->ijc',
                    'num_intermediate_channel': 64,
                    'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': True,
                },
                'triangle_multiplication_outgoing': {
                    'dropout_rate': 0.25,
                    'equation': 'ikc,jkc->ijc',
                    'num_intermediate_channel': 64,
                    'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': True,
                }
            }
        },
    },
    'global_config': {
+        'bfloat16': True,
+        'bfloat16_output': False,
        'deterministic': False,
        'multimer_mode': True,
        'subbatch_size': 4,
        'use_remat': False,
-        'zero_init': True
+        'zero_init': True,
    },
    'heads': {
        'distogram': {
@@ -651,7 +685,13 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
            }
        },
        'num_ensemble_eval': 1,
-        'num_recycle': 3,
+        'num_recycle': 20,
+        # A negative value indicates that no early stopping will occur, i.e.
+        # the model will always run `num_recycle` number of recycling
+        # iterations. A positive value will enable early stopping if the
+        # difference in pairwise distances is less than the tolerance between
+        # recycling steps.
+        'recycle_early_stop_tolerance': 0.5,
        'resample_msa_in_recycling': True
    }
})
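As a sketch of how the new knob is meant to be used (not part of the diff; it assumes the config paths shown above), early stopping can be disabled by making the tolerance negative, so the model always runs the full recycle budget:

from alphafold.model import config

cfg = config.model_config('model_1_multimer_v3')
cfg.model.num_recycle = 20
cfg.model.recycle_early_stop_tolerance = -1.0  # negative => never stop early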
@@ -331,7 +331,7 @@ class FoldIteration(hk.Module):
    safe_key, *sub_keys = safe_key.split(3)
    sub_keys = iter(sub_keys)
    act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
@@ -353,7 +353,7 @@ class FoldIteration(hk.Module):
    act = jax.nn.relu(act)
    act += input_act
    act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
@@ -410,7 +410,7 @@ def generate_affines(representations, batch, config, global_config,
  c = config
  sequence_mask = batch['seq_mask'][:, None]

-  act = hk.LayerNorm(
+  act = common_modules.LayerNorm(
      axis=[-1],
      create_scale=True,
      create_offset=True,
@@ -433,7 +433,7 @@ def generate_affines(representations, batch, config, global_config,
      'affine': affine.to_tensor(),
  }

-  act_2d = hk.LayerNorm(
+  act_2d = common_modules.LayerNorm(
      axis=[-1],
      create_scale=True,
      create_offset=True,
...
@@ -427,7 +427,7 @@ class FoldIteration(hk.Module):
    safe_key, *sub_keys = safe_key.split(3)
    sub_keys = iter(sub_keys)
    act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
        axis=-1,
        create_scale=True,
        create_offset=True,
@@ -448,7 +448,7 @@ class FoldIteration(hk.Module):
    act = jax.nn.relu(act)
    act += input_act
    act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
        axis=-1,
        create_scale=True,
        create_offset=True,
@@ -500,7 +500,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
  """
  c = config
  sequence_mask = batch['seq_mask'][:, None]

-  act = hk.LayerNorm(
+  act = common_modules.LayerNorm(
      axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')(
          representations['single'])
@@ -523,7 +523,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
      rigid
  }

-  act_2d = hk.LayerNorm(
+  act_2d = common_modules.LayerNorm(
      axis=-1,
      create_scale=True,
      create_offset=True,
...
@@ -133,7 +133,7 @@ def flatten(instance):
  inner_treedefs = []
  num_arrays = []
  for array_like in array_likes:
-    flat_array_like, inner_treedef = jax.tree_flatten(array_like)
+    flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like)
    inner_treedefs.append(inner_treedef)
    flat_array_likes += flat_array_like
    num_arrays.append(len(flat_array_like))
@@ -206,7 +206,7 @@ class StructOfArray:
    for num_array, inner_treedef, array_field in zip(num_arrays,
                                                     inner_treedefs,
                                                     array_fields):
-      value_dict[array_field] = jax.tree_unflatten(
+      value_dict[array_field] = jax.tree_util.tree_unflatten(
          inner_treedef, data[array_start:array_start + num_array])
      array_start += num_array
    metadata_fields = get_metadata_fields(new_cls)
...
@@ -47,11 +47,11 @@ def _maybe_get_size(array, axis):

def _expand_axes(axes, values, name='sharded_apply'):
-  values_tree_def = jax.tree_flatten(values)[1]
+  values_tree_def = jax.tree_util.tree_flatten(values)[1]
  flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
  # Replace None's with PROXY
  flat_axes = [PROXY if x is None else x for x in flat_axes]
-  return jax.tree_unflatten(values_tree_def, flat_axes)
+  return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)

def sharded_map(
@@ -126,7 +126,7 @@ def sharded_apply(
    in_axes_ = _expand_axes(in_axes, args)
    in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
-    flat_sizes = jax.tree_flatten(in_sizes)[0]
+    flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
    in_size = max(flat_sizes)
    assert all(i in {in_size, -1} for i in flat_sizes)
...
@@ -501,7 +501,7 @@ class Transition(hk.Module):
    num_intermediate = int(nc * self.config.num_intermediate_factor)
    mask = jnp.expand_dims(mask, axis=-1)

-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
@@ -569,12 +569,15 @@ class Attention(hk.Module):
    q_weights = hk.get_parameter(
        'query_w', shape=(q_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
        init=glorot_uniform())
    k_weights = hk.get_parameter(
        'key_w', shape=(m_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
        init=glorot_uniform())
    v_weights = hk.get_parameter(
        'value_w', shape=(m_data.shape[-1], num_head, value_dim),
+        dtype=q_data.dtype,
        init=glorot_uniform())

    q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
@@ -595,10 +598,12 @@ class Attention(hk.Module):
      gating_weights = hk.get_parameter(
          'gating_w',
          shape=(q_data.shape[-1], num_head, value_dim),
+          dtype=q_data.dtype,
          init=hk.initializers.Constant(0.0))
      gating_bias = hk.get_parameter(
          'gating_b',
          shape=(num_head, value_dim),
+          dtype=q_data.dtype,
          init=hk.initializers.Constant(1.0))

      gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
@@ -610,9 +615,12 @@ class Attention(hk.Module):
    o_weights = hk.get_parameter(
        'output_w', shape=(num_head, value_dim, self.output_dim),
+        dtype=q_data.dtype,
        init=init)
-    o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),
-                              init=hk.initializers.Constant(0.0))
+    o_bias = hk.get_parameter(
+        'output_b', shape=(self.output_dim,),
+        dtype=q_data.dtype,
+        init=hk.initializers.Constant(0.0))

    output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
@@ -658,12 +666,15 @@ class GlobalAttention(hk.Module):
    q_weights = hk.get_parameter(
        'query_w', shape=(q_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
        init=glorot_uniform())
    k_weights = hk.get_parameter(
        'key_w', shape=(m_data.shape[-1], key_dim),
+        dtype=q_data.dtype,
        init=glorot_uniform())
    v_weights = hk.get_parameter(
        'value_w', shape=(m_data.shape[-1], value_dim),
+        dtype=q_data.dtype,
        init=glorot_uniform())

    v = jnp.einsum('bka,ac->bkc', m_data, v_weights)
@@ -684,18 +695,23 @@ class GlobalAttention(hk.Module):
    o_weights = hk.get_parameter(
        'output_w', shape=(num_head, value_dim, self.output_dim),
+        dtype=q_data.dtype,
        init=init)
-    o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),
-                              init=hk.initializers.Constant(0.0))
+    o_bias = hk.get_parameter(
+        'output_b', shape=(self.output_dim,),
+        dtype=q_data.dtype,
+        init=hk.initializers.Constant(0.0))

    if self.config.gating:
      gating_weights = hk.get_parameter(
          'gating_w',
          shape=(q_data.shape[-1], num_head, value_dim),
+          dtype=q_data.dtype,
          init=hk.initializers.Constant(0.0))
      gating_bias = hk.get_parameter(
          'gating_b',
          shape=(num_head, value_dim),
+          dtype=q_data.dtype,
          init=hk.initializers.Constant(1.0))

      gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights)
@@ -745,11 +761,11 @@ class MSARowAttentionWithPairBias(hk.Module):
    bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
    assert len(bias.shape) == 4

-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
            msa_act)

-    pair_act = hk.LayerNorm(
+    pair_act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
@@ -760,6 +776,7 @@ class MSARowAttentionWithPairBias(hk.Module):
    weights = hk.get_parameter(
        'feat_2d_weights',
        shape=(pair_act.shape[-1], c.num_head),
+        dtype=msa_act.dtype,
        init=hk.initializers.RandomNormal(stddev=init_factor))
    nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
@@ -812,7 +829,7 @@ class MSAColumnAttention(hk.Module):
    bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
    assert len(bias.shape) == 4

-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
            msa_act)
@@ -867,7 +884,7 @@ class MSAColumnGlobalAttention(hk.Module):
    bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
    assert len(bias.shape) == 4

-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
            msa_act)
@@ -924,7 +941,7 @@ class TriangleAttention(hk.Module):
    bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
    assert len(bias.shape) == 4

-    pair_act = hk.LayerNorm(
+    pair_act = common_modules.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
            pair_act)
@@ -932,6 +949,7 @@ class TriangleAttention(hk.Module):
    weights = hk.get_parameter(
        'feat_2d_weights',
        shape=(pair_act.shape[-1], c.num_head),
+        dtype=pair_act.dtype,
        init=hk.initializers.RandomNormal(stddev=init_factor))
    nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
@@ -1029,7 +1047,7 @@ class PredictedLDDTHead(hk.Module):
    """
    act = representations['structure_module']

-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
@@ -1251,6 +1269,19 @@ class ExperimentallyResolvedHead(hk.Module):
    return output
def _layer_norm(axis=-1, name='layer_norm'):
return common_modules.LayerNorm(
axis=axis,
create_scale=True,
create_offset=True,
eps=1e-5,
use_fast_variance=True,
scale_init=hk.initializers.Constant(1.),
offset_init=hk.initializers.Constant(0.),
param_axis=axis,
name=name)
class TriangleMultiplication(hk.Module):
  """Triangle multiplication layer ("outgoing" or "incoming").
@@ -1263,25 +1294,34 @@ class TriangleMultiplication(hk.Module):
    self.config = config
    self.global_config = global_config

-  def __call__(self, act, mask, is_training=True):
+  def __call__(self, left_act, left_mask, is_training=True):
    """Builds TriangleMultiplication module.

    Arguments:
-      act: Pair activations, shape [N_res, N_res, c_z]
-      mask: Pair mask, shape [N_res, N_res].
+      left_act: Pair activations, shape [N_res, N_res, c_z]
+      left_mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.

    Returns:
-      Outputs, same shape/type as act.
+      Outputs, same shape/type as left_act.
    """
    del is_training

+    if self.config.fuse_projection_weights:
+      return self._fused_triangle_multiplication(left_act, left_mask)
+    else:
+      return self._triangle_multiplication(left_act, left_mask)
+
+  @hk.transparent
+  def _triangle_multiplication(self, left_act, left_mask):
+    """Implementation of TriangleMultiplication used in AF2 and AF-M<2.3."""
    c = self.config
    gc = self.global_config

-    mask = mask[..., None]
+    mask = left_mask[..., None]

-    act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
-                       name='layer_norm_input')(act)
+    act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
+                                   name='layer_norm_input')(left_act)
    input_act = act

    left_projection = common_modules.Linear(
@@ -1317,7 +1357,7 @@ class TriangleMultiplication(hk.Module):
    # b = left_proj_act and a = right_proj_act
    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
@@ -1340,6 +1380,50 @@ class TriangleMultiplication(hk.Module):

    return act
@hk.transparent
def _fused_triangle_multiplication(self, left_act, left_mask):
"""TriangleMultiplication with fused projection weights."""
mask = left_mask[..., None]
c = self.config
gc = self.global_config
left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)
# Both left and right projections are fused into projection.
projection = common_modules.Linear(
2*c.num_intermediate_channel, name='projection')
proj_act = mask * projection(left_act)
# Both left + right gate are fused into gate_values.
gate_values = common_modules.Linear(
2 * c.num_intermediate_channel,
name='gate',
bias_init=1.,
initializer=utils.final_init(gc))(left_act)
proj_act *= jax.nn.sigmoid(gate_values)
left_proj_act = proj_act[:, :, :c.num_intermediate_channel]
right_proj_act = proj_act[:, :, c.num_intermediate_channel:]
act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
act = _layer_norm(axis=-1, name='center_norm')(act)
output_channel = int(left_act.shape[-1])
act = common_modules.Linear(
output_channel,
initializer=utils.final_init(gc),
name='output_projection')(act)
gate_values = common_modules.Linear(
output_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='gating_linear')(left_act)
act *= jax.nn.sigmoid(gate_values)
return act
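A standalone sketch (an illustration, not code from the commit) of why fusing the two projections is safe: a single Linear over 2*c channels with concatenated weights produces the same left/right halves as two separate c-channel projections.

import numpy as np

c = 4                                   # num_intermediate_channel
x = np.random.randn(5, 5, 8)            # pair activations [N_res, N_res, c_z]
w_left = np.random.randn(8, c)
w_right = np.random.randn(8, c)
w_fused = np.concatenate([w_left, w_right], axis=-1)   # shape (8, 2*c)

proj = x @ w_fused
assert np.allclose(proj[..., :c], x @ w_left)
assert np.allclose(proj[..., c:], x @ w_right)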
class DistogramHead(hk.Module):
  """Head to predict a distogram.
@@ -1446,7 +1530,7 @@ class OuterProductMean(hk.Module):
    c = self.config

    mask = mask[..., None]
-    act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act)
+    act = common_modules.LayerNorm([-1], True, True, name='layer_norm_input')(act)

    left_act = mask * common_modules.Linear(
        c.num_outer_channel,
@@ -1469,9 +1553,11 @@ class OuterProductMean(hk.Module):
        'output_w',
        shape=(c.num_outer_channel, c.num_outer_channel,
               self.num_output_channel),
+        dtype=act.dtype,
        init=init_w)
    output_b = hk.get_parameter(
        'output_b', shape=(self.num_output_channel,),
+        dtype=act.dtype,
        init=hk.initializers.Constant(0.0))

    def compute_chunk(left_act):
@@ -1738,7 +1824,7 @@ class EmbeddingsAndEvoformer(hk.Module):
              dgram)

    if c.recycle_features:
-      prev_msa_first_row = hk.LayerNorm(
+      prev_msa_first_row = common_modules.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,
@@ -1746,7 +1832,7 @@ class EmbeddingsAndEvoformer(hk.Module):
              batch['prev_msa_first_row'])
      msa_activations = msa_activations.at[0].add(prev_msa_first_row)

-      pair_activations += hk.LayerNorm(
+      pair_activations += common_modules.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,
@@ -2020,7 +2106,7 @@ class SingleTemplateEmbedding(hk.Module):
        self.config.template_pair_stack, self.global_config)(
            act, mask_2d, is_training)

-    act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act)
+    act = common_modules.LayerNorm([-1], True, True, name='output_layer_norm')(act)
    return act
...
@@ -475,20 +475,51 @@ class AlphaFold(hk.Module):
        # Eval mode or tests: use the maximum number of iterations.
        num_iter = c.num_recycle

-      def recycle_body(i, x):
-        del i
-        prev, safe_key = x
+      def distances(points):
+        """Compute all pairwise distances for a set of points."""
+        return jnp.sqrt(jnp.sum((points[:, None] - points[None, :])**2,
+                                axis=-1))
+
+      def recycle_body(x):
+        i, _, prev, safe_key = x
        safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate()  # pylint: disable=line-too-long
        ret = apply_network(prev=prev, safe_key=safe_key2)
-        return get_prev(ret), safe_key1
+        return i+1, prev, get_prev(ret), safe_key1

-      prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key))
+      def recycle_cond(x):
+        i, prev, next_in, _ = x
+        ca_idx = residue_constants.atom_order['CA']
+        sq_diff = jnp.square(distances(prev['prev_pos'][:, ca_idx, :]) -
+                             distances(next_in['prev_pos'][:, ca_idx, :]))
+        mask = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
+        sq_diff = utils.mask_mean(mask, sq_diff)
+        # Early stopping criteria based on criteria used in
+        # AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
+        diff = jnp.sqrt(sq_diff + 1e-8)  # avoid bad numerics giving negatives
+        less_than_max_recycles = (i < num_iter)
+        has_exceeded_tolerance = (
+            (i == 0) | (diff > c.recycle_early_stop_tolerance))
+        return less_than_max_recycles & has_exceeded_tolerance
+
+      if hk.running_init():
+        num_recycles, _, prev, safe_key = recycle_body(
+            (0, prev, prev, safe_key))
+      else:
+        num_recycles, _, prev, safe_key = hk.while_loop(
+            recycle_cond,
+            recycle_body,
+            (0, prev, prev, safe_key))
    else:
+      # No recycling.
+      num_recycles = 0

    # Run extra iteration.
    ret = apply_network(prev=prev, safe_key=safe_key)

    if not return_representations:
      del ret['representations']
+    ret['num_recycles'] = num_recycles
    return ret
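A self-contained sketch (simplified: no sequence mask, and an assumed tolerance of 0.5, matching the config default above) of the early-stop test that recycle_cond applies between recycling iterations:

import jax.numpy as jnp

def distances(points):  # points: [N_res, 3] CA coordinates
  return jnp.sqrt(jnp.sum((points[:, None] - points[None, :])**2, axis=-1))

def keep_recycling(prev_ca, next_ca, i, num_iter, tolerance=0.5):
  # Mean squared change in the pairwise CA distance matrix between recycles.
  sq_diff = jnp.mean(jnp.square(distances(prev_ca) - distances(next_ca)))
  diff = jnp.sqrt(sq_diff + 1e-8)
  # Always run at least one iteration; stop once the structure stops moving.
  return (i < num_iter) & ((i == 0) | (diff > tolerance))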
@@ -524,11 +555,13 @@ class EmbeddingsAndEvoformer(hk.Module):
      Feature embedding using the features as described before.
    """
    c = self.config
+    gc = self.global_config
    rel_feats = []
    pos = batch['residue_index']
    asym_id = batch['asym_id']
    asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
    offset = pos[:, None] - pos[None, :]
+    dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32

    clipped_offset = jnp.clip(
        offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)
@@ -568,6 +601,7 @@ class EmbeddingsAndEvoformer(hk.Module):

    rel_feat = jnp.concatenate(rel_feats, axis=-1)

+    rel_feat = rel_feat.astype(dtype)
    return common_modules.Linear(
        c.pair_channel,
        name='position_activations')(
@@ -579,6 +613,7 @@ class EmbeddingsAndEvoformer(hk.Module):
    gc = self.global_config
    batch = dict(batch)
+    dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())
@@ -587,177 +622,178 @@ class EmbeddingsAndEvoformer(hk.Module):
batch['msa_profile'] = make_msa_profile(batch) batch['msa_profile'] = make_msa_profile(batch)
target_feat = jax.nn.one_hot(batch['aatype'], 21) with utils.bfloat16_context():
target_feat = jax.nn.one_hot(batch['aatype'], 21).astype(dtype)
preprocess_1d = common_modules.Linear(
c.msa_channel, name='preprocess_1d')( preprocess_1d = common_modules.Linear(
target_feat) c.msa_channel, name='preprocess_1d')(
target_feat)
safe_key, sample_key, mask_key = safe_key.split(3)
batch = sample_msa(sample_key, batch, c.num_msa) safe_key, sample_key, mask_key = safe_key.split(3)
batch = make_masked_msa(batch, mask_key, c.masked_msa) batch = sample_msa(sample_key, batch, c.num_msa)
batch = make_masked_msa(batch, mask_key, c.masked_msa)
(batch['cluster_profile'],
batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch) (batch['cluster_profile'],
batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)
msa_feat = create_msa_feat(batch)
msa_feat = create_msa_feat(batch).astype(dtype)
preprocess_msa = common_modules.Linear(
c.msa_channel, name='preprocess_msa')( preprocess_msa = common_modules.Linear(
msa_feat) c.msa_channel, name='preprocess_msa')(
msa_feat)
msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
left_single = common_modules.Linear( left_single = common_modules.Linear(
c.pair_channel, name='left_single')( c.pair_channel, name='left_single')(
target_feat) target_feat)
right_single = common_modules.Linear( right_single = common_modules.Linear(
c.pair_channel, name='right_single')( c.pair_channel, name='right_single')(
target_feat) target_feat)
pair_activations = left_single[:, None] + right_single[None] pair_activations = left_single[:, None] + right_single[None]
mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
mask_2d = mask_2d.astype(jnp.float32) mask_2d = mask_2d.astype(dtype)
if c.recycle_pos: if c.recycle_pos:
prev_pseudo_beta = modules.pseudo_beta_fn( prev_pseudo_beta = modules.pseudo_beta_fn(
batch['aatype'], batch['prev_pos'], None) batch['aatype'], batch['prev_pos'], None)
dgram = modules.dgram_from_positions( dgram = modules.dgram_from_positions(
prev_pseudo_beta, **self.config.prev_pos) prev_pseudo_beta, **self.config.prev_pos)
pair_activations += common_modules.Linear( dgram = dgram.astype(dtype)
c.pair_channel, name='prev_pos_linear')( pair_activations += common_modules.Linear(
dgram) c.pair_channel, name='prev_pos_linear')(
dgram)
if c.recycle_features: if c.recycle_features:
prev_msa_first_row = hk.LayerNorm( prev_msa_first_row = common_modules.LayerNorm(
axis=[-1], axis=[-1],
create_scale=True, create_scale=True,
create_offset=True, create_offset=True,
name='prev_msa_first_row_norm')( name='prev_msa_first_row_norm')(
batch['prev_msa_first_row']) batch['prev_msa_first_row']).astype(dtype)
msa_activations = msa_activations.at[0].add(prev_msa_first_row) msa_activations = msa_activations.at[0].add(prev_msa_first_row)
pair_activations += hk.LayerNorm( pair_activations += common_modules.LayerNorm(
axis=[-1], axis=[-1],
create_scale=True, create_scale=True,
create_offset=True, create_offset=True,
name='prev_pair_norm')( name='prev_pair_norm')(
batch['prev_pair']) batch['prev_pair']).astype(dtype)
if c.max_relative_idx: if c.max_relative_idx:
pair_activations += self._relative_encoding(batch) pair_activations += self._relative_encoding(batch)
if c.template.enabled: if c.template.enabled:
template_module = TemplateEmbedding(c.template, gc) template_module = TemplateEmbedding(c.template, gc)
template_batch = { template_batch = {
'template_aatype': batch['template_aatype'], 'template_aatype': batch['template_aatype'],
'template_all_atom_positions': batch['template_all_atom_positions'], 'template_all_atom_positions': batch['template_all_atom_positions'],
'template_all_atom_mask': batch['template_all_atom_mask'] 'template_all_atom_mask': batch['template_all_atom_mask']
}
# Construct a mask such that only intra-chain template features are
# computed, since all templates are for each chain individually.
multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
safe_key, safe_subkey = safe_key.split()
template_act = template_module(
query_embedding=pair_activations,
template_batch=template_batch,
padding_mask_2d=mask_2d,
multichain_mask_2d=multichain_mask,
is_training=is_training,
safe_key=safe_subkey)
pair_activations += template_act
# Extra MSA stack.
(extra_msa_feat,
extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
extra_msa_activations = common_modules.Linear(
c.extra_msa_channel,
name='extra_msa_activations')(
extra_msa_feat).astype(dtype)
extra_msa_mask = extra_msa_mask.astype(dtype)
extra_evoformer_input = {
'msa': extra_msa_activations,
'pair': pair_activations,
} }
# Construct a mask such that only intra-chain template features are extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
# computed, since all templates are for each chain individually.
multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
safe_key, safe_subkey = safe_key.split()
template_act = template_module(
query_embedding=pair_activations,
template_batch=template_batch,
padding_mask_2d=mask_2d,
multichain_mask_2d=multichain_mask,
is_training=is_training,
safe_key=safe_subkey)
pair_activations += template_act
# Extra MSA stack.
(extra_msa_feat,
extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
extra_msa_activations = common_modules.Linear(
c.extra_msa_channel,
name='extra_msa_activations')(
extra_msa_feat)
extra_msa_mask = extra_msa_mask.astype(jnp.float32)
extra_evoformer_input = {
'msa': extra_msa_activations,
'pair': pair_activations,
}
extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
extra_evoformer_iteration = modules.EvoformerIteration( extra_evoformer_iteration = modules.EvoformerIteration(
c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack') c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
def extra_evoformer_fn(x): def extra_evoformer_fn(x):
act, safe_key = x act, safe_key = x
safe_key, safe_subkey = safe_key.split() safe_key, safe_subkey = safe_key.split()
extra_evoformer_output = extra_evoformer_iteration( extra_evoformer_output = extra_evoformer_iteration(
activations=act, activations=act,
masks=extra_masks, masks=extra_masks,
is_training=is_training, is_training=is_training,
safe_key=safe_subkey) safe_key=safe_subkey)
return (extra_evoformer_output, safe_key) return (extra_evoformer_output, safe_key)
if gc.use_remat:
extra_evoformer_fn = hk.remat(extra_evoformer_fn)
safe_key, safe_subkey = safe_key.split()
extra_evoformer_stack = layer_stack.layer_stack(
c.extra_msa_stack_num_block)(
extra_evoformer_fn)
extra_evoformer_output, safe_key = extra_evoformer_stack(
(extra_evoformer_input, safe_subkey))
pair_activations = extra_evoformer_output['pair']
# Get the size of the MSA before potentially adding templates, so we
# can crop out the templates later.
num_msa_sequences = msa_activations.shape[0]
evoformer_input = {
'msa': msa_activations,
'pair': pair_activations,
}
evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32),
'pair': mask_2d}
if c.template.enabled:
template_features, template_masks = (
template_embedding_1d(batch=batch, num_channel=c.msa_channel))
evoformer_input['msa'] = jnp.concatenate(
[evoformer_input['msa'], template_features], axis=0)
evoformer_masks['msa'] = jnp.concatenate(
[evoformer_masks['msa'], template_masks], axis=0)
evoformer_iteration = modules.EvoformerIteration( if gc.use_remat:
c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration') extra_evoformer_fn = hk.remat(extra_evoformer_fn)
def evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split() safe_key, safe_subkey = safe_key.split()
evoformer_output = evoformer_iteration( extra_evoformer_stack = layer_stack.layer_stack(
activations=act, c.extra_msa_stack_num_block)(
masks=evoformer_masks, extra_evoformer_fn)
is_training=is_training, extra_evoformer_output, safe_key = extra_evoformer_stack(
safe_key=safe_subkey) (extra_evoformer_input, safe_subkey))
return (evoformer_output, safe_key)
pair_activations = extra_evoformer_output['pair']
if gc.use_remat:
evoformer_fn = hk.remat(evoformer_fn) # Get the size of the MSA before potentially adding templates, so we
# can crop out the templates later.
num_msa_sequences = msa_activations.shape[0]
evoformer_input = {
'msa': msa_activations,
'pair': pair_activations,
}
evoformer_masks = {
'msa': batch['msa_mask'].astype(dtype),
'pair': mask_2d
}
if c.template.enabled:
template_features, template_masks = (
template_embedding_1d(
batch=batch, num_channel=c.msa_channel, global_config=gc))
evoformer_input['msa'] = jnp.concatenate(
[evoformer_input['msa'], template_features], axis=0)
evoformer_masks['msa'] = jnp.concatenate(
[evoformer_masks['msa'], template_masks], axis=0)
evoformer_iteration = modules.EvoformerIteration(
c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
def evoformer_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
evoformer_output = evoformer_iteration(
activations=act,
masks=evoformer_masks,
is_training=is_training,
safe_key=safe_subkey)
return (evoformer_output, safe_key)
if gc.use_remat:
evoformer_fn = hk.remat(evoformer_fn)
safe_key, safe_subkey = safe_key.split() safe_key, safe_subkey = safe_key.split()
evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)( evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
evoformer_fn) evoformer_fn)
def run_evoformer(evoformer_input): def run_evoformer(evoformer_input):
evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey)) evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
return evoformer_output return evoformer_output
evoformer_output = run_evoformer(evoformer_input) evoformer_output = run_evoformer(evoformer_input)
msa_activations = evoformer_output['msa'] msa_activations = evoformer_output['msa']
pair_activations = evoformer_output['pair'] pair_activations = evoformer_output['pair']
single_activations = common_modules.Linear( single_activations = common_modules.Linear(
c.seq_channel, name='single_activations')( c.seq_channel, name='single_activations')(
msa_activations[0]) msa_activations[0])
output.update({ output.update({
'single': 'single':
...@@ -771,6 +807,12 @@ class EmbeddingsAndEvoformer(hk.Module): ...@@ -771,6 +807,12 @@ class EmbeddingsAndEvoformer(hk.Module):
msa_activations[0], msa_activations[0],
}) })
# Convert back to float32 if we're not saving memory.
if not gc.bfloat16_output:
for k, v in output.items():
if v.dtype == jnp.bfloat16:
output[k] = v.astype(jnp.float32)
    return output
@@ -917,6 +959,9 @@ class SingleTemplateEmbedding(hk.Module):
      # backbone affine - i.e. in each residues local frame, what direction are
      # each of the other residues.
      raw_atom_pos = template_all_atom_positions
+      if gc.bfloat16:
+        # Vec3Arrays are required to be float32
+        raw_atom_pos = raw_atom_pos.astype(jnp.float32)
      atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
      rigid, backbone_mask = folding_multimer.make_backbone_affine(
@@ -928,6 +973,10 @@ class SingleTemplateEmbedding(hk.Module):
      unit_vector = rigid_vec.normalized()
      unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]

+      if gc.bfloat16:
+        unit_vector = [x.astype(jnp.bfloat16) for x in unit_vector]
+        backbone_mask = backbone_mask.astype(jnp.bfloat16)
+
      backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
      backbone_mask_2d *= multichain_mask_2d
      unit_vector = [x*backbone_mask_2d for x in unit_vector]
@@ -937,7 +986,7 @@ class SingleTemplateEmbedding(hk.Module):
      to_concat.extend([(x, 0) for x in unit_vector])
      to_concat.append((backbone_mask_2d, 0))

-      query_embedding = hk.LayerNorm(
+      query_embedding = common_modules.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,
@@ -986,12 +1035,13 @@ class SingleTemplateEmbedding(hk.Module):
            template_iteration_fn)
    act, safe_key = template_stack((act, safe_subkey))

-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='output_layer_norm')(
            act)
    return act
@@ -1044,21 +1094,18 @@ class TemplateEmbeddingIteration(hk.Module):
        act,
        pair_mask,
        safe_key=next(sub_keys))
    act = dropout_wrapper_fn(
        modules.TriangleAttention(c.triangle_attention_starting_node, gc,
                                  name='triangle_attention_starting_node'),
        act,
        pair_mask,
        safe_key=next(sub_keys))
    act = dropout_wrapper_fn(
        modules.TriangleAttention(c.triangle_attention_ending_node, gc,
                                  name='triangle_attention_ending_node'),
        act,
        pair_mask,
        safe_key=next(sub_keys))
    act = dropout_wrapper_fn(
        modules.Transition(c.pair_transition, gc,
                           name='pair_transition'),
@@ -1069,7 +1116,7 @@ class TemplateEmbeddingIteration(hk.Module):
    return act

-def template_embedding_1d(batch, num_channel):
+def template_embedding_1d(batch, num_channel, global_config):
  """Embed templates into an (num_res, num_templates, num_channels) embedding.

  Args:
@@ -1080,6 +1127,7 @@ def template_embedding_1d(batch, num_channel):
    template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
      each template.
    num_channel: The number of channels in the output.
+    global_config: The global_config.

  Returns:
    An embedding of shape (num_templates, num_res, num_channels) and a mask of
@@ -1112,6 +1160,10 @@ def template_embedding_1d(batch, num_channel):
  template_mask = chi_mask[:, :, 0]

+  if global_config.bfloat16:
+    template_features = template_features.astype(jnp.bfloat16)
+    template_mask = template_mask.astype(jnp.bfloat16)
+
  template_activations = common_modules.Linear(
      num_channel,
      initializer='relu',
...
@@ -15,6 +15,7 @@
"""A collection of JAX utility functions for use in protein folding."""

import collections
+import contextlib
import functools
import numbers
from typing import Mapping
@@ -25,6 +26,27 @@ import jax.numpy as jnp
import numpy as np

+def bfloat16_creator(next_creator, shape, dtype, init, context):
+  """Creates float32 variables when bfloat16 is requested."""
+  if context.original_dtype == jnp.bfloat16:
+    dtype = jnp.float32
+  return next_creator(shape, dtype, init)
+
+
+def bfloat16_getter(next_getter, value, context):
+  """Casts float32 to bfloat16 when bfloat16 was originally requested."""
+  if context.original_dtype == jnp.bfloat16:
+    assert value.dtype == jnp.float32
+    value = value.astype(jnp.bfloat16)
+  return next_getter(value)
+
+
+@contextlib.contextmanager
+def bfloat16_context():
+  with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter):
+    yield
+
+
def final_init(config):
  if config.zero_init:
    return 'zeros'
...
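A minimal sketch (an assumption about typical use, not code from the commit) of what the new bfloat16_context buys: parameters requested in bfloat16 are created and stored as float32, but handed back to the module as bfloat16.

import haiku as hk
import jax
import jax.numpy as jnp
from alphafold.model import utils

def forward(x):
  with utils.bfloat16_context():
    w = hk.get_parameter('w', (x.shape[-1],), jnp.bfloat16, init=jnp.ones)
    return x * w  # inside the module, w arrives as bfloat16

fn = hk.transform(forward)
params = fn.init(jax.random.PRNGKey(0), jnp.ones((4,), jnp.bfloat16))
# The stored parameter is float32 thanks to bfloat16_creator; bfloat16_getter
# performs the cast to bfloat16 at use time.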
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
"""Helper methods for the AlphaFold Colab notebook.""" """Helper methods for the AlphaFold Colab notebook."""
import enum
import json import json
from typing import Any, Mapping, Optional, Sequence, Tuple from typing import Any, Mapping, Optional, Sequence, Tuple
...@@ -23,13 +22,7 @@ from matplotlib import pyplot as plt ...@@ -23,13 +22,7 @@ from matplotlib import pyplot as plt
import numpy as np import numpy as np
-@enum.unique
-class ModelType(enum.Enum):
-MONOMER = 0
-MULTIMER = 1
-def clean_and_validate_sequence(
+def clean_and_validate_single_sequence(
input_sequence: str, min_length: int, max_length: int) -> str:
"""Checks that the input sequence is ok and returns a clean version of it."""
# Remove all whitespaces, tabs and end lines; upper-case.
@@ -54,41 +47,23 @@ def clean_and_validate_sequence(
return clean_sequence
-def validate_input(
+def clean_and_validate_input_sequences(
input_sequences: Sequence[str],
-min_length: int,
-max_length: int,
-max_multimer_length: int) -> Tuple[Sequence[str], ModelType]:
-"""Validates and cleans input sequences and determines which model to use."""
+min_sequence_length: int,
+max_sequence_length: int) -> Sequence[str]:
+"""Validates and cleans input sequences."""
sequences = []
for input_sequence in input_sequences:
if input_sequence.strip():
-input_sequence = clean_and_validate_sequence(
+input_sequence = clean_and_validate_single_sequence(
input_sequence=input_sequence,
-min_length=min_length,
-max_length=max_length)
+min_length=min_sequence_length,
+max_length=max_sequence_length)
sequences.append(input_sequence)
-if len(sequences) == 1:
-print('Using the single-chain model.')
+if sequences:
+return sequences
-return sequences, ModelType.MONOMER
-elif len(sequences) > 1:
-total_multimer_length = sum([len(seq) for seq in sequences])
-if total_multimer_length > max_multimer_length:
-raise ValueError(f'The total length of multimer sequences is too long: '
-f'{total_multimer_length}, while the maximum is '
-f'{max_multimer_length}. Please use the full AlphaFold '
-f'system for long multimers.')
-elif total_multimer_length > 1536:
-print('WARNING: The accuracy of the system has not been fully validated '
-'above 1536 residues, and you may experience long running times or '
-f'run out of memory for your complex with {total_multimer_length} '
-'residues.')
-print(f'Using the multimer model with {len(sequences)} sequences.')
-return sequences, ModelType.MULTIMER
else:
raise ValueError('No input amino acid sequence provided, please provide at '
'least one sequence.')
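A hypothetical usage sketch (not part of this commit) of the renamed validation helper as it might be called from the Colab notebook; the import path, sequences and length limits below are placeholder assumptions.

from alphafold.notebooks import notebook_utils

raw_inputs = ['mktayiakqr ', '', '\tMVLSPADKTN']
sequences = notebook_utils.clean_and_validate_input_sequences(
    input_sequences=raw_inputs,
    min_sequence_length=10,
    max_sequence_length=2500)
# Empty entries are dropped; the remaining sequences are stripped and upper-cased.
print(sequences)  # ['MKTAYIAKQR', 'MVLSPADKTN']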
...
@@ -85,7 +85,7 @@ class NotebookUtilsTest(parameterized.TestCase):
('DeepMind', 'DEEPMIND'), ('A ', 'A'), ('\tA', 'A'), (' A\t\n', 'A'),
('ACDEFGHIKLMNPQRSTVWY', 'ACDEFGHIKLMNPQRSTVWY'))
def test_clean_and_validate_sequence_ok(self, sequence, exp_clean):
-clean = notebook_utils.clean_and_validate_sequence(
+clean = notebook_utils.clean_and_validate_single_sequence(
sequence, min_length=1, max_length=100)
self.assertEqual(clean, exp_clean)
@@ -100,35 +100,29 @@ class NotebookUtilsTest(parameterized.TestCase):
('bad_amino_acids_Z', 'ZZZZ', 'non-amino acid'))
def test_clean_and_validate_sequence_bad(self, sequence, exp_error):
with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
-notebook_utils.clean_and_validate_sequence(
+notebook_utils.clean_and_validate_single_sequence(
sequence, min_length=4, max_length=8)
@parameterized.parameters(
-(['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A'],
-notebook_utils.ModelType.MONOMER),
-(['', 'A'], ['A'],
-notebook_utils.ModelType.MONOMER),
-(['A', 'C ', ''], ['A', 'C'],
-notebook_utils.ModelType.MULTIMER),
-(['', 'A', '', 'C '], ['A', 'C'],
-notebook_utils.ModelType.MULTIMER))
-def test_validate_input_ok(
-self, input_sequences, exp_sequences, exp_model_type):
-sequences, model_type = notebook_utils.validate_input(
+(['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A']),
+(['', 'A'], ['A']),
+(['A', 'C ', ''], ['A', 'C']),
+(['', 'A', '', 'C '], ['A', 'C']))
+def test_validate_input_ok(self, input_sequences, exp_sequences):
+sequences = notebook_utils.clean_and_validate_input_sequences(
input_sequences=input_sequences,
-min_length=1, max_length=100, max_multimer_length=100)
+min_sequence_length=1, max_sequence_length=100)
self.assertSequenceEqual(sequences, exp_sequences)
-self.assertEqual(model_type, exp_model_type)
@parameterized.named_parameters(
('no_input_sequence', ['', '\t', '\n'], 'No input amino acid sequence'),
('too_long_single', ['AAAAAAAAA', 'AAAA'], 'Input sequence is too long'),
-('too_long_multimer', ['AAAA', 'AAAAA'], 'The total length of multimer'))
+('too_short_single', ['AAA', 'AAAA'], 'Input sequence is too short'))
def test_validate_input_bad(self, input_sequences, exp_error):
with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
-notebook_utils.validate_input(
-input_sequences=input_sequences,
-min_length=4, max_length=8, max_multimer_length=6)
+notebook_utils.clean_and_validate_input_sequences(
+input_sequences=input_sequences, min_sequence_length=4,
+max_sequence_length=8)
def test_merge_chunked_msa_no_hits(self):
results = [ONLY_QUERY_HIT, ONLY_QUERY_HIT]
...
@@ -56,7 +56,8 @@ class AmberRelaxation(object):
self._use_gpu = use_gpu
def process(self, *,
-prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
+prot: protein.Protein
+) -> Tuple[str, Dict[str, Any], Sequence[float]]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline(
prot=prot, max_iterations=self._max_iterations,
@@ -73,12 +74,11 @@ class AmberRelaxation(object):
'attempts': out['min_attempts'],
'rmsd': rmsd
}
-pdb_str = amber_minimize.clean_protein(prot)
-min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
+min_pdb = out['min_pdb']
min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
utils.assert_equal_nonterminal_atom_types(
protein.from_pdb_string(min_pdb).atom_mask,
prot.atom_mask)
violations = out['structural_violations'][
-'total_per_residue_violations_mask']
+'total_per_residue_violations_mask'].tolist()
return min_pdb, debug_data, violations
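A brief sketch (not part of this commit) of calling process() with the updated signature; the constructor arguments mirror the defaults used by the run scripts but are treated as assumptions here, and unrelaxed_protein is a placeholder protein.Protein instance.

from alphafold.relax import relax
import numpy as np

amber_relaxer = relax.AmberRelaxation(
    max_iterations=0, tolerance=2.39, stiffness=10.0,
    exclude_residues=[], max_outer_iterations=3, use_gpu=False)
relaxed_pdb, debug_data, violations = amber_relaxer.process(prot=unrelaxed_protein)
# violations is now a plain Python list, so convert before array comparisons.
num_violations = np.sum(np.array(violations))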
@@ -82,7 +82,7 @@ class RunAmberRelaxTest(absltest.TestCase):
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0])
# Check no violations were added. Can't check exactly due to stochasticity.
-self.assertTrue(np.all(num_violations <= exp_num_violations))
+self.assertTrue(np.all(np.array(num_violations) <= exp_num_violations))
if __name__ == '__main__':
...
@@ -17,17 +17,6 @@ import io
from alphafold.common import residue_constants
from Bio import PDB
import numpy as np
-from simtk.openmm import app as openmm_app
-from simtk.openmm.app.internal.pdbstructure import PdbStructure
-def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
-pdb_file = io.StringIO(pdb_str)
-structure = PdbStructure(pdb_file)
-topology = openmm_app.PDBFile(structure).getTopology()
-with io.StringIO() as f:
-openmm_app.PDBFile.writeFile(topology, pos, f)
-return f.getvalue()
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
...
@@ -21,7 +21,7 @@ ARG CUDA
# Use bash to support string substitution.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
build-essential \
cmake \
@@ -59,7 +59,7 @@ RUN conda install -qy conda==4.13.0 \
cudatoolkit==${CUDA_VERSION} \
pdbfixer \
pip \
-python=3.7 \
+python=3.8 \
&& conda clean --all --force-pkgs-dirs --yes
COPY . /app/alphafold
@@ -75,7 +75,7 @@ RUN pip3 install --upgrade pip --no-cache-dir \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Apply OpenMM patch.
-WORKDIR /opt/conda/lib/python3.7/site-packages
+WORKDIR /opt/conda/lib/python3.8/site-packages
RUN patch -p0 < /app/alphafold/docker/openmm.patch
# Add SETUID bit to the ldconfig binary so that non-root users can run it.
...
@@ -133,7 +133,7 @@ def main(argv):
# Path to the MGnify database for use by JackHMMER.
mgnify_database_path = os.path.join(
-FLAGS.data_dir, 'mgnify', 'mgy_clusters_2018_12.fa')
+FLAGS.data_dir, 'mgnify', 'mgy_clusters_2022_05.fa')
# Path to the BFD database for use by HHblits.
bfd_database_path = os.path.join(
@@ -144,9 +144,9 @@ def main(argv):
small_bfd_database_path = os.path.join(
FLAGS.data_dir, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta')
-# Path to the Uniclust30 database for use by HHblits.
+# Path to the Uniref30 database for use by HHblits.
-uniclust30_database_path = os.path.join(
+uniref30_database_path = os.path.join(
-FLAGS.data_dir, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08')
+FLAGS.data_dir, 'uniref30', 'UniRef30_2021_03')
# Path to the PDB70 database for use by HHsearch.
pdb70_database_path = os.path.join(FLAGS.data_dir, 'pdb70', 'pdb70')
@@ -199,7 +199,7 @@ def main(argv):
database_paths.append(('small_bfd_database_path', small_bfd_database_path))
else:
database_paths.extend([
-('uniclust30_database_path', uniclust30_database_path),
+('uniref30_database_path', uniref30_database_path),
('bfd_database_path', bfd_database_path),
])
for name, path in database_paths:
...
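A small sketch (not part of this commit) for sanity-checking the updated database locations referenced in run_docker.py above; the member file names under each database directory are assumptions about a typical download layout, and the full data directory contains additional databases not listed here.

import os

data_dir = '/path/to/alphafold_databases'  # placeholder
expected = [
    os.path.join('uniref30', 'UniRef30_2021_03_hhm.ffdata'),  # hypothetical member file
    os.path.join('mgnify', 'mgy_clusters_2022_05.fa'),
    os.path.join('small_bfd', 'bfd-first_non_consensus_sequences.fasta'),
    os.path.join('pdb70', 'pdb70_hhm.ffdata'),  # hypothetical member file
]
for rel_path in expected:
    path = os.path.join(data_dir, rel_path)
    print(path, 'ok' if os.path.exists(path) else 'MISSING')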