Commit 2cd61ade authored by Michael Figurnov, committed by Copybara-Service

Improve support for num_recycle=0.

Previously, setting num_recycle=0 in a pretrained recycling model skipped creating some Linears and LayerNorms. This meant that the offsets of these modules were not applied, leading to degraded performance in that case.

PiperOrigin-RevId: 429075328
Change-Id: I9257f859521799f45e2deef3803c249311051225
parent 929b188a
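
For readers unfamiliar with why skipping these calls matters: a Haiku module only contributes its parameters when it is actually called inside the transformed function. Below is a minimal toy sketch (hypothetical code, not part of this change) of the failure mode described above and of the fix of always feeding a zero-initialized previous output through the norm:

import haiku as hk
import jax
import jax.numpy as jnp

def old_embed(pair_activations, prev_pair=None):
  # Old behaviour: the norm is only applied when a previous pair
  # representation is present, so with num_recycle=0 the pretrained
  # 'prev_pair_norm' scale/offset never touch the activations.
  if prev_pair is not None:
    pair_activations += hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True,
        name='prev_pair_norm')(prev_pair)
  return pair_activations

def new_embed(pair_activations, prev_pair):
  # New behaviour: prev_pair is always provided (zeros when recycling is
  # disabled), and LayerNorm of an all-zero input reduces to the learned
  # offset, so the pretrained parameters still act on the single pass.
  return pair_activations + hk.LayerNorm(
      axis=[-1], create_scale=True, create_offset=True,
      name='prev_pair_norm')(prev_pair)

pair = jnp.zeros([4, 4, 8])
fn = hk.transform(new_embed)
params = fn.init(jax.random.PRNGKey(0), pair, jnp.zeros_like(pair))
out = fn.apply(params, None, pair, jnp.zeros_like(pair))  # offset is applied
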
@@ -341,17 +341,18 @@ class AlphaFold(hk.Module):
           compute_loss=compute_loss,
           ensemble_representations=ensemble_representations)
 
-    if self.config.num_recycle:
-      emb_config = self.config.embeddings_and_evoformer
-      prev = {
-          'prev_pos': jnp.zeros(
-              [num_residues, residue_constants.atom_type_num, 3]),
-          'prev_msa_first_row': jnp.zeros(
-              [num_residues, emb_config.msa_channel]),
-          'prev_pair': jnp.zeros(
-              [num_residues, num_residues, emb_config.pair_channel]),
-      }
+    prev = {}
+    emb_config = self.config.embeddings_and_evoformer
+    if emb_config.recycle_pos:
+      prev['prev_pos'] = jnp.zeros(
+          [num_residues, residue_constants.atom_type_num, 3])
+    if emb_config.recycle_features:
+      prev['prev_msa_first_row'] = jnp.zeros(
+          [num_residues, emb_config.msa_channel])
+      prev['prev_pair'] = jnp.zeros(
+          [num_residues, num_residues, emb_config.pair_channel])
 
+    if self.config.num_recycle:
       if 'num_iter_recycling' in batch:
         # Training time: num_iter_recycling is in batch.
         # The value for each ensemble batch is the same, so arbitrarily taking
@@ -378,7 +379,6 @@ class AlphaFold(hk.Module):
             body,
             (0, prev))
     else:
-      prev = {}
       num_iter = 0
 
     ret = do_call(prev=prev, recycle_idx=num_iter)
@@ -1730,7 +1730,7 @@ class EmbeddingsAndEvoformer(hk.Module):
     # Inject previous outputs for recycling.
     # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6
     # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"
-    if c.recycle_pos and 'prev_pos' in batch:
+    if c.recycle_pos:
       prev_pseudo_beta = pseudo_beta_fn(
           batch['aatype'], batch['prev_pos'], None)
       dgram = dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
@@ -1739,18 +1739,18 @@ class EmbeddingsAndEvoformer(hk.Module):
               dgram)
 
     if c.recycle_features:
-      if 'prev_msa_first_row' in batch:
-        prev_msa_first_row = hk.LayerNorm([-1],
-                                          True,
-                                          True,
-                                          name='prev_msa_first_row_norm')(
-                                              batch['prev_msa_first_row'])
-        msa_activations = msa_activations.at[0].add(prev_msa_first_row)
+      prev_msa_first_row = hk.LayerNorm(
+          axis=[-1],
+          create_scale=True,
+          create_offset=True,
+          name='prev_msa_first_row_norm')(
+              batch['prev_msa_first_row'])
+      msa_activations = msa_activations.at[0].add(prev_msa_first_row)
 
-      if 'prev_pair' in batch:
-        pair_activations += hk.LayerNorm([-1],
-                                         True,
-                                         True,
-                                         name='prev_pair_norm')(
-                                             batch['prev_pair'])
+      pair_activations += hk.LayerNorm(
+          axis=[-1],
+          create_scale=True,
+          create_offset=True,
+          name='prev_pair_norm')(
+              batch['prev_pair'])
...
@@ -451,17 +451,18 @@ class AlphaFold(hk.Module):
         is_training=is_training,
         safe_key=safe_key)
 
-    if self.config.num_recycle:
-      emb_config = self.config.embeddings_and_evoformer
-      prev = {
-          'prev_pos':
-              jnp.zeros([num_res, residue_constants.atom_type_num, 3]),
-          'prev_msa_first_row':
-              jnp.zeros([num_res, emb_config.msa_channel]),
-          'prev_pair':
-              jnp.zeros([num_res, num_res, emb_config.pair_channel]),
-      }
+    prev = {}
+    emb_config = self.config.embeddings_and_evoformer
+    if emb_config.recycle_pos:
+      prev['prev_pos'] = jnp.zeros(
+          [num_res, residue_constants.atom_type_num, 3])
+    if emb_config.recycle_features:
+      prev['prev_msa_first_row'] = jnp.zeros(
+          [num_res, emb_config.msa_channel])
+      prev['prev_pair'] = jnp.zeros(
+          [num_res, num_res, emb_config.pair_channel])
 
+    if self.config.num_recycle:
       if 'num_iter_recycling' in batch:
         # Training time: num_iter_recycling is in batch.
         # Value for each ensemble batch is the same, so arbitrarily taking 0-th.
@@ -482,8 +483,6 @@ class AlphaFold(hk.Module):
           return get_prev(ret), safe_key1
 
         prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key))
-    else:
-      prev = {}
 
     # Run extra iteration.
     ret = apply_network(prev=prev, safe_key=safe_key)
@@ -619,7 +618,7 @@ class EmbeddingsAndEvoformer(hk.Module):
     mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
     mask_2d = mask_2d.astype(jnp.float32)
 
-    if c.recycle_pos and 'prev_pos' in batch:
+    if c.recycle_pos:
       prev_pseudo_beta = modules.pseudo_beta_fn(
           batch['aatype'], batch['prev_pos'], None)
@@ -630,7 +629,6 @@ class EmbeddingsAndEvoformer(hk.Module):
               dgram)
 
     if c.recycle_features:
-      if 'prev_msa_first_row' in batch:
       prev_msa_first_row = hk.LayerNorm(
           axis=[-1],
           create_scale=True,
@@ -639,7 +637,6 @@ class EmbeddingsAndEvoformer(hk.Module):
               batch['prev_msa_first_row'])
       msa_activations = msa_activations.at[0].add(prev_msa_first_row)
 
-      if 'prev_pair' in batch:
       pair_activations += hk.LayerNorm(
           axis=[-1],
           create_scale=True,
...
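
As a usage note, here is a hedged sketch of the case this change targets, assuming the standard alphafold.model config/RunModel API and an already-downloaded parameter directory (the path below is a placeholder):

from alphafold.model import config, data, model

cfg = config.model_config('model_1')
cfg.model.num_recycle = 0  # single forward pass, no recycling iterations
# recycle_pos / recycle_features stay enabled, so with this change the
# prev_* Linears and LayerNorms are still created and their trained
# offsets are applied even though no recycled outputs are fed back.
params = data.get_model_haiku_params(
    model_name='model_1', data_dir='/path/to/alphafold/data')  # placeholder path
model_runner = model.RunModel(cfg, params)
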