Commit bc3ba06e authored by Sachin Kadyan

Added the option to switch off column attention in the Evoformer when using sequence embeddings.

- Added a `no_column_attention` flag to the Evoformer config (sketches below).
- Added a check in `evoformer.py` that skips constructing and applying `MSAColumnAttention` when `no_column_attention` is `True`.
parent 6403401f
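
For context on the motivation: with precomputed sequence embeddings the MSA stack collapses to a single row, so attention along the column (cross-sequence) axis has nothing to mix. A toy demonstration of that, with hypothetical shapes rather than OpenFold code:

```python
import torch

# Toy illustration: softmax over a length-1 sequence axis puts weight 1.0
# on the only entry, so column attention over a single-row "MSA" can only
# return a projection of its own input.
N_seq, N_res = 1, 8
logits = torch.randn(N_res, N_seq, N_seq)   # per-column attention logits
weights = logits.softmax(dim=-1)            # shape [8, 1, 1]
print(torch.allclose(weights, torch.ones_like(weights)))  # True
```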
```diff
@@ -514,6 +514,7 @@ config = mlc.ConfigDict(
             "transition_n": 4,
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
+            "no_column_attention": False,
             "blocks_per_ckpt": blocks_per_ckpt,
             "clear_cache_between_blocks": False,
             "tune_chunk_size": tune_chunk_size,
```
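
How the new key behaves, sketched with the same `ml_collections` API the config file uses (the nesting below is illustrative; the real flag sits deep inside the full model config):

```python
import ml_collections as mlc

# The flag defaults to False, so existing configs keep column attention on;
# callers opt out explicitly.
config = mlc.ConfigDict({
    "evoformer_stack": {
        "no_column_attention": False,
    }
})
config.evoformer_stack.no_column_attention = True  # e.g. single-sequence mode
assert config.evoformer_stack.no_column_attention
```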
```diff
@@ -323,6 +323,7 @@ class EvoformerBlock(nn.Module):
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
+        no_column_attention: bool,
         inf: float,
         eps: float,
     ):
```
```diff
@@ -336,12 +337,14 @@ class EvoformerBlock(nn.Module):
             inf=inf,
         )
-        self.msa_att_col = MSAColumnAttention(
-            c_m,
-            c_hidden_msa_att,
-            no_heads_msa,
-            inf=inf,
-        )
+        self.no_column_attention = no_column_attention
+        if not self.no_column_attention:
+            self.msa_att_col = MSAColumnAttention(
+                c_m,
+                c_hidden_msa_att,
+                no_heads_msa,
+                inf=inf,
+            )
 
         self.msa_dropout_layer = DropoutRowwise(msa_dropout)
```
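
Note that the gate above skips *constructing* `MSAColumnAttention` rather than building it and bypassing it later, so its weights never appear in the module's `state_dict()` or the optimizer. A self-contained sketch of that pattern, with toy names:

```python
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, c_m: int, no_column_attention: bool):
        super().__init__()
        self.no_column_attention = no_column_attention
        if not self.no_column_attention:
            # Stand-in for MSAColumnAttention: only allocated when used.
            self.msa_att_col = nn.Linear(c_m, c_m)

block = ToyBlock(c_m=8, no_column_attention=True)
print(hasattr(block, "msa_att_col"))   # False
print(len(block.state_dict()))         # 0 -- nothing to train or load
```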
```diff
@@ -397,17 +400,18 @@
             ),
             inplace=inplace_safe,
         )
-        m = add(m,
-            self.msa_att_col(
-                m,
-                mask=msa_mask,
-                chunk_size=chunk_size,
-                use_lma=use_lma,
-                use_flash=use_flash,
-            ),
-            inplace=inplace_safe,
-        )
+        if not self.no_column_attention:
+            m = add(m,
+                self.msa_att_col(
+                    m,
+                    mask=msa_mask,
+                    chunk_size=chunk_size,
+                    use_lma=use_lma,
+                    use_flash=use_flash,
+                ),
+                inplace=inplace_safe,
+            )
 
         if(not inplace_safe):
             input_tensors = [m, input_tensors[1]]
```
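
For readers outside the codebase, `add(..., inplace=inplace_safe)` is OpenFold's small residual helper; a hedged sketch of its behavior (the real definition lives in the project's tensor utilities):

```python
import torch

# Sketch of the residual-update helper: mutate the running activation in
# place when the caller has established that is safe (to save memory),
# otherwise allocate a fresh tensor.
def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
    if inplace:
        m1 += m2
        return m1
    return m1 + m2
```

The forward-pass guard must mirror the one in `__init__`: when the flag is set, `self.msa_att_col` was never created, so an unguarded call here would raise an `AttributeError`.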
```diff
@@ -595,6 +599,7 @@ class EvoformerStack(nn.Module):
         msa_dropout: float,
         pair_dropout: float,
         blocks_per_ckpt: int,
+        no_column_attention: bool,
         inf: float,
         eps: float,
         clear_cache_between_blocks: bool = False,
```
```diff
@@ -658,6 +663,7 @@ class EvoformerStack(nn.Module):
                 transition_n=transition_n,
                 msa_dropout=msa_dropout,
                 pair_dropout=pair_dropout,
+                no_column_attention=no_column_attention,
                 inf=inf,
                 eps=eps,
             )
```
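
Because the config key and the new constructor parameter share a name, the flag needs no extra plumbing wherever the stack is built by unpacking its config section; a hypothetical call site, assuming that construction style:

```python
# Hypothetical call site: the new key flows through kwargs unchanged.
evoformer = EvoformerStack(**config["evoformer_stack"])
```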