Commit 07e64267 authored by Gustaf Ahdritz

Standardize code style

parent de07730f
......@@ -26,28 +26,25 @@ _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
# With Param, a poor man's enum with attributes (Rust-style)
class ParamType(Enum):
LinearWeight = partial( # hack: partial prevents fns from becoming methods
lambda w: w.transpose(-1, -2)
)
LinearWeightMHA = partial(
lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
)
LinearMHAOutputWeight = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
LinearBiasMHA = partial(
lambda w: w.reshape(*w.shape[:-2], -1)
)
LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
LinearWeightOPM = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
Other = partial(
lambda w: w
)
Other = partial(lambda w: w)
def __init__(self, fn):
self.transformation = fn
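# Illustrative sketch (not part of this commit): each ParamType carries the
# reshape that maps a JAX/NPZ-layout array onto the matching PyTorch
# parameter. Assuming the usual Haiku Linear layout of [in, out], a plain
# Linear weight just needs its last two dims transposed; shapes are made up.
def _param_type_example():
    w_npz = torch.randn(64, 128)                           # [in, out]
    w_pt = ParamType.LinearWeight.transformation(w_npz)    # [out, in]
    return w_pt.shape  # torch.Size([128, 64])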
@dataclass
class Param:
param: Union[torch.Tensor, List[torch.Tensor]]
......@@ -58,16 +55,17 @@ class Param:
def _process_translations_dict(d, top_layer=True):
flat = {}
for k, v in d.items():
if(type(v) == dict):
prefix = _NPZ_KEY_PREFIX if top_layer else ''
if type(v) == dict:
prefix = _NPZ_KEY_PREFIX if top_layer else ""
sub_flat = {
(prefix + '/'.join([k, k_prime])):v_prime
for k_prime, v_prime in
_process_translations_dict(v, top_layer=False).items()
(prefix + "/".join([k, k_prime])): v_prime
for k_prime, v_prime in _process_translations_dict(
v, top_layer=False
).items()
}
flat.update(sub_flat)
else:
k = '/' + k if not top_layer else k
k = "/" + k if not top_layer else k
flat[k] = v
return flat
......@@ -75,29 +73,29 @@ def _process_translations_dict(d, top_layer=True):
def stacked(param_dict_list, out=None):
"""
Args:
param_dict_list:
A list of (nested) Param dicts to stack. The structure of
each dict must be identical (down to the ParamTypes of
"parallel" Params). There must be at least one dict
in the list.
"""
if(out is None):
if out is None:
out = {}
template = param_dict_list[0]
for k, _ in template.items():
v = [d[k] for d in param_dict_list]
if(type(v[0]) is dict):
if type(v[0]) is dict:
out[k] = {}
stacked(v, out=out[k])
elif(type(v[0]) is Param):
elif type(v[0]) is Param:
stacked_param = Param(
param=[param.param for param in v],
param_type=v[0].param_type,
stacked=True
stacked=True,
)
out[k] = stacked_param
return out
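# Illustrative sketch (not part of this commit): stacking the translation
# dicts of two hypothetical parallel blocks. The "weights" key and the linear
# modules passed in are made up; only the call pattern matters.
def _stacked_example(linear_0, linear_1):
    per_block = [
        {"weights": Param(linear_0.weight, param_type=ParamType.LinearWeight)},
        {"weights": Param(linear_1.weight, param_type=ParamType.LinearWeight)},
    ]
    merged = stacked(per_block)
    # merged["weights"].param is [linear_0.weight, linear_1.weight] and
    # merged["weights"].stacked is True
    return merged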
......@@ -107,12 +105,12 @@ def assign(translation_dict, orig_weights):
with torch.no_grad():
weights = torch.as_tensor(orig_weights[k])
ref, param_type = param.param, param.param_type
if(param.stacked):
if param.stacked:
weights = torch.unbind(weights, 0)
else:
weights = [weights]
ref = [ref]
try:
weights = list(map(param_type.transformation, weights))
for p, w in zip(ref, weights):
......@@ -121,36 +119,25 @@ def assign(translation_dict, orig_weights):
print(k)
print(ref[0].shape)
print(weights[0].shape)
raise
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
#######################
# Some templates
#######################
LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
LinearWeight = lambda l: (
Param(l, param_type=ParamType.LinearWeight)
)
LinearBias = lambda l: (
Param(l)
)
LinearBias = lambda l: (Param(l))
LinearWeightMHA = lambda l: (
Param(l, param_type=ParamType.LinearWeightMHA)
)
LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
LinearBiasMHA = lambda b: (
Param(b, param_type=ParamType.LinearBiasMHA)
)
LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
LinearWeightOPM = lambda l: (
Param(l, param_type=ParamType.LinearWeightOPM)
)
LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
LinearParams = lambda l: {
"weights": LinearWeight(l.weight),
......@@ -167,7 +154,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"key_w": LinearWeightMHA(att.linear_k.weight),
"value_w": LinearWeightMHA(att.linear_v.weight),
"output_w": Param(
att.linear_o.weight, param_type=ParamType.LinearMHAOutputWeight,
att.linear_o.weight,
param_type=ParamType.LinearMHAOutputWeight,
),
"output_b": LinearBias(att.linear_o.bias),
}
......@@ -205,7 +193,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outgoing
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
TriMulInParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
......@@ -231,7 +219,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
MSAGlobalAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": GlobalAttentionParams(matt.global_attention)
"attention": GlobalAttentionParams(matt.global_attention),
}
MSAAttPairBiasParams = lambda matt: dict(
......@@ -247,8 +235,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points),
"kv_point_local": LinearParams(ipa.linear_kv_points),
"trainable_point_weights":
Param(param=ipa.head_weights, param_type=ParamType.Other),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
......@@ -276,7 +265,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
}
def EvoformerBlockParams(b, is_extra_msa=False):
if(is_extra_msa):
if is_extra_msa:
col_att_name = "msa_column_global_attention"
msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
else:
......@@ -284,8 +273,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):
msa_col_att_params = MSAAttParams(b.msa_att_col)
d = {
"msa_row_attention_with_pair_bias":
MSAAttPairBiasParams(b.msa_att_row),
"msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
b.msa_att_row
),
col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.msa_transition),
"outer_product_mean": OuterProductMeanParams(b.outer_product_mean),
......@@ -316,10 +306,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
}
},
}
############################
# translations dict overflow
############################
......@@ -330,14 +319,10 @@ def import_jax_weights_(model, npz_path, version="model_1"):
)
ems_blocks = model.extra_msa_stack.stack.blocks
ems_blocks_params = stacked(
[ExtraMSABlockParams(b) for b in ems_blocks]
)
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
evo_blocks = model.evoformer.blocks
evo_blocks_params = stacked(
[EvoformerBlockParams(b) for b in evo_blocks]
)
evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
translations = {
"evoformer": {
......@@ -346,101 +331,108 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"left_single": LinearParams(model.input_embedder.linear_tf_z_i),
"right_single": LinearParams(model.input_embedder.linear_tf_z_j),
"prev_pos_linear": LinearParams(model.recycling_embedder.linear),
"prev_msa_first_row_norm":
LayerNormParams(model.recycling_embedder.layer_norm_m),
"prev_pair_norm":
LayerNormParams(model.recycling_embedder.layer_norm_z),
"pair_activiations":
LinearParams(model.input_embedder.linear_relpos),
"prev_msa_first_row_norm": LayerNormParams(
model.recycling_embedder.layer_norm_m
),
"prev_pair_norm": LayerNormParams(
model.recycling_embedder.layer_norm_z
),
"pair_activiations": LinearParams(
model.input_embedder.linear_relpos
),
"template_embedding": {
"single_template_embedding": {
"embedding2d":
LinearParams(model.template_pair_embedder.linear),
"embedding2d": LinearParams(
model.template_pair_embedder.linear
),
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm":
LayerNormParams(model.template_pair_stack.layer_norm),
"output_layer_norm": LayerNormParams(
model.template_pair_stack.layer_norm
),
},
"attention": AttentionParams(model.template_pointwise_att.mha),
},
"extra_msa_activations":
LinearParams(model.extra_msa_embedder.linear),
"extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear
),
"extra_msa_stack": ems_blocks_params,
"template_single_embedding":
LinearParams(model.template_angle_embedder.linear_1),
"template_projection":
LinearParams(model.template_angle_embedder.linear_2),
"template_single_embedding": LinearParams(
model.template_angle_embedder.linear_1
),
"template_projection": LinearParams(
model.template_angle_embedder.linear_2
),
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
},
"structure_module": {
"single_layer_norm":
LayerNormParams(model.structure_module.layer_norm_s),
"initial_projection":
LinearParams(model.structure_module.linear_in),
"pair_layer_norm":
LayerNormParams(model.structure_module.layer_norm_z),
"single_layer_norm": LayerNormParams(
model.structure_module.layer_norm_s
),
"initial_projection": LinearParams(
model.structure_module.linear_in
),
"pair_layer_norm": LayerNormParams(
model.structure_module.layer_norm_z
),
"fold_iteration": FoldIterationParams(model.structure_module),
},
"predicted_lddt_head": {
"input_layer_norm":
LayerNormParams(model.aux_heads.plddt.layer_norm),
"act_0":
LinearParams(model.aux_heads.plddt.linear_1),
"act_1":
LinearParams(model.aux_heads.plddt.linear_2),
"logits":
LinearParams(model.aux_heads.plddt.linear_3),
"input_layer_norm": LayerNormParams(
model.aux_heads.plddt.layer_norm
),
"act_0": LinearParams(model.aux_heads.plddt.linear_1),
"act_1": LinearParams(model.aux_heads.plddt.linear_2),
"logits": LinearParams(model.aux_heads.plddt.linear_3),
},
"distogram_head": {
"half_logits":
LinearParams(model.aux_heads.distogram.linear),
"half_logits": LinearParams(model.aux_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits":
LinearParams(model.aux_heads.experimentally_resolved.linear),
"logits": LinearParams(
model.aux_heads.experimentally_resolved.linear
),
},
"masked_msa_head": {
"logits":
LinearParams(model.aux_heads.masked_msa.linear),
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
no_templ = [
"model_3",
"model_4",
"model_5",
"model_3_ptm",
"model_4_ptm",
"model_3",
"model_4",
"model_5",
"model_3_ptm",
"model_4_ptm",
"model_5_ptm",
]
if(version in no_templ):
if version in no_templ:
evo_dict = translations["evoformer"]
keys = list(evo_dict.keys())
for k in keys:
if("template_" in k):
if "template_" in k:
evo_dict.pop(k)
if("_ptm" in version):
if "_ptm" in version:
translations["predicted_aligned_error_head"] = {
"logits":
LinearParams(model.aux_heads.tm.linear)
"logits": LinearParams(model.aux_heads.tm.linear)
}
# Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations)
# Sanity check
keys = list(data.keys())
flat_keys = list(flat.keys())
incorrect = [k for k in flat_keys if k not in keys]
missing = [k for k in keys if k not in flat_keys]
#print(f"Incorrect: {incorrect}")
#print(f"Missing: {missing}")
# print(f"Incorrect: {incorrect}")
# print(f"Missing: {missing}")
assert(len(incorrect) == 0)
assert len(incorrect) == 0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
# Set weights
assign(flat, data)
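# Illustrative usage sketch (not part of this commit); the model construction
# and checkpoint filename below are hypothetical:
#
#   model = AlphaFold(config)   # an OpenFold model built with a model_1 config
#   import_jax_weights_(model, "params_model_1.npz", version="model_1")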
......@@ -25,8 +25,8 @@ from openfold.np import residue_constants
from openfold.utils import feats
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
masked_mean,
permute_final_dims,
batched_gather,
......@@ -49,9 +49,9 @@ def sigmoid_cross_entropy(logits, labels):
def torsion_angle_loss(
a, # [*, N, 7, 2]
a_gt, # [*, N, 7, 2]
a_alt_gt, # [*, N, 7, 2]
):
# [*, N, 7]
norm = torch.norm(a, dim=-1)
......@@ -81,7 +81,7 @@ def compute_fape(
positions_mask: torch.Tensor,
length_scale: float,
l1_clamp_distance: Optional[float] = None,
eps=1e-8
eps=1e-8,
) -> torch.Tensor:
# [*, N_frames, N_pts, 3]
local_pred_pos = pred_frames.invert()[..., None].apply(
......@@ -91,10 +91,10 @@ def compute_fape(
target_positions[..., None, :, :],
)
error_dist = torch.sqrt(
torch.sum((local_pred_pos - local_target_pos)**2, dim=-1) + eps
torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
)
if(l1_clamp_distance is not None):
if l1_clamp_distance is not None:
error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
normed_error = error_dist / length_scale
......@@ -111,7 +111,9 @@ def compute_fape(
#
# ("roughly" because eps is necessarily duplicated in the latter
normed_error = torch.sum(normed_error, dim=-1)
normed_error = normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
normed_error = (
normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
)
normed_error = torch.sum(normed_error, dim=-1)
normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
......@@ -126,14 +128,14 @@ def backbone_loss(
backbone_affine_mask: torch.Tensor,
traj: torch.Tensor,
use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.,
loss_unit_distance: float = 10.,
clamp_distance: float = 10.0,
loss_unit_distance: float = 10.0,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
pred_aff = T.from_tensor(traj)
gt_aff = T.from_tensor(backbone_affine_tensor)
fape_loss = compute_fape(
pred_aff,
gt_aff[..., None, :],
......@@ -145,7 +147,7 @@ def backbone_loss(
length_scale=loss_unit_distance,
eps=eps,
)
if(use_clamped_fape is not None):
if use_clamped_fape is not None:
unclamped_fape_loss = compute_fape(
pred_aff,
gt_aff[..., None, :],
......@@ -158,9 +160,8 @@ def backbone_loss(
eps=eps,
)
fape_loss = (
fape_loss * use_clamped_fape +
unclamped_fape_loss * (1 - use_clamped_fape)
fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
1 - use_clamped_fape
)
# Take the mean over the layer dimension
......@@ -177,42 +178,31 @@ def sidechain_loss(
renamed_atom14_gt_positions: torch.Tensor,
renamed_atom14_gt_exists: torch.Tensor,
alt_naming_is_better: torch.Tensor,
clamp_distance: float = 10.,
length_scale: float = 10.,
clamp_distance: float = 10.0,
length_scale: float = 10.0,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
renamed_gt_frames = (
(1. - alt_naming_is_better[..., None, None, None]) *
rigidgroups_gt_frames +
alt_naming_is_better[..., None, None, None] *
rigidgroups_alt_gt_frames
)
1.0 - alt_naming_is_better[..., None, None, None]
) * rigidgroups_gt_frames + alt_naming_is_better[
..., None, None, None
] * rigidgroups_alt_gt_frames
# Steamroll the inputs
sidechain_frames = sidechain_frames[-1]
batch_dims = sidechain_frames.shape[:-4]
sidechain_frames = sidechain_frames.view(
*batch_dims, -1, 4, 4
)
sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
sidechain_frames = T.from_4x4(sidechain_frames)
renamed_gt_frames = renamed_gt_frames.view(
*batch_dims, -1, 4, 4
)
renamed_gt_frames = T.from_4x4(renamed_gt_frames)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(
*batch_dims, -1
)
renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
renamed_gt_frames = T.from_4x4(renamed_gt_frames)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
sidechain_atom_pos = sidechain_atom_pos[-1]
sidechain_atom_pos = sidechain_atom_pos.view(
*batch_dims, -1, 3
)
sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
*batch_dims, -1, 3
)
renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(
*batch_dims, -1
)
renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)
fape = compute_fape(
sidechain_frames,
......@@ -235,19 +225,17 @@ def fape_loss(
config: ml_collections.ConfigDict,
) -> torch.Tensor:
bb_loss = backbone_loss(
traj=out["sm"]["frames"], **{**batch, **config.backbone},
traj=out["sm"]["frames"],
**{**batch, **config.backbone},
)
sc_loss = sidechain_loss(
out["sm"]["sidechain_frames"],
out["sm"]["positions"],
**{**batch, **config.sidechain}
**{**batch, **config.sidechain},
)
return (
config.backbone.weight * bb_loss +
config.sidechain.weight * sc_loss
)
return config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
def supervised_chi_loss(
......@@ -264,10 +252,11 @@ def supervised_chi_loss(
) -> torch.Tensor:
pred_angles = angles_sin_cos[..., 3:, :]
residue_type_one_hot = torch.nn.functional.one_hot(
aatype, residue_constants.restype_num + 1,
aatype,
residue_constants.restype_num + 1,
)
chi_pi_periodic = torch.einsum(
"...ij,jk->ik",
"...ij,jk->ik",
residue_type_one_hot.type(angles_sin_cos.dtype),
angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
)
......@@ -276,11 +265,9 @@ def supervised_chi_loss(
shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
true_chi_shifted = shifted_mask * true_chi
sq_chi_error = torch.sum(
(true_chi - pred_angles)**2, dim=-1
)
sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
sq_chi_error_shifted = torch.sum(
(true_chi_shifted - pred_angles)**2, dim=-1
(true_chi_shifted - pred_angles) ** 2, dim=-1
)
sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
# The ol' switcheroo
......@@ -295,14 +282,14 @@ def supervised_chi_loss(
loss = loss + chi_weight * sq_chi_loss
angle_norm = torch.sqrt(
torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps
torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
)
norm_error = torch.abs(angle_norm - 1.)
norm_error = torch.abs(angle_norm - 1.0)
norm_error = norm_error.permute(
*range(len(norm_error.shape))[1:-2], 0, -2, -1
)
angle_norm_loss = masked_mean(
seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
)
loss = loss + angle_norm_weight * angle_norm_loss
......@@ -312,14 +299,13 @@ def supervised_chi_loss(
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
num_bins = logits.shape[-1]
bin_width = 1. / num_bins
bin_width = 1.0 / num_bins
bounds = torch.arange(
start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device
)
probs = torch.nn.functional.softmax(logits, dim=-1)
pred_lddt_ca = torch.sum(
probs *
bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
dim=-1,
)
return pred_lddt_ca * 100
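# Illustrative sketch (not part of this commit): with 50 bins the centers are
# 0.01, 0.03, ..., 0.99, so a distribution concentrated in the top bin maps to
# a per-residue pLDDT close to 99. Shapes below are made up.
def _compute_plddt_example():
    logits = torch.full((1, 10, 50), -10.0)  # [*, N, no_bins]
    logits[..., -1] = 10.0                   # put nearly all mass in the last bin
    return compute_plddt(logits)             # [*, N], values close to 99.0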
......@@ -331,7 +317,7 @@ def lddt_loss(
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15.,
cutoff: float = 15.0,
no_bins: int = 50,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
......@@ -339,55 +325,57 @@ def lddt_loss(
**kwargs,
) -> torch.Tensor:
n = all_atom_mask.shape[-2]
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos:(ca_pos + 1)] # keep dim
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
dmat_true = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
all_atom_positions[..., None, :] -
all_atom_positions[..., None, :, :]
)**2,
all_atom_positions[..., None, :]
- all_atom_positions[..., None, :, :]
)
** 2,
dim=-1,
)
)
dmat_pred = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
all_atom_pred_pos[..., None, :] -
all_atom_pred_pos[..., None, :, :]
)**2,
all_atom_pred_pos[..., None, :]
- all_atom_pred_pos[..., None, :, :]
)
** 2,
dim=-1,
)
)
dists_to_score = (
(dmat_true < cutoff) * all_atom_mask *
permute_final_dims(all_atom_mask, (1, 0)) *
(1. - torch.eye(n, device=all_atom_mask.device))
(dmat_true < cutoff)
* all_atom_mask
* permute_final_dims(all_atom_mask, (1, 0))
* (1.0 - torch.eye(n, device=all_atom_mask.device))
)
dist_l1 = torch.abs(dmat_true - dmat_pred)
score = (
(dist_l1 < 0.5).type(dist_l1.dtype) +
(dist_l1 < 1.0).type(dist_l1.dtype) +
(dist_l1 < 2.0).type(dist_l1.dtype) +
(dist_l1 < 4.0).type(dist_l1.dtype)
(dist_l1 < 0.5).type(dist_l1.dtype)
+ (dist_l1 < 1.0).type(dist_l1.dtype)
+ (dist_l1 < 2.0).type(dist_l1.dtype)
+ (dist_l1 < 4.0).type(dist_l1.dtype)
)
score = score * 0.25
norm = 1. / (eps + torch.sum(dists_to_score, dim=-1))
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=-1))
score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
score = score.detach()
bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1))
lddt_ca_one_hot = torch.nn.functional.one_hot(
......@@ -396,40 +384,39 @@ def lddt_loss(
errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
all_atom_mask = all_atom_mask.squeeze(-1)
loss = (
torch.sum(errors * all_atom_mask, dim=-1) /
(eps + torch.sum(all_atom_mask, dim=-1))
loss = torch.sum(errors * all_atom_mask, dim=-1) / (
eps + torch.sum(all_atom_mask, dim=-1)
)
loss = loss * (
(resolution >= min_resolution) &
(resolution <= max_resolution)
(resolution >= min_resolution) & (resolution <= max_resolution)
)
return loss
def distogram_loss(
logits,
pseudo_beta,
pseudo_beta_mask,
min_bin=2.3125,
max_bin=21.6875,
no_bins=64,
eps=1e-6,
**kwargs,
):
boundaries = torch.linspace(
min_bin, max_bin, no_bins - 1, device=logits.device,
min_bin,
max_bin,
no_bins - 1,
device=logits.device,
)
boundaries = boundaries ** 2
dists = torch.sum(
(
pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]
) ** 2,
dim=-1,
keepdims=True
(pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
dim=-1,
keepdims=True,
)
true_bins = torch.sum(dists > boundaries, dim=-1)
......@@ -442,7 +429,7 @@ def distogram_loss(
square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
# FP16-friendly sum. Equivalent to:
# mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
# (eps + torch.sum(square_mask, dim=(-1, -2))))
denom = eps + torch.sum(square_mask, dim=(-1, -2))
mean = errors * square_mask
......@@ -450,7 +437,7 @@ def distogram_loss(
mean = mean / denom[..., None]
mean = torch.sum(mean, dim=-1)
return mean
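# Illustrative sketch (not part of this commit) of why the staged division
# above is FP16-friendly: dividing each partial row sum by the denominator
# before the final reduction is algebraically the same as dividing the full
# masked sum once, but keeps the running sums small. Inputs here are assumed
# to be any [..., A, B] error/mask pair.
def _fp16_mean_equivalence(errors, square_mask, eps=1e-6):
    denom = eps + torch.sum(square_mask, dim=(-1, -2))
    direct = torch.sum(errors * square_mask, dim=(-1, -2)) / denom
    staged = torch.sum(
        torch.sum(errors * square_mask, dim=-1) / denom[..., None], dim=-1
    )
    return direct, staged  # equal up to floating-point error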
def _calculate_bin_centers(boundaries: torch.Tensor):
......@@ -469,7 +456,7 @@ def _calculate_expected_aligned_error(
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
return (
torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
bin_centers[-1]
bin_centers[-1],
)
......@@ -480,7 +467,7 @@ def compute_predicted_aligned_error(
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes aligned confidence metrics from logits.
Args:
logits: [*, num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
......@@ -494,18 +481,16 @@ def compute_predicted_aligned_error(
max_predicted_aligned_error: [*] the maximum predicted error possible.
"""
boundaries = torch.linspace(
0,
max_bin,
steps=(no_bins - 1),
device=logits.device
0, max_bin, steps=(no_bins - 1), device=logits.device
)
aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
predicted_aligned_error, max_predicted_aligned_error = (
_calculate_expected_aligned_error(
alignment_confidence_breaks=boundaries,
aligned_distance_error_probs=aligned_confidence_probs
)
(
predicted_aligned_error,
max_predicted_aligned_error,
) = _calculate_expected_aligned_error(
alignment_confidence_breaks=boundaries,
aligned_distance_error_probs=aligned_confidence_probs,
)
return {
......@@ -523,14 +508,11 @@ def compute_tm(
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor:
if(residue_weights is None):
if residue_weights is None:
residue_weights = logits.new_ones(logits.shape[-2])
boundaries = torch.linspace(
0,
max_bin,
steps=(no_bins - 1),
device=logits.device
0, max_bin, steps=(no_bins - 1), device=logits.device
)
bin_centers = _calculate_bin_centers(boundaries)
......@@ -538,11 +520,11 @@ def compute_tm(
n = logits.shape[-2]
clipped_n = max(n, 19)
d0 = 1.24 * (clipped_n - 15) ** (1./3) - 1.8
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
probs = torch.nn.functional.softmax(logits, dim=-1)
tm_per_bin = 1. / (1 + (bin_centers ** 2) / (d0 ** 2))
tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
normed_residue_mask = residue_weights / (eps + residue_weights.sum())
......@@ -554,12 +536,12 @@ def compute_tm(
def tm_loss(
logits,
final_affine_tensor,
backbone_affine_tensor,
backbone_affine_mask,
resolution,
max_bin=31,
no_bins=64,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps=1e-8,
......@@ -573,25 +555,18 @@ def tm_loss(
return affine.invert()[..., None].apply(pts)
sq_diff = torch.sum(
(_points(pred_affine) - _points(backbone_affine)) ** 2,
dim=-1
(_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1
)
sq_diff = sq_diff.detach()
boundaries = torch.linspace(
0,
max_bin,
steps=(no_bins - 1),
device=logits.device
0, max_bin, steps=(no_bins - 1), device=logits.device
)
boundaries = boundaries ** 2
true_bins = torch.sum(
sq_diff[..., None] > boundaries, dim=-1
)
true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)
errors = softmax_cross_entropy(
logits,
torch.nn.functional.one_hot(true_bins, no_bins)
logits, torch.nn.functional.one_hot(true_bins, no_bins)
)
square_mask = (
......@@ -599,15 +574,14 @@ def tm_loss(
)
loss = torch.sum(errors * square_mask, dim=-1)
scale = 0.5 # hack to help FP16 training along
denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
loss = loss / denom[..., None]
loss = torch.sum(loss, dim=-1)
loss = loss * scale
loss = loss * (
(resolution >= min_resolution) &
(resolution <= max_resolution)
(resolution >= min_resolution) & (resolution <= max_resolution)
)
return loss
......@@ -623,11 +597,11 @@ def between_residue_bond_loss(
eps=1e-6,
) -> Dict[str, torch.Tensor]:
"""Flat-bottom loss to penalize structural violations between residues.
This is a loss penalizing any violation of the geometry around the peptide
bond between consecutive amino acids. This loss corresponds to
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.
Args:
pred_atom_positions: Atom positions in atom37/14 representation
pred_atom_mask: Atom mask in atom37/14 representation
......@@ -638,7 +612,7 @@ def between_residue_bond_loss(
of pdb distributions
tolerance_factor_hard: hard tolerance factor measured in standard deviations
of pdb distributions
Returns:
Dict containing:
* 'c_n_loss_mean': Loss for peptide bond length violations
......@@ -659,126 +633,116 @@ def between_residue_bond_loss(
next_n_mask = pred_atom_mask[..., 1:, 0]
next_ca_pos = pred_atom_positions[..., 1:, 1, :]
next_ca_mask = pred_atom_mask[..., 1:, 1]
has_no_gap_mask = (
(residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
)
has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
# Compute loss for the C--N bond.
c_n_bond_length = torch.sqrt(
eps +
torch.sum(
(this_c_pos - next_n_pos)**2, dim=-1
)
eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
)
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline = (
aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
)
next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
gt_length = (
(~next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
+ next_is_proline * residue_constants.between_res_bond_length_c_n[1]
)
~next_is_proline
) * residue_constants.between_res_bond_length_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_c_n[
1
]
gt_stddev = (
(~next_is_proline) *
residue_constants.between_res_bond_length_stddev_c_n[0] +
next_is_proline *
residue_constants.between_res_bond_length_stddev_c_n[1]
)
c_n_bond_length_error = torch.sqrt(
eps + (c_n_bond_length - gt_length)**2
)
~next_is_proline
) * residue_constants.between_res_bond_length_stddev_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
1
]
c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
c_n_loss_per_residue = torch.nn.functional.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev
)
mask = this_c_mask * next_n_mask * has_no_gap_mask
c_n_loss = (
torch.sum(mask * c_n_loss_per_residue, dim=-1) /
(torch.sum(mask, dim=-1) + eps)
c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
c_n_violation_mask = mask * (
c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
)
# Compute loss for the angles.
ca_c_bond_length = torch.sqrt(
eps + torch.sum((this_ca_pos - this_c_pos)**2, dim=-1)
eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
)
n_ca_bond_length = torch.sqrt(
eps + torch.sum((next_n_pos - next_ca_pos)**2, dim=-1)
eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
)
c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]
ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
ca_c_n_cos_angle_error = torch.sqrt(
eps + (ca_c_n_cos_angle - gt_angle)**2
eps + (ca_c_n_cos_angle - gt_angle) ** 2
)
ca_c_n_loss_per_residue = torch.nn.functional.relu(
ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
)
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
ca_c_n_loss = (
torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) /
(torch.sum(mask, dim=-1) + eps)
ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
(tolerance_factor_hard * gt_stddev))
ca_c_n_violation_mask = mask * (
ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
)
c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
c_n_ca_cos_angle_error = torch.sqrt(
eps + torch.square(c_n_ca_cos_angle - gt_angle))
eps + torch.square(c_n_ca_cos_angle - gt_angle)
)
c_n_ca_loss_per_residue = torch.nn.functional.relu(
c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
)
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
c_n_ca_loss = (
torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) /
(torch.sum(mask, dim=-1) + eps)
c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
c_n_ca_violation_mask = mask * (
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
)
# Compute a per residue loss (equally distribute the loss to both
# neighbouring residues).
per_residue_loss_sum = (c_n_loss_per_residue +
ca_c_n_loss_per_residue +
c_n_ca_loss_per_residue)
per_residue_loss_sum = (
c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
)
per_residue_loss_sum = 0.5 * (
torch.nn.functional.pad(per_residue_loss_sum, (0, 1)) +
torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
+ torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
)
# Compute hard violations.
violation_mask = torch.max(
torch.stack(
[
c_n_violation_mask,
ca_c_n_violation_mask,
c_n_ca_violation_mask
],
[c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
dim=-2,
),
dim=-2
),
dim=-2,
)[0]
violation_mask = torch.maximum(
torch.nn.functional.pad(violation_mask, (0, 1)),
torch.nn.functional.pad(violation_mask, (1, 0))
torch.nn.functional.pad(violation_mask, (1, 0)),
)
return {
'c_n_loss_mean': c_n_loss,
'ca_c_n_loss_mean': ca_c_n_loss,
'c_n_ca_loss_mean': c_n_ca_loss,
'per_residue_loss_sum': per_residue_loss_sum,
'per_residue_violation_mask': violation_mask
"c_n_loss_mean": c_n_loss,
"ca_c_n_loss_mean": ca_c_n_loss,
"c_n_ca_loss_mean": c_n_ca_loss,
"per_residue_loss_sum": per_residue_loss_sum,
"per_residue_violation_mask": violation_mask,
}
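# Illustrative sketch (not part of this commit): each per-residue term above
# is a "flat-bottom" penalty -- zero while the deviation stays within
# tolerance_factor_soft standard deviations of the literature value, linear
# beyond that. The tolerance default below is made up for illustration.
def _flat_bottom_example(length, gt_length, gt_stddev, tolerance_factor_soft=12.0):
    error = torch.sqrt(1e-6 + (length - gt_length) ** 2)
    return torch.nn.functional.relu(error - tolerance_factor_soft * gt_stddev)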
......@@ -792,12 +756,12 @@ def between_residue_clash_loss(
eps=1e-10,
) -> Dict[str, torch.Tensor]:
"""Loss to penalize steric clashes between residues.
This is a loss penalizing any steric clashes due to non bonded atoms in
different peptides coming too close. This loss corresponds to the part with
different residues of
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
Args:
atom14_pred_positions: Predicted positions of atoms in
global prediction frame
......@@ -807,7 +771,7 @@ def between_residue_clash_loss(
residue_index: Residue index for given amino acid.
overlap_tolerance_soft: Soft tolerance factor.
overlap_tolerance_hard: Hard tolerance factor.
Returns:
Dict containing:
* 'mean_loss': average clash loss
......@@ -816,33 +780,36 @@ def between_residue_clash_loss(
shape (N, 14)
"""
fp_type = atom14_pred_positions.dtype
# Create the distance matrix.
# (N, N, 14, 14)
dists = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
atom14_pred_positions[..., :, None, :, None, :] -
atom14_pred_positions[..., None, :, None, :, :]
)**2,
dim=-1)
atom14_pred_positions[..., :, None, :, None, :]
- atom14_pred_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
# Create the mask for valid distances.
# shape (N, N, 14, 14)
dists_mask = (
atom14_atom_exists[..., :, None, :, None] *
atom14_atom_exists[..., None, :, None, :]
atom14_atom_exists[..., :, None, :, None]
* atom14_atom_exists[..., None, :, None, :]
).type(fp_type)
# Mask out all the duplicate entries in the lower triangular matrix.
# Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
# are handled separately.
dists_mask = dists_mask * (
residue_index[..., :, None, None, None] < residue_index[..., None, :, None, None]
residue_index[..., :, None, None, None]
< residue_index[..., None, :, None, None]
)
# Backbone C--N bond between subsequent residues is no clash.
c_one_hot = torch.nn.functional.one_hot(
residue_index.new_tensor(2), num_classes=14
......@@ -860,74 +827,69 @@ def between_residue_clash_loss(
n_one_hot = n_one_hot.type(fp_type)
neighbour_mask = (
(residue_index[..., :, None, None, None] + 1) ==
residue_index[..., None, :, None, None]
)
residue_index[..., :, None, None, None] + 1
) == residue_index[..., None, :, None, None]
c_n_bonds = (
neighbour_mask *
c_one_hot[..., None, None, :, None] *
n_one_hot[..., None, None, None, :]
neighbour_mask
* c_one_hot[..., None, None, :, None]
* n_one_hot[..., None, None, None, :]
)
dists_mask = dists_mask * (1. - c_n_bonds)
dists_mask = dists_mask * (1.0 - c_n_bonds)
# Disulfide bridge between two cysteines is no clash.
cys = residue_constants.restype_name_to_atom14_names["CYS"]
cys_sg_idx = cys.index('SG')
cys_sg_idx = cys.index("SG")
cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
cys_sg_idx = cys_sg_idx.reshape(
*((1,) * len(residue_index.shape[:-1])), 1
).squeeze(-1)
cys_sg_one_hot = torch.nn.functional.one_hot(
cys_sg_idx, num_classes=14
)
cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
disulfide_bonds = (
cys_sg_one_hot[..., None, None, :, None] *
cys_sg_one_hot[..., None, None, None, :])
dists_mask = dists_mask * (1. - disulfide_bonds)
cys_sg_one_hot[..., None, None, :, None]
* cys_sg_one_hot[..., None, None, None, :]
)
dists_mask = dists_mask * (1.0 - disulfide_bonds)
# Compute the lower bound for the allowed distances.
# shape (N, N, 14, 14)
dists_lower_bound = dists_mask * (
atom14_atom_radius[..., :, None, :, None] +
atom14_atom_radius[..., None, :, None, :]
atom14_atom_radius[..., :, None, :, None]
+ atom14_atom_radius[..., None, :, None, :]
)
# Compute the error.
# shape (N, N, 14, 14)
dists_to_low_error = dists_mask * torch.nn.functional.relu(
dists_lower_bound - overlap_tolerance_soft - dists
)
# Compute the mean loss.
# shape ()
mean_loss = (
torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))
)
mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = (
torch.sum(dists_to_low_error, dim=(-4, -2)) +
torch.sum(dists_to_low_error, axis=(-3, -1))
per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
dists_to_low_error, axis=(-3, -1)
)
# Compute the hard clash mask.
# shape (N, N, 14, 14)
clash_mask = dists_mask * (
dists < (dists_lower_bound - overlap_tolerance_hard)
)
# Compute the per atom clash.
# shape (N, 14)
per_atom_clash_mask = torch.maximum(
torch.amax(clash_mask, axis=(-4, -2)),
torch.amax(clash_mask, axis=(-3, -1)),
)
return {
'mean_loss': mean_loss, # shape ()
'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14)
'per_atom_clash_mask': per_atom_clash_mask # shape (N, 14)
"mean_loss": mean_loss, # shape ()
"per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14)
"per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14)
}
......@@ -940,54 +902,53 @@ def within_residue_violations(
eps=1e-10,
) -> Dict[str, torch.Tensor]:
"""Loss to penalize steric clashes within residues.
This is a loss penalizing any steric violations or clashes of non-bonded atoms
in a given peptide. This loss corresponds to the part with
the same residues of
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
Args:
atom14_pred_positions ([*, N, 14, 3]):
Predicted positions of atoms in global prediction frame.
atom14_atom_exists ([*, N, 14]):
Mask denoting whether atom at positions exists for given
amino acid type
atom14_dists_lower_bound ([*, N, 14]):
Lower bound on allowed distances.
atom14_dists_upper_bound ([*, N, 14]):
Upper bound on allowed distances
tighten_bounds_for_loss ([*, N]):
Extra factor to tighten loss
Returns:
Dict containing:
* 'per_atom_loss_sum' ([*, N, 14]):
sum of all clash losses per atom, shape
* 'per_atom_clash_mask' ([*, N, 14]):
mask whether atom clashes with any other atom shape
"""
# Compute the mask for each residue.
dists_masks = (
1. - torch.eye(14, device=atom14_atom_exists.device)[None]
)
dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
dists_masks = dists_masks.reshape(
*((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
)
dists_masks = (
atom14_atom_exists[..., :, :, None] *
atom14_atom_exists[..., :, None, :] *
dists_masks
atom14_atom_exists[..., :, :, None]
* atom14_atom_exists[..., :, None, :]
* dists_masks
)
# Distance matrix
dists = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
atom14_pred_positions[..., :, :, None, :] -
atom14_pred_positions[..., :, None, :, :]
)**2,
dim=-1
atom14_pred_positions[..., :, :, None, :]
- atom14_pred_positions[..., :, None, :, :]
)
** 2,
dim=-1,
)
)
......@@ -999,34 +960,26 @@ def within_residue_violations(
dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
)
loss = dists_masks * (dists_to_low_error + dists_to_high_error)
# Compute the per atom loss sum.
per_atom_loss_sum = (
torch.sum(loss, dim=-2) +
torch.sum(loss, dim=-1)
)
per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)
# Compute the violations mask.
violations = (
dists_masks *
(
(dists < atom14_dists_lower_bound) |
(dists > atom14_dists_upper_bound)
)
violations = dists_masks * (
(dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
)
# Compute the per atom violations.
per_atom_violations = torch.maximum(
torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]
)
return {
'per_atom_loss_sum': per_atom_loss_sum,
'per_atom_violations': per_atom_violations
"per_atom_loss_sum": per_atom_loss_sum,
"per_atom_violations": per_atom_violations,
}
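# Illustrative sketch (not part of this commit): the per-pair penalty above is
# two-sided -- zero while a distance lies inside the (tightened) allowed
# interval [lower, upper], and growing linearly outside it.
def _bound_penalty_example(dist, lower, upper, tighten=0.0):
    low_error = torch.nn.functional.relu(lower + tighten - dist)
    high_error = torch.nn.functional.relu(dist - (upper - tighten))
    return low_error + high_error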
def find_structural_violations(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor,
......@@ -1035,7 +988,7 @@ def find_structural_violations(
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes several checks for structural violations."""
# Compute between residue backbone violations of bonds and angles.
connection_violations = between_residue_bond_loss(
pred_atom_positions=atom14_pred_positions,
......@@ -1043,9 +996,9 @@ def find_structural_violations(
residue_index=batch["residue_index"],
aatype=batch["aatype"],
tolerance_factor_soft=violation_tolerance_factor,
tolerance_factor_hard=violation_tolerance_factor
tolerance_factor_hard=violation_tolerance_factor,
)
# Compute the Van der Waals radius for every atom
# (the first letter of the atom name is the element type).
# Shape: (N, 14).
......@@ -1053,14 +1006,12 @@ def find_structural_violations(
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
atomtype_radius = atom14_pred_positions.new_tensor(
atomtype_radius
)
atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
atom14_atom_radius = (
batch["atom14_atom_exists"] *
atomtype_radius[batch["residx_atom14_to_atom37"]]
batch["atom14_atom_exists"]
* atomtype_radius[batch["residx_atom14_to_atom37"]]
)
# Compute the between residue clash loss.
between_residue_clashes = between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions,
......@@ -1068,32 +1019,28 @@ def find_structural_violations(
atom14_atom_radius=atom14_atom_radius,
residue_index=batch["residue_index"],
overlap_tolerance_soft=clash_overlap_tolerance,
overlap_tolerance_hard=clash_overlap_tolerance
overlap_tolerance_hard=clash_overlap_tolerance,
)
# Compute all within-residue violations (clashes,
# bond length and angle violations).
restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
overlap_tolerance=clash_overlap_tolerance,
bond_length_tolerance_factor=violation_tolerance_factor
bond_length_tolerance_factor=violation_tolerance_factor,
)
atom14_atom_exists = batch["atom14_atom_exists"]
atom14_dists_lower_bound = (
atom14_pred_positions.new_tensor(restype_atom14_bounds["lower_bound"])[
batch["aatype"]
]
)
atom14_dists_upper_bound = (
atom14_pred_positions.new_tensor(restype_atom14_bounds["upper_bound"])[
batch["aatype"]
]
)
atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
restype_atom14_bounds["lower_bound"]
)[batch["aatype"]]
atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
restype_atom14_bounds["upper_bound"]
)[batch["aatype"]]
residue_violations = within_residue_violations(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_dists_lower_bound=atom14_dists_lower_bound,
atom14_dists_upper_bound=atom14_dists_upper_bound,
tighten_bounds_for_loss=0.0
tighten_bounds_for_loss=0.0,
)
# Combine them to a single per-residue violation mask (used later for LDDT).
......@@ -1104,49 +1051,52 @@ def find_structural_violations(
torch.max(
between_residue_clashes["per_atom_clash_mask"], dim=-1
)[0],
torch.max(
residue_violations["per_atom_violations"], dim=-1
)[0],
],
torch.max(residue_violations["per_atom_violations"], dim=-1)[0],
],
dim=-1,
),
),
dim=-1,
)[0]
return {
'between_residues': {
'bonds_c_n_loss_mean':
connection_violations["c_n_loss_mean"], # ()
'angles_ca_c_n_loss_mean':
connection_violations["ca_c_n_loss_mean"], # ()
'angles_c_n_ca_loss_mean':
connection_violations["c_n_ca_loss_mean"], # ()
'connections_per_residue_loss_sum':
connection_violations["per_residue_loss_sum"], # (N)
'connections_per_residue_violation_mask':
connection_violations["per_residue_violation_mask"], # (N)
'clashes_mean_loss':
between_residue_clashes["mean_loss"], # ()
'clashes_per_atom_loss_sum':
between_residue_clashes["per_atom_loss_sum"], # (N, 14)
'clashes_per_atom_clash_mask':
between_residue_clashes["per_atom_clash_mask"], # (N, 14)
"between_residues": {
"bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"], # ()
"angles_ca_c_n_loss_mean": connection_violations[
"ca_c_n_loss_mean"
], # ()
"angles_c_n_ca_loss_mean": connection_violations[
"c_n_ca_loss_mean"
], # ()
"connections_per_residue_loss_sum": connection_violations[
"per_residue_loss_sum"
], # (N)
"connections_per_residue_violation_mask": connection_violations[
"per_residue_violation_mask"
], # (N)
"clashes_mean_loss": between_residue_clashes["mean_loss"], # ()
"clashes_per_atom_loss_sum": between_residue_clashes[
"per_atom_loss_sum"
], # (N, 14)
"clashes_per_atom_clash_mask": between_residue_clashes[
"per_atom_clash_mask"
], # (N, 14)
},
'within_residues': {
'per_atom_loss_sum':
residue_violations["per_atom_loss_sum"], # (N, 14)
'per_atom_violations':
residue_violations["per_atom_violations"], # (N, 14),
"within_residues": {
"per_atom_loss_sum": residue_violations[
"per_atom_loss_sum"
], # (N, 14)
"per_atom_violations": residue_violations[
"per_atom_violations"
], # (N, 14),
},
'total_per_residue_violations_mask':
per_residue_violations_mask, # (N)
"total_per_residue_violations_mask": per_residue_violations_mask, # (N)
}
def find_structural_violations_np(
batch: Dict[str, np.ndarray],
atom14_pred_positions: np.ndarray,
config: ml_collections.ConfigDict
config: ml_collections.ConfigDict,
) -> Dict[str, np.ndarray]:
to_tensor = lambda x: torch.tensor(x)
batch = tree_map(to_tensor, batch, np.ndarray)
......@@ -1161,17 +1111,17 @@ def find_structural_violations_np(
def extreme_ca_ca_distance_violations(
pred_atom_positions: torch.Tensor, # (N, 37(14), 3)
pred_atom_mask: torch.Tensor, # (N, 37(14))
residue_index: torch.Tensor, # (N)
max_angstrom_tolerance=1.5,
eps=1e-6,
) -> torch.Tensor:
"""Counts residues whose Ca is a large distance from its neighbour.
Measures the fraction of CA-CA pairs between consecutive amino acids that are
more than 'max_angstrom_tolerance' apart.
Args:
pred_atom_positions: Atom positions in atom37/14 representation
pred_atom_mask: Atom mask in atom37/14 representation
......@@ -1185,13 +1135,13 @@ def extreme_ca_ca_distance_violations(
this_ca_mask = pred_atom_mask[..., :-1, 1]
next_ca_pos = pred_atom_positions[..., 1:, 1, :]
next_ca_mask = pred_atom_mask[..., 1:, 1]
has_no_gap_mask = ((residue_index[..., 1:] - residue_index[..., :-1]) == 1.0)
has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
ca_ca_distance = torch.sqrt(
eps + torch.sum((this_ca_pos - next_ca_pos)**2, dim=-1)
eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
)
violations = (
(ca_ca_distance - residue_constants.ca_ca) > max_angstrom_tolerance
)
ca_ca_distance - residue_constants.ca_ca
) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
mean = masked_mean(mask, violations, -1)
return mean
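# Illustrative sketch (not part of this commit), assuming the canonical CA-CA
# distance residue_constants.ca_ca of roughly 3.80 A: three residues in a row
# with the second CA-CA step stretched to 6 A should flag one of the two
# consecutive pairs, giving a violation fraction of about 0.5.
def _extreme_ca_ca_example():
    pos = torch.zeros(3, 37, 3)
    pos[1, 1, 0] = 3.8              # CA of residue 1
    pos[2, 1, 0] = 3.8 + 6.0        # CA of residue 2, 6 A further along
    mask = torch.ones(3, 37)
    residue_index = torch.arange(3).float()
    return extreme_ca_ca_distance_violations(pos, mask, residue_index)  # ~0.5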
......@@ -1202,18 +1152,18 @@ def compute_violation_metrics(
atom14_pred_positions: torch.Tensor, # (N, 14, 3)
violations: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""Compute several metrics to assess the structural violations."""
"""Compute several metrics to assess the structural violations."""
ret = {}
extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch["residue_index"]
residue_index=batch["residue_index"],
)
ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
ret["violations_between_residue_bond"] = masked_mean(
batch["seq_mask"],
violations["between_residues"][
'connections_per_residue_violation_mask'
"connections_per_residue_violation_mask"
],
dim=-1,
)
......@@ -1221,7 +1171,7 @@ def compute_violation_metrics(
mask=batch["seq_mask"],
value=torch.max(
violations["between_residues"]["clashes_per_atom_clash_mask"],
dim=-1
dim=-1,
)[0],
dim=-1,
)
......@@ -1250,7 +1200,6 @@ def compute_violation_metrics_np(
atom14_pred_positions = to_tensor(atom14_pred_positions)
violations = tree_map(to_tensor, violations, np.ndarray)
out = compute_violation_metrics(batch, atom14_pred_positions, violations)
to_np = lambda x: np.array(x)
......@@ -1265,15 +1214,15 @@ def violation_loss(
) -> torch.Tensor:
num_atoms = torch.sum(atom14_atom_exists)
l_clash = torch.sum(
violations["between_residues"]["clashes_per_atom_loss_sum"] +
violations["within_residues"]["per_atom_loss_sum"]
)
violations["between_residues"]["clashes_per_atom_loss_sum"]
+ violations["within_residues"]["per_atom_loss_sum"]
)
l_clash = l_clash / (eps + num_atoms)
loss = (
violations["between_residues"]["bonds_c_n_loss_mean"] +
violations["between_residues"]["angles_ca_c_n_loss_mean"] +
violations["between_residues"]["angles_c_n_ca_loss_mean"] +
l_clash
violations["between_residues"]["bonds_c_n_loss_mean"]
+ violations["between_residues"]["angles_ca_c_n_loss_mean"]
+ violations["between_residues"]["angles_c_n_ca_loss_mean"]
+ l_clash
)
return loss
......@@ -1286,12 +1235,12 @@ def compute_renamed_ground_truth(
) -> Dict[str, torch.Tensor]:
"""
Find optimal renaming of ground truth based on the predicted positions.
Alg. 26 "renameSymmetricGroundTruthAtoms"
This renamed ground truth is then used for all losses,
such that each loss moves the atoms in the same direction.
Args:
batch: Dictionary containing:
* atom14_gt_positions: Ground truth positions.
......@@ -1313,50 +1262,53 @@ def compute_renamed_ground_truth(
"""
pred_dists = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
atom14_pred_positions[..., None, :, None, :] -
atom14_pred_positions[..., None, :, None, :, :]
)**2,
atom14_pred_positions[..., None, :, None, :]
- atom14_pred_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
atom14_gt_positions = batch["atom14_gt_positions"]
gt_dists = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
atom14_gt_positions[..., None, :, None, :] -
atom14_gt_positions[..., None, :, None, :, :]
)**2,
atom14_gt_positions[..., None, :, None, :]
- atom14_gt_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
alt_gt_dists = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
atom14_alt_gt_positions[..., None, :, None, :] -
atom14_alt_gt_positions[..., None, :, None, :, :]
)**2,
atom14_alt_gt_positions[..., None, :, None, :]
- atom14_alt_gt_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
lddt = torch.sqrt(eps + (pred_dists - gt_dists)**2)
alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists)**2)
lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)
atom14_gt_exists = batch["atom14_gt_exists"]
atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
mask = (
atom14_gt_exists[..., None, :, None] *
atom14_atom_is_ambiguous[..., None, :, None] *
atom14_gt_exists[..., None, :, None, :] *
(1. - atom14_atom_is_ambiguous[..., None, :, None, :])
atom14_gt_exists[..., None, :, None]
* atom14_atom_is_ambiguous[..., None, :, None]
* atom14_gt_exists[..., None, :, None, :]
* (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
)
per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
......@@ -1366,16 +1318,16 @@ def compute_renamed_ground_truth(
alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)
renamed_atom14_gt_positions = (
(1. - alt_naming_is_better[..., None, None]) *
atom14_gt_positions +
alt_naming_is_better[..., None, None] *
atom14_alt_gt_positions
)
1.0 - alt_naming_is_better[..., None, None]
) * atom14_gt_positions + alt_naming_is_better[
..., None, None
] * atom14_alt_gt_positions
renamed_atom14_gt_mask = (
(1. - alt_naming_is_better[..., None]) * atom14_gt_exists +
alt_naming_is_better[..., None] * batch["atom14_alt_gt_exists"]
)
1.0 - alt_naming_is_better[..., None]
) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
"atom14_alt_gt_exists"
]
return {
"alt_naming_is_better": alt_naming_is_better,
......@@ -1398,10 +1350,9 @@ def experimentally_resolved_loss(
loss = torch.sum(errors * atom37_atom_exists, dim=-1)
loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
loss = torch.sum(loss, dim=-1)
loss = loss * (
(resolution >= min_resolution) &
(resolution <= max_resolution)
(resolution >= min_resolution) & (resolution <= max_resolution)
)
return loss
......@@ -1409,10 +1360,9 @@ def experimentally_resolved_loss(
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
errors = softmax_cross_entropy(
logits,
torch.nn.functional.one_hot(true_msa, num_classes=23)
logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
)
# FP16-friendly averaging. Equivalent to:
# loss = (
# torch.sum(errors * bert_mask, dim=(-1, -2)) /
......@@ -1435,7 +1385,7 @@ def compute_drmsd(structure_1, structure_2):
d1 = d1 ** 2
d2 = d2 ** 2
d1 = torch.sqrt(torch.sum(d1, dim=-1))
d2 = torch.sqrt(torch.sum(d2, dim=-1))
......@@ -1450,81 +1400,74 @@ def compute_drmsd(structure_1, structure_2):
class AlphaFoldLoss(nn.Module):
""" Aggregation of the various losses described in the supplement """
"""Aggregation of the various losses described in the supplement"""
def __init__(self, config):
super(AlphaFoldLoss, self).__init__()
self.config = config
def forward(self, out, batch):
if("violation" not in out.keys() and self.config.violation.weight):
if "violation" not in out.keys() and self.config.violation.weight:
out["violation"] = find_structural_violations(
batch,
out["sm"]["positions"][-1],
**self.config.violation,
)
if("renamed_atom14_gt_positions" not in out.keys()):
batch.update(compute_renamed_ground_truth(
batch,
out["sm"]["positions"][-1],
))
if "renamed_atom14_gt_positions" not in out.keys():
batch.update(
compute_renamed_ground_truth(
batch,
out["sm"]["positions"][-1],
)
)
loss_fns = {
"distogram":
lambda: distogram_loss(
logits=out["distogram_logits"],
**{**batch,
**self.config.distogram},
),
"experimentally_resolved":
lambda: experimentally_resolved_loss(
logits=out["experimentally_resolved_logits"],
**{**batch, **self.config.experimentally_resolved},
),
"fape":
lambda: fape_loss(
out,
batch,
self.config.fape,
),
"lddt":
lambda: lddt_loss(
logits=out["lddt_logits"],
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.lddt},
),
"masked_msa":
lambda: masked_msa_loss(
logits=out["masked_msa_logits"],
**{**batch,
**self.config.masked_msa},
),
"supervised_chi":
lambda: supervised_chi_loss(
out["sm"]["angles"],
out["sm"]["unnormalized_angles"],
**{**batch, **self.config.supervised_chi},
),
"violation":
lambda: violation_loss(
out["violation"],
**batch,
),
"tm":
lambda: tm_loss(
logits=out["tm_logits"],
**{**batch, **out, **self.config.tm},
),
"distogram": lambda: distogram_loss(
logits=out["distogram_logits"],
**{**batch, **self.config.distogram},
),
"experimentally_resolved": lambda: experimentally_resolved_loss(
logits=out["experimentally_resolved_logits"],
**{**batch, **self.config.experimentally_resolved},
),
"fape": lambda: fape_loss(
out,
batch,
self.config.fape,
),
"lddt": lambda: lddt_loss(
logits=out["lddt_logits"],
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.lddt},
),
"masked_msa": lambda: masked_msa_loss(
logits=out["masked_msa_logits"],
**{**batch, **self.config.masked_msa},
),
"supervised_chi": lambda: supervised_chi_loss(
out["sm"]["angles"],
out["sm"]["unnormalized_angles"],
**{**batch, **self.config.supervised_chi},
),
"violation": lambda: violation_loss(
out["violation"],
**batch,
),
"tm": lambda: tm_loss(
logits=out["tm_logits"],
**{**batch, **out, **self.config.tm},
),
}
cum_loss = 0
for k,loss_fn in loss_fns.items():
for k, loss_fn in loss_fns.items():
weight = self.config[k].weight
if(weight):
#print(k)
if weight:
# print(k)
loss = loss_fn()
#print(weight * loss)
# print(weight * loss)
cum_loss = cum_loss + weight * loss
#print(cum_loss)
# print(cum_loss)
return cum_loss
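# Illustrative usage sketch (not part of this commit); "out" and "batch" are
# the model output and feature dicts produced elsewhere, and the config
# attribute name is an assumption:
#
#   loss_fn = AlphaFoldLoss(config.loss)
#   loss = loss_fn(out, batch)   # weighted sum of the enabled loss terms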
......@@ -49,11 +49,11 @@ def dict_multimap(fn, dicts):
new_dict = {}
for k, v in first.items():
all_v = [d[k] for d in dicts]
if(type(v) is dict):
if type(v) is dict:
new_dict[k] = dict_multimap(fn, all_v)
else:
new_dict[k] = fn(all_v)
return new_dict
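# Illustrative sketch (not part of this commit): dict_multimap applies fn to
# the per-key lists of values gathered across structurally identical dicts.
def _dict_multimap_example():
    d1 = {"a": 1, "b": {"c": 2}}
    d2 = {"a": 3, "b": {"c": 4}}
    return dict_multimap(sum, [d1, d2])  # {"a": 4, "b": {"c": 6}}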
......@@ -83,7 +83,7 @@ def batched_gather(data, inds, dim=0, no_batch_dims=0):
def dict_map(fn, dic, leaf_type):
new_dict = {}
for k, v in dic.items():
if(type(v) is dict):
if type(v) is dict:
new_dict[k] = dict_map(fn, v, leaf_type)
else:
new_dict[k] = tree_map(fn, v, leaf_type)
......@@ -92,76 +92,77 @@ def dict_map(fn, dic, leaf_type):
def tree_map(fn, tree, leaf_type):
if(isinstance(tree, dict)):
if isinstance(tree, dict):
return dict_map(fn, tree, leaf_type)
elif(isinstance(tree, list)):
elif isinstance(tree, list):
return [tree_map(fn, x, leaf_type) for x in tree]
elif(isinstance(tree, tuple)):
elif isinstance(tree, tuple):
return tuple([tree_map(fn, x, leaf_type) for x in tree])
elif(isinstance(tree, leaf_type)):
elif isinstance(tree, leaf_type):
return fn(tree)
else:
print(type(tree))
raise ValueError("Not supported")
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
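# Minimal illustrative sketch of tree_map/tensor_tree_map: apply a function to
# every torch.Tensor leaf of a nested dict/list/tuple while preserving structure.
import torch

nested = {"a": torch.ones(2), "b": [torch.zeros(3), (torch.ones(1),)]}
doubled = tensor_tree_map(lambda t: t * 2, nested)
# "doubled" has the same nesting as "nested", with every tensor multiplied by 2.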
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple "pytrees,"
consisting only of (arbitrarily nested) lists, tuples, and dicts with
torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must
be tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
indexing of all batch dimensions simultaneously (s.t. the
number of sub-batches is the product of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can
be considered batch dimensions.
Returns:
The reassembled output of the layer on the inputs.
"""
if(not (len(inputs) > 0)):
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple "pytrees,"
consisting only of (arbitrarily nested) lists, tuples, and dicts with
torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must
be tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
indexing of all batch dimensions simultaneously (s.t. the
number of sub-batches is the product of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can
be considered batch dimensions.
Returns:
The reassembled output of the layer on the inputs.
"""
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
def fetch_dims(tree):
shapes = []
tree_type = type(tree)
if(tree_type is dict):
if tree_type is dict:
for v in tree.values():
shapes.extend(fetch_dims(v))
elif(tree_type is list or tree_type is tuple):
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(fetch_dims(t))
elif(tree_type is torch.Tensor):
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def prep_inputs(t):
# TODO: make this more memory efficient. This sucks
if(not sum(t.shape[:no_batch_dims]) == no_batch_dims):
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
return t
......@@ -172,40 +173,42 @@ def chunk_layer(
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = (
flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
no_chunks = flat_batch_dim // chunk_size + (
flat_batch_dim % chunk_size != 0
)
i = 0
out = None
for _ in range(no_chunks):
# Chunk the input
select_chunk = lambda t: t[i:i+chunk_size] if t.shape[0] != 1 else t
select_chunk = lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
chunks = tensor_tree_map(select_chunk, flattened_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
# Allocate space for the output
if(out is None):
if out is None:
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
out = tensor_tree_map(allocate, output_chunk)
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if(out_type is dict):
if out_type is dict:
def assign(d1, d2):
for k,v in d1.items():
if(type(v) is dict):
for k, v in d1.items():
if type(v) is dict:
assign(v, d2[k])
else:
v[i:i+chunk_size] = d2[k]
v[i : i + chunk_size] = d2[k]
assign(out, output_chunk)
elif(out_type is tuple):
elif out_type is tuple:
for x1, x2 in zip(out, output_chunk):
x1[i:i+chunk_size] = x2
elif(out_type is torch.Tensor):
out[i:i+chunk_size] = output_chunk
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
out[i : i + chunk_size] = output_chunk
else:
raise ValueError("Not supported")
......@@ -214,4 +217,4 @@ def chunk_layer(
reshape = lambda t: t.reshape(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
return out
return out
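# Illustrative usage sketch of chunk_layer (assuming it is exposed from
# openfold.utils.tensor_utils): run a linear layer over a [4, 8, 32] input in
# chunks of 16 sub-batches, where the 4 * 8 = 32 leading entries are the
# flattened batch dimensions.
import torch
from openfold.utils.tensor_utils import chunk_layer

linear = torch.nn.Linear(32, 64)
x = torch.rand(4, 8, 32)
out = chunk_layer(
    lambda act: linear(act),  # called as layer(**chunk), so the arg name matches the key
    {"act": x},
    chunk_size=16,
    no_batch_dims=2,
)
assert out.shape == (4, 8, 64)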
......@@ -15,7 +15,7 @@ from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts
# Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also
# (by default it hogs 90% of GPU memory. This disables that behavior and also
# forces it to proactively free memory that it allocates)
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_PLATFORM_NAME"] = "gpu"
......@@ -30,17 +30,15 @@ def skip_unless_alphafold_installed():
def import_alphafold():
"""
If AlphaFold is installed using the provided setuptools script, this
is necessary to expose all of AlphaFold's precious insides
"""
if("alphafold" in sys.modules):
return sys.modules["alphafold"]
If AlphaFold is installed using the provided setuptools script, this
is necessary to expose all of AlphaFold's precious insides
"""
if "alphafold" in sys.modules:
return sys.modules["alphafold"]
module = importlib.import_module("alphafold")
# Forcefully import alphafold's submodules
submodules = pkgutil.walk_packages(
module.__path__, prefix=("alphafold.")
)
submodules = pkgutil.walk_packages(module.__path__, prefix=("alphafold."))
for submodule_info in submodules:
importlib.import_module(submodule_info.name)
sys.modules["alphafold"] = module
......@@ -57,16 +55,18 @@ def get_alphafold_config():
_param_path = "openfold/resources/params/params_model_1_ptm.npz"
_model = None
def get_global_pretrained_openfold():
global _model
if(_model is None):
if _model is None:
_model = AlphaFold(model_config("model_1_ptm").model)
_model = _model.eval()
if(not os.path.exists(_param_path)):
if not os.path.exists(_param_path):
raise FileNotFoundError(
"""Cannot load pretrained parameters. Make sure to run the
installation script before running tests."""
)
)
import_jax_weights_(_model, _param_path, version="model_1_ptm")
_model = _model.cuda()
......@@ -74,9 +74,11 @@ def get_global_pretrained_openfold():
_orig_weights = None
def _get_orig_weights():
global _orig_weights
if(_orig_weights is None):
if _orig_weights is None:
_orig_weights = np.load(_param_path)
return _orig_weights
......@@ -84,22 +86,19 @@ def _get_orig_weights():
def _remove_key_prefix(d, prefix):
for k, v in list(d.items()):
if(k.startswith(prefix)):
if k.startswith(prefix):
d.pop(k)
d[k[len(prefix):]] = v
d[k[len(prefix) :]] = v
def fetch_alphafold_module_weights(weight_path):
orig_weights = _get_orig_weights()
params = {
k:v for k,v in orig_weights.items()
if weight_path in k
}
if('/' in weight_path):
spl = weight_path.split('/')
params = {k: v for k, v in orig_weights.items() if weight_path in k}
if "/" in weight_path:
spl = weight_path.split("/")
spl = spl if len(spl[-1]) != 0 else spl[:-1]
module_name = spl[-1]
prefix = '/'.join(spl[:-1]) + '/'
prefix = "/".join(spl[:-1]) + "/"
_remove_key_prefix(params, prefix)
params = alphafold.model.utils.flat_params_to_haiku(params)
return params
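# Illustrative usage sketch of the helper above: fetch only the parameters under
# a given AlphaFold module path (this particular path appears in the tests
# below); the tests then keep block 0 of the stacked weights via
# tree_map(lambda n: n[0], params, jax.numpy.DeviceArray).
params = fetch_alphafold_module_weights(
    "alphafold/alphafold_iteration/evoformer/evoformer_iteration/msa_transition"
)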
import ml_collections as mlc
consts = mlc.ConfigDict({
"batch_size": 2,
"n_res": 11,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
"eps": 5e-4,
# For compatibility with DeepMind's pretrained weights, it's easiest for
# everyone if these take their real values.
"c_m": 256,
"c_z": 128,
"c_s": 384,
"c_t": 64,
"c_e": 64,
})
consts = mlc.ConfigDict(
{
"batch_size": 2,
"n_res": 11,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
"eps": 5e-4,
# For compatibility with DeepMind's pretrained weights, it's easiest for
# everyone if these take their real values.
"c_m": 256,
"c_z": 128,
"c_s": 384,
"c_t": 64,
"c_e": 64,
}
)
......@@ -18,7 +18,7 @@ from scipy.spatial.transform import Rotation
def random_template_feats(n_templ, n, batch_size=None):
b = []
if(batch_size is not None):
if batch_size is not None:
b.append(batch_size)
batch = {
"template_mask": np.random.randint(0, 2, (*b, n_templ)),
......@@ -28,28 +28,31 @@ def random_template_feats(n_templ, n, batch_size=None):
"template_all_atom_masks": np.random.randint(
0, 2, (*b, n_templ, n, 37)
),
"template_all_atom_positions": np.random.rand(
*b, n_templ, n, 37, 3
) * 10,
"template_all_atom_positions": np.random.rand(*b, n_templ, n, 37, 3)
* 10,
}
batch = {k:v.astype(np.float32) for k,v in batch.items()}
batch = {k: v.astype(np.float32) for k, v in batch.items()}
batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
return batch
def random_extra_msa_feats(n_extra, n, batch_size=None):
b = []
if(batch_size is not None):
if batch_size is not None:
b.append(batch_size)
batch = {
"extra_msa":
np.random.randint(0, 22, (*b, n_extra, n)).astype(np.int64),
"extra_has_deletion":
np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
"extra_deletion_value":
np.random.rand(*b, n_extra, n).astype(np.float32),
"extra_msa_mask":
np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
"extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(
np.int64
),
"extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(
np.float32
),
"extra_deletion_value": np.random.rand(*b, n_extra, n).astype(
np.float32
),
"extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(
np.float32
),
}
return batch
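# Illustrative shape check for the two generators above, using only the shapes
# visible in their definitions:
feats = random_template_feats(n_templ=3, n=11, batch_size=2)
assert feats["template_mask"].shape == (2, 3)
assert feats["template_all_atom_masks"].shape == (2, 3, 11, 37)

extra = random_extra_msa_feats(n_extra=17, n=11, batch_size=2)
assert extra["extra_msa"].shape == (2, 17, 11)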
......@@ -63,7 +66,9 @@ def random_affines_vector(dim):
for i in range(prod_dim):
affines[i, :4] = Rotation.random(random_state=42).as_quat()
affines[i, 4:] = np.random.rand(3,).astype(np.float32)
affines[i, 4:] = np.random.rand(
3,
).astype(np.float32)
return affines.reshape(*dim, 7)
......@@ -77,9 +82,10 @@ def random_affines_4x4(dim):
for i in range(prod_dim):
affines[i, :3, :3] = Rotation.random(random_state=42).as_matrix()
affines[i, :3, 3] = np.random.rand(3,).astype(np.float32)
affines[i, :3, 3] = np.random.rand(
3,
).astype(np.float32)
affines[:, 3, 3] = 1
return affines.reshape(*dim, 4, 4)
......@@ -24,30 +24,30 @@ from openfold.model.embedders import (
class TestInputEmbedder(unittest.TestCase):
def test_shape(self):
def test_shape(self):
tf_dim = 2
msa_dim = 3
c_z = 5
c_m = 7
relpos_k = 11
b = 13
n_res = 17
n_clust = 19
tf = torch.rand((b, n_res, tf_dim))
ri = torch.rand((b, n_res))
msa = torch.rand((b, n_clust, n_res, msa_dim))
ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
msa_emb, pair_emb = ie(tf, ri, msa)
self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m))
self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
class TestRecyclingEmbedder(unittest.TestCase):
def test_shape(self):
def test_shape(self):
batch_size = 2
n = 3
c_z = 5
......@@ -66,7 +66,7 @@ class TestRecyclingEmbedder(unittest.TestCase):
self.assertTrue(z.shape == (batch_size, n, n, c_z))
self.assertTrue(m_1.shape == (batch_size, n, c_m))
class TestTemplateAngleEmbedder(unittest.TestCase):
def test_shape(self):
......@@ -80,13 +80,11 @@ class TestTemplateAngleEmbedder(unittest.TestCase):
template_angle_dim,
c_m,
)
x = torch.rand((batch_size, n_templ, n_res, template_angle_dim))
x = tae(x)
self.assertTrue(
x.shape == (batch_size, n_templ, n_res, c_m)
)
self.assertTrue(x.shape == (batch_size, n_templ, n_res, c_m))
class TestTemplatePairEmbedder(unittest.TestCase):
......@@ -96,20 +94,17 @@ class TestTemplatePairEmbedder(unittest.TestCase):
n_res = 5
template_pair_dim = 7
c_t = 11
tpe = TemplatePairEmbedder(
template_pair_dim,
c_t,
)
x = torch.rand((batch_size, n_templ, n_res, n_res, template_pair_dim))
x = tpe(x)
self.assertTrue(
x.shape == (batch_size, n_templ, n_res, n_res, c_t)
)
self.assertTrue(x.shape == (batch_size, n_templ, n_res, n_res, c_t))
if __name__ == "__main__":
unittest.main()
......@@ -24,14 +24,14 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestEvoformerStack(unittest.TestCase):
def test_shape(self):
def test_shape(self):
batch_size = consts.batch_size
n_seq = consts.n_seq
n_res = consts.n_res
......@@ -91,56 +91,54 @@ class TestEvoformerStack(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
ei = alphafold.model.modules.EvoformerIteration(
c_e, config.model.global_config, is_extra_msa=False)
c_e, config.model.global_config, is_extra_msa=False
)
return ei(activations, masks, is_training=False)
f = hk.transform(run_ei)
n_res = consts.n_res
n_seq = consts.n_seq
activations = {
'msa': np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
'pair': np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
"msa": np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
"pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
}
masks = {
'msa': np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
'pair': np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
"msa": np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
"pair": np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
}
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
key = jax.random.PRNGKey(42)
out_gt = f.apply(
params, key, activations, masks
)
out_gt = f.apply(params, key, activations, masks)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt_msa = torch.as_tensor(np.array(out_gt["msa"]))
out_gt_pair = torch.as_tensor(np.array(out_gt["pair"]))
model = compare_utils.get_global_pretrained_openfold()
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
torch.as_tensor(activations["msa"]).cuda(),
torch.as_tensor(activations["pair"]).cuda(),
torch.as_tensor(masks["msa"]).cuda(),
torch.as_tensor(activations["msa"]).cuda(),
torch.as_tensor(activations["pair"]).cuda(),
torch.as_tensor(masks["msa"]).cuda(),
torch.as_tensor(masks["pair"]).cuda(),
_mask_trans=False,
)
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
assert(torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps))
assert(torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps))
assert torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps)
assert torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps)
class TestExtraMSAStack(unittest.TestCase):
def test_shape(self):
def test_shape(self):
batch_size = 2
s_t = 23
n_res = 5
......@@ -180,8 +178,24 @@ class TestExtraMSAStack(unittest.TestCase):
m = torch.rand((batch_size, s_t, n_res, c_m))
z = torch.rand((batch_size, n_res, n_res, c_z))
msa_mask = torch.randint(0, 2, size=(batch_size, s_t, n_res,))
pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res,))
msa_mask = torch.randint(
0,
2,
size=(
batch_size,
s_t,
n_res,
),
)
pair_mask = torch.randint(
0,
2,
size=(
batch_size,
n_res,
n_res,
),
)
shape_z_before = z.shape
......@@ -191,7 +205,7 @@ class TestExtraMSAStack(unittest.TestCase):
class TestMSATransition(unittest.TestCase):
def test_shape(self):
def test_shape(self):
batch_size = 2
s_t = 3
n_r = 5
......@@ -214,39 +228,43 @@ class TestMSATransition(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
msa_trans = alphafold.model.modules.Transition(
c_e.msa_transition,
c_e.msa_transition,
config.model.global_config,
name="msa_transition"
name="msa_transition",
)
act = msa_trans(act=msa_act, mask=msa_mask)
return act
f = hk.transform(run_msa_transition)
n_res = consts.n_res
n_seq = consts.n_seq
msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
msa_mask = np.ones((n_seq, n_res)).astype(np.float32) # no mask here either
msa_mask = np.ones((n_seq, n_res)).astype(
np.float32
) # no mask here either
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"msa_transition"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_transition"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].msa_transition(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
).cpu()
out_repro = (
model.evoformer.blocks[0]
.msa_transition(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
)
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
......
......@@ -26,14 +26,14 @@ from openfold.np.residue_constants import (
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
tree_map,
tensor_tree_map,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -48,21 +48,21 @@ class TestFeats(unittest.TestCase):
all_atom_pos,
all_atom_mask,
)
f = hk.transform(test_pbf)
n_res = consts.n_res
n_res = consts.n_res
aatype = np.random.randint(0, 22, (n_res,))
all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
all_atom_mask = np.random.randint(0, 2, (n_res, 37))
out_gt_pos, out_gt_mask = f.apply(
{}, None, aatype, all_atom_pos, all_atom_mask
)
out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))
out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
torch.tensor(aatype).cuda(),
torch.tensor(all_atom_pos).cuda(),
......@@ -70,7 +70,7 @@ class TestFeats(unittest.TestCase):
)
out_repro_pos = out_repro_pos.cpu()
out_repro_mask = out_repro_mask.cpu()
self.assertTrue(
torch.max(torch.abs(out_gt_pos - out_repro_pos)) < consts.eps
)
......@@ -82,26 +82,26 @@ class TestFeats(unittest.TestCase):
def test_atom37_to_torsion_angles_compare(self):
def run_test(aatype, all_atom_pos, all_atom_mask):
return alphafold.model.all_atom.atom37_to_torsion_angles(
aatype,
all_atom_pos,
aatype,
all_atom_pos,
all_atom_mask,
placeholder_for_undefined=False,
)
f = hk.transform(run_test)
n_templ = 7
n_templ = 7
n_res = 13
aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
all_atom_mask = np.random.randint(
0, 2, (n_templ, n_res, 37)
).astype(np.float32)
all_atom_mask = np.random.randint(0, 2, (n_templ, n_res, 37)).astype(
np.float32
)
out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)
out_repro = feats.atom37_to_torsion_angles(
torch.as_tensor(aatype).cuda(),
torch.as_tensor(all_atom_pos).cuda(),
......@@ -110,20 +110,21 @@ class TestFeats(unittest.TestCase):
tasc = out_repro["torsion_angles_sin_cos"].cpu()
atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
tam = out_repro["torsion_angles_mask"].cpu()
# This function is extremely sensitive to floating point imprecisions,
# so it is given much greater latitude in comparison tests.
self.assertTrue(
torch.mean(
torch.abs(out_gt["torsion_angles_sin_cos"] - tasc)
) < 0.01
torch.mean(torch.abs(out_gt["torsion_angles_sin_cos"] - tasc))
< 0.01
)
self.assertTrue(
torch.mean(torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc))
< 0.01
)
self.assertTrue(
torch.mean(
torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc)
) < 0.01
torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam))
< consts.eps
)
self.assertTrue(torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_atom37_to_frames_compare(self):
......@@ -131,48 +132,50 @@ class TestFeats(unittest.TestCase):
return alphafold.model.all_atom.atom37_to_frames(
aatype, all_atom_positions, all_atom_mask
)
f = hk.transform(run_atom37_to_frames)
n_res = consts.n_res
n_res = consts.n_res
batch = {
"aatype": np.random.randint(0, 21, (n_res,)),
"all_atom_positions":
np.random.rand(n_res, 37, 3).astype(np.float32),
"all_atom_mask":
np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
"all_atom_positions": np.random.rand(n_res, 37, 3).astype(
np.float32
),
"all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
np.float32
),
}
out_gt = f.apply({}, None, **batch)
to_tensor = lambda t: torch.tensor(np.array(t))
out_gt = {k:to_tensor(v) for k,v in out_gt.items()}
out_gt = {k: to_tensor(v) for k, v in out_gt.items()}
def flat12_to_4x4(flat12):
rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
trans = flat12[..., 9:]
four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
four_by_four[..., :3, :3] = rot
four_by_four[..., :3, 3] = trans
four_by_four[..., 3, 3] = 1
return four_by_four
out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
out_gt["rigidgroups_gt_frames"]
)
out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
out_gt["rigidgroups_alt_gt_frames"]
)
to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
batch = tree_map(to_tensor, batch, np.ndarray)
out_repro = data_transforms.atom37_to_frames(batch)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
for k,v in out_gt.items():
for k, v in out_gt.items():
self.assertTrue(
torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
)
......@@ -190,56 +193,50 @@ class TestFeats(unittest.TestCase):
aas = torch.stack([aas for _ in range(batch_size)])
frames = feats.torsion_angles_to_frames(
ts,
angles,
aas,
ts,
angles,
aas,
torch.tensor(restype_rigid_group_default_frame),
)
self.assertTrue(frames.shape == (batch_size, n, 8))
@compare_utils.skip_unless_alphafold_installed()
def test_torsion_angles_to_frames_compare(self):
def run_torsion_angles_to_frames(
aatype,
backb_to_global,
torsion_angles_sin_cos
aatype, backb_to_global, torsion_angles_sin_cos
):
return alphafold.model.all_atom.torsion_angles_to_frames(
aatype,
backb_to_global,
torsion_angles_sin_cos,
)
f = hk.transform(run_torsion_angles_to_frames)
n_res = consts.n_res
aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float())
torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
out_gt = f.apply(
{}, None, aatype, rigids, torsion_angles_sin_cos
)
out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out = feats.torsion_angles_to_frames(
transformations.cuda(),
torch.as_tensor(torsion_angles_sin_cos).cuda(),
torch.as_tensor(aatype).cuda(),
torch.tensor(restype_rigid_group_default_frame).cuda(),
)
# Convert the Rigids to 4x4 transformation tensors
rots_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot)
)
rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot))
trans_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
)
......@@ -250,9 +247,9 @@ class TestFeats(unittest.TestCase):
bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
bottom_row[..., 3] = 1
transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)
transforms_repro = out.to_4x4().cpu()
self.assertTrue(
torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps)
)
......@@ -275,7 +272,7 @@ class TestFeats(unittest.TestCase):
torch.tensor(restype_atom14_mask),
torch.tensor(restype_atom14_rigid_group_positions),
)
self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))
@compare_utils.skip_unless_alphafold_installed()
......@@ -285,34 +282,32 @@ class TestFeats(unittest.TestCase):
return am.all_atom.frames_and_literature_positions_to_atom14_pos(
aatype, affines
)
f = hk.transform(run_f)
n_res = consts.n_res
aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res, 8))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float())
out_gt = f.apply(
{}, None, aatype, rigids
)
out_gt = f.apply({}, None, aatype, rigids)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = torch.stack(
[torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
)
out_repro = feats.frames_and_literature_positions_to_atom14_pos(
transformations.cuda(),
transformations.cuda(),
torch.as_tensor(aatype).cuda(),
torch.tensor(restype_rigid_group_default_frame).cuda(),
torch.tensor(restype_atom14_to_rigid_group).cuda(),
torch.tensor(restype_atom14_mask).cuda(),
torch.tensor(restype_atom14_rigid_group_positions).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
......
......@@ -24,13 +24,14 @@ from openfold.utils.import_weights import import_jax_weights_
class TestImportWeights(unittest.TestCase):
def test_import_jax_weights_(self):
npz_path = "openfold/resources/params/params_model_1_ptm.npz"
c = model_config("model_1_ptm")
c.globals.blocks_per_ckpt = None
model = AlphaFold(c.model)
import_jax_weights_(
model, npz_path,
model,
npz_path,
)
data = np.load(npz_path)
......@@ -38,23 +39,34 @@ class TestImportWeights(unittest.TestCase):
test_pairs = [
# Normal linear weight
(torch.as_tensor(
data[prefix + "structure_module/initial_projection//weights"]
).transpose(-1, -2),
model.structure_module.linear_in.weight),
(
torch.as_tensor(
data[
prefix + "structure_module/initial_projection//weights"
]
).transpose(-1, -2),
model.structure_module.linear_in.weight,
),
# Normal layer norm param
(torch.as_tensor(
data[prefix + "evoformer/prev_pair_norm//offset"],
),
model.recycling_embedder.layer_norm_z.bias),
(
torch.as_tensor(
data[prefix + "evoformer/prev_pair_norm//offset"],
),
model.recycling_embedder.layer_norm_z.bias,
),
# From a stack
(torch.as_tensor(data[
prefix + (
"evoformer/evoformer_iteration/outer_product_mean/"
"left_projection//weights"
)
][1].transpose(-1, -2)),
model.evoformer.blocks[1].outer_product_mean.linear_1.weight,),
(
torch.as_tensor(
data[
prefix
+ (
"evoformer/evoformer_iteration/outer_product_mean/"
"left_projection//weights"
)
][1].transpose(-1, -2)
),
model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
),
]
for w_alpha, w_repro in test_pairs:
......
......@@ -41,15 +41,15 @@ from openfold.utils.loss import (
tm_loss,
)
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
tree_map,
tensor_tree_map,
dict_multimap,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_vector, random_affines_4x4
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -99,12 +99,19 @@ class TestLoss(unittest.TestCase):
pred_pos = torch.rand(bs, n, 14, 3)
pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
residue_index = torch.arange(n).unsqueeze(0)
aatype = torch.randint(0, 22, (bs, n,))
aatype = torch.randint(
0,
22,
(
bs,
n,
),
)
between_residue_bond_loss(
pred_pos,
pred_atom_mask,
residue_index,
residue_index,
aatype,
)
......@@ -117,27 +124,26 @@ class TestLoss(unittest.TestCase):
residue_index,
aatype,
)
f = hk.transform(run_brbl)
n_res = consts.n_res
n_res = consts.n_res
pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
pred_atom_mask = np.random.randint(
0, 2, (n_res, 14)
).astype(np.float32)
pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
residue_index = np.arange(n_res)
aatype = np.random.randint(0, 22, (n_res,))
out_gt = f.apply(
{}, None,
pred_pos,
pred_atom_mask,
{},
None,
pred_pos,
pred_atom_mask,
residue_index,
aatype,
)
out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
out_repro = between_residue_bond_loss(
torch.tensor(pred_pos).cuda(),
torch.tensor(pred_atom_mask).cuda(),
......@@ -145,13 +151,12 @@ class TestLoss(unittest.TestCase):
torch.tensor(aatype).cuda(),
)
out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
for k in out_gt.keys():
self.assertTrue(
torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
)
def test_run_between_residue_clash_loss(self):
bs = consts.batch_size
n = consts.n_res
......@@ -164,7 +169,7 @@ class TestLoss(unittest.TestCase):
loss = between_residue_clash_loss(
pred_pos,
pred_atom_mask,
atom14_atom_radius,
atom14_atom_radius,
residue_index,
)
......@@ -185,10 +190,13 @@ class TestLoss(unittest.TestCase):
pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
atom_radius = np.random.rand(n_res, 14).astype(np.float32)
res_ind = np.arange(n_res,)
res_ind = np.arange(
n_res,
)
out_gt = f.apply(
{}, None,
{},
None,
pred_pos,
atom_exists,
atom_radius,
......@@ -196,7 +204,7 @@ class TestLoss(unittest.TestCase):
)
out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
out_repro = between_residue_clash_loss(
torch.tensor(pred_pos).cuda(),
torch.tensor(atom_exists).cuda(),
......@@ -204,7 +212,7 @@ class TestLoss(unittest.TestCase):
torch.tensor(res_ind).cuda(),
)
out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
for k in out_gt.keys():
self.assertTrue(
torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
......@@ -221,7 +229,7 @@ class TestLoss(unittest.TestCase):
}
pred_pos = torch.rand(n, 14, 3)
config = {
"clash_overlap_tolerance": 1.5,
"violation_tolerance_factor": 12.0,
......@@ -242,50 +250,44 @@ class TestLoss(unittest.TestCase):
os.chdir(cwd)
return loss
f = hk.transform(run_fsv)
n_res = consts.n_res
batch = {
"atom14_atom_exists": np.random.randint(0, 2, (n_res, 14)),
"residue_index": np.arange(n_res),
"aatype": np.random.randint(0, 20, (n_res,)),
"residx_atom14_to_atom37":
np.random.randint(0, 37, (n_res, 14)).astype(np.int64),
"residx_atom14_to_atom37": np.random.randint(
0, 37, (n_res, 14)
).astype(np.int64),
}
pred_pos = np.random.rand(n_res, 14, 3)
config = mlc.ConfigDict({
"clash_overlap_tolerance": 1.5,
"violation_tolerance_factor": 12.0,
})
out_gt = f.apply(
{}, None,
batch,
pred_pos,
config
config = mlc.ConfigDict(
{
"clash_overlap_tolerance": 1.5,
"violation_tolerance_factor": 12.0,
}
)
out_gt = f.apply({}, None, batch, pred_pos, config)
out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
out_repro = find_structural_violations(
batch,
torch.tensor(pred_pos).cuda(),
**config,
)
out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
def compare(out):
gt, repro = out
assert(torch.max(torch.abs(gt - repro)) < consts.eps)
assert torch.max(torch.abs(gt - repro)) < consts.eps
dict_multimap(compare, [out_gt, out_repro])
@compare_utils.skip_unless_alphafold_installed()
......@@ -295,44 +297,45 @@ class TestLoss(unittest.TestCase):
batch,
atom14_pred_pos,
)
f = hk.transform(run_crgt)
n_res = consts.n_res
batch = {
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"aatype": np.random.randint(0, 20, (n_res,)),
"atom14_gt_positions": np.random.rand(n_res, 14, 3),
"atom14_gt_exists":
np.random.randint(0, 2, (n_res, 14)).astype(np.float32),
"all_atom_mask":
np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
"all_atom_positions":
np.random.rand(n_res, 37, 3).astype(np.float32),
"atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
np.float32
),
"all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
np.float32
),
"all_atom_positions": np.random.rand(n_res, 37, 3).astype(
np.float32
),
}
def _build_extra_feats_np():
b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
b = data_transforms.make_atom14_masks(b)
b = data_transforms.make_atom14_positions(b)
return tensor_tree_map(lambda t: np.array(t), b)
batch = _build_extra_feats_np()
atom14_pred_pos = np.random.rand(n_res, 14, 3)
out_gt = f.apply({}, None, batch, atom14_pred_pos)
out_gt = jax.tree_map(lambda x: torch.tensor(np.array(x)), out_gt)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()
out_repro = compute_renamed_ground_truth(batch, atom14_pred_pos)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
for k in out_repro:
self.assertTrue(
torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
......@@ -346,84 +349,76 @@ class TestLoss(unittest.TestCase):
config.model.heads.masked_msa, config.model.global_config
)
return msa_head.loss(value, batch)
f = hk.transform(run_msa_loss)
n_res = consts.n_res
n_seq = consts.n_seq
value = {
"logits": np.random.rand(n_res, n_seq, 23).astype(np.float32),
}
batch = {
"true_msa": np.random.randint(0, 21, (n_res, n_seq)),
"bert_mask":
np.random.randint(0, 2, (n_res, n_seq)).astype(np.float32),
"bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
np.float32
),
}
out_gt = f.apply({}, None, value, batch)["loss"]
out_gt = torch.tensor(np.array(out_gt))
value = tree_map(
lambda x: torch.tensor(x).cuda(), value, np.ndarray
)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
with torch.no_grad():
out_repro = masked_msa_loss(
value["logits"],
**batch,
)
)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_distogram_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_distogram = config.model.heads.distogram
def run_distogram_loss(value, batch):
dist_head = alphafold.model.modules.DistogramHead(
c_distogram, config.model.global_config
)
return dist_head.loss(value, batch)
f = hk.transform(run_distogram_loss)
n_res = consts.n_res
value = {
"logits": np.random.rand(
n_res,
n_res,
c_distogram.num_bins
).astype(np.float32),
"logits": np.random.rand(n_res, n_res, c_distogram.num_bins).astype(
np.float32
),
"bin_edges": np.linspace(
c_distogram.first_break,
c_distogram.last_break,
c_distogram.num_bins,
)
),
}
batch = {
"pseudo_beta": np.random.rand(n_res, 3).astype(np.float32),
"pseudo_beta_mask": np.random.randint(0, 2, (n_res,))
"pseudo_beta_mask": np.random.randint(0, 2, (n_res,)),
}
out_gt = f.apply({}, None, value, batch)["loss"]
out_gt = torch.tensor(np.array(out_gt))
value = tree_map(
lambda x: torch.tensor(x).cuda(), value, np.ndarray
)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
with torch.no_grad():
out_repro = distogram_loss(
logits=value["logits"],
......@@ -431,66 +426,64 @@ class TestLoss(unittest.TestCase):
max_bin=c_distogram.last_break,
no_bins=c_distogram.num_bins,
**batch,
)
)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_experimentally_resolved_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_experimentally_resolved = config.model.heads.experimentally_resolved
def run_experimentally_resolved_loss(value, batch):
er_head = alphafold.model.modules.ExperimentallyResolvedHead(
c_experimentally_resolved, config.model.global_config
)
return er_head.loss(value, batch)
f = hk.transform(run_experimentally_resolved_loss)
n_res = consts.n_res
value = {
"logits": np.random.rand(n_res, 37).astype(np.float32),
}
batch = {
"all_atom_mask": np.random.randint(0, 2, (n_res, 37)),
"atom37_atom_exists": np.random.randint(0, 2, (n_res, 37)),
"resolution": np.array(1.0)
"resolution": np.array(1.0),
}
out_gt = f.apply({}, None, value, batch)["loss"]
out_gt = torch.tensor(np.array(out_gt))
value = tree_map(
lambda x: torch.tensor(x).cuda(), value, np.ndarray
)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
with torch.no_grad():
out_repro = experimentally_resolved_loss(
logits=value["logits"],
min_resolution=c_experimentally_resolved.min_resolution,
max_resolution=c_experimentally_resolved.max_resolution,
**batch,
)
)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_supervised_chi_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_chi_loss = config.model.heads.structure_module
def run_supervised_chi_loss(value, batch):
ret = {
"loss": jax.numpy.array(0.),
"loss": jax.numpy.array(0.0),
}
alphafold.model.folding.supervised_chi_loss(
ret, batch, value, c_chi_loss
......@@ -503,10 +496,12 @@ class TestLoss(unittest.TestCase):
value = {
"sidechains": {
"angles_sin_cos":
np.random.rand(8, n_res, 7, 2).astype(np.float32),
"unnormalized_angles_sin_cos":
np.random.rand(8, n_res, 7, 2).astype(np.float32),
"angles_sin_cos": np.random.rand(8, n_res, 7, 2).astype(
np.float32
),
"unnormalized_angles_sin_cos": np.random.rand(
8, n_res, 7, 2
).astype(np.float32),
}
}
......@@ -519,13 +514,9 @@ class TestLoss(unittest.TestCase):
out_gt = f.apply({}, None, value, batch)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
value = tree_map(
lambda x: torch.tensor(x).cuda(), value, np.ndarray
)
value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
batch["chi_angles_sin_cos"] = torch.stack(
[
......@@ -539,9 +530,9 @@ class TestLoss(unittest.TestCase):
out_repro = supervised_chi_loss(
chi_weight=c_chi_loss.chi_weight,
angle_norm_weight=c_chi_loss.angle_norm_weight,
**{**batch, **value["sidechains"]}
)
**{**batch, **value["sidechains"]},
)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
......@@ -550,111 +541,119 @@ class TestLoss(unittest.TestCase):
def test_violation_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_viol = config.model.heads.structure_module
def run_viol_loss(batch, atom14_pred_pos):
ret = {
"loss": np.array(0.).astype(np.float32),
"loss": np.array(0.0).astype(np.float32),
}
value = {}
value["violations"] = (
alphafold.model.folding.find_structural_violations(
batch,
atom14_pred_pos,
c_viol,
)
value[
"violations"
] = alphafold.model.folding.find_structural_violations(
batch,
atom14_pred_pos,
c_viol,
)
alphafold.model.folding.structural_violation_loss(
ret, batch, value, c_viol,
ret,
batch,
value,
c_viol,
)
return ret["loss"]
f = hk.transform(run_viol_loss)
n_res = consts.n_res
batch = {
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"residue_index": np.arange(n_res),
"aatype": np.random.randint(0, 21, (n_res,)),
}
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
batch = {k:np.array(v) for k,v in batch.items()}
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
batch = {k: np.array(v) for k, v in batch.items()}
atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
out_gt = f.apply({}, None, batch, atom14_pred_pos)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
batch = tree_map(
lambda n: torch.tensor(n).cuda(), batch, np.ndarray
)
batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)
atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()
batch = data_transforms.make_atom14_masks(batch)
out_repro = violation_loss(
find_structural_violations(batch, atom14_pred_pos, **c_viol),
**batch,
)
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_lddt_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_plddt = config.model.heads.predicted_lddt
def run_plddt_loss(value, batch):
head = alphafold.model.modules.PredictedLDDTHead(
c_plddt, config.model.global_config
)
return head.loss(value, batch)
f = hk.transform(run_plddt_loss)
n_res = consts.n_res
value = {
"predicted_lddt": {
"logits":
np.random.rand(n_res, c_plddt.num_bins).astype(np.float32),
"logits": np.random.rand(n_res, c_plddt.num_bins).astype(
np.float32
),
},
"structure_module": {
"final_atom_positions":
np.random.rand(n_res, 37, 3).astype(np.float32),
}
"final_atom_positions": np.random.rand(n_res, 37, 3).astype(
np.float32
),
},
}
batch = {
"all_atom_positions":
np.random.rand(n_res, 37, 3).astype(np.float32),
"all_atom_mask":
np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
"resolution": np.array(1.).astype(np.float32),
"all_atom_positions": np.random.rand(n_res, 37, 3).astype(
np.float32
),
"all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
np.float32
),
"resolution": np.array(1.0).astype(np.float32),
}
out_gt = f.apply({}, None, value, batch)
out_gt = torch.tensor(np.array(out_gt["loss"]))
to_tensor = lambda t: torch.tensor(t).cuda()
value = tree_map(to_tensor, value, np.ndarray)
batch = tree_map(to_tensor, batch, np.ndarray)
out_repro = lddt_loss(
logits=value["predicted_lddt"]["logits"],
all_atom_pred_pos=value["structure_module"]["final_atom_positions"],
**{**batch, **c_plddt},
)
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_backbone_loss(self):
config = compare_utils.get_alphafold_config()
c_sm = config.model.heads.structure_module
def run_bb_loss(batch, value):
ret = {
"loss": np.array(0.),
"loss": np.array(0.0),
}
alphafold.model.folding.backbone_loss(ret, batch, value, c_sm)
return ret["loss"]
......@@ -665,13 +664,19 @@ class TestLoss(unittest.TestCase):
batch = {
"backbone_affine_tensor": random_affines_vector((n_res,)),
"backbone_affine_mask":
np.random.randint(0, 2, (n_res,)).astype(np.float32),
"use_clamped_fape": np.array(0.),
"backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
np.float32
),
"use_clamped_fape": np.array(0.0),
}
value = {
"traj": random_affines_vector((c_sm.num_layer, n_res,)),
"traj": random_affines_vector(
(
c_sm.num_layer,
n_res,
)
),
}
out_gt = f.apply({}, None, batch, value)
......@@ -695,6 +700,7 @@ class TestLoss(unittest.TestCase):
def test_sidechain_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_sm = config.model.heads.structure_module
def run_sidechain_loss(batch, value, atom14_pred_positions):
batch = {
**batch,
......@@ -702,88 +708,94 @@ class TestLoss(unittest.TestCase):
batch["aatype"],
batch["all_atom_positions"],
batch["all_atom_mask"],
)
),
}
v = {}
v["sidechains"] = {}
v["sidechains"]["frames"] = (
alphafold.model.r3.rigids_from_tensor4x4(
value["sidechains"]["frames"]
)
v["sidechains"][
"frames"
] = alphafold.model.r3.rigids_from_tensor4x4(
value["sidechains"]["frames"]
)
v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor(
value["sidechains"]["atom_pos"]
)
v.update(alphafold.model.folding.compute_renamed_ground_truth(
batch,
atom14_pred_positions,
))
v.update(
alphafold.model.folding.compute_renamed_ground_truth(
batch,
atom14_pred_positions,
)
)
value = v
ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm)
return ret["loss"]
f = hk.transform(run_sidechain_loss)
n_res = consts.n_res
batch = {
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"aatype": np.random.randint(0, 20, (n_res,)),
"atom14_gt_positions":
np.random.rand(n_res, 14, 3).astype(np.float32),
"atom14_gt_exists":
np.random.randint(0, 2, (n_res, 14)).astype(np.float32),
"all_atom_positions":
np.random.rand(n_res, 37, 3).astype(np.float32),
"all_atom_mask":
np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
"atom14_gt_positions": np.random.rand(n_res, 14, 3).astype(
np.float32
),
"atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
np.float32
),
"all_atom_positions": np.random.rand(n_res, 37, 3).astype(
np.float32
),
"all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
np.float32
),
}
def _build_extra_feats_np():
b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
b = data_transforms.make_atom14_masks(b)
b = data_transforms.make_atom14_positions(b)
return tensor_tree_map(lambda t: np.array(t), b)
batch = _build_extra_feats_np()
batch = _build_extra_feats_np()
value = {
"sidechains": {
"frames": random_affines_4x4((c_sm.num_layer, n_res, 8)),
"atom_pos":
np.random.rand(
c_sm.num_layer, n_res, 14, 3
).astype(np.float32),
"atom_pos": np.random.rand(c_sm.num_layer, n_res, 14, 3).astype(
np.float32
),
}
}
atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
out_gt = f.apply({}, None, batch, value, atom14_pred_pos)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
to_tensor = lambda t: torch.tensor(t).cuda()
batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray)
atom14_pred_pos = to_tensor(atom14_pred_pos)
batch = data_transforms.atom37_to_frames(batch)
batch.update(compute_renamed_ground_truth(batch, atom14_pred_pos))
out_repro = sidechain_loss(
sidechain_frames=value["sidechains"]["frames"],
sidechain_atom_pos=value["sidechains"]["atom_pos"],
**{**batch, **c_sm},
)
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tm_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_tm = config.model.heads.predicted_aligned_error
def run_tm_loss(representations, batch, value):
head = alphafold.model.modules.PredictedAlignedErrorHead(
c_tm, config.model.global_config
......@@ -792,58 +804,58 @@ class TestLoss(unittest.TestCase):
v.update(value)
v["predicted_aligned_error"] = head(representations, batch, False)
return head.loss(v, batch)["loss"]
f = hk.transform(run_tm_loss)
n_res = consts.n_res
representations = {
"pair":
np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
"pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
}
batch = {
"backbone_affine_tensor": random_affines_vector((n_res,)),
"backbone_affine_mask":
np.random.randint(0, 2, (n_res,)).astype(np.float32),
"resolution": np.array(1.).astype(np.float32),
"backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
np.float32
),
"resolution": np.array(1.0).astype(np.float32),
}
value = {
"structure_module": {
"final_affines": random_affines_vector((n_res,)),
}
}
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/predicted_aligned_error_head"
)
out_gt = f.apply(params, None, representations, batch, value)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
to_tensor = lambda n: torch.tensor(n).cuda()
representations = tree_map(to_tensor, representations, np.ndarray)
batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = (
affine_vector_to_4x4(batch["backbone_affine_tensor"])
batch["backbone_affine_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"]
)
value["structure_module"]["final_affines"] = (
affine_vector_to_4x4(value["structure_module"]["final_affines"])
value["structure_module"]["final_affines"] = affine_vector_to_4x4(
value["structure_module"]["final_affines"]
)
model = compare_utils.get_global_pretrained_openfold()
logits = model.aux_heads.tm(representations["pair"])
out_repro = tm_loss(
logits=logits,
final_affine_tensor=value["structure_module"]["final_affines"],
**{**batch, **c_tm},
)
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
......
......@@ -29,7 +29,7 @@ from tests.data_utils import (
random_extra_msa_feats,
)
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -43,36 +43,29 @@ class TestModel(unittest.TestCase):
n_extra_seq = consts.n_extra
c = model_config("model_1").model
c.no_cycles = 2
c.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test
c.no_cycles = 2
c.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test
model = AlphaFold(c)
batch = {}
tf = torch.randint(
c.input_embedder.tf_dim - 1, size=(n_res,)
)
tf = torch.randint(c.input_embedder.tf_dim - 1, size=(n_res,))
batch["target_feat"] = nn.functional.one_hot(
tf, c.input_embedder.tf_dim).float()
tf, c.input_embedder.tf_dim
).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand(
(n_seq, n_res, c.input_embedder.msa_dim)
)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.input_embedder.msa_dim))
t_feats = random_template_feats(n_templ, n_res)
batch.update({k:torch.tensor(v) for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats(
n_extra_seq, n_res
)
batch.update({k:torch.tensor(v) for k, v in extra_feats.items()})
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
batch["msa_mask"] = torch.randint(
low=0, high=2, size=(n_seq, n_res)
).float()
batch["seq_mask"] = torch.randint(
low=0, high=2, size=(n_res,)
).float()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
batch.update(make_atom14_masks(batch))
add_recycling_dims = lambda t: (
......@@ -80,7 +73,7 @@ class TestModel(unittest.TestCase):
)
batch = tensor_tree_map(add_recycling_dims, batch)
with torch.no_grad():
with torch.no_grad():
out = model(batch)
@compare_utils.skip_unless_alphafold_installed()
......@@ -89,12 +82,14 @@ class TestModel(unittest.TestCase):
config = compare_utils.get_alphafold_config()
model = alphafold.model.modules.AlphaFold(config.model)
return model(
batch=batch, is_training=False, return_representations=True,
batch=batch,
is_training=False,
return_representations=True,
)
f = hk.transform(run_alphafold)
params = compare_utils.fetch_alphafold_module_weights('')
params = compare_utils.fetch_alphafold_module_weights("")
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
batch = pickle.load(fp)
......@@ -107,14 +102,14 @@ class TestModel(unittest.TestCase):
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
batch = {
k:torch.as_tensor(v).cuda() for k,v in batch.items()
}
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
batch["aatype"] = batch["aatype"].long()
batch["template_aatype"] = batch["template_aatype"].long()
batch["extra_msa"] = batch["extra_msa"].long()
batch["residx_atom37_to_atom14"] = batch["residx_atom37_to_atom14"].long()
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
......@@ -130,4 +125,3 @@ class TestModel(unittest.TestCase):
out_repro = out_repro.squeeze(0)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3))
......@@ -24,14 +24,14 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestMSARowAttentionWithPairBias(unittest.TestCase):
def test_shape(self):
def test_shape(self):
batch_size = consts.batch_size
n_seq = consts.n_seq
n_res = consts.n_res
......@@ -39,7 +39,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
c_z = consts.c_z
c = 52
no_heads = 4
chunk_size=None
chunk_size = None
mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads, chunk_size)
......@@ -58,29 +58,26 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
msa_row = alphafold.model.modules.MSARowAttentionWithPairBias(
c_e.msa_row_attention_with_pair_bias,
config.model.global_config
)
act = msa_row(
msa_act=msa_act, msa_mask=msa_mask, pair_act=pair_act
c_e.msa_row_attention_with_pair_bias, config.model.global_config
)
act = msa_row(msa_act=msa_act, msa_mask=msa_mask, pair_act=pair_act)
return act
f = hk.transform(run_msa_row_att)
n_res = consts.n_res
n_seq = consts.n_seq
msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
msa_mask = np.random.randint(
low=0, high=2, size=(n_seq, n_res)
).astype(np.float32)
msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res)).astype(
np.float32
)
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"msa_row_attention"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_row_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
......@@ -90,17 +87,21 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].msa_att_row(
torch.as_tensor(msa_act).cuda(),
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(msa_mask).cuda(),
).cpu()
out_repro = (
model.evoformer.blocks[0]
.msa_att_row(
torch.as_tensor(msa_act).cuda(),
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(msa_mask).cuda(),
)
.cpu()
)
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
class TestMSAColumnAttention(unittest.TestCase):
def test_shape(self):
def test_shape(self):
batch_size = consts.batch_size
n_seq = consts.n_seq
n_res = consts.n_res
......@@ -124,47 +125,46 @@ class TestMSAColumnAttention(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
msa_col = alphafold.model.modules.MSAColumnAttention(
c_e.msa_column_attention,
config.model.global_config
)
act = msa_col(
msa_act=msa_act, msa_mask=msa_mask
c_e.msa_column_attention, config.model.global_config
)
act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
return act
f = hk.transform(run_msa_col_att)
n_res = consts.n_res
n_seq = consts.n_seq
msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
msa_mask = np.random.randint(
low=0, high=2, size=(n_seq, n_res)
).astype(np.float32)
msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res)).astype(
np.float32
)
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"msa_column_attention"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_column_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].msa_att_col(
torch.as_tensor(msa_act).cuda(),
torch.as_tensor(msa_mask).cuda(),
).cpu()
out_repro = (
model.evoformer.blocks[0]
.msa_att_col(
torch.as_tensor(msa_act).cuda(),
torch.as_tensor(msa_mask).cuda(),
)
.cpu()
)
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
class TestMSAColumnGlobalAttention(unittest.TestCase):
def test_shape(self):
def test_shape(self):
batch_size = consts.batch_size
n_seq = consts.n_seq
n_res = consts.n_res
......@@ -188,40 +188,42 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
msa_col = alphafold.model.modules.MSAColumnGlobalAttention(
c_e.msa_column_attention,
config.model.global_config,
name="msa_column_global_attention"
c_e.msa_column_attention,
config.model.global_config,
name="msa_column_global_attention",
)
act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
return act
f = hk.transform(run_msa_col_global_att)
n_res = consts.n_res
n_seq = consts.n_seq
c_e = consts.c_e
msa_act = np.random.rand(n_seq, n_res, c_e)
msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res))
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/extra_msa_stack/" +
"msa_column_global_attention"
"alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
+ "msa_column_global_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.extra_msa_stack.stack.blocks[0].msa_att_col(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
).cpu()
out_repro = (
model.extra_msa_stack.stack.blocks[0]
.msa_att_col(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
)
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
......
......@@ -19,7 +19,8 @@ from openfold.model.outer_product_mean import OuterProductMean
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -40,51 +41,54 @@ class TestOuterProductMean(unittest.TestCase):
m = opm(m, mask)
self.assertTrue(
m.shape == (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
m.shape
== (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
)
@compare_utils.skip_unless_alphafold_installed()
def test_opm_compare(self):
def test_opm_compare(self):
def run_opm(msa_act, msa_mask):
config = compare_utils.get_alphafold_config()
c_evo = config.model.embeddings_and_evoformer.evoformer
opm = alphafold.model.modules.OuterProductMean(
c_evo.outer_product_mean,
c_evo.outer_product_mean,
config.model.global_config,
consts.c_z,
)
act = opm(act=msa_act, mask=msa_mask)
return act
f = hk.transform(run_opm)
n_res = consts.n_res
n_seq = consts.n_seq
c_m = consts.c_m
msa_act = np.random.rand(n_seq, n_res, c_m).astype(np.float32) * 100
msa_mask = np.random.randint(
low=0, high=2, size=(n_seq, n_res)
).astype(np.float32)
msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res)).astype(
np.float32
)
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/" +
"evoformer_iteration/outer_product_mean"
"alphafold/alphafold_iteration/evoformer/"
+ "evoformer_iteration/outer_product_mean"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].outer_product_mean(
torch.as_tensor(msa_act).cuda(),
mask=torch.as_tensor(msa_mask).cuda(),
).cpu()
out_repro = (
model.evoformer.blocks[0]
.outer_product_mean(
torch.as_tensor(msa_act).cuda(),
mask=torch.as_tensor(msa_mask).cuda(),
)
.cpu()
)
# Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps.
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4)
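# The relaxed 5e-4 bound above (rather than consts.eps) accounts for
# accumulated float32 error in the outer-product-mean reduction. A short,
# self-contained sketch of the max-absolute-error check the test performs
# (all values here are fabricated for illustration):
import torch

eps = 1e-4       # stand-in for consts.eps
opm_eps = 5e-4   # relaxed bound used only for this comparison
out_gt = torch.randn(8, 8, 4)
out_repro = out_gt + 1e-4 * torch.randn_like(out_gt)  # small numerical drift
max_err = torch.max(torch.abs(out_gt - out_repro))
print(bool(max_err < eps), bool(max_err < opm_eps))  # typically: False True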
......
......@@ -20,14 +20,14 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestPairTransition(unittest.TestCase):
def test_shape(self):
def test_shape(self):
c_z = consts.c_z
n = 4
......@@ -50,42 +50,42 @@ class TestPairTransition(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
pt = alphafold.model.modules.Transition(
c_e.pair_transition,
c_e.pair_transition,
config.model.global_config,
name="pair_transition"
name="pair_transition",
)
act = pt(act=pair_act, mask=pair_mask)
return act
f = hk.transform(run_pair_transition)
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.ones((n_res, n_res)).astype(np.float32) # no mask
pair_mask = np.ones((n_res, n_res)).astype(np.float32) # no mask
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"pair_transition"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "pair_transition"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, pair_act, pair_mask
).block_until_ready()
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
out_repro = (
model.evoformer.blocks[0]
.pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
)
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
if __name__ == "__main__":
unittest.main()
......@@ -23,7 +23,7 @@ from openfold.np.residue_constants import (
restype_atom14_mask,
restype_atom14_rigid_group_positions,
restype_atom37_mask,
)
)
from openfold.model.structure_module import (
StructureModule,
StructureModuleTransition,
......@@ -39,7 +39,7 @@ from tests.data_utils import (
random_affines_4x4,
)
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -89,9 +89,7 @@ class TestStructureModule(unittest.TestCase):
out = sm(s, z, f)
self.assertTrue(
out["frames"].shape == (no_layers, batch_size, n, 4, 4)
)
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
)
......@@ -121,78 +119,70 @@ class TestStructureModule(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_sm = config.model.heads.structure_module
c_global = config.model.global_config
def run_sm(representations, batch):
sm = alphafold.model.folding.StructureModule(c_sm, c_global)
representations = {
k:jax.lax.stop_gradient(v) for k,v in representations.items()
}
batch = {
k:jax.lax.stop_gradient(v) for k,v in batch.items()
k: jax.lax.stop_gradient(v) for k, v in representations.items()
}
batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
return sm(representations, batch, is_training=False)
f = hk.transform(run_sm)
n_res = 200
representations = {
'single': np.random.rand(n_res, consts.c_s).astype(np.float32),
'pair':
np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
"single": np.random.rand(n_res, consts.c_s).astype(np.float32),
"pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
}
batch = {
'seq_mask': np.random.randint(0, 2, (n_res,)).astype(np.float32),
'aatype': np.random.randint(0, 21, (n_res,)),
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"aatype": np.random.randint(0, 21, (n_res,)),
}
batch['atom14_atom_exists'] = np.take(
restype_atom14_mask,
batch['aatype'],
axis=0
batch["atom14_atom_exists"] = np.take(
restype_atom14_mask, batch["aatype"], axis=0
)
batch['atom37_atom_exists'] = np.take(
restype_atom37_mask,
batch['aatype'],
axis=0
batch["atom37_atom_exists"] = np.take(
restype_atom37_mask, batch["aatype"], axis=0
)
batch.update(make_atom14_masks_np(batch))
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/structure_module"
)
key = jax.random.PRNGKey(42)
out_gt = f.apply(
params, key, representations, batch
)
out_gt = f.apply(params, key, representations, batch)
out_gt = torch.as_tensor(
np.array(out_gt["final_atom14_positions"].block_until_ready())
)
)
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.structure_module(
torch.as_tensor(representations["single"]).cuda(),
torch.as_tensor(representations["pair"]).cuda(),
torch.as_tensor(batch["aatype"]).cuda(),
torch.as_tensor(representations["single"]).cuda(),
torch.as_tensor(representations["pair"]).cuda(),
torch.as_tensor(batch["aatype"]).cuda(),
mask=torch.as_tensor(batch["seq_mask"]).cuda(),
)
out_repro = out_repro["positions"][-1].cpu()
# The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to
# We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data.
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.01)
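# Because angle normalization in the structure module can turn tiny
# numerical differences into occasional large per-atom deviations, the test
# above bounds the mean error instead of the max. A toy illustration of why
# the two criteria behave differently (the numbers are made up):
import torch

err = torch.full((200, 14, 3), 1e-3)  # mostly tiny coordinate error
err[0, 0, 0] = 0.5                    # one volatile outlier
print(torch.max(err).item())          # 0.5 -- a max-based check would fail
print(torch.mean(err).item())         # ~1.06e-3 -- the mean-based check passes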
class TestBackboneUpdate(unittest.TestCase):
def test_shape(self):
def test_shape(self):
batch_size = 2
n_res = 3
c_in = 5
bu = BackboneUpdate(c_in)
s = torch.rand((batch_size, n_res, c_in))
......@@ -237,25 +227,25 @@ class TestInvariantPointAttention(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_ipa_compare(self):
def run_ipa(act, static_feat_2d, mask, affine):
config = compare_utils.get_alphafold_config()
config = compare_utils.get_alphafold_config()
ipa = alphafold.model.folding.InvariantPointAttention(
config.model.heads.structure_module,
config.model.global_config,
config.model.heads.structure_module,
config.model.global_config,
)
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
affine=affine
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
affine=affine,
)
return attn
f = hk.transform(run_ipa)
n_res = consts.n_res
c_s = consts.c_s
c_z = consts.c_z
sample_act = np.random.rand(n_res, c_s)
sample_2d = np.random.rand(n_res, n_res, c_z)
sample_mask = np.ones((n_res, 1))
......@@ -263,15 +253,13 @@ class TestInvariantPointAttention(unittest.TestCase):
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids)
transformations = T.from_4x4(
torch.as_tensor(affines).float().cuda()
)
transformations = T.from_4x4(torch.as_tensor(affines).float().cuda())
sample_affine = quats
ipa_params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/structure_module/" +
"fold_iteration/invariant_point_attention"
"alphafold/alphafold_iteration/structure_module/"
+ "fold_iteration/invariant_point_attention"
)
out_gt = f.apply(
......@@ -282,17 +270,17 @@ class TestInvariantPointAttention(unittest.TestCase):
with torch.no_grad():
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.structure_module.ipa(
torch.as_tensor(sample_act).float().cuda(),
torch.as_tensor(sample_2d).float().cuda(),
transformations,
torch.as_tensor(sample_act).float().cuda(),
torch.as_tensor(sample_2d).float().cuda(),
transformations,
torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
class TestAngleResnet(unittest.TestCase):
def test_shape(self):
def test_shape(self):
batch_size = 2
n = 3
c_s = 13
......@@ -300,7 +288,7 @@ class TestAngleResnet(unittest.TestCase):
no_layers = 5
no_angles = 7
epsilon = 1e-12
ar = AngleResnet(c_s, c_hidden, no_layers, no_angles, epsilon)
a = torch.rand((batch_size, n, c_s))
a_initial = torch.rand((batch_size, n, c_s))
......
......@@ -24,14 +24,14 @@ import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestTemplatePointwiseAttention(unittest.TestCase):
def test_shape(self):
def test_shape(self):
batch_size = consts.batch_size
n_seq = consts.n_seq
c_t = consts.c_t
......@@ -40,7 +40,7 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
no_heads = 13
n_res = consts.n_res
inf = 1e7
tpa = TemplatePointwiseAttention(
c_t, c_z, c, no_heads, chunk_size=4, inf=inf
)
......@@ -67,8 +67,8 @@ class TestTemplatePairStack(unittest.TestCase):
n_res = consts.n_res
blocks_per_ckpt = None
chunk_size = 4
inf=1e7
eps=1e-7
inf = 1e7
eps = 1e-7
tpe = TemplatePairStack(
c_t,
......@@ -98,45 +98,47 @@ class TestTemplatePairStack(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_ee = config.model.embeddings_and_evoformer
tps = alphafold.model.modules.TemplatePairStack(
c_ee.template.template_pair_stack,
c_ee.template.template_pair_stack,
config.model.global_config,
name="template_pair_stack"
name="template_pair_stack",
)
act = tps(pair_act, pair_mask, is_training=False)
ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
act = ln(act)
return act
f = hk.transform(run_template_pair_stack)
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_t).astype(np.float32)
pair_mask = np.random.randint(
low=0, high=2, size=(n_res, n_res)
).astype(np.float32)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/" +
"single_template_embedding/template_pair_stack"
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_pair_stack"
)
params.update(
compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/output_layer_norm"
)
)
params.update(compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/" +
"single_template_embedding/output_layer_norm"
))
out_gt = f.apply(
params, jax.random.PRNGKey(42), pair_act, pair_mask
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.template_pair_stack(
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
_mask_trans=False,
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
......@@ -146,46 +148,46 @@ class Template(unittest.TestCase):
def test_template_embedding(pair, batch, mask_2d):
config = compare_utils.get_alphafold_config()
te = alphafold.model.modules.TemplateEmbedding(
config.model.embeddings_and_evoformer.template,
config.model.global_config
config.model.embeddings_and_evoformer.template,
config.model.global_config,
)
act = te(pair, batch, mask_2d, is_training=False)
return act
f = hk.transform(test_template_embedding)
n_res = consts.n_res
n_templ = consts.n_templ
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding"
)
out_gt = f.apply(
params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.embed_templates(
{k:torch.as_tensor(v).cuda() for k,v in batch.items()},
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
)
out_repro = out_repro["template_pair_embedding"]
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
if __name__ == "__main__":
unittest.main()
unittest.main()
......@@ -21,7 +21,7 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -34,12 +34,7 @@ class TestTriangularAttention(unittest.TestCase):
no_heads = 4
starting = True
tan = TriangleAttention(
c_z,
c,
no_heads,
starting
)
tan = TriangleAttention(c_z, c, no_heads, starting)
batch_size = consts.batch_size
n_res = consts.n_res
......@@ -53,22 +48,24 @@ class TestTriangularAttention(unittest.TestCase):
def _tri_att_compare(self, starting=False):
name = (
"triangle_attention_" +
("starting" if starting else "ending") +
"_node"
"triangle_attention_"
+ ("starting" if starting else "ending")
+ "_node"
)
def run_tri_att(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
tri_att = alphafold.model.modules.TriangleAttention(
c_e.triangle_attention_starting_node if starting else
c_e.triangle_attention_ending_node,
c_e.triangle_attention_starting_node
if starting
else c_e.triangle_attention_ending_node,
config.model.global_config,
name=name,
)
act = tri_att(pair_act=pair_act, pair_mask=pair_mask)
return act
f = hk.transform(run_tri_att)
n_res = consts.n_res
......@@ -78,24 +75,23 @@ class TestTriangularAttention(unittest.TestCase):
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
name
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, pair_act, pair_mask
).block_until_ready()
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].tri_att_start if starting else
model.evoformer.blocks[0].tri_att_end
model.evoformer.blocks[0].tri_att_start
if starting
else model.evoformer.blocks[0].tri_att_end
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
......@@ -110,4 +106,4 @@ class TestTriangularAttention(unittest.TestCase):
if __name__ == "__main__":
unittest.main()
unittest.main()
......@@ -20,14 +20,14 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def test_shape(self):
def test_shape(self):
c_z = consts.c_z
c = 11
outgoing = True
......@@ -50,22 +50,23 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
self.assertTrue(shape_before == shape_after)
def _tri_mul_compare(self, incoming=False):
name = (
"triangle_multiplication_" +
("incoming" if incoming else "outgoing")
name = "triangle_multiplication_" + (
"incoming" if incoming else "outgoing"
)
def run_tri_mul(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
tri_mul = alphafold.model.modules.TriangleMultiplication(
c_e.triangle_multiplication_incoming if incoming else
c_e.triangle_multiplication_outgoing,
c_e.triangle_multiplication_incoming
if incoming
else c_e.triangle_multiplication_outgoing,
config.model.global_config,
name=name,
)
act = tri_mul(act=pair_act, mask=pair_mask)
return act
f = hk.transform(run_tri_mul)
n_res = consts.n_res
......@@ -76,24 +77,23 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
name
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, pair_act, pair_mask
).block_until_ready()
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].tri_mul_in if incoming else
model.evoformer.blocks[0].tri_mul_out
model.evoformer.blocks[0].tri_mul_in
if incoming
else model.evoformer.blocks[0].tri_mul_out
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
......@@ -109,4 +109,3 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
if __name__ == "__main__":
unittest.main()
......@@ -20,17 +20,21 @@ from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.tensor_utils import chunk_layer
X_90_ROT = torch.tensor([
[1, 0, 0],
[0, 0,-1],
[0, 1, 0],
])
X_NEG_90_ROT = torch.tensor([
[1, 0, 0],
[0, 0, 1],
[0,-1, 0],
])
X_90_ROT = torch.tensor(
[
[1, 0, 0],
[0, 0, -1],
[0, 1, 0],
]
)
X_NEG_90_ROT = torch.tensor(
[
[1, 0, 0],
[0, 0, 1],
[0, -1, 0],
]
)
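# Sanity sketch (not part of the test file): the two constants above are the
# +90 and -90 degree rotations about the x axis, so composing them should
# recover the identity. Float copies are used so the product can be compared
# against torch.eye(3) with allclose.
import torch

x_90 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
x_neg_90 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, -1.0, 0.0]])
assert torch.allclose(x_90 @ x_neg_90, torch.eye(3))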
class TestAffineT(unittest.TestCase):
......@@ -53,7 +57,7 @@ class TestAffineT(unittest.TestCase):
batch_size = 2
transf = [
[1, 0, 0, 1],
[0, 0,-1, 2],
[0, 0, -1, 2],
[0, 1, 0, 3],
[0, 0, 0, 1],
]
......@@ -62,10 +66,7 @@ class TestAffineT(unittest.TestCase):
true_rot = transf[:3, :3]
true_trans = transf[:3, 3]
transf = torch.stack(
[transf for _ in range(batch_size)],
dim=0
)
transf = torch.stack([transf for _ in range(batch_size)], dim=0)
t = T.from_4x4(transf)
......@@ -78,8 +79,7 @@ class TestAffineT(unittest.TestCase):
batch_size = 2
n = 5
transf = T(
torch.rand((batch_size, n, 3, 3)),
torch.rand((batch_size, n, 3))
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
)
self.assertTrue(transf.shape == (batch_size, n))
......@@ -88,12 +88,11 @@ class TestAffineT(unittest.TestCase):
batch_size = 2
n = 5
transf = T(
torch.rand((batch_size, n, 3, 3)),
torch.rand((batch_size, n, 3))
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
)
transf_concat = T.concat([transf, transf], dim=0)
self.assertTrue(transf_concat.rots.shape == (batch_size * 2, n, 3, 3))
transf_concat = T.concat([transf, transf], dim=1)
......@@ -124,7 +123,7 @@ class TestAffineT(unittest.TestCase):
x = torch.arange(30)
x = torch.stack([x, x], dim=0)
x = x.view(2, -1, 3) # [2, 10, 3]
x = x.view(2, -1, 3) # [2, 10, 3]
pts = t[..., None].apply(x)
......@@ -165,4 +164,4 @@ class TestAffineT(unittest.TestCase):
self.assertTrue(torch.all(chunked["out"] == unchunked["out"]))
self.assertTrue(
torch.all(chunked["inner"]["out"] == unchunked["inner"]["out"])
)
)
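# The chunk_layer test above checks that running a module in slices along
# the batch-like dimensions reproduces the unchunked output. A
# self-contained sketch of that invariant (the linear layer and chunk size
# are made up; this does not use OpenFold's chunk_layer API):
import torch

layer = torch.nn.Linear(3, 7)
x = torch.rand(10, 3)
unchunked = layer(x)
chunk_size = 4
chunked = torch.cat(
    [layer(x[i : i + chunk_size]) for i in range(0, x.shape[0], chunk_size)],
    dim=0,
)
assert torch.allclose(chunked, unchunked)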