"pcdet/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "883fc156518bb5f4c6d9120635ea998f3169a04d"
Commit d8ee9c5f authored by Christina Floristean's avatar Christina Floristean
Browse files

All non-CUDA tests passing for monomer/multimer. Triangle multiplication/attention and outer product mean (OPM) order switched.

parent 260db67f
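For orientation, a minimal, purely illustrative sketch of the reordering this commit makes configurable. Only the flag names (opm_first, tri_mul_first) correspond to the config keys added in the diff below; the function and argument names here are hypothetical stand-ins for the real OpenFold modules, not part of the commit.

# Illustrative sketch only, not part of this commit. Callers pass in the actual
# layer callables; only the control flow around the new flags is shown.
def evoformer_block(m, z, *, opm_first, outer_product_mean, msa_update, pair_stack):
    # Monomer configs keep opm_first=False; the multimer config sets it to True,
    # moving the outer product mean ahead of the MSA-track updates.
    if opm_first:
        z = z + outer_product_mean(m)
    m = msa_update(m, z)
    if not opm_first:
        z = z + outer_product_mean(m)
    z = pair_stack(z)
    return m, z

def template_pair_block(t, *, tri_mul_first, tri_mul_out_in, tri_att_start_end, transition):
    # Monomer configs keep tri_mul_first=False (triangle attention runs first);
    # the multimer config sets it to True (triangle multiplicative updates run first).
    if tri_mul_first:
        t = tri_att_start_end(tri_mul_out_in(t))
    else:
        t = tri_mul_out_in(tri_att_start_end(t))
    return t + transition(t)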
.DS_Store
*.DS_Store
**/.DS_Store
.idea/
**/__pycache__
*.pyc
build/
dist/
*.egg-info/
openfold/resources
**/stereo_chemical_props.txt
**/sample_feats.pickle
@@ -331,6 +331,7 @@ config = mlc.ConfigDict(
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": False,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
@@ -367,6 +368,7 @@ config = mlc.ConfigDict(
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": False,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
@@ -388,6 +390,7 @@ config = mlc.ConfigDict(
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": False,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
@@ -546,6 +549,7 @@ multimer_model_config_update = {
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": True,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
@@ -555,6 +559,53 @@ multimer_model_config_update = {
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
"use_unit_vector": True
},
"extra_msa": {
"extra_msa_embedder": {
"c_in": 25,
"c_out": c_e,
},
"extra_msa_stack": {
"c_m": c_e,
"c_z": c_z,
"c_hidden_msa_att": 8,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 4,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
},
"enabled": True,
},
"evoformer_stack": {
"c_m": c_m,
"c_z": c_z,
"c_hidden_msa_att": 32,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"c_s": c_s,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 48,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
"eps": eps, # 1e-10,
},
"heads": {
"lddt": {
......
@@ -93,7 +93,7 @@ def fix_templates_aatype(protein):
# Map hhsearch-aatype to our aatype.
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
new_order_list, dtype=torch.int64, device=protein["aatype"].device,
new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
).expand(num_templates, -1)
protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"]
@@ -669,8 +669,8 @@ def make_atom14_masks(protein):
def make_atom14_masks_np(batch):
batch = tree_map(
lambda n: torch.tensor(n, device=batch["aatype"].device),
batch,
lambda n: torch.tensor(n, device="cpu"),
batch,
np.ndarray
)
out = make_atom14_masks(batch)
......
@@ -1048,7 +1048,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
for i in idx:
# We got all the templates we wanted, stop processing hits.
if len(already_seen) >= self.max_hits:
if len(already_seen) >= self._max_hits:
break
hit = filtered[i]
@@ -1088,16 +1088,29 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
for k in template_features:
template_features[k].append(result.features[k])
for name in template_features:
if num_hits > 0:
if already_seen:
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
else:
# Make sure the feature has correct dtype even if empty.
template_features[name] = np.array(
[], dtype=TEMPLATE_FEATURES[name]
)
else:
num_res = len(query_sequence)
# Construct a default template with all zeros.
template_features = {
"template_aatype": np.zeros(
(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
np.float32
),
"template_all_atom_masks": np.zeros(
(1, num_res, residue_constants.atom_type_num), np.float32
),
"template_all_atom_positions": np.zeros(
(1, num_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=np.object),
"template_sequence": np.array([''.encode()], dtype=np.object),
"template_sum_probs": np.array([0], dtype=np.float32),
}
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings
......
@@ -216,10 +216,12 @@ class InputEmbedderMultimer(nn.Module):
(2 * self.max_relative_idx + 1) *
torch.ones_like(clipped_offset)
)
rel_pos = torch.nn.functional.one_hot(
boundaries = torch.arange(
start=0, end=2 * self.max_relative_idx + 2, device=final_offset.device
)
rel_pos = one_hot(
final_offset,
2 * self.max_relative_idx + 2,
boundaries,
)
rel_feats.append(rel_pos)
@@ -245,15 +247,21 @@ class InputEmbedderMultimer(nn.Module):
torch.ones_like(clipped_rel_chain)
)
rel_chain = torch.nn.functional.one_hot(
boundaries = torch.arange(
start=0, end=2 * max_rel_chain + 2, device=final_rel_chain.device
)
rel_chain = one_hot(
final_rel_chain,
2 * max_rel_chain + 2,
boundaries,
)
rel_feats.append(rel_chain)
else:
rel_pos = torch.nn.functional.one_hot(
clipped_offset, 2 * self.max_relative_idx + 1,
boundaries = torch.arange(
start=0, end=2 * self.max_relative_idx + 1, device=clipped_offset.device
)
rel_pos = one_hot(
clipped_offset, boundaries,
)
rel_feats.append(rel_pos)
@@ -471,102 +479,6 @@ class TemplatePairEmbedder(nn.Module):
return x
class TemplateEmbedder(nn.Module):
def __init__(
self,
config,
):
super().__init__()
self.config = config
self.template_angle_embedder = TemplateAngleEmbedder(
**config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**config["template_pointwise_attention"],
)
def forward(
self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True,
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
ret = {}
if self.config.embed_angles:
ret["template_pair_embedding"] = template_embeds["angle"]
ret.update({"template_pair_embedding": t})
return ret
class ExtraMSAEmbedder(nn.Module):
"""
Embeds unclustered MSA sequences.
@@ -625,12 +537,13 @@ class TemplateEmbedder(nn.Module):
**config["template_pointwise_attention"],
)
def forward(self,
batch,
def forward(
self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True
):
# Embed the templates one at a time (with a poor man's vmap)
@@ -706,7 +619,7 @@ class TemplatePairEmbedderMultimer(nn.Module):
c_dgram: int,
c_aatype: int,
):
super().__init__()
super(TemplatePairEmbedderMultimer, self).__init__()
self.dgram_linear = Linear(c_dgram, c_out)
self.aatype_linear_1 = Linear(c_aatype, c_out)
@@ -765,7 +678,7 @@ class TemplateSingleEmbedderMultimer(nn.Module):
c_in: int,
c_m: int,
):
super().__init__()
super(TemplateSingleEmbedderMultimer, self).__init__()
self.template_single_embedder = Linear(c_in, c_m)
self.template_projector = Linear(c_m, c_m)
......
@@ -117,34 +117,19 @@ class MSATransition(nn.Module):
return m
class EvoformerBlockCore(nn.Module):
class PairStack(nn.Module):
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
pair_dropout: float,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
eps: float
):
super(EvoformerBlockCore, self).__init__()
self.msa_transition = MSATransition(
c_m=c_m,
n=transition_n,
)
self.outer_product_mean = OuterProductMean(
c_m,
c_z,
c_hidden_opm,
)
super(PairStack, self).__init__()
self.tri_mul_out = TriangleMultiplicationOutgoing(
c_z,
@@ -178,25 +163,15 @@ class EvoformerBlockCore(nn.Module):
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None
m = m + self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size
)
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
)
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(
@@ -209,7 +184,7 @@ class EvoformerBlockCore(nn.Module):
z, mask=pair_trans_mask, chunk_size=chunk_size
)
return m, z
return z
class EvoformerBlock(nn.Module):
@@ -225,11 +200,14 @@ class EvoformerBlock(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
opm_first: bool,
inf: float,
eps: float,
):
super(EvoformerBlock, self).__init__()
self.opm_first = opm_first
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
@@ -247,18 +225,26 @@ class EvoformerBlock(nn.Module):
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
self.msa_transition = MSATransition(
c_m=c_m,
n=transition_n,
)
self.outer_product_mean = OuterProductMean(
c_m,
c_z,
c_hidden_opm,
)
self.pair_stack = PairStack(
c_z=c_z,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
eps=eps
)
def forward(self,
@@ -269,17 +255,34 @@ class EvoformerBlock(nn.Module):
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
msa_trans_mask = msa_mask if _mask_trans else None
if self.opm_first:
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
)
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m, z = self.core(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
m = m + self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size
)
if not self.opm_first:
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
)
z = self.pair_stack(
z,
pair_mask=pair_mask,
chunk_size=chunk_size,
)
return m, z
@@ -304,12 +307,14 @@ class ExtraMSABlock(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
opm_first: bool,
inf: float,
eps: float,
ckpt: bool,
):
super(ExtraMSABlock, self).__init__()
self.opm_first = opm_first
self.ckpt = ckpt
self.msa_att_row = MSARowAttentionWithPairBias(
@@ -330,13 +335,21 @@ class ExtraMSABlock(nn.Module):
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
self.msa_transition = MSATransition(
c_m=c_m,
n=transition_n,
)
self.outer_product_mean = OuterProductMean(
c_m,
c_z,
c_hidden_opm,
)
self.pair_stack = PairStack(
c_z=c_z,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
@@ -361,6 +374,11 @@ class ExtraMSABlock(nn.Module):
m1 += m2
return m1
if self.opm_first:
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
)
m = add(m, self.msa_dropout_layer(
self.msa_att_row(
@@ -377,8 +395,17 @@ class ExtraMSABlock(nn.Module):
def fn(m, z):
m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
m, z = self.core(
m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
m = add(m, self.msa_transition(
m, mask=msa_mask, chunk_size=chunk_size
))
if not self.opm_first:
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
)
z = self.pair_stack(
z, pair_mask=pair_mask, chunk_size=chunk_size
)
return m, z
@@ -414,6 +441,7 @@ class EvoformerStack(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
opm_first: bool,
blocks_per_ckpt: int,
inf: float,
eps: float,
@@ -475,6 +503,7 @@ class EvoformerStack(nn.Module):
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
opm_first=opm_first,
inf=inf,
eps=eps,
)
@@ -555,6 +584,7 @@ class ExtraMSAStack(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
opm_first: bool,
inf: float,
eps: float,
ckpt: bool,
@@ -581,6 +611,7 @@ class ExtraMSAStack(nn.Module):
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
opm_first=opm_first,
inf=inf,
eps=eps,
ckpt=ckpt if chunk_msa_attn else False,
......
@@ -169,8 +169,8 @@ class PointProjection(nn.Module):
def forward(self,
activations: torch.Tensor,
rigids: Rigid3Array,
) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array]]:
rigids: Union[Rigid, Rigid3Array],
) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array], torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO: Needs to run in high precision during training
points_local = self.linear(activations)
points_local = points_local.reshape(
@@ -181,8 +181,9 @@ class PointProjection(nn.Module):
points_local = torch.split(
points_local, points_local.shape[-1] // 3, dim=-1
)
points_local = Vec3Array(*points_local)
points_global = rigids[..., None, None].apply_to_point(points_local)
points_local = torch.stack(points_local, dim=-1)
points_global = rigids[..., None, None].apply(points_local)
if(self.return_local_points):
return points_global, points_local
@@ -285,7 +286,7 @@ class InvariantPointAttention(nn.Module):
self,
s: torch.Tensor,
z: torch.Tensor,
r: Rigid,
r: Union[Rigid, Rigid3Array],
mask: torch.Tensor,
) -> torch.Tensor:
"""
@@ -340,9 +341,6 @@ class InvariantPointAttention(nn.Module):
k, v = torch.split(kv, self.c_hidden, dim=-1)
kv_pts = self.linear_kv_points(s, r)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
# [*, N_res, H, P_q/P_v, 3]
k_pts, v_pts = torch.split(
@@ -364,10 +362,16 @@ class InvariantPointAttention(nn.Module):
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :]
# [*, N_res, N_res, H, P_q]
pt_att = sum([c**2 for c in pt_att])
if self.is_multimer:
pt_att = q_pts.unsqueeze(-3) - k_pts.unsqueeze(-4)
# [*, N_res, N_res, H, P_q]
pt_att = sum([c ** 2 for c in pt_att])
else:
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att ** 2
pt_att = sum(torch.unbind(pt_att, dim=-1))
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
)
@@ -399,20 +403,42 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, N_res, H, P_v]
o_pt = v_pts * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
o_pt = o_pt.sum(dim=-3)
if self.is_multimer:
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, N_res, H, P_v]
o_pt = v_pts[..., None, :, :, :] * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
o_pt = o_pt.sum(dim=-3)
# [*, N_res, H, P_v]
o_pt = r[..., None, None].apply_inverse_to_point(o_pt)
# [*, N_res, H, P_v]
o_pt = r[..., None, None].apply_inverse_to_point(o_pt)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(self.eps)
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(self.eps)
else:
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(
torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
o_pt = torch.unbind(o_pt, dim=-1)
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
@@ -617,7 +643,10 @@ class StructureModule(nn.Module):
self.dropout_rate,
)
self.bb_update = QuatRigid(self.c_s, full_quat=False)
if self.is_multimer:
self.bb_update = QuatRigid(self.c_s, full_quat=False)
else:
self.bb_update = BackboneUpdate(self.c_s)
self.angle_resnet = AngleResnet(
self.c_s,
......
@@ -141,6 +141,7 @@ class TemplatePairStackBlock(nn.Module):
no_heads: int,
pair_transition_n: int,
dropout_rate: float,
tri_mul_first: bool,
inf: float,
**kwargs,
):
@@ -153,6 +154,7 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n = pair_transition_n
self.dropout_rate = dropout_rate
self.inf = inf
self.tri_mul_first = tri_mul_first
self.dropout_row = DropoutRowwise(self.dropout_rate)
self.dropout_col = DropoutColumnwise(self.dropout_rate)
@@ -184,6 +186,38 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n,
)
def tri_att_start_end(self, single, single_mask, chunk_size):
single = single + self.dropout_row(
self.tri_att_start(
single,
chunk_size=chunk_size,
mask=single_mask
)
)
single = single + self.dropout_col(
self.tri_att_end(
single,
chunk_size=chunk_size,
mask=single_mask
)
)
return single
def tri_mul_out_in(self, single, single_mask):
single = single + self.dropout_row(
self.tri_mul_out(
single,
mask=single_mask
)
)
single = single + self.dropout_row(
self.tri_mul_in(
single,
mask=single_mask
)
)
return single
def forward(self,
z: torch.Tensor,
mask: torch.Tensor,
@@ -200,32 +234,17 @@ class TemplatePairStackBlock(nn.Module):
single = single_templates[i]
single_mask = single_templates_masks[i]
single = single + self.dropout_row(
self.tri_att_start(
single,
chunk_size=chunk_size,
mask=single_mask
)
)
single = single + self.dropout_col(
self.tri_att_end(
single,
chunk_size=chunk_size,
mask=single_mask
)
)
single = single + self.dropout_row(
self.tri_mul_out(
single,
mask=single_mask
)
)
single = single + self.dropout_row(
self.tri_mul_in(
single,
mask=single_mask
)
)
if self.tri_mul_first:
single = self.tri_att_start_end(single=self.tri_mul_out_in(single=single,
single_mask=single_mask),
single_mask=single_mask,
chunk_size=chunk_size)
else:
single = self.tri_mul_out_in(single=self.tri_att_start_end(single=single,
single_mask=single_mask,
chunk_size=chunk_size),
single_mask=single_mask)
single = single + self.pair_transition(
single,
mask=single_mask if _mask_trans else None,
@@ -252,6 +271,7 @@ class TemplatePairStack(nn.Module):
no_heads,
pair_transition_n,
dropout_rate,
tri_mul_first,
blocks_per_ckpt,
inf=1e9,
**kwargs,
@@ -287,6 +307,7 @@ class TemplatePairStack(nn.Module):
no_heads=no_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
tri_mul_first=tri_mul_first,
inf=inf,
)
self.blocks.append(block)
......
@@ -18,7 +18,7 @@ import math
import numpy as np
import torch
import torch.nn as nn
from typing import Dict
from typing import Dict, Union
from openfold.np import protein
import openfold.np.residue_constants as rc
@@ -179,11 +179,11 @@ def build_extra_msa_feat(batch):
batch["extra_has_deletion"].unsqueeze(-1),
batch["extra_deletion_value"].unsqueeze(-1),
]
return msa_feat
return torch.cat(msa_feat, dim=-1)
def torsion_angles_to_frames(
r: Rigid,
r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
@@ -220,8 +220,14 @@ def torsion_angles_to_frames(
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
all_frames = default_r.compose_rotation(all_rots)
if isinstance(r, Rigid):
rigid_type = Rigid
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_r.compose(all_rots)
else:
rigid_type = rigid_matrix_vector.Rigid3Array
all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
all_frames = default_r.compose_rotation(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
@@ -232,7 +238,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = rigid_matrix_vector.Rigid3Array.cat(
all_frames_to_bb = rigid_type.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
@@ -248,7 +254,7 @@ def torsion_angles_to_frames(
def frames_and_literature_positions_to_atom14_pos(
r: rigid_matrix_vector.Rigid3Array,
r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
aatype: torch.Tensor,
default_frames,
group_idx,
@@ -277,6 +283,8 @@ def frames_and_literature_positions_to_atom14_pos(
# [*, N, 14]
atom_mask = atom_mask[aatype, ...]
if isinstance(r, Rigid):
atom_mask = atom_mask.unsqueeze(-1)
# [*, N, 14, 3]
lit_positions = lit_positions[aatype, ...]
......
@@ -142,7 +142,7 @@ class Rigid3Array:
def reshape(self, new_shape) -> Rigid3Array:
rots = self.rotation.reshape(new_shape)
trans = self.translation.reshape(new_shape)
return Rigid3Aray(rots, trans)
return Rigid3Array(rots, trans)
def stop_rot_gradient(self) -> Rigid3Array:
return Rigid3Array(
@@ -174,3 +174,6 @@ class Rigid3Array:
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
)
return cls(rotation, translation)
def cuda(self) -> Rigid3Array:
return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda())
@@ -15,6 +15,7 @@
from __future__ import annotations
import dataclasses
from typing import List
import torch
import numpy as np
@@ -172,10 +173,10 @@ class Rot3Array:
"""Construct Rot3Array from components of quaternion."""
if normalize:
inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2)
w *= inv_norm
x *= inv_norm
y *= inv_norm
z *= inv_norm
w = w * inv_norm
x = x * inv_norm
y = y * inv_norm
z = z * inv_norm
xx = 1 - 2 * (y ** 2 + z ** 2)
xy = 2 * (x * y - w * z)
xz = 2 * (x * z + w * y)
......
@@ -110,7 +110,7 @@ class Vec3Array:
# To avoid NaN on the backward pass, we must use maximum before the sqrt
norm2 = self.dot(self)
if epsilon:
norm2 = torch.clamp(norm2, max=epsilon**2)
norm2 = torch.clamp(norm2, min=epsilon**2)
return torch.sqrt(norm2)
def norm2(self):
......
@@ -129,7 +129,7 @@ def assign(translation_dict, orig_weights):
raise
def get_translation_dict(model, is_multimer=False):
def get_translation_dict(model, version, is_multimer=False):
#######################
# Some templates
#######################
@@ -247,7 +247,7 @@ def get_translation_dict(model, is_multimer=False):
)
IPAParams = lambda ipa: {
"q_scalar_projection": LinearParams(ipa.linear_q),
"q_scalar": LinearParams(ipa.linear_q),
"kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points.linear),
"kv_point_local": LinearParams(ipa.linear_kv_points.linear),
@@ -331,19 +331,19 @@ def get_translation_dict(model, is_multimer=False):
b.msa_att_row
),
col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.core.msa_transition),
"msa_transition": MSATransitionParams(b.msa_transition),
"outer_product_mean":
OuterProductMeanParams(b.core.outer_product_mean),
OuterProductMeanParams(b.outer_product_mean),
"triangle_multiplication_outgoing":
TriMulOutParams(b.core.tri_mul_out),
TriMulOutParams(b.pair_stack.tri_mul_out),
"triangle_multiplication_incoming":
TriMulInParams(b.core.tri_mul_in),
TriMulInParams(b.pair_stack.tri_mul_in),
"triangle_attention_starting_node":
TriAttParams(b.core.tri_att_start),
TriAttParams(b.pair_stack.tri_att_start),
"triangle_attention_ending_node":
TriAttParams(b.core.tri_att_end),
TriAttParams(b.pair_stack.tri_att_end),
"pair_transition":
PairTransitionParams(b.core.pair_transition),
PairTransitionParams(b.pair_stack.pair_transition),
}
return d
@@ -584,17 +584,6 @@ def get_translation_dict(model, is_multimer=False):
},
}
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(
model,
is_multimer=("multimer" in version)
)
no_templ = [
"model_3",
"model_4",
@@ -615,6 +604,18 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"logits": LinearParams(model.aux_heads.tm.linear)
}
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(
model,
version,
is_multimer=("multimer" in version)
)
# Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations)
......
@@ -636,9 +636,7 @@ def compute_tm(
)
bin_centers = _calculate_bin_centers(boundaries)
soft_n = torch.sum(residue_weights, dim=-1).to(torch.int32)
other = n.new_zeros() + 19
clipped_n = torch.max(soft_n, other, dim=-1)
clipped_n = max(torch.sum(residue_weights), 19)
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
......
@@ -986,6 +986,16 @@ class Rigid:
"""
return self._trans.device
@property
def dtype(self) -> torch.dtype:
"""
Returns the dtype of the Rigid tensors.
Returns:
The dtype of the Rigid tensors
"""
return self._rots.dtype
def get_rots(self) -> Rotation:
"""
Getter for the rotation.
......
@@ -46,26 +46,26 @@ def import_alphafold():
def get_alphafold_config():
config = alphafold.model.config.model_config("model_1_ptm") # noqa
config = alphafold.model.config.model_config(consts.model) # noqa
config.model.global_config.deterministic = True
return config
_param_path = "openfold/resources/params/params_model_1_ptm.npz"
_param_path = f"openfold/resources/params/params_{consts.model}.npz"
_model = None
def get_global_pretrained_openfold():
global _model
if _model is None:
_model = AlphaFold(model_config("model_1_ptm"))
_model = AlphaFold(model_config(consts.model))
_model = _model.eval()
if not os.path.exists(_param_path):
raise FileNotFoundError(
"""Cannot load pretrained parameters. Make sure to run the
installation script before running tests."""
)
import_jax_weights_(_model, _param_path, version="model_1_ptm")
import_jax_weights_(_model, _param_path, version=consts.model)
_model = _model.cuda()
return _model
......
@@ -2,6 +2,9 @@ import ml_collections as mlc
consts = mlc.ConfigDict(
{
"model": "model_1_multimer_v2", # monomer:model_1_ptm, multimer: model_1_multimer_v2
"is_multimer": True, # monomer: False, multimer: True
"chunk_size": 4,
"batch_size": 2,
"n_res": 11,
"n_seq": 13,
@@ -15,6 +18,7 @@ consts = mlc.ConfigDict(
"c_s": 384,
"c_t": 64,
"c_e": 64,
"msa_logits": 22 # monomer: 23, multimer: 22
}
)
......
@@ -12,9 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from random import randint
import numpy as np
from scipy.spatial.transform import Rotation
from tests.config import consts
def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
n_chain = randint(1, n_res // min_chain_len) if consts.is_multimer else 1
if not split_chains:
return [0] * n_res
assert n_res >= n_chain
pieces = []
asym_ids = []
for idx in range(n_chain - 1):
piece = randint(min_chain_len, (n_res - sum(pieces) - n_chain + idx - min_chain_len))
pieces.append(piece)
asym_ids.extend(piece * [idx])
asym_ids.extend((n_res - sum(pieces)) * [n_chain - 1])
return np.array(asym_ids).astype(np.int64)
def random_template_feats(n_templ, n, batch_size=None):
b = []
@@ -39,6 +61,11 @@ def random_template_feats(n_templ, n, batch_size=None):
}
batch = {k: v.astype(np.float32) for k, v in batch.items()}
batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
if consts.is_multimer:
asym_ids = np.array(random_asym_ids(n))
batch["asym_id"] = np.tile(asym_ids[np.newaxis, :], (*b, n_templ, 1))
return batch
......
@@ -15,19 +15,13 @@
import pickle
import shutil
import torch
import numpy as np
import unittest
from openfold.data.data_pipeline import DataPipeline
from openfold.data.templates import TemplateHitFeaturizer
from openfold.model.embedders import (
InputEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
)
from openfold.data.templates import HhsearchHitFeaturizer, HmmsearchHitFeaturizer
import tests.compare_utils as compare_utils
from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
@@ -45,13 +39,29 @@ class TestDataPipeline(unittest.TestCase):
with open("tests/test_data/alphafold_feature_dict.pickle", "rb") as fp:
alphafold_feature_dict = pickle.load(fp)
template_featurizer = TemplateHitFeaturizer(
mmcif_dir="tests/test_data/mmcifs",
max_template_date="2021-12-20",
max_hits=20,
kalign_binary_path=shutil.which("kalign"),
_zero_center_positions=False,
)
if consts.is_multimer:
# template_featurizer = HmmsearchHitFeaturizer(
# mmcif_dir="tests/test_data/mmcifs",
# max_template_date="2021-12-20",
# max_hits=20,
# kalign_binary_path=shutil.which("kalign"),
# _zero_center_positions=False,
# )
template_featurizer = HhsearchHitFeaturizer(
mmcif_dir="tests/test_data/mmcifs",
max_template_date="2021-12-20",
max_hits=20,
kalign_binary_path=shutil.which("kalign"),
_zero_center_positions=False,
)
else:
template_featurizer = HhsearchHitFeaturizer(
mmcif_dir="tests/test_data/mmcifs",
max_template_date="2021-12-20",
max_hits=20,
kalign_binary_path=shutil.which("kalign"),
_zero_center_positions=False,
)
data_pipeline = DataPipeline(
template_featurizer=template_featurizer,
......
import copy
import gzip
import os
import pickle
import numpy as np
@@ -178,7 +174,7 @@ class TestDataTransforms(unittest.TestCase):
protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}
protein = make_hhblits_profile(protein)
masked_msa_config = config.data.common.masked_msa
protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15)
protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15, seed=42)
assert 'bert_mask' in protein
assert 'true_msa' in protein
assert 'msa' in protein
......