Commit 56d5e39c authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

Merge remote-tracking branch 'upstream/multimer' into multimer

parents 56b86074 51556d52
...@@ -13,12 +13,26 @@ ...@@ -13,12 +13,26 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Tuple, Optional from typing import Tuple, Optional
from openfold.utils import all_atom_multimer
from openfold.utils.feats import (
pseudo_beta_fn,
dgram_from_positions,
build_template_angle_feat,
build_template_pair_feat,
)
from openfold.model.primitives import Linear, LayerNorm from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import add, one_hot from openfold.model.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from openfold.utils import geometry
from openfold.utils.tensor_utils import add, one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module): class InputEmbedder(nn.Module):
...@@ -99,12 +113,13 @@ class InputEmbedder(nn.Module): ...@@ -99,12 +113,13 @@ class InputEmbedder(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
tf: batch: Dict containing
"target_feat" features of shape [*, N_res, tf_dim] "target_feat":
ri: Features of shape [*, N_res, tf_dim]
"residue_index" features of shape [*, N_res] "residue_index":
msa: Features of shape [*, N_res]
"msa_feat" features of shape [*, N_clust, N_res, msa_dim] "msa_feat":
Features of shape [*, N_clust, N_res, msa_dim]
Returns: Returns:
msa_emb: msa_emb:
[*, N_clust, N_res, C_m] MSA embedding [*, N_clust, N_res, C_m] MSA embedding
...@@ -139,6 +154,162 @@ class InputEmbedder(nn.Module): ...@@ -139,6 +154,162 @@ class InputEmbedder(nn.Module):
return msa_emb, pair_emb return msa_emb, pair_emb
class InputEmbedderMultimer(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
max_relative_idx: int,
use_chain_relative: bool,
max_relative_chain: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedderMultimer, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.max_relative_idx = max_relative_idx
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if(self.use_chain_relative):
self.no_bins = (
2 * max_relative_idx + 2 +
1 +
2 * max_relative_chain + 2
)
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, batch):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = (asym_id[..., None] == asym_id[..., None, :])
offset = pos[..., None] - pos[..., None, :]
clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
)
rel_feats = []
if(self.use_chain_relative):
final_offset = torch.where(
asym_id_same,
clipped_offset,
(2 * self.max_relative_idx + 1) *
torch.ones_like(clipped_offset)
)
boundaries = torch.arange(
start=0, end=2 * self.max_relative_idx + 2, device=final_offset.device
)
rel_pos = one_hot(
final_offset,
boundaries,
)
rel_feats.append(rel_pos)
entity_id = batch["entity_id"]
entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
rel_feats.append(entity_id_same[..., None])
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain,
0,
2 * max_rel_chain,
)
final_rel_chain = torch.where(
entity_id_same,
clipped_rel_chain,
(2 * max_rel_chain + 1) *
torch.ones_like(clipped_rel_chain)
)
boundaries = torch.arange(
start=0, end=2 * max_rel_chain + 2, device=final_rel_chain.device
)
rel_chain = one_hot(
final_rel_chain,
boundaries,
)
rel_feats.append(rel_chain)
else:
boundaries = torch.arange(
start=0, end=2 * self.max_relative_idx + 1, device=clipped_offset.device
)
rel_pos = one_hot(
clipped_offset, boundaries,
)
rel_feats.append(rel_pos)
rel_feat = torch.cat(rel_feats, dim=-1).to(
self.linear_relpos.weight.dtype
)
return self.linear_relpos(rel_feat)
def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
tf = batch["target_feat"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(batch)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
class RecyclingEmbedder(nn.Module): class RecyclingEmbedder(nn.Module):
""" """
Embeds the output of an iteration of the model for recycling. Embeds the output of an iteration of the model for recycling.
...@@ -365,3 +536,345 @@ class ExtraMSAEmbedder(nn.Module): ...@@ -365,3 +536,345 @@ class ExtraMSAEmbedder(nn.Module):
x = self.linear(x) x = self.linear(x)
return x return x
class TemplateEmbedder(nn.Module):
def __init__(self, config):
super(TemplateEmbedder, self).__init__()
self.config = config
self.template_angle_embedder = TemplateAngleEmbedder(
**config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**config["template_pointwise_attention"],
)
def forward(
self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True,
use_lma=False,
inplace_safe=False
):
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
if (inplace_safe):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair = z.new_zeros(
z.shape[:-3] +
(n_templ, n, n, self.config.template_pair_embedder.c_t)
)
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)
# [*, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
if (inplace_safe):
t_pair[..., i, :, :, :] = t
else:
pair_embeds.append(t)
del t
if (not inplace_safe):
t_pair = torch.stack(pair_embeds, dim=templ_dim)
del pair_embeds
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
del t_pair
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
use_lma=use_lma,
)
t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
)
if (inplace_safe):
t *= t_mask
else:
t = t * t_mask
ret = {}
ret.update({"template_pair_embedding": t})
del t
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
batch
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
ret["template_single_embedding"] = a
return ret
class TemplatePairEmbedderMultimer(nn.Module):
def __init__(self,
c_z: int,
c_out: int,
c_dgram: int,
c_aatype: int,
):
super(TemplatePairEmbedderMultimer, self).__init__()
self.dgram_linear = Linear(c_dgram, c_out)
self.aatype_linear_1 = Linear(c_aatype, c_out)
self.aatype_linear_2 = Linear(c_aatype, c_out)
self.query_embedding_layer_norm = LayerNorm(c_z)
self.query_embedding_linear = Linear(c_z, c_out)
self.pseudo_beta_mask_linear = Linear(1, c_out)
self.x_linear = Linear(1, c_out)
self.y_linear = Linear(1, c_out)
self.z_linear = Linear(1, c_out)
self.backbone_mask_linear = Linear(1, c_out)
def forward(self,
template_dgram: torch.Tensor,
aatype_one_hot: torch.Tensor,
query_embedding: torch.Tensor,
pseudo_beta_mask: torch.Tensor,
backbone_mask: torch.Tensor,
multichain_mask_2d: torch.Tensor,
unit_vector: geometry.Vec3Array,
) -> torch.Tensor:
act = 0.
pseudo_beta_mask_2d = (
pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
)
pseudo_beta_mask_2d *= multichain_mask_2d
template_dgram *= pseudo_beta_mask_2d[..., None]
act += self.dgram_linear(template_dgram)
act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])
aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
act += self.aatype_linear_2(aatype_one_hot[..., None, :])
backbone_mask_2d = (
backbone_mask[..., None] * backbone_mask[..., None, :]
)
backbone_mask_2d *= multichain_mask_2d
x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
act += self.x_linear(x[..., None])
act += self.y_linear(y[..., None])
act += self.z_linear(z[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None])
query_embedding = self.query_embedding_layer_norm(query_embedding)
act += self.query_embedding_linear(query_embedding)
return act
class TemplateSingleEmbedderMultimer(nn.Module):
def __init__(self,
c_in: int,
c_m: int,
):
super(TemplateSingleEmbedderMultimer, self).__init__()
self.template_single_embedder = Linear(c_in, c_m)
self.template_projector = Linear(c_m, c_m)
def forward(self,
batch,
atom_pos,
aatype_one_hot,
):
out = {}
template_chi_angles, template_chi_mask = (
all_atom_multimer.compute_chi_angles(
atom_pos,
batch["template_all_atom_mask"],
batch["template_aatype"],
)
)
template_features = torch.cat(
[
aatype_one_hot,
torch.sin(template_chi_angles) * template_chi_mask,
torch.cos(template_chi_angles) * template_chi_mask,
template_chi_mask,
],
dim=-1,
)
template_mask = template_chi_mask[..., 0]
template_activations = self.template_single_embedder(
template_features
)
template_activations = torch.nn.functional.relu(
template_activations
)
template_activations = self.template_projector(
template_activations,
)
out["template_single_embedding"] = (
template_activations
)
out["template_mask"] = template_mask
return out
class TemplateEmbedderMultimer(nn.Module):
def __init__(self, config):
super(TemplateEmbedderMultimer, self).__init__()
self.config = config
self.template_pair_embedder = TemplatePairEmbedderMultimer(
**config["template_pair_embedder"],
)
self.template_single_embedder = TemplateSingleEmbedderMultimer(
**config["template_single_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.linear_t = Linear(config.c_t, config.c_z)
def forward(self,
batch,
z,
padding_mask_2d,
templ_dim,
chunk_size,
multichain_mask_2d,
use_lma=False,
inplace_safe=False
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
act = 0.
template_positions, pseudo_beta_mask = (
single_template_feats["template_pseudo_beta"],
single_template_feats["template_pseudo_beta_mask"],
)
template_dgram = dgram_from_positions(
template_positions,
inf=self.config.inf,
**self.config.distogram,
)
aatype_one_hot = torch.nn.functional.one_hot(
single_template_feats["template_aatype"], 22,
)
raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos,
single_template_feats["template_all_atom_mask"],
single_template_feats["template_aatype"],
)
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
pair_act = self.template_pair_embedder(
template_dgram,
aatype_one_hot,
z,
pseudo_beta_mask,
backbone_mask,
multichain_mask_2d,
unit_vector,
)
single_template_embeds["template_pair_embedding"] = pair_act
single_template_embeds.update(
self.template_single_embedder(
single_template_feats,
atom_pos,
aatype_one_hot,
)
)
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)
# [*, N, N, C_z]
t = torch.sum(t, dim=-4) / n_templ
t = torch.nn.functional.relu(t)
t = self.linear_t(t)
template_embeds["template_pair_embedding"] = t
return template_embeds
This diff is collapsed.
...@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module): ...@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module):
if self.config.tm.enabled: if self.config.tm.enabled:
tm_logits = self.tm(outputs["pair"]) tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm( aux_out["ptm_score"] = compute_tm(
tm_logits, **self.config.tm tm_logits, **self.config.tm
) )
asym_id = outputs.get("asym_id")
if asym_id is not None:
aux_out["iptm_score"] = compute_tm(
tm_logits, asym_id=asym_id, interface=True, **self.config.tm
)
aux_out["weighted_ptm_score"] = (self.config.tm["iptm_weight"] * aux_out["iptm_score"]
+ self.config.tm["ptm_weight"] * aux_out["ptm_score"])
aux_out.update( aux_out.update(
compute_predicted_aligned_error( compute_predicted_aligned_error(
tm_logits, tm_logits,
......
This diff is collapsed.
This diff is collapsed.
...@@ -33,6 +33,8 @@ from openfold.model.triangular_attention import ( ...@@ -33,6 +33,8 @@ from openfold.model.triangular_attention import (
from openfold.model.triangular_multiplicative_update import ( from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing, TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming, TriangleMultiplicationIncoming,
FusedTriangleMultiplicationOutgoing,
FusedTriangleMultiplicationIncoming
) )
from openfold.utils.checkpointing import checkpoint_blocks from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.chunk_utils import ( from openfold.utils.chunk_utils import (
...@@ -154,6 +156,8 @@ class TemplatePairStackBlock(nn.Module): ...@@ -154,6 +156,8 @@ class TemplatePairStackBlock(nn.Module):
no_heads: int, no_heads: int,
pair_transition_n: int, pair_transition_n: int,
dropout_rate: float, dropout_rate: float,
tri_mul_first: bool,
fuse_projection_weights: bool,
inf: float, inf: float,
**kwargs, **kwargs,
): ):
...@@ -166,6 +170,7 @@ class TemplatePairStackBlock(nn.Module): ...@@ -166,6 +170,7 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n = pair_transition_n self.pair_transition_n = pair_transition_n
self.dropout_rate = dropout_rate self.dropout_rate = dropout_rate
self.inf = inf self.inf = inf
self.tri_mul_first = tri_mul_first
self.dropout_row = DropoutRowwise(self.dropout_rate) self.dropout_row = DropoutRowwise(self.dropout_rate)
self.dropout_col = DropoutColumnwise(self.dropout_rate) self.dropout_col = DropoutColumnwise(self.dropout_rate)
...@@ -183,20 +188,88 @@ class TemplatePairStackBlock(nn.Module): ...@@ -183,20 +188,88 @@ class TemplatePairStackBlock(nn.Module):
inf=inf, inf=inf,
) )
self.tri_mul_out = TriangleMultiplicationOutgoing( if fuse_projection_weights:
self.c_t, self.tri_mul_out = FusedTriangleMultiplicationOutgoing(
self.c_hidden_tri_mul, self.c_t,
) self.c_hidden_tri_mul,
self.tri_mul_in = TriangleMultiplicationIncoming( )
self.c_t, self.tri_mul_in = FusedTriangleMultiplicationIncoming(
self.c_hidden_tri_mul, self.c_t,
) self.c_hidden_tri_mul,
)
else:
self.tri_mul_out = TriangleMultiplicationOutgoing(
self.c_t,
self.c_hidden_tri_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
self.c_t,
self.c_hidden_tri_mul,
)
self.pair_transition = PairTransition( self.pair_transition = PairTransition(
self.c_t, self.c_t,
self.pair_transition_n, self.pair_transition_n,
) )
def tri_att_start_end(self, single, _attn_chunk_size, single_mask, use_lma, inplace_safe):
single = add(single,
self.dropout_row(
self.tri_att_start(
single,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace_safe,
)
single = add(single,
self.dropout_col(
self.tri_att_end(
single,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace_safe,
)
return single
def tri_mul_out_in(self, single, single_mask, inplace_safe):
tmu_update = self.tri_mul_out(
single,
mask=single_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if (not inplace_safe):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
del tmu_update
tmu_update = self.tri_mul_in(
single,
mask=single_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if (not inplace_safe):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
del tmu_update
return single
def forward(self, def forward(self,
z: torch.Tensor, z: torch.Tensor,
mask: torch.Tensor, mask: torch.Tensor,
...@@ -219,72 +292,37 @@ class TemplatePairStackBlock(nn.Module): ...@@ -219,72 +292,37 @@ class TemplatePairStackBlock(nn.Module):
for i in range(len(single_templates)): for i in range(len(single_templates)):
single = single_templates[i] single = single_templates[i]
single_mask = single_templates_masks[i] single_mask = single_templates_masks[i]
single = add(single,
self.dropout_row(
self.tri_att_start(
single,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace_safe,
)
single = add(single,
self.dropout_col(
self.tri_att_end(
single,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace_safe,
)
tmu_update = self.tri_mul_out( if self.tri_mul_first:
single, single = self.tri_att_start_end(single=self.tri_mul_out_in(single=single,
mask=single_mask, single_mask=single_mask,
inplace_safe=inplace_safe, inplace_safe=inplace_safe),
_add_with_inplace=True, _attn_chunk_size=_attn_chunk_size,
) single_mask=single_mask,
if(not inplace_safe): use_lma=use_lma,
single = single + self.dropout_row(tmu_update) inplace_safe=inplace_safe)
else: else:
single = tmu_update single = self.tri_mul_out_in(single=self.tri_att_start_end(single=single,
_attn_chunk_size=_attn_chunk_size,
del tmu_update single_mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe),
single_mask=single_mask,
inplace_safe=inplace_safe)
tmu_update = self.tri_mul_in(
single,
mask=single_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not inplace_safe):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
del tmu_update
single = add(single, single = add(single,
self.pair_transition( self.pair_transition(
single, single,
mask=single_mask if _mask_trans else None, mask=single_mask if _mask_trans else None,
chunk_size=chunk_size, chunk_size=chunk_size,
), ),
inplace_safe, inplace_safe,
) )
if(not inplace_safe): if (not inplace_safe):
single_templates[i] = single single_templates[i] = single
if(not inplace_safe): if (not inplace_safe):
z = torch.cat(single_templates, dim=-4) z = torch.cat(single_templates, dim=-4)
return z return z
...@@ -303,6 +341,8 @@ class TemplatePairStack(nn.Module): ...@@ -303,6 +341,8 @@ class TemplatePairStack(nn.Module):
no_heads, no_heads,
pair_transition_n, pair_transition_n,
dropout_rate, dropout_rate,
tri_mul_first,
fuse_projection_weights,
blocks_per_ckpt, blocks_per_ckpt,
tune_chunk_size: bool = False, tune_chunk_size: bool = False,
inf=1e9, inf=1e9,
...@@ -339,6 +379,8 @@ class TemplatePairStack(nn.Module): ...@@ -339,6 +379,8 @@ class TemplatePairStack(nn.Module):
no_heads=no_heads, no_heads=no_heads,
pair_transition_n=pair_transition_n, pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate, dropout_rate=dropout_rate,
tri_mul_first=tri_mul_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf, inf=inf,
) )
self.blocks.append(block) self.blocks.append(block)
...@@ -512,7 +554,7 @@ def embed_templates_offload( ...@@ -512,7 +554,7 @@ def embed_templates_offload(
# [*, N, C_m] # [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat) a = model.template_angle_embedder(template_angle_feat)
ret["template_angle_embedding"] = a ret["template_single_embedding"] = a
ret.update({"template_pair_embedding": t}) ret.update({"template_pair_embedding": t})
...@@ -623,7 +665,7 @@ def embed_templates_average( ...@@ -623,7 +665,7 @@ def embed_templates_average(
# [*, N, C_m] # [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat) a = model.template_angle_embedder(template_angle_feat)
ret["template_angle_embedding"] = a ret["template_single_embedding"] = a
ret.update({"template_pair_embedding": out_tensor}) ret.update({"template_pair_embedding": out_tensor})
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from openfold.utils.geometry import rigid_matrix_vector
from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry import struct_of_array
from openfold.utils.geometry import vector
Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
StructOfArray = struct_of_array.StructOfArray
Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment