Commit 2cd61ade authored by Michael Figurnov, committed by Copybara-Service

Improve support for num_recycle=0.

Previously, setting num_recycle=0 in a pretrained recycling model skipped creating some Linears and LayerNorms. This meant that the offsets of these modules were not applied, leading to degraded performance in that case.

PiperOrigin-RevId: 429075328
Change-Id: I9257f859521799f45e2deef3803c249311051225
parent 929b188a
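
A minimal sketch (not part of this commit) of why the guard mattered: Haiku only adds a module's parameters to the parameter tree if the module is actually called during tracing, so creating the recycling LayerNorms only when num_recycle is non-zero means a pretrained checkpoint's scales and offsets for those modules are never applied when num_recycle=0. The use_recycling flag below is a hypothetical stand-in for that check.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(pair, use_recycling):
  # Only call the LayerNorm when recycling is enabled -- mirroring the old
  # behaviour this commit removes.
  if use_recycling:
    pair += hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True,
        name='prev_pair_norm')(jnp.zeros_like(pair))
  return pair

fn = hk.transform(forward)
rng = jax.random.PRNGKey(0)
pair = jnp.ones([4, 4, 8])

params_with = fn.init(rng, pair, use_recycling=True)
params_without = fn.init(rng, pair, use_recycling=False)
print(sorted(params_with))     # ['prev_pair_norm'] -- scale/offset created
print(sorted(params_without))  # [] -- the pretrained offsets would be ignored
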
@@ -341,17 +341,18 @@ class AlphaFold(hk.Module):
compute_loss=compute_loss,
ensemble_representations=ensemble_representations)
if self.config.num_recycle:
emb_config = self.config.embeddings_and_evoformer
prev = {
'prev_pos': jnp.zeros(
[num_residues, residue_constants.atom_type_num, 3]),
'prev_msa_first_row': jnp.zeros(
[num_residues, emb_config.msa_channel]),
'prev_pair': jnp.zeros(
[num_residues, num_residues, emb_config.pair_channel]),
}
prev = {}
emb_config = self.config.embeddings_and_evoformer
if emb_config.recycle_pos:
prev['prev_pos'] = jnp.zeros(
[num_residues, residue_constants.atom_type_num, 3])
if emb_config.recycle_features:
prev['prev_msa_first_row'] = jnp.zeros(
[num_residues, emb_config.msa_channel])
prev['prev_pair'] = jnp.zeros(
[num_residues, num_residues, emb_config.pair_channel])
if self.config.num_recycle:
if 'num_iter_recycling' in batch:
# Training time: num_iter_recycling is in batch.
# The value for each ensemble batch is the same, so arbitrarily taking
@@ -378,7 +379,6 @@ class AlphaFold(hk.Module):
body,
(0, prev))
else:
prev = {}
num_iter = 0
ret = do_call(prev=prev, recycle_idx=num_iter)
@@ -1730,7 +1730,7 @@ class EmbeddingsAndEvoformer(hk.Module):
# Inject previous outputs for recycling.
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6
# Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"
if c.recycle_pos and 'prev_pos' in batch:
if c.recycle_pos:
prev_pseudo_beta = pseudo_beta_fn(
batch['aatype'], batch['prev_pos'], None)
dgram = dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
@@ -1739,20 +1739,20 @@ class EmbeddingsAndEvoformer(hk.Module):
dgram)
if c.recycle_features:
if 'prev_msa_first_row' in batch:
prev_msa_first_row = hk.LayerNorm([-1],
True,
True,
name='prev_msa_first_row_norm')(
batch['prev_msa_first_row'])
msa_activations = msa_activations.at[0].add(prev_msa_first_row)
if 'prev_pair' in batch:
pair_activations += hk.LayerNorm([-1],
True,
True,
name='prev_pair_norm')(
batch['prev_pair'])
prev_msa_first_row = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_msa_first_row_norm')(
batch['prev_msa_first_row'])
msa_activations = msa_activations.at[0].add(prev_msa_first_row)
pair_activations += hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_pair_norm')(
batch['prev_pair'])
# Relative position encoding.
# Jumper et al. (2021) Suppl. Alg. 4 "relpos"
@@ -451,17 +451,18 @@ class AlphaFold(hk.Module):
is_training=is_training,
safe_key=safe_key)
if self.config.num_recycle:
emb_config = self.config.embeddings_and_evoformer
prev = {
'prev_pos':
jnp.zeros([num_res, residue_constants.atom_type_num, 3]),
'prev_msa_first_row':
jnp.zeros([num_res, emb_config.msa_channel]),
'prev_pair':
jnp.zeros([num_res, num_res, emb_config.pair_channel]),
}
prev = {}
emb_config = self.config.embeddings_and_evoformer
if emb_config.recycle_pos:
prev['prev_pos'] = jnp.zeros(
[num_res, residue_constants.atom_type_num, 3])
if emb_config.recycle_features:
prev['prev_msa_first_row'] = jnp.zeros(
[num_res, emb_config.msa_channel])
prev['prev_pair'] = jnp.zeros(
[num_res, num_res, emb_config.pair_channel])
if self.config.num_recycle:
if 'num_iter_recycling' in batch:
# Training time: num_iter_recycling is in batch.
# Value for each ensemble batch is the same, so arbitrarily taking 0-th.
@@ -482,8 +483,6 @@ class AlphaFold(hk.Module):
return get_prev(ret), safe_key1
prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key))
else:
prev = {}
# Run extra iteration.
ret = apply_network(prev=prev, safe_key=safe_key)
@@ -619,7 +618,7 @@ class EmbeddingsAndEvoformer(hk.Module):
mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
mask_2d = mask_2d.astype(jnp.float32)
if c.recycle_pos and 'prev_pos' in batch:
if c.recycle_pos:
prev_pseudo_beta = modules.pseudo_beta_fn(
batch['aatype'], batch['prev_pos'], None)
@@ -630,22 +629,20 @@ class EmbeddingsAndEvoformer(hk.Module):
dgram)
if c.recycle_features:
if 'prev_msa_first_row' in batch:
prev_msa_first_row = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_msa_first_row_norm')(
batch['prev_msa_first_row'])
msa_activations = msa_activations.at[0].add(prev_msa_first_row)
if 'prev_pair' in batch:
pair_activations += hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_pair_norm')(
batch['prev_pair'])
prev_msa_first_row = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_msa_first_row_norm')(
batch['prev_msa_first_row'])
msa_activations = msa_activations.at[0].add(prev_msa_first_row)
pair_activations += hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='prev_pair_norm')(
batch['prev_pair'])
if c.max_relative_idx:
pair_activations += self._relative_encoding(batch)
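
For completeness, a hedged usage sketch of running a pretrained model with recycling disabled; the config attribute paths and helper names below follow the public AlphaFold repository layout and are assumptions, not part of this diff.

from alphafold.model import config, data, model

cfg = config.model_config('model_1')
cfg.model.num_recycle = 0        # no recycling iterations inside the network
cfg.data.common.num_recycle = 0  # keep the data pipeline consistent

params = data.get_model_haiku_params(
    model_name='model_1', data_dir='/path/to/alphafold/data')
runner = model.RunModel(cfg, params)
# With this change, the recycling LayerNorms and Linears are still constructed,
# so their pretrained offsets are applied even though num_recycle is zero.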