Commit 40325b18 authored by Sachin Kadyan
Browse files

Added embedder for handling single-sequence embeddings.

- Added a `PreembeddingEmbedder` for embedding single-sequence (NUM_RESIDUE, ...) shaped embeddings as input.
parent 60d0b15a
...@@ -139,6 +139,100 @@ class InputEmbedder(nn.Module): ...@@ -139,6 +139,100 @@ class InputEmbedder(nn.Module):
return msa_emb, pair_emb return msa_emb, pair_emb
class PreembeddingEmbedder(nn.Module):
    """
    Input embedder for single-sequence mode.

    Projects a sequence pre-embedding and the target_feat features into a
    single-sequence embedding and an initial pair embedding; the pair
    embedding also receives a relative positional encoding derived from
    the residue index.
    """

    def __init__(
        self,
        tf_dim: int,
        preembedding_dim: int,
        c_z: int,
        c_m: int,
        relpos_k: int,
        **kwargs,
    ):
        """
        Args:
            tf_dim:
                Final (channel) dimension of the incoming target features
            preembedding_dim:
                Final (channel) dimension of the incoming pre-embeddings
            c_z:
                Channel dimension of the pair embedding
            c_m:
                Channel dimension of the single-sequence embedding
            relpos_k:
                Window size used in relative position encoding; offsets
                are binned into 2 * relpos_k + 1 classes
            **kwargs:
                Accepted and ignored by this constructor
        """
        super().__init__()

        self.tf_dim = tf_dim
        self.preembedding_dim = preembedding_dim
        self.c_z = c_z
        self.c_m = c_m

        # NOTE: submodules are created in a fixed order so that parameter
        # registration (and therefore initialisation order) stays stable.
        self.linear_tf_m = Linear(tf_dim, c_m)
        self.linear_preemb_m = Linear(self.preembedding_dim, c_m)
        self.linear_preemb_z_i = Linear(self.preembedding_dim, c_z)
        self.linear_preemb_z_j = Linear(self.preembedding_dim, c_z)

        # Relative positional encoding: one class per offset in
        # [-relpos_k, relpos_k].
        self.relpos_k = relpos_k
        self.no_bins = 2 * relpos_k + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

    def relpos(self, ri: torch.Tensor):
        """
        Computes relative positional encodings.

        Args:
            ri:
                "residue_index" feature of shape [*, N]
        Returns:
            The one-hot encoding of each pairwise residue offset (snapped
            to the nearest value in [-relpos_k, relpos_k]) passed through
            ``linear_relpos``
        """
        # Pairwise offsets: entry (i, j) holds ri[i] - ri[j].
        offsets = ri.unsqueeze(-1) - ri.unsqueeze(-2)

        bins = torch.arange(
            -self.relpos_k, self.relpos_k + 1, device=offsets.device
        )
        # Give the bins leading singleton dims so they line up with a new
        # trailing axis on the offsets.
        bins_bcast = bins.view(*([1] * offsets.dim()), bins.numel())

        # Index of the closest bin; offsets outside the window therefore
        # land in the first or last bin.
        gaps = (offsets.unsqueeze(-1) - bins_bcast).abs()
        nearest = gaps.argmin(dim=-1)

        encoding = nn.functional.one_hot(nearest, num_classes=bins.numel())
        encoding = encoding.float().to(ri.dtype)
        return self.linear_relpos(encoding)

    def forward(
        self,
        tf: torch.Tensor,
        ri: torch.Tensor,
        preemb: torch.Tensor,
        inplace_safe: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            tf:
                target_feat features; last dimension is expected to be
                tf_dim (presumably [*, N, tf_dim] -- confirm with caller)
            ri:
                "residue_index" feature of shape [*, N]
            preemb:
                Sequence pre-embedding; last dimension is expected to be
                preembedding_dim (presumably [*, N, preembedding_dim] --
                confirm with caller)
            inplace_safe:
                Forwarded as ``inplace`` to the ``add`` helper when the
                pair embedding is accumulated
        Returns:
            A tuple ``(preemb_emb, pair_emb)``: the single-sequence
            embedding, carrying an extra singleton dimension at position
            -3, and the pair embedding
        """
        # Single-sequence track: both projections gain a singleton axis at
        # -3 before being summed.
        target_single = self.linear_tf_m(tf).unsqueeze(-3)
        single_emb = self.linear_preemb_m(preemb.unsqueeze(-3)) + target_single

        # Pair track: positional encoding plus two projections of the
        # pre-embedding broadcast along complementary axes.
        row_term = self.linear_preemb_z_i(preemb)
        col_term = self.linear_preemb_z_j(preemb)

        pair_emb = self.relpos(ri.type(row_term.dtype))
        pair_emb = add(pair_emb, row_term.unsqueeze(-2), inplace=inplace_safe)
        pair_emb = add(pair_emb, col_term.unsqueeze(-3), inplace=inplace_safe)

        return single_emb, pair_emb
class RecyclingEmbedder(nn.Module): class RecyclingEmbedder(nn.Module):
""" """
Embeds the output of an iteration of the model for recycling. Embeds the output of an iteration of the model for recycling.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment