Commit 07e64267 authored by Gustaf Ahdritz

Standardize code style

parent de07730f
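
Note: the changes below are formatting only; no behavior is modified. Judging from the diff, the code appears to have been run through an automatic formatter in the style of black with a roughly 79-character line limit (an assumption; the commit message does not name a tool): double-quoted strings, bare "if cond:" instead of "if(cond):", spaces around the binary "**" operator, trailing commas, and long signatures and calls wrapped one argument per line. A minimal standalone sketch of the target style, using a hypothetical helper rather than code from the repository:

import torch


def masked_mean_sq_error(
    pred: torch.Tensor,
    target: torch.Tensor,
    mask: torch.Tensor,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Mean squared error over masked positions (hypothetical helper)."""
    sq_error = torch.sum((pred - target) ** 2, dim=-1)
    sq_error = sq_error * mask
    return torch.sum(sq_error, dim=-1) / (eps + torch.sum(mask, dim=-1))
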
......@@ -35,19 +35,16 @@ class ParamType(Enum):
LinearMHAOutputWeight = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
LinearBiasMHA = partial(
lambda w: w.reshape(*w.shape[:-2], -1)
)
LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
LinearWeightOPM = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
Other = partial(
lambda w: w
)
Other = partial(lambda w: w)
def __init__(self, fn):
self.transformation = fn
@dataclass
class Param:
param: Union[torch.Tensor, List[torch.Tensor]]
......@@ -58,16 +55,17 @@ class Param:
def _process_translations_dict(d, top_layer=True):
flat = {}
for k, v in d.items():
if(type(v) == dict):
prefix = _NPZ_KEY_PREFIX if top_layer else ''
if type(v) == dict:
prefix = _NPZ_KEY_PREFIX if top_layer else ""
sub_flat = {
(prefix + '/'.join([k, k_prime])):v_prime
for k_prime, v_prime in
_process_translations_dict(v, top_layer=False).items()
(prefix + "/".join([k, k_prime])): v_prime
for k_prime, v_prime in _process_translations_dict(
v, top_layer=False
).items()
}
flat.update(sub_flat)
else:
k = '/' + k if not top_layer else k
k = "/" + k if not top_layer else k
flat[k] = v
return flat
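
For context on the helper above: it flattens the nested translations dictionary into flat NPZ-style keys, joining nested keys with "/" and attaching the NPZ prefix only at the top level. A toy, standalone illustration of the resulting key shape (the prefix value here is an assumption made for the example, not taken from this diff):

_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"  # assumed value, illustration only


def flatten(d, top_layer=True):
    # Mirrors the logic of _process_translations_dict above.
    flat = {}
    for k, v in d.items():
        if type(v) == dict:
            prefix = _NPZ_KEY_PREFIX if top_layer else ""
            sub = flatten(v, top_layer=False)
            flat.update(
                {prefix + "/".join([k, k2]): v2 for k2, v2 in sub.items()}
            )
        else:
            flat["/" + k if not top_layer else k] = v
    return flat


print(flatten({"evoformer": {"msa_transition": {"weights": 1}}}))
# {'alphafold/alphafold_iteration/evoformer/msa_transition//weights': 1}
# The double slash before the leaf name comes from the "/" prefix added to
# non-top-level leaves before joining.
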
......@@ -82,19 +80,19 @@ def stacked(param_dict_list, out=None):
"parallel" Params). There must be at least one dict
in the list.
"""
if(out is None):
if out is None:
out = {}
template = param_dict_list[0]
for k, _ in template.items():
v = [d[k] for d in param_dict_list]
if(type(v[0]) is dict):
if type(v[0]) is dict:
out[k] = {}
stacked(v, out=out[k])
elif(type(v[0]) is Param):
elif type(v[0]) is Param:
stacked_param = Param(
param=[param.param for param in v],
param_type=v[0].param_type,
stacked=True
stacked=True,
)
out[k] = stacked_param
......@@ -107,7 +105,7 @@ def assign(translation_dict, orig_weights):
with torch.no_grad():
weights = torch.as_tensor(orig_weights[k])
ref, param_type = param.param, param.param_type
if(param.stacked):
if param.stacked:
weights = torch.unbind(weights, 0)
else:
weights = [weights]
......@@ -131,26 +129,15 @@ def import_jax_weights_(model, npz_path, version="model_1"):
# Some templates
#######################
LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
LinearWeight = lambda l: (
Param(l, param_type=ParamType.LinearWeight)
)
LinearBias = lambda l: (Param(l))
LinearBias = lambda l: (
Param(l)
)
LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
LinearWeightMHA = lambda l: (
Param(l, param_type=ParamType.LinearWeightMHA)
)
LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
LinearBiasMHA = lambda b: (
Param(b, param_type=ParamType.LinearBiasMHA)
)
LinearWeightOPM = lambda l: (
Param(l, param_type=ParamType.LinearWeightOPM)
)
LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
LinearParams = lambda l: {
"weights": LinearWeight(l.weight),
......@@ -167,7 +154,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"key_w": LinearWeightMHA(att.linear_k.weight),
"value_w": LinearWeightMHA(att.linear_v.weight),
"output_w": Param(
att.linear_o.weight, param_type=ParamType.LinearMHAOutputWeight,
att.linear_o.weight,
param_type=ParamType.LinearMHAOutputWeight,
),
"output_b": LinearBias(att.linear_o.bias),
}
......@@ -231,7 +219,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
MSAGlobalAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": GlobalAttentionParams(matt.global_attention)
"attention": GlobalAttentionParams(matt.global_attention),
}
MSAAttPairBiasParams = lambda matt: dict(
......@@ -247,8 +235,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points),
"kv_point_local": LinearParams(ipa.linear_kv_points),
"trainable_point_weights":
Param(param=ipa.head_weights, param_type=ParamType.Other),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
......@@ -276,7 +265,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
}
def EvoformerBlockParams(b, is_extra_msa=False):
if(is_extra_msa):
if is_extra_msa:
col_att_name = "msa_column_global_attention"
msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
else:
......@@ -284,8 +273,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):
msa_col_att_params = MSAAttParams(b.msa_att_col)
d = {
"msa_row_attention_with_pair_bias":
MSAAttPairBiasParams(b.msa_att_row),
"msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
b.msa_att_row
),
col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.msa_transition),
"outer_product_mean": OuterProductMeanParams(b.outer_product_mean),
......@@ -316,9 +306,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
},
}
}
############################
# translations dict overflow
......@@ -330,14 +319,10 @@ def import_jax_weights_(model, npz_path, version="model_1"):
)
ems_blocks = model.extra_msa_stack.stack.blocks
ems_blocks_params = stacked(
[ExtraMSABlockParams(b) for b in ems_blocks]
)
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
evo_blocks = model.evoformer.blocks
evo_blocks_params = stacked(
[EvoformerBlockParams(b) for b in evo_blocks]
)
evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
translations = {
"evoformer": {
......@@ -346,64 +331,72 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"left_single": LinearParams(model.input_embedder.linear_tf_z_i),
"right_single": LinearParams(model.input_embedder.linear_tf_z_j),
"prev_pos_linear": LinearParams(model.recycling_embedder.linear),
"prev_msa_first_row_norm":
LayerNormParams(model.recycling_embedder.layer_norm_m),
"prev_pair_norm":
LayerNormParams(model.recycling_embedder.layer_norm_z),
"pair_activiations":
LinearParams(model.input_embedder.linear_relpos),
"prev_msa_first_row_norm": LayerNormParams(
model.recycling_embedder.layer_norm_m
),
"prev_pair_norm": LayerNormParams(
model.recycling_embedder.layer_norm_z
),
"pair_activiations": LinearParams(
model.input_embedder.linear_relpos
),
"template_embedding": {
"single_template_embedding": {
"embedding2d":
LinearParams(model.template_pair_embedder.linear),
"embedding2d": LinearParams(
model.template_pair_embedder.linear
),
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm":
LayerNormParams(model.template_pair_stack.layer_norm),
"output_layer_norm": LayerNormParams(
model.template_pair_stack.layer_norm
),
},
"attention": AttentionParams(model.template_pointwise_att.mha),
},
"extra_msa_activations":
LinearParams(model.extra_msa_embedder.linear),
"extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear
),
"extra_msa_stack": ems_blocks_params,
"template_single_embedding":
LinearParams(model.template_angle_embedder.linear_1),
"template_projection":
LinearParams(model.template_angle_embedder.linear_2),
"template_single_embedding": LinearParams(
model.template_angle_embedder.linear_1
),
"template_projection": LinearParams(
model.template_angle_embedder.linear_2
),
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
},
"structure_module": {
"single_layer_norm":
LayerNormParams(model.structure_module.layer_norm_s),
"initial_projection":
LinearParams(model.structure_module.linear_in),
"pair_layer_norm":
LayerNormParams(model.structure_module.layer_norm_z),
"single_layer_norm": LayerNormParams(
model.structure_module.layer_norm_s
),
"initial_projection": LinearParams(
model.structure_module.linear_in
),
"pair_layer_norm": LayerNormParams(
model.structure_module.layer_norm_z
),
"fold_iteration": FoldIterationParams(model.structure_module),
},
"predicted_lddt_head": {
"input_layer_norm":
LayerNormParams(model.aux_heads.plddt.layer_norm),
"act_0":
LinearParams(model.aux_heads.plddt.linear_1),
"act_1":
LinearParams(model.aux_heads.plddt.linear_2),
"logits":
LinearParams(model.aux_heads.plddt.linear_3),
"input_layer_norm": LayerNormParams(
model.aux_heads.plddt.layer_norm
),
"act_0": LinearParams(model.aux_heads.plddt.linear_1),
"act_1": LinearParams(model.aux_heads.plddt.linear_2),
"logits": LinearParams(model.aux_heads.plddt.linear_3),
},
"distogram_head": {
"half_logits":
LinearParams(model.aux_heads.distogram.linear),
"half_logits": LinearParams(model.aux_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits":
LinearParams(model.aux_heads.experimentally_resolved.linear),
"logits": LinearParams(
model.aux_heads.experimentally_resolved.linear
),
},
"masked_msa_head": {
"logits":
LinearParams(model.aux_heads.masked_msa.linear),
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
......@@ -415,17 +408,16 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"model_4_ptm",
"model_5_ptm",
]
if(version in no_templ):
if version in no_templ:
evo_dict = translations["evoformer"]
keys = list(evo_dict.keys())
for k in keys:
if("template_" in k):
if "template_" in k:
evo_dict.pop(k)
if("_ptm" in version):
if "_ptm" in version:
translations["predicted_aligned_error_head"] = {
"logits":
LinearParams(model.aux_heads.tm.linear)
"logits": LinearParams(model.aux_heads.tm.linear)
}
# Flatten keys and insert missing key prefixes
......@@ -436,10 +428,10 @@ def import_jax_weights_(model, npz_path, version="model_1"):
flat_keys = list(flat.keys())
incorrect = [k for k in flat_keys if k not in keys]
missing = [k for k in keys if k not in flat_keys]
#print(f"Incorrect: {incorrect}")
#print(f"Missing: {missing}")
# print(f"Incorrect: {incorrect}")
# print(f"Missing: {missing}")
assert(len(incorrect) == 0)
assert len(incorrect) == 0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
# Set weights
......@@ -81,7 +81,7 @@ def compute_fape(
positions_mask: torch.Tensor,
length_scale: float,
l1_clamp_distance: Optional[float] = None,
eps=1e-8
eps=1e-8,
) -> torch.Tensor:
# [*, N_frames, N_pts, 3]
local_pred_pos = pred_frames.invert()[..., None].apply(
......@@ -91,10 +91,10 @@ def compute_fape(
target_positions[..., None, :, :],
)
error_dist = torch.sqrt(
torch.sum((local_pred_pos - local_target_pos)**2, dim=-1) + eps
torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
)
if(l1_clamp_distance is not None):
if l1_clamp_distance is not None:
error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
normed_error = error_dist / length_scale
......@@ -111,7 +111,9 @@ def compute_fape(
#
# ("roughly" because eps is necessarily duplicated in the latter
normed_error = torch.sum(normed_error, dim=-1)
normed_error = normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
normed_error = (
normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
)
normed_error = torch.sum(normed_error, dim=-1)
normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
......@@ -126,8 +128,8 @@ def backbone_loss(
backbone_affine_mask: torch.Tensor,
traj: torch.Tensor,
use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.,
loss_unit_distance: float = 10.,
clamp_distance: float = 10.0,
loss_unit_distance: float = 10.0,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
......@@ -145,7 +147,7 @@ def backbone_loss(
length_scale=loss_unit_distance,
eps=eps,
)
if(use_clamped_fape is not None):
if use_clamped_fape is not None:
unclamped_fape_loss = compute_fape(
pred_aff,
gt_aff[..., None, :],
......@@ -158,9 +160,8 @@ def backbone_loss(
eps=eps,
)
fape_loss = (
fape_loss * use_clamped_fape +
unclamped_fape_loss * (1 - use_clamped_fape)
fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
1 - use_clamped_fape
)
# Take the mean over the layer dimension
......@@ -177,42 +178,31 @@ def sidechain_loss(
renamed_atom14_gt_positions: torch.Tensor,
renamed_atom14_gt_exists: torch.Tensor,
alt_naming_is_better: torch.Tensor,
clamp_distance: float = 10.,
length_scale: float = 10.,
clamp_distance: float = 10.0,
length_scale: float = 10.0,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
renamed_gt_frames = (
(1. - alt_naming_is_better[..., None, None, None]) *
rigidgroups_gt_frames +
alt_naming_is_better[..., None, None, None] *
rigidgroups_alt_gt_frames
)
1.0 - alt_naming_is_better[..., None, None, None]
) * rigidgroups_gt_frames + alt_naming_is_better[
..., None, None, None
] * rigidgroups_alt_gt_frames
# Steamroll the inputs
sidechain_frames = sidechain_frames[-1]
batch_dims = sidechain_frames.shape[:-4]
sidechain_frames = sidechain_frames.view(
*batch_dims, -1, 4, 4
)
sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
sidechain_frames = T.from_4x4(sidechain_frames)
renamed_gt_frames = renamed_gt_frames.view(
*batch_dims, -1, 4, 4
)
renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
renamed_gt_frames = T.from_4x4(renamed_gt_frames)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(
*batch_dims, -1
)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
sidechain_atom_pos = sidechain_atom_pos[-1]
sidechain_atom_pos = sidechain_atom_pos.view(
*batch_dims, -1, 3
)
sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
*batch_dims, -1, 3
)
renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(
*batch_dims, -1
)
renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)
fape = compute_fape(
sidechain_frames,
......@@ -235,19 +225,17 @@ def fape_loss(
config: ml_collections.ConfigDict,
) -> torch.Tensor:
bb_loss = backbone_loss(
traj=out["sm"]["frames"], **{**batch, **config.backbone},
traj=out["sm"]["frames"],
**{**batch, **config.backbone},
)
sc_loss = sidechain_loss(
out["sm"]["sidechain_frames"],
out["sm"]["positions"],
**{**batch, **config.sidechain}
**{**batch, **config.sidechain},
)
return (
config.backbone.weight * bb_loss +
config.sidechain.weight * sc_loss
)
return config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
def supervised_chi_loss(
......@@ -264,7 +252,8 @@ def supervised_chi_loss(
) -> torch.Tensor:
pred_angles = angles_sin_cos[..., 3:, :]
residue_type_one_hot = torch.nn.functional.one_hot(
aatype, residue_constants.restype_num + 1,
aatype,
residue_constants.restype_num + 1,
)
chi_pi_periodic = torch.einsum(
"...ij,jk->ik",
......@@ -276,11 +265,9 @@ def supervised_chi_loss(
shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
true_chi_shifted = shifted_mask * true_chi
sq_chi_error = torch.sum(
(true_chi - pred_angles)**2, dim=-1
)
sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
sq_chi_error_shifted = torch.sum(
(true_chi_shifted - pred_angles)**2, dim=-1
(true_chi_shifted - pred_angles) ** 2, dim=-1
)
sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
# The ol' switcheroo
......@@ -295,9 +282,9 @@ def supervised_chi_loss(
loss = loss + chi_weight * sq_chi_loss
angle_norm = torch.sqrt(
torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps
torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
)
norm_error = torch.abs(angle_norm - 1.)
norm_error = torch.abs(angle_norm - 1.0)
norm_error = norm_error.permute(
*range(len(norm_error.shape))[1:-2], 0, -2, -1
)
......@@ -312,14 +299,13 @@ def supervised_chi_loss(
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
num_bins = logits.shape[-1]
bin_width = 1. / num_bins
bin_width = 1.0 / num_bins
bounds = torch.arange(
start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device
)
probs = torch.nn.functional.softmax(logits, dim=-1)
pred_lddt_ca = torch.sum(
probs *
bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
dim=-1,
)
return pred_lddt_ca * 100
......@@ -331,7 +317,7 @@ def lddt_loss(
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15.,
cutoff: float = 15.0,
no_bins: int = 50,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
......@@ -343,47 +329,49 @@ def lddt_loss(
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos:(ca_pos + 1)] # keep dim
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
dmat_true = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
all_atom_positions[..., None, :] -
all_atom_positions[..., None, :, :]
)**2,
all_atom_positions[..., None, :]
- all_atom_positions[..., None, :, :]
)
** 2,
dim=-1,
)
)
dmat_pred = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
all_atom_pred_pos[..., None, :] -
all_atom_pred_pos[..., None, :, :]
)**2,
all_atom_pred_pos[..., None, :]
- all_atom_pred_pos[..., None, :, :]
)
** 2,
dim=-1,
)
)
dists_to_score = (
(dmat_true < cutoff) * all_atom_mask *
permute_final_dims(all_atom_mask, (1, 0)) *
(1. - torch.eye(n, device=all_atom_mask.device))
(dmat_true < cutoff)
* all_atom_mask
* permute_final_dims(all_atom_mask, (1, 0))
* (1.0 - torch.eye(n, device=all_atom_mask.device))
)
dist_l1 = torch.abs(dmat_true - dmat_pred)
score = (
(dist_l1 < 0.5).type(dist_l1.dtype) +
(dist_l1 < 1.0).type(dist_l1.dtype) +
(dist_l1 < 2.0).type(dist_l1.dtype) +
(dist_l1 < 4.0).type(dist_l1.dtype)
(dist_l1 < 0.5).type(dist_l1.dtype)
+ (dist_l1 < 1.0).type(dist_l1.dtype)
+ (dist_l1 < 2.0).type(dist_l1.dtype)
+ (dist_l1 < 4.0).type(dist_l1.dtype)
)
score = score * 0.25
norm = 1. / (eps + torch.sum(dists_to_score, dim=-1))
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=-1))
score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
score = score.detach()
......@@ -396,14 +384,12 @@ def lddt_loss(
errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
all_atom_mask = all_atom_mask.squeeze(-1)
loss = (
torch.sum(errors * all_atom_mask, dim=-1) /
(eps + torch.sum(all_atom_mask, dim=-1))
loss = torch.sum(errors * all_atom_mask, dim=-1) / (
eps + torch.sum(all_atom_mask, dim=-1)
)
loss = loss * (
(resolution >= min_resolution) &
(resolution <= max_resolution)
(resolution >= min_resolution) & (resolution <= max_resolution)
)
return loss
......@@ -420,16 +406,17 @@ def distogram_loss(
**kwargs,
):
boundaries = torch.linspace(
min_bin, max_bin, no_bins - 1, device=logits.device,
min_bin,
max_bin,
no_bins - 1,
device=logits.device,
)
boundaries = boundaries ** 2
dists = torch.sum(
(
pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]
) ** 2,
(pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
dim=-1,
keepdims=True
keepdims=True,
)
true_bins = torch.sum(dists > boundaries, dim=-1)
......@@ -469,7 +456,7 @@ def _calculate_expected_aligned_error(
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
return (
torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
bin_centers[-1]
bin_centers[-1],
)
......@@ -494,18 +481,16 @@ def compute_predicted_aligned_error(
max_predicted_aligned_error: [*] the maximum predicted error possible.
"""
boundaries = torch.linspace(
0,
max_bin,
steps=(no_bins - 1),
device=logits.device
0, max_bin, steps=(no_bins - 1), device=logits.device
)
aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
predicted_aligned_error, max_predicted_aligned_error = (
_calculate_expected_aligned_error(
(
predicted_aligned_error,
max_predicted_aligned_error,
) = _calculate_expected_aligned_error(
alignment_confidence_breaks=boundaries,
aligned_distance_error_probs=aligned_confidence_probs
)
aligned_distance_error_probs=aligned_confidence_probs,
)
return {
......@@ -523,14 +508,11 @@ def compute_tm(
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor:
if(residue_weights is None):
if residue_weights is None:
residue_weights = logits.new_ones(logits.shape[-2])
boundaries = torch.linspace(
0,
max_bin,
steps=(no_bins - 1),
device=logits.device
0, max_bin, steps=(no_bins - 1), device=logits.device
)
bin_centers = _calculate_bin_centers(boundaries)
......@@ -538,11 +520,11 @@ def compute_tm(
n = logits.shape[-2]
clipped_n = max(n, 19)
d0 = 1.24 * (clipped_n - 15) ** (1./3) - 1.8
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
probs = torch.nn.functional.softmax(logits, dim=-1)
tm_per_bin = 1. / (1 + (bin_centers ** 2) / (d0 ** 2))
tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
normed_residue_mask = residue_weights / (eps + residue_weights.sum())
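
For context, the per-bin term computed just above is the standard TM-score kernel: a distance error of d contributes 1 / (1 + (d / d0)^2), where d0 depends only on the clipped sequence length. A small standalone evaluation (not code from this diff):

n = 128
clipped_n = max(n, 19)
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8  # roughly 4.2 for n = 128
for d in (1.0, 5.0, 15.0):
    print(d, 1.0 / (1 + (d ** 2) / (d0 ** 2)))
# Small errors contribute values near 1 (about 0.95 at d = 1.0); large errors
# contribute values near 0 (about 0.07 at d = 15.0).
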
......@@ -573,25 +555,18 @@ def tm_loss(
return affine.invert()[..., None].apply(pts)
sq_diff = torch.sum(
(_points(pred_affine) - _points(backbone_affine)) ** 2,
dim=-1
(_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1
)
sq_diff = sq_diff.detach()
boundaries = torch.linspace(
0,
max_bin,
steps=(no_bins - 1),
device=logits.device
0, max_bin, steps=(no_bins - 1), device=logits.device
)
boundaries = boundaries ** 2
true_bins = torch.sum(
sq_diff[..., None] > boundaries, dim=-1
)
true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)
errors = softmax_cross_entropy(
logits,
torch.nn.functional.one_hot(true_bins, no_bins)
logits, torch.nn.functional.one_hot(true_bins, no_bins)
)
square_mask = (
......@@ -606,8 +581,7 @@ def tm_loss(
loss = loss * scale
loss = loss * (
(resolution >= min_resolution) &
(resolution <= max_resolution)
(resolution >= min_resolution) & (resolution <= max_resolution)
)
return loss
......@@ -659,42 +633,36 @@ def between_residue_bond_loss(
next_n_mask = pred_atom_mask[..., 1:, 0]
next_ca_pos = pred_atom_positions[..., 1:, 1, :]
next_ca_mask = pred_atom_mask[..., 1:, 1]
has_no_gap_mask = (
(residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
)
has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
# Compute loss for the C--N bond.
c_n_bond_length = torch.sqrt(
eps +
torch.sum(
(this_c_pos - next_n_pos)**2, dim=-1
)
eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
)
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline = (
aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
)
next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
gt_length = (
(~next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
+ next_is_proline * residue_constants.between_res_bond_length_c_n[1]
)
~next_is_proline
) * residue_constants.between_res_bond_length_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_c_n[
1
]
gt_stddev = (
(~next_is_proline) *
residue_constants.between_res_bond_length_stddev_c_n[0] +
next_is_proline *
residue_constants.between_res_bond_length_stddev_c_n[1]
)
c_n_bond_length_error = torch.sqrt(
eps + (c_n_bond_length - gt_length)**2
)
~next_is_proline
) * residue_constants.between_res_bond_length_stddev_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
1
]
c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
c_n_loss_per_residue = torch.nn.functional.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev
)
mask = this_c_mask * next_n_mask * has_no_gap_mask
c_n_loss = (
torch.sum(mask * c_n_loss_per_residue, dim=-1) /
(torch.sum(mask, dim=-1) + eps)
c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
c_n_violation_mask = mask * (
c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
......@@ -702,10 +670,10 @@ def between_residue_bond_loss(
# Compute loss for the angles.
ca_c_bond_length = torch.sqrt(
eps + torch.sum((this_ca_pos - this_c_pos)**2, dim=-1)
eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
)
n_ca_bond_length = torch.sqrt(
eps + torch.sum((next_n_pos - next_ca_pos)**2, dim=-1)
eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
)
c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
......@@ -716,31 +684,31 @@ def between_residue_bond_loss(
gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
ca_c_n_cos_angle_error = torch.sqrt(
eps + (ca_c_n_cos_angle - gt_angle)**2
eps + (ca_c_n_cos_angle - gt_angle) ** 2
)
ca_c_n_loss_per_residue = torch.nn.functional.relu(
ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
)
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
ca_c_n_loss = (
torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) /
(torch.sum(mask, dim=-1) + eps)
ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
ca_c_n_violation_mask = mask * (
ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
)
ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
(tolerance_factor_hard * gt_stddev))
c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
c_n_ca_cos_angle_error = torch.sqrt(
eps + torch.square(c_n_ca_cos_angle - gt_angle))
eps + torch.square(c_n_ca_cos_angle - gt_angle)
)
c_n_ca_loss_per_residue = torch.nn.functional.relu(
c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
)
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
c_n_ca_loss = (
torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) /
(torch.sum(mask, dim=-1) + eps)
c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
c_n_ca_violation_mask = mask * (
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
......@@ -748,37 +716,33 @@ def between_residue_bond_loss(
# Compute a per residue loss (equally distribute the loss to both
# neighbouring residues).
per_residue_loss_sum = (c_n_loss_per_residue +
ca_c_n_loss_per_residue +
c_n_ca_loss_per_residue)
per_residue_loss_sum = (
c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
)
per_residue_loss_sum = 0.5 * (
torch.nn.functional.pad(per_residue_loss_sum, (0, 1)) +
torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
+ torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
)
# Compute hard violations.
violation_mask = torch.max(
torch.stack(
[
c_n_violation_mask,
ca_c_n_violation_mask,
c_n_ca_violation_mask
],
[c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
dim=-2,
),
dim=-2
dim=-2,
)[0]
violation_mask = torch.maximum(
torch.nn.functional.pad(violation_mask, (0, 1)),
torch.nn.functional.pad(violation_mask, (1, 0))
torch.nn.functional.pad(violation_mask, (1, 0)),
)
return {
'c_n_loss_mean': c_n_loss,
'ca_c_n_loss_mean': ca_c_n_loss,
'c_n_ca_loss_mean': c_n_ca_loss,
'per_residue_loss_sum': per_residue_loss_sum,
'per_residue_violation_mask': violation_mask
"c_n_loss_mean": c_n_loss,
"ca_c_n_loss_mean": ca_c_n_loss,
"c_n_ca_loss_mean": c_n_ca_loss,
"per_residue_loss_sum": per_residue_loss_sum,
"per_residue_violation_mask": violation_mask,
}
......@@ -820,27 +784,30 @@ def between_residue_clash_loss(
# Create the distance matrix.
# (N, N, 14, 14)
dists = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
atom14_pred_positions[..., :, None, :, None, :] -
atom14_pred_positions[..., None, :, None, :, :]
)**2,
dim=-1)
atom14_pred_positions[..., :, None, :, None, :]
- atom14_pred_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
# Create the mask for valid distances.
# shape (N, N, 14, 14)
dists_mask = (
atom14_atom_exists[..., :, None, :, None] *
atom14_atom_exists[..., None, :, None, :]
atom14_atom_exists[..., :, None, :, None]
* atom14_atom_exists[..., None, :, None, :]
).type(fp_type)
# Mask out all the duplicate entries in the lower triangular matrix.
# Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
# are handled separately.
dists_mask = dists_mask * (
residue_index[..., :, None, None, None] < residue_index[..., None, :, None, None]
residue_index[..., :, None, None, None]
< residue_index[..., None, :, None, None]
)
# Backbone C--N bond between subsequent residues is no clash.
......@@ -860,36 +827,34 @@ def between_residue_clash_loss(
n_one_hot = n_one_hot.type(fp_type)
neighbour_mask = (
(residue_index[..., :, None, None, None] + 1) ==
residue_index[..., None, :, None, None]
)
residue_index[..., :, None, None, None] + 1
) == residue_index[..., None, :, None, None]
c_n_bonds = (
neighbour_mask *
c_one_hot[..., None, None, :, None] *
n_one_hot[..., None, None, None, :]
neighbour_mask
* c_one_hot[..., None, None, :, None]
* n_one_hot[..., None, None, None, :]
)
dists_mask = dists_mask * (1. - c_n_bonds)
dists_mask = dists_mask * (1.0 - c_n_bonds)
# Disulfide bridge between two cysteines is no clash.
cys = residue_constants.restype_name_to_atom14_names["CYS"]
cys_sg_idx = cys.index('SG')
cys_sg_idx = cys.index("SG")
cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
cys_sg_idx = cys_sg_idx.reshape(
*((1,) * len(residue_index.shape[:-1])), 1
).squeeze(-1)
cys_sg_one_hot = torch.nn.functional.one_hot(
cys_sg_idx, num_classes=14
)
cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
disulfide_bonds = (
cys_sg_one_hot[..., None, None, :, None] *
cys_sg_one_hot[..., None, None, None, :])
dists_mask = dists_mask * (1. - disulfide_bonds)
cys_sg_one_hot[..., None, None, :, None]
* cys_sg_one_hot[..., None, None, None, :]
)
dists_mask = dists_mask * (1.0 - disulfide_bonds)
# Compute the lower bound for the allowed distances.
# shape (N, N, 14, 14)
dists_lower_bound = dists_mask * (
atom14_atom_radius[..., :, None, :, None] +
atom14_atom_radius[..., None, :, None, :]
atom14_atom_radius[..., :, None, :, None]
+ atom14_atom_radius[..., None, :, None, :]
)
# Compute the error.
......@@ -900,15 +865,12 @@ def between_residue_clash_loss(
# Compute the mean loss.
# shape ()
mean_loss = (
torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))
)
mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = (
torch.sum(dists_to_low_error, dim=(-4, -2)) +
torch.sum(dists_to_low_error, axis=(-3, -1))
per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
dists_to_low_error, axis=(-3, -1)
)
# Compute the hard clash mask.
......@@ -925,9 +887,9 @@ def between_residue_clash_loss(
)
return {
'mean_loss': mean_loss, # shape ()
'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14)
'per_atom_clash_mask': per_atom_clash_mask # shape (N, 14)
"mean_loss": mean_loss, # shape ()
"per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14)
"per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14)
}
......@@ -967,27 +929,26 @@ def within_residue_violations(
mask whether atom clashes with any other atom shape
"""
# Compute the mask for each residue.
dists_masks = (
1. - torch.eye(14, device=atom14_atom_exists.device)[None]
)
dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
dists_masks = dists_masks.reshape(
*((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
)
dists_masks = (
atom14_atom_exists[..., :, :, None] *
atom14_atom_exists[..., :, None, :] *
dists_masks
atom14_atom_exists[..., :, :, None]
* atom14_atom_exists[..., :, None, :]
* dists_masks
)
# Distance matrix
dists = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
atom14_pred_positions[..., :, :, None, :] -
atom14_pred_positions[..., :, None, :, :]
)**2,
dim=-1
atom14_pred_positions[..., :, :, None, :]
- atom14_pred_positions[..., :, None, :, :]
)
** 2,
dim=-1,
)
)
......@@ -1001,18 +962,11 @@ def within_residue_violations(
loss = dists_masks * (dists_to_low_error + dists_to_high_error)
# Compute the per atom loss sum.
per_atom_loss_sum = (
torch.sum(loss, dim=-2) +
torch.sum(loss, dim=-1)
)
per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)
# Compute the violations mask.
violations = (
dists_masks *
(
(dists < atom14_dists_lower_bound) |
(dists > atom14_dists_upper_bound)
)
violations = dists_masks * (
(dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
)
# Compute the per atom violations.
......@@ -1021,12 +975,11 @@ def within_residue_violations(
)
return {
'per_atom_loss_sum': per_atom_loss_sum,
'per_atom_violations': per_atom_violations
"per_atom_loss_sum": per_atom_loss_sum,
"per_atom_violations": per_atom_violations,
}
def find_structural_violations(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor,
......@@ -1043,7 +996,7 @@ def find_structural_violations(
residue_index=batch["residue_index"],
aatype=batch["aatype"],
tolerance_factor_soft=violation_tolerance_factor,
tolerance_factor_hard=violation_tolerance_factor
tolerance_factor_hard=violation_tolerance_factor,
)
# Compute the Van der Waals radius for every atom
......@@ -1053,12 +1006,10 @@ def find_structural_violations(
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
atomtype_radius = atom14_pred_positions.new_tensor(
atomtype_radius
)
atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
atom14_atom_radius = (
batch["atom14_atom_exists"] *
atomtype_radius[batch["residx_atom14_to_atom37"]]
batch["atom14_atom_exists"]
* atomtype_radius[batch["residx_atom14_to_atom37"]]
)
# Compute the between residue clash loss.
......@@ -1068,32 +1019,28 @@ def find_structural_violations(
atom14_atom_radius=atom14_atom_radius,
residue_index=batch["residue_index"],
overlap_tolerance_soft=clash_overlap_tolerance,
overlap_tolerance_hard=clash_overlap_tolerance
overlap_tolerance_hard=clash_overlap_tolerance,
)
# Compute all within-residue violations (clashes,
# bond length and angle violations).
restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
overlap_tolerance=clash_overlap_tolerance,
bond_length_tolerance_factor=violation_tolerance_factor
bond_length_tolerance_factor=violation_tolerance_factor,
)
atom14_atom_exists = batch["atom14_atom_exists"]
atom14_dists_lower_bound = (
atom14_pred_positions.new_tensor(restype_atom14_bounds["lower_bound"])[
batch["aatype"]
]
)
atom14_dists_upper_bound = (
atom14_pred_positions.new_tensor(restype_atom14_bounds["upper_bound"])[
batch["aatype"]
]
)
atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
restype_atom14_bounds["lower_bound"]
)[batch["aatype"]]
atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
restype_atom14_bounds["upper_bound"]
)[batch["aatype"]]
residue_violations = within_residue_violations(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_dists_lower_bound=atom14_dists_lower_bound,
atom14_dists_upper_bound=atom14_dists_upper_bound,
tighten_bounds_for_loss=0.0
tighten_bounds_for_loss=0.0,
)
# Combine them to a single per-residue violation mask (used later for LDDT).
......@@ -1104,9 +1051,7 @@ def find_structural_violations(
torch.max(
between_residue_clashes["per_atom_clash_mask"], dim=-1
)[0],
torch.max(
residue_violations["per_atom_violations"], dim=-1
)[0],
torch.max(residue_violations["per_atom_violations"], dim=-1)[0],
],
dim=-1,
),
......@@ -1114,39 +1059,44 @@ def find_structural_violations(
)[0]
return {
'between_residues': {
'bonds_c_n_loss_mean':
connection_violations["c_n_loss_mean"], # ()
'angles_ca_c_n_loss_mean':
connection_violations["ca_c_n_loss_mean"], # ()
'angles_c_n_ca_loss_mean':
connection_violations["c_n_ca_loss_mean"], # ()
'connections_per_residue_loss_sum':
connection_violations["per_residue_loss_sum"], # (N)
'connections_per_residue_violation_mask':
connection_violations["per_residue_violation_mask"], # (N)
'clashes_mean_loss':
between_residue_clashes["mean_loss"], # ()
'clashes_per_atom_loss_sum':
between_residue_clashes["per_atom_loss_sum"], # (N, 14)
'clashes_per_atom_clash_mask':
between_residue_clashes["per_atom_clash_mask"], # (N, 14)
"between_residues": {
"bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"], # ()
"angles_ca_c_n_loss_mean": connection_violations[
"ca_c_n_loss_mean"
], # ()
"angles_c_n_ca_loss_mean": connection_violations[
"c_n_ca_loss_mean"
], # ()
"connections_per_residue_loss_sum": connection_violations[
"per_residue_loss_sum"
], # (N)
"connections_per_residue_violation_mask": connection_violations[
"per_residue_violation_mask"
], # (N)
"clashes_mean_loss": between_residue_clashes["mean_loss"], # ()
"clashes_per_atom_loss_sum": between_residue_clashes[
"per_atom_loss_sum"
], # (N, 14)
"clashes_per_atom_clash_mask": between_residue_clashes[
"per_atom_clash_mask"
], # (N, 14)
},
'within_residues': {
'per_atom_loss_sum':
residue_violations["per_atom_loss_sum"], # (N, 14)
'per_atom_violations':
residue_violations["per_atom_violations"], # (N, 14),
"within_residues": {
"per_atom_loss_sum": residue_violations[
"per_atom_loss_sum"
], # (N, 14)
"per_atom_violations": residue_violations[
"per_atom_violations"
], # (N, 14),
},
'total_per_residue_violations_mask':
per_residue_violations_mask, # (N)
"total_per_residue_violations_mask": per_residue_violations_mask, # (N)
}
def find_structural_violations_np(
batch: Dict[str, np.ndarray],
atom14_pred_positions: np.ndarray,
config: ml_collections.ConfigDict
config: ml_collections.ConfigDict,
) -> Dict[str, np.ndarray]:
to_tensor = lambda x: torch.tensor(x)
batch = tree_map(to_tensor, batch, np.ndarray)
......@@ -1185,13 +1135,13 @@ def extreme_ca_ca_distance_violations(
this_ca_mask = pred_atom_mask[..., :-1, 1]
next_ca_pos = pred_atom_positions[..., 1:, 1, :]
next_ca_mask = pred_atom_mask[..., 1:, 1]
has_no_gap_mask = ((residue_index[..., 1:] - residue_index[..., :-1]) == 1.0)
has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
ca_ca_distance = torch.sqrt(
eps + torch.sum((this_ca_pos - next_ca_pos)**2, dim=-1)
eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
)
violations = (
(ca_ca_distance - residue_constants.ca_ca) > max_angstrom_tolerance
)
ca_ca_distance - residue_constants.ca_ca
) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
mean = masked_mean(mask, violations, -1)
return mean
......@@ -1207,13 +1157,13 @@ def compute_violation_metrics(
extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch["residue_index"]
residue_index=batch["residue_index"],
)
ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
ret["violations_between_residue_bond"] = masked_mean(
batch["seq_mask"],
violations["between_residues"][
'connections_per_residue_violation_mask'
"connections_per_residue_violation_mask"
],
dim=-1,
)
......@@ -1221,7 +1171,7 @@ def compute_violation_metrics(
mask=batch["seq_mask"],
value=torch.max(
violations["between_residues"]["clashes_per_atom_clash_mask"],
dim=-1
dim=-1,
)[0],
dim=-1,
)
......@@ -1250,7 +1200,6 @@ def compute_violation_metrics_np(
atom14_pred_positions = to_tensor(atom14_pred_positions)
violations = tree_map(to_tensor, violations, np.ndarray)
out = compute_violation_metrics(batch, atom14_pred_positions, violations)
to_np = lambda x: np.array(x)
......@@ -1265,15 +1214,15 @@ def violation_loss(
) -> torch.Tensor:
num_atoms = torch.sum(atom14_atom_exists)
l_clash = torch.sum(
violations["between_residues"]["clashes_per_atom_loss_sum"] +
violations["within_residues"]["per_atom_loss_sum"]
violations["between_residues"]["clashes_per_atom_loss_sum"]
+ violations["within_residues"]["per_atom_loss_sum"]
)
l_clash = l_clash / (eps + num_atoms)
loss = (
violations["between_residues"]["bonds_c_n_loss_mean"] +
violations["between_residues"]["angles_ca_c_n_loss_mean"] +
violations["between_residues"]["angles_c_n_ca_loss_mean"] +
l_clash
violations["between_residues"]["bonds_c_n_loss_mean"]
+ violations["between_residues"]["angles_ca_c_n_loss_mean"]
+ violations["between_residues"]["angles_c_n_ca_loss_mean"]
+ l_clash
)
return loss
......@@ -1313,50 +1262,53 @@ def compute_renamed_ground_truth(
"""
pred_dists = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
atom14_pred_positions[..., None, :, None, :] -
atom14_pred_positions[..., None, :, None, :, :]
)**2,
atom14_pred_positions[..., None, :, None, :]
- atom14_pred_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
atom14_gt_positions = batch["atom14_gt_positions"]
gt_dists = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
atom14_gt_positions[..., None, :, None, :] -
atom14_gt_positions[..., None, :, None, :, :]
)**2,
atom14_gt_positions[..., None, :, None, :]
- atom14_gt_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
alt_gt_dists = torch.sqrt(
eps +
torch.sum(
eps
+ torch.sum(
(
atom14_alt_gt_positions[..., None, :, None, :] -
atom14_alt_gt_positions[..., None, :, None, :, :]
)**2,
atom14_alt_gt_positions[..., None, :, None, :]
- atom14_alt_gt_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
lddt = torch.sqrt(eps + (pred_dists - gt_dists)**2)
alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists)**2)
lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)
atom14_gt_exists = batch["atom14_gt_exists"]
atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
mask = (
atom14_gt_exists[..., None, :, None] *
atom14_atom_is_ambiguous[..., None, :, None] *
atom14_gt_exists[..., None, :, None, :] *
(1. - atom14_atom_is_ambiguous[..., None, :, None, :])
atom14_gt_exists[..., None, :, None]
* atom14_atom_is_ambiguous[..., None, :, None]
* atom14_gt_exists[..., None, :, None, :]
* (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
)
per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
......@@ -1366,16 +1318,16 @@ def compute_renamed_ground_truth(
alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)
renamed_atom14_gt_positions = (
(1. - alt_naming_is_better[..., None, None]) *
atom14_gt_positions +
alt_naming_is_better[..., None, None] *
atom14_alt_gt_positions
)
1.0 - alt_naming_is_better[..., None, None]
) * atom14_gt_positions + alt_naming_is_better[
..., None, None
] * atom14_alt_gt_positions
renamed_atom14_gt_mask = (
(1. - alt_naming_is_better[..., None]) * atom14_gt_exists +
alt_naming_is_better[..., None] * batch["atom14_alt_gt_exists"]
)
1.0 - alt_naming_is_better[..., None]
) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
"atom14_alt_gt_exists"
]
return {
"alt_naming_is_better": alt_naming_is_better,
......@@ -1400,8 +1352,7 @@ def experimentally_resolved_loss(
loss = torch.sum(loss, dim=-1)
loss = loss * (
(resolution >= min_resolution) &
(resolution <= max_resolution)
(resolution >= min_resolution) & (resolution <= max_resolution)
)
return loss
......@@ -1409,8 +1360,7 @@ def experimentally_resolved_loss(
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
errors = softmax_cross_entropy(
logits,
torch.nn.functional.one_hot(true_msa, num_classes=23)
logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
)
# FP16-friendly averaging. Equivalent to:
......@@ -1450,81 +1400,74 @@ def compute_drmsd(structure_1, structure_2):
class AlphaFoldLoss(nn.Module):
""" Aggregation of the various losses described in the supplement """
"""Aggregation of the various losses described in the supplement"""
def __init__(self, config):
super(AlphaFoldLoss, self).__init__()
self.config = config
def forward(self, out, batch):
if("violation" not in out.keys() and self.config.violation.weight):
if "violation" not in out.keys() and self.config.violation.weight:
out["violation"] = find_structural_violations(
batch,
out["sm"]["positions"][-1],
**self.config.violation,
)
if("renamed_atom14_gt_positions" not in out.keys()):
batch.update(compute_renamed_ground_truth(
if "renamed_atom14_gt_positions" not in out.keys():
batch.update(
compute_renamed_ground_truth(
batch,
out["sm"]["positions"][-1],
))
)
)
loss_fns = {
"distogram":
lambda: distogram_loss(
"distogram": lambda: distogram_loss(
logits=out["distogram_logits"],
**{**batch,
**self.config.distogram},
**{**batch, **self.config.distogram},
),
"experimentally_resolved":
lambda: experimentally_resolved_loss(
"experimentally_resolved": lambda: experimentally_resolved_loss(
logits=out["experimentally_resolved_logits"],
**{**batch, **self.config.experimentally_resolved},
),
"fape":
lambda: fape_loss(
"fape": lambda: fape_loss(
out,
batch,
self.config.fape,
),
"lddt":
lambda: lddt_loss(
"lddt": lambda: lddt_loss(
logits=out["lddt_logits"],
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.lddt},
),
"masked_msa":
lambda: masked_msa_loss(
"masked_msa": lambda: masked_msa_loss(
logits=out["masked_msa_logits"],
**{**batch,
**self.config.masked_msa},
**{**batch, **self.config.masked_msa},
),
"supervised_chi":
lambda: supervised_chi_loss(
"supervised_chi": lambda: supervised_chi_loss(
out["sm"]["angles"],
out["sm"]["unnormalized_angles"],
**{**batch, **self.config.supervised_chi},
),
"violation":
lambda: violation_loss(
"violation": lambda: violation_loss(
out["violation"],
**batch,
),
"tm":
lambda: tm_loss(
"tm": lambda: tm_loss(
logits=out["tm_logits"],
**{**batch, **out, **self.config.tm},
),
}
cum_loss = 0
for k,loss_fn in loss_fns.items():
for k, loss_fn in loss_fns.items():
weight = self.config[k].weight
if(weight):
#print(k)
if weight:
# print(k)
loss = loss_fn()
#print(weight * loss)
# print(weight * loss)
cum_loss = cum_loss + weight * loss
#print(cum_loss)
# print(cum_loss)
return cum_loss
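
The forward pass above wraps each loss term in a zero-argument lambda and evaluates it only when its configured weight is nonzero, so terms with zero weight are never computed. A minimal standalone sketch of that pattern with hypothetical names (not code from the repository):

import torch


def weighted_loss(loss_fns, weights):
    # loss_fns: name -> zero-argument callable returning a scalar tensor.
    # weights: name -> float; terms with weight 0 are skipped entirely.
    cum_loss = 0.0
    for name, fn in loss_fns.items():
        weight = weights.get(name, 0.0)
        if weight:
            cum_loss = cum_loss + weight * fn()
    return cum_loss


terms = {
    "fape": lambda: torch.tensor(2.0),
    "tm": lambda: torch.tensor(100.0),  # never evaluated below
}
print(weighted_loss(terms, {"fape": 0.5, "tm": 0.0}))  # tensor(1.)
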
......@@ -49,7 +49,7 @@ def dict_multimap(fn, dicts):
new_dict = {}
for k, v in first.items():
all_v = [d[k] for d in dicts]
if(type(v) is dict):
if type(v) is dict:
new_dict[k] = dict_multimap(fn, all_v)
else:
new_dict[k] = fn(all_v)
......@@ -83,7 +83,7 @@ def batched_gather(data, inds, dim=0, no_batch_dims=0):
def dict_map(fn, dic, leaf_type):
new_dict = {}
for k, v in dic.items():
if(type(v) is dict):
if type(v) is dict:
new_dict[k] = dict_map(fn, v, leaf_type)
else:
new_dict[k] = tree_map(fn, v, leaf_type)
......@@ -92,18 +92,19 @@ def dict_map(fn, dic, leaf_type):
def tree_map(fn, tree, leaf_type):
if(isinstance(tree, dict)):
if isinstance(tree, dict):
return dict_map(fn, tree, leaf_type)
elif(isinstance(tree, list)):
elif isinstance(tree, list):
return [tree_map(fn, x, leaf_type) for x in tree]
elif(isinstance(tree, tuple)):
elif isinstance(tree, tuple):
return tuple([tree_map(fn, x, leaf_type) for x in tree])
elif(isinstance(tree, leaf_type)):
elif isinstance(tree, leaf_type):
return fn(tree)
else:
print(type(tree))
raise ValueError("Not supported")
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
......@@ -137,19 +138,19 @@ def chunk_layer(
Returns:
The reassembled output of the layer on the inputs.
"""
if(not (len(inputs) > 0)):
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
def fetch_dims(tree):
shapes = []
tree_type = type(tree)
if(tree_type is dict):
if tree_type is dict:
for v in tree.values():
shapes.extend(fetch_dims(v))
elif(tree_type is list or tree_type is tuple):
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(fetch_dims(t))
elif(tree_type is torch.Tensor):
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
......@@ -161,7 +162,7 @@ def chunk_layer(
def prep_inputs(t):
# TODO: make this more memory efficient. This sucks
if(not sum(t.shape[:no_batch_dims]) == no_batch_dims):
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
return t
......@@ -172,40 +173,42 @@ def chunk_layer(
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = (
flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
no_chunks = flat_batch_dim // chunk_size + (
flat_batch_dim % chunk_size != 0
)
i = 0
out = None
for _ in range(no_chunks):
# Chunk the input
select_chunk = lambda t: t[i:i+chunk_size] if t.shape[0] != 1 else t
select_chunk = lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
chunks = tensor_tree_map(select_chunk, flattened_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
# Allocate space for the output
if(out is None):
if out is None:
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
out = tensor_tree_map(allocate, output_chunk)
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if(out_type is dict):
if out_type is dict:
def assign(d1, d2):
for k,v in d1.items():
if(type(v) is dict):
for k, v in d1.items():
if type(v) is dict:
assign(v, d2[k])
else:
v[i:i+chunk_size] = d2[k]
v[i : i + chunk_size] = d2[k]
assign(out, output_chunk)
elif(out_type is tuple):
elif out_type is tuple:
for x1, x2 in zip(out, output_chunk):
x1[i:i+chunk_size] = x2
elif(out_type is torch.Tensor):
out[i:i+chunk_size] = output_chunk
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
out[i : i + chunk_size] = output_chunk
else:
raise ValueError("Not supported")
......@@ -34,13 +34,11 @@ def import_alphafold():
If AlphaFold is installed using the provided setuptools script, this
is necessary to expose all of AlphaFold's precious insides
"""
if("alphafold" in sys.modules):
if "alphafold" in sys.modules:
return sys.modules["alphafold"]
module = importlib.import_module("alphafold")
# Forcefully import alphafold's submodules
submodules = pkgutil.walk_packages(
module.__path__, prefix=("alphafold.")
)
submodules = pkgutil.walk_packages(module.__path__, prefix=("alphafold."))
for submodule_info in submodules:
importlib.import_module(submodule_info.name)
sys.modules["alphafold"] = module
......@@ -57,12 +55,14 @@ def get_alphafold_config():
_param_path = "openfold/resources/params/params_model_1_ptm.npz"
_model = None
def get_global_pretrained_openfold():
global _model
if(_model is None):
if _model is None:
_model = AlphaFold(model_config("model_1_ptm").model)
_model = _model.eval()
if(not os.path.exists(_param_path)):
if not os.path.exists(_param_path):
raise FileNotFoundError(
"""Cannot load pretrained parameters. Make sure to run the
installation script before running tests."""
......@@ -74,9 +74,11 @@ def get_global_pretrained_openfold():
_orig_weights = None
def _get_orig_weights():
global _orig_weights
if(_orig_weights is None):
if _orig_weights is None:
_orig_weights = np.load(_param_path)
return _orig_weights
......@@ -84,22 +86,19 @@ def _get_orig_weights():
def _remove_key_prefix(d, prefix):
for k, v in list(d.items()):
if(k.startswith(prefix)):
if k.startswith(prefix):
d.pop(k)
d[k[len(prefix):]] = v
d[k[len(prefix) :]] = v
def fetch_alphafold_module_weights(weight_path):
orig_weights = _get_orig_weights()
params = {
k:v for k,v in orig_weights.items()
if weight_path in k
}
if('/' in weight_path):
spl = weight_path.split('/')
params = {k: v for k, v in orig_weights.items() if weight_path in k}
if "/" in weight_path:
spl = weight_path.split("/")
spl = spl if len(spl[-1]) != 0 else spl[:-1]
module_name = spl[-1]
prefix = '/'.join(spl[:-1]) + '/'
prefix = "/".join(spl[:-1]) + "/"
_remove_key_prefix(params, prefix)
params = alphafold.model.utils.flat_params_to_haiku(params)
return params
import ml_collections as mlc
consts = mlc.ConfigDict({
consts = mlc.ConfigDict(
{
"batch_size": 2,
"n_res": 11,
"n_seq": 13,
......@@ -14,4 +15,5 @@ consts = mlc.ConfigDict({
"c_s": 384,
"c_t": 64,
"c_e": 64,
})
}
)
......@@ -18,7 +18,7 @@ from scipy.spatial.transform import Rotation
def random_template_feats(n_templ, n, batch_size=None):
b = []
if(batch_size is not None):
if batch_size is not None:
b.append(batch_size)
batch = {
"template_mask": np.random.randint(0, 2, (*b, n_templ)),
......@@ -28,28 +28,31 @@ def random_template_feats(n_templ, n, batch_size=None):
"template_all_atom_masks": np.random.randint(
0, 2, (*b, n_templ, n, 37)
),
"template_all_atom_positions": np.random.rand(
*b, n_templ, n, 37, 3
) * 10,
"template_all_atom_positions": np.random.rand(*b, n_templ, n, 37, 3)
* 10,
}
batch = {k:v.astype(np.float32) for k,v in batch.items()}
batch = {k: v.astype(np.float32) for k, v in batch.items()}
batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
return batch
def random_extra_msa_feats(n_extra, n, batch_size=None):
b = []
if(batch_size is not None):
if batch_size is not None:
b.append(batch_size)
batch = {
"extra_msa":
np.random.randint(0, 22, (*b, n_extra, n)).astype(np.int64),
"extra_has_deletion":
np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
"extra_deletion_value":
np.random.rand(*b, n_extra, n).astype(np.float32),
"extra_msa_mask":
np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
"extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(
np.int64
),
"extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(
np.float32
),
"extra_deletion_value": np.random.rand(*b, n_extra, n).astype(
np.float32
),
"extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(
np.float32
),
}
return batch
......@@ -63,7 +66,9 @@ def random_affines_vector(dim):
for i in range(prod_dim):
affines[i, :4] = Rotation.random(random_state=42).as_quat()
affines[i, 4:] = np.random.rand(3,).astype(np.float32)
affines[i, 4:] = np.random.rand(
3,
).astype(np.float32)
return affines.reshape(*dim, 7)
......@@ -77,9 +82,10 @@ def random_affines_4x4(dim):
for i in range(prod_dim):
affines[i, :3, :3] = Rotation.random(random_state=42).as_matrix()
affines[i, :3, 3] = np.random.rand(3,).astype(np.float32)
affines[i, :3, 3] = np.random.rand(
3,
).astype(np.float32)
affines[:, 3, 3] = 1
return affines.reshape(*dim, 4, 4)
......@@ -84,9 +84,7 @@ class TestTemplateAngleEmbedder(unittest.TestCase):
x = torch.rand((batch_size, n_templ, n_res, template_angle_dim))
x = tae(x)
self.assertTrue(
x.shape == (batch_size, n_templ, n_res, c_m)
)
self.assertTrue(x.shape == (batch_size, n_templ, n_res, c_m))
class TestTemplatePairEmbedder(unittest.TestCase):
......@@ -105,11 +103,8 @@ class TestTemplatePairEmbedder(unittest.TestCase):
x = torch.rand((batch_size, n_templ, n_res, n_res, template_pair_dim))
x = tpe(x)
self.assertTrue(
x.shape == (batch_size, n_templ, n_res, n_res, c_t)
)
self.assertTrue(x.shape == (batch_size, n_templ, n_res, n_res, c_t))
if __name__ == "__main__":
unittest.main()
......@@ -24,7 +24,7 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -91,7 +91,8 @@ class TestEvoformerStack(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
ei = alphafold.model.modules.EvoformerIteration(
c_e, config.model.global_config, is_extra_msa=False)
c_e, config.model.global_config, is_extra_msa=False
)
return ei(activations, masks, is_training=False)
f = hk.transform(run_ei)
......@@ -100,13 +101,13 @@ class TestEvoformerStack(unittest.TestCase):
n_seq = consts.n_seq
activations = {
'msa': np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
'pair': np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
"msa": np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
"pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
}
masks = {
'msa': np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
'pair': np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
"msa": np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
"pair": np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
}
params = compare_utils.fetch_alphafold_module_weights(
......@@ -115,9 +116,7 @@ class TestEvoformerStack(unittest.TestCase):
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
key = jax.random.PRNGKey(42)
out_gt = f.apply(
params, key, activations, masks
)
out_gt = f.apply(params, key, activations, masks)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt_msa = torch.as_tensor(np.array(out_gt["msa"]))
out_gt_pair = torch.as_tensor(np.array(out_gt["pair"]))
......@@ -134,9 +133,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
assert(torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps))
assert(torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps))
assert torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps)
assert torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps)
class TestExtraMSAStack(unittest.TestCase):
......@@ -180,8 +178,24 @@ class TestExtraMSAStack(unittest.TestCase):
m = torch.rand((batch_size, s_t, n_res, c_m))
z = torch.rand((batch_size, n_res, n_res, c_z))
msa_mask = torch.randint(0, 2, size=(batch_size, s_t, n_res,))
pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res,))
msa_mask = torch.randint(
0,
2,
size=(
batch_size,
s_t,
n_res,
),
)
pair_mask = torch.randint(
0,
2,
size=(
batch_size,
n_res,
n_res,
),
)
shape_z_before = z.shape
......@@ -216,7 +230,7 @@ class TestMSATransition(unittest.TestCase):
msa_trans = alphafold.model.modules.Transition(
c_e.msa_transition,
config.model.global_config,
name="msa_transition"
name="msa_transition",
)
act = msa_trans(act=msa_act, mask=msa_mask)
return act
......@@ -227,25 +241,29 @@ class TestMSATransition(unittest.TestCase):
n_seq = consts.n_seq
msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
msa_mask = np.ones((n_seq, n_res)).astype(np.float32) # no mask here either
msa_mask = np.ones((n_seq, n_res)).astype(
np.float32
) # no mask here either
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"msa_transition"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_transition"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].msa_transition(
out_repro = (
model.evoformer.blocks[0]
.msa_transition(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
).cpu()
)
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
......
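The MSATransition comparison above follows the same golden-test recipe used throughout these files: wrap the AlphaFold module in hk.transform, apply it with pretrained parameters, run the matching OpenFold block, and assert that the outputs agree within consts.eps. A condensed, hedged sketch of that recipe, with a placeholder module rather than the real test code:

import haiku as hk
import jax
import numpy as np
import torch

# Placeholder for an AlphaFold module such as Transition.
def run_module(x):
    return hk.Linear(8, name="placeholder_module")(x)

f = hk.transform(run_module)
x = np.random.rand(4, 8).astype(np.float32)

# The tests load pretrained AlphaFold weights instead of calling init().
params = f.init(jax.random.PRNGKey(0), x)
out_gt = torch.as_tensor(np.array(f.apply(params, None, x)))

# In the real tests, out_repro comes from the matching OpenFold block on CUDA;
# it is reused here only to show the final tolerance check.
out_repro = out_gt.clone()
assert torch.max(torch.abs(out_gt - out_repro)) < 1e-6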
......@@ -33,7 +33,7 @@ import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -95,9 +95,9 @@ class TestFeats(unittest.TestCase):
aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
all_atom_mask = np.random.randint(
0, 2, (n_templ, n_res, 37)
).astype(np.float32)
all_atom_mask = np.random.randint(0, 2, (n_templ, n_res, 37)).astype(
np.float32
)
out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)
......@@ -114,16 +114,17 @@ class TestFeats(unittest.TestCase):
# This function is extremely sensitive to floating point imprecisions,
# so it is given much greater latitude in comparison tests.
self.assertTrue(
torch.mean(
torch.abs(out_gt["torsion_angles_sin_cos"] - tasc)
) < 0.01
torch.mean(torch.abs(out_gt["torsion_angles_sin_cos"] - tasc))
< 0.01
)
self.assertTrue(
torch.mean(torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc))
< 0.01
)
self.assertTrue(
torch.mean(
torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc)
) < 0.01
torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam))
< consts.eps
)
self.assertTrue(torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_atom37_to_frames_compare(self):
......@@ -138,15 +139,17 @@ class TestFeats(unittest.TestCase):
batch = {
"aatype": np.random.randint(0, 21, (n_res,)),
"all_atom_positions":
np.random.rand(n_res, 37, 3).astype(np.float32),
"all_atom_mask":
np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
"all_atom_positions": np.random.rand(n_res, 37, 3).astype(
np.float32
),
"all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
np.float32
),
}
out_gt = f.apply({}, None, **batch)
to_tensor = lambda t: torch.tensor(np.array(t))
out_gt = {k:to_tensor(v) for k,v in out_gt.items()}
out_gt = {k: to_tensor(v) for k, v in out_gt.items()}
def flat12_to_4x4(flat12):
rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
......@@ -172,7 +175,7 @@ class TestFeats(unittest.TestCase):
out_repro = data_transforms.atom37_to_frames(batch)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
for k,v in out_gt.items():
for k, v in out_gt.items():
self.assertTrue(
torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
)
......@@ -201,9 +204,7 @@ class TestFeats(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_torsion_angles_to_frames_compare(self):
def run_torsion_angles_to_frames(
aatype,
backb_to_global,
torsion_angles_sin_cos
aatype, backb_to_global, torsion_angles_sin_cos
):
return alphafold.model.all_atom.torsion_angles_to_frames(
aatype,
......@@ -223,9 +224,7 @@ class TestFeats(unittest.TestCase):
torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
out_gt = f.apply(
{}, None, aatype, rigids, torsion_angles_sin_cos
)
out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
......@@ -237,9 +236,7 @@ class TestFeats(unittest.TestCase):
)
# Convert the Rigids to 4x4 transformation tensors
rots_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot)
)
rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot))
trans_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
)
......@@ -296,9 +293,7 @@ class TestFeats(unittest.TestCase):
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float())
out_gt = f.apply(
{}, None, aatype, rigids
)
out_gt = f.apply({}, None, aatype, rigids)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = torch.stack(
[torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
......
......@@ -30,7 +30,8 @@ class TestImportWeights(unittest.TestCase):
model = AlphaFold(c.model)
import_jax_weights_(
model, npz_path,
model,
npz_path,
)
data = np.load(npz_path)
......@@ -38,23 +39,34 @@ class TestImportWeights(unittest.TestCase):
test_pairs = [
# Normal linear weight
(torch.as_tensor(
data[prefix + "structure_module/initial_projection//weights"]
(
torch.as_tensor(
data[
prefix + "structure_module/initial_projection//weights"
]
).transpose(-1, -2),
model.structure_module.linear_in.weight),
model.structure_module.linear_in.weight,
),
# Normal layer norm param
(torch.as_tensor(
(
torch.as_tensor(
data[prefix + "evoformer/prev_pair_norm//offset"],
),
model.recycling_embedder.layer_norm_z.bias),
model.recycling_embedder.layer_norm_z.bias,
),
# From a stack
(torch.as_tensor(data[
prefix + (
(
torch.as_tensor(
data[
prefix
+ (
"evoformer/evoformer_iteration/outer_product_mean/"
"left_projection//weights"
)
][1].transpose(-1, -2)),
model.evoformer.blocks[1].outer_product_mean.linear_1.weight,),
][1].transpose(-1, -2)
),
model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
),
]
for w_alpha, w_repro in test_pairs:
......
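The test pairs above exercise the Linear weight-layout mismatch between the two frameworks: Haiku stores Linear weights as (in_features, out_features) while torch.nn.Linear stores (out_features, in_features), which is why the npz entries are transposed before comparison. A small hedged illustration with made-up shapes:

import numpy as np
import torch

npz_weight = np.random.rand(16, 8).astype(np.float32)     # (in, out), as in the .npz
expected = torch.as_tensor(npz_weight).transpose(-1, -2)  # (out, in), torch convention
assert expected.shape == (8, 16)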
......@@ -49,7 +49,7 @@ import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_vector, random_affines_4x4
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -99,7 +99,14 @@ class TestLoss(unittest.TestCase):
pred_pos = torch.rand(bs, n, 14, 3)
pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
residue_index = torch.arange(n).unsqueeze(0)
aatype = torch.randint(0, 22, (bs, n,))
aatype = torch.randint(
0,
22,
(
bs,
n,
),
)
between_residue_bond_loss(
pred_pos,
......@@ -122,14 +129,13 @@ class TestLoss(unittest.TestCase):
n_res = consts.n_res
pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
pred_atom_mask = np.random.randint(
0, 2, (n_res, 14)
).astype(np.float32)
pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
residue_index = np.arange(n_res)
aatype = np.random.randint(0, 22, (n_res,))
out_gt = f.apply(
{}, None,
{},
None,
pred_pos,
pred_atom_mask,
residue_index,
......@@ -151,7 +157,6 @@ class TestLoss(unittest.TestCase):
torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
)
def test_run_between_residue_clash_loss(self):
bs = consts.batch_size
n = consts.n_res
......@@ -185,10 +190,13 @@ class TestLoss(unittest.TestCase):
pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
atom_radius = np.random.rand(n_res, 14).astype(np.float32)
res_ind = np.arange(n_res,)
res_ind = np.arange(
n_res,
)
out_gt = f.apply(
{}, None,
{},
None,
pred_pos,
atom_exists,
atom_radius,
......@@ -242,7 +250,6 @@ class TestLoss(unittest.TestCase):
os.chdir(cwd)
return loss
f = hk.transform(run_fsv)
n_res = consts.n_res
......@@ -251,30 +258,25 @@ class TestLoss(unittest.TestCase):
"atom14_atom_exists": np.random.randint(0, 2, (n_res, 14)),
"residue_index": np.arange(n_res),
"aatype": np.random.randint(0, 20, (n_res,)),
"residx_atom14_to_atom37":
np.random.randint(0, 37, (n_res, 14)).astype(np.int64),
"residx_atom14_to_atom37": np.random.randint(
0, 37, (n_res, 14)
).astype(np.int64),
}
pred_pos = np.random.rand(n_res, 14, 3)
config = mlc.ConfigDict({
config = mlc.ConfigDict(
{
"clash_overlap_tolerance": 1.5,
"violation_tolerance_factor": 12.0,
})
out_gt = f.apply(
{}, None,
batch,
pred_pos,
config
}
)
out_gt = f.apply({}, None, batch, pred_pos, config)
out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
out_repro = find_structural_violations(
batch,
torch.tensor(pred_pos).cuda(),
......@@ -284,7 +286,7 @@ class TestLoss(unittest.TestCase):
def compare(out):
gt, repro = out
assert(torch.max(torch.abs(gt - repro)) < consts.eps)
assert torch.max(torch.abs(gt - repro)) < consts.eps
dict_multimap(compare, [out_gt, out_repro])
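dict_multimap (from openfold.utils.tensor_utils) applies the compare closure to each matched leaf of the two output dicts. A flat, hedged sketch of the same idea, for illustration only:

import torch

def compare_flat(gt: dict, repro: dict, eps: float = 1e-4):
    # Assert that every pair of matching tensors agrees within eps.
    for k in gt:
        assert torch.max(torch.abs(gt[k] - repro[k])) < eps, k

compare_flat(
    {"x": torch.ones(3)},
    {"x": torch.ones(3) + 1e-6},
)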
......@@ -304,12 +306,15 @@ class TestLoss(unittest.TestCase):
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"aatype": np.random.randint(0, 20, (n_res,)),
"atom14_gt_positions": np.random.rand(n_res, 14, 3),
"atom14_gt_exists":
np.random.randint(0, 2, (n_res, 14)).astype(np.float32),
"all_atom_mask":
np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
"all_atom_positions":
np.random.rand(n_res, 37, 3).astype(np.float32),
"atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
np.float32
),
"all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
np.float32
),
"all_atom_positions": np.random.rand(n_res, 37, 3).astype(
np.float32
),
}
def _build_extra_feats_np():
......@@ -325,9 +330,7 @@ class TestLoss(unittest.TestCase):
out_gt = f.apply({}, None, batch, atom14_pred_pos)
out_gt = jax.tree_map(lambda x: torch.tensor(np.array(x)), out_gt)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()
out_repro = compute_renamed_ground_truth(batch, atom14_pred_pos)
......@@ -358,19 +361,16 @@ class TestLoss(unittest.TestCase):
batch = {
"true_msa": np.random.randint(0, 21, (n_res, n_seq)),
"bert_mask":
np.random.randint(0, 2, (n_res, n_seq)).astype(np.float32),
"bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
np.float32
),
}
out_gt = f.apply({}, None, value, batch)["loss"]
out_gt = torch.tensor(np.array(out_gt))
value = tree_map(
lambda x: torch.tensor(x).cuda(), value, np.ndarray
)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
with torch.no_grad():
out_repro = masked_msa_loss(
......@@ -385,6 +385,7 @@ class TestLoss(unittest.TestCase):
def test_distogram_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_distogram = config.model.heads.distogram
def run_distogram_loss(value, batch):
dist_head = alphafold.model.modules.DistogramHead(
c_distogram, config.model.global_config
......@@ -396,33 +397,27 @@ class TestLoss(unittest.TestCase):
n_res = consts.n_res
value = {
"logits": np.random.rand(
n_res,
n_res,
c_distogram.num_bins
).astype(np.float32),
"logits": np.random.rand(n_res, n_res, c_distogram.num_bins).astype(
np.float32
),
"bin_edges": np.linspace(
c_distogram.first_break,
c_distogram.last_break,
c_distogram.num_bins,
)
),
}
batch = {
"pseudo_beta": np.random.rand(n_res, 3).astype(np.float32),
"pseudo_beta_mask": np.random.randint(0, 2, (n_res,))
"pseudo_beta_mask": np.random.randint(0, 2, (n_res,)),
}
out_gt = f.apply({}, None, value, batch)["loss"]
out_gt = torch.tensor(np.array(out_gt))
value = tree_map(
lambda x: torch.tensor(x).cuda(), value, np.ndarray
)
value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
with torch.no_grad():
out_repro = distogram_loss(
......@@ -441,6 +436,7 @@ class TestLoss(unittest.TestCase):
def test_experimentally_resolved_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_experimentally_resolved = config.model.heads.experimentally_resolved
def run_experimentally_resolved_loss(value, batch):
er_head = alphafold.model.modules.ExperimentallyResolvedHead(
c_experimentally_resolved, config.model.global_config
......@@ -458,19 +454,15 @@ class TestLoss(unittest.TestCase):
batch = {
"all_atom_mask": np.random.randint(0, 2, (n_res, 37)),
"atom37_atom_exists": np.random.randint(0, 2, (n_res, 37)),
"resolution": np.array(1.0)
"resolution": np.array(1.0),
}
out_gt = f.apply({}, None, value, batch)["loss"]
out_gt = torch.tensor(np.array(out_gt))
value = tree_map(
lambda x: torch.tensor(x).cuda(), value, np.ndarray
)
value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
with torch.no_grad():
out_repro = experimentally_resolved_loss(
......@@ -488,9 +480,10 @@ class TestLoss(unittest.TestCase):
def test_supervised_chi_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_chi_loss = config.model.heads.structure_module
def run_supervised_chi_loss(value, batch):
ret = {
"loss": jax.numpy.array(0.),
"loss": jax.numpy.array(0.0),
}
alphafold.model.folding.supervised_chi_loss(
ret, batch, value, c_chi_loss
......@@ -503,10 +496,12 @@ class TestLoss(unittest.TestCase):
value = {
"sidechains": {
"angles_sin_cos":
np.random.rand(8, n_res, 7, 2).astype(np.float32),
"unnormalized_angles_sin_cos":
np.random.rand(8, n_res, 7, 2).astype(np.float32),
"angles_sin_cos": np.random.rand(8, n_res, 7, 2).astype(
np.float32
),
"unnormalized_angles_sin_cos": np.random.rand(
8, n_res, 7, 2
).astype(np.float32),
}
}
......@@ -519,13 +514,9 @@ class TestLoss(unittest.TestCase):
out_gt = f.apply({}, None, value, batch)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
value = tree_map(
lambda x: torch.tensor(x).cuda(), value, np.ndarray
)
value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
batch = tree_map(
lambda x: torch.tensor(x).cuda(), batch, np.ndarray
)
batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
batch["chi_angles_sin_cos"] = torch.stack(
[
......@@ -539,7 +530,7 @@ class TestLoss(unittest.TestCase):
out_repro = supervised_chi_loss(
chi_weight=c_chi_loss.chi_weight,
angle_norm_weight=c_chi_loss.angle_norm_weight,
**{**batch, **value["sidechains"]}
**{**batch, **value["sidechains"]},
)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
......@@ -550,20 +541,24 @@ class TestLoss(unittest.TestCase):
def test_violation_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_viol = config.model.heads.structure_module
def run_viol_loss(batch, atom14_pred_pos):
ret = {
"loss": np.array(0.).astype(np.float32),
"loss": np.array(0.0).astype(np.float32),
}
value = {}
value["violations"] = (
alphafold.model.folding.find_structural_violations(
value[
"violations"
] = alphafold.model.folding.find_structural_violations(
batch,
atom14_pred_pos,
c_viol,
)
)
alphafold.model.folding.structural_violation_loss(
ret, batch, value, c_viol,
ret,
batch,
value,
c_viol,
)
return ret["loss"]
......@@ -577,16 +572,14 @@ class TestLoss(unittest.TestCase):
"aatype": np.random.randint(0, 21, (n_res,)),
}
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
batch = {k:np.array(v) for k,v in batch.items()}
batch = {k: np.array(v) for k, v in batch.items()}
atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
out_gt = f.apply({}, None, batch, atom14_pred_pos)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
batch = tree_map(
lambda n: torch.tensor(n).cuda(), batch, np.ndarray
)
batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)
atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()
batch = data_transforms.make_atom14_masks(batch)
......@@ -603,6 +596,7 @@ class TestLoss(unittest.TestCase):
def test_lddt_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_plddt = config.model.heads.predicted_lddt
def run_plddt_loss(value, batch):
head = alphafold.model.modules.PredictedLDDTHead(
c_plddt, config.model.global_config
......@@ -615,21 +609,25 @@ class TestLoss(unittest.TestCase):
value = {
"predicted_lddt": {
"logits":
np.random.rand(n_res, c_plddt.num_bins).astype(np.float32),
"logits": np.random.rand(n_res, c_plddt.num_bins).astype(
np.float32
),
},
"structure_module": {
"final_atom_positions":
np.random.rand(n_res, 37, 3).astype(np.float32),
}
"final_atom_positions": np.random.rand(n_res, 37, 3).astype(
np.float32
),
},
}
batch = {
"all_atom_positions":
np.random.rand(n_res, 37, 3).astype(np.float32),
"all_atom_mask":
np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
"resolution": np.array(1.).astype(np.float32),
"all_atom_positions": np.random.rand(n_res, 37, 3).astype(
np.float32
),
"all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
np.float32
),
"resolution": np.array(1.0).astype(np.float32),
}
out_gt = f.apply({}, None, value, batch)
......@@ -652,9 +650,10 @@ class TestLoss(unittest.TestCase):
def test_backbone_loss(self):
config = compare_utils.get_alphafold_config()
c_sm = config.model.heads.structure_module
def run_bb_loss(batch, value):
ret = {
"loss": np.array(0.),
"loss": np.array(0.0),
}
alphafold.model.folding.backbone_loss(ret, batch, value, c_sm)
return ret["loss"]
......@@ -665,13 +664,19 @@ class TestLoss(unittest.TestCase):
batch = {
"backbone_affine_tensor": random_affines_vector((n_res,)),
"backbone_affine_mask":
np.random.randint(0, 2, (n_res,)).astype(np.float32),
"use_clamped_fape": np.array(0.),
"backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
np.float32
),
"use_clamped_fape": np.array(0.0),
}
value = {
"traj": random_affines_vector((c_sm.num_layer, n_res,)),
"traj": random_affines_vector(
(
c_sm.num_layer,
n_res,
)
),
}
out_gt = f.apply({}, None, batch, value)
......@@ -695,6 +700,7 @@ class TestLoss(unittest.TestCase):
def test_sidechain_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_sm = config.model.heads.structure_module
def run_sidechain_loss(batch, value, atom14_pred_positions):
batch = {
**batch,
......@@ -702,22 +708,24 @@ class TestLoss(unittest.TestCase):
batch["aatype"],
batch["all_atom_positions"],
batch["all_atom_mask"],
)
),
}
v = {}
v["sidechains"] = {}
v["sidechains"]["frames"] = (
alphafold.model.r3.rigids_from_tensor4x4(
v["sidechains"][
"frames"
] = alphafold.model.r3.rigids_from_tensor4x4(
value["sidechains"]["frames"]
)
)
v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor(
value["sidechains"]["atom_pos"]
)
v.update(alphafold.model.folding.compute_renamed_ground_truth(
v.update(
alphafold.model.folding.compute_renamed_ground_truth(
batch,
atom14_pred_positions,
))
)
)
value = v
ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm)
......@@ -730,14 +738,18 @@ class TestLoss(unittest.TestCase):
batch = {
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"aatype": np.random.randint(0, 20, (n_res,)),
"atom14_gt_positions":
np.random.rand(n_res, 14, 3).astype(np.float32),
"atom14_gt_exists":
np.random.randint(0, 2, (n_res, 14)).astype(np.float32),
"all_atom_positions":
np.random.rand(n_res, 37, 3).astype(np.float32),
"all_atom_mask":
np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
"atom14_gt_positions": np.random.rand(n_res, 14, 3).astype(
np.float32
),
"atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
np.float32
),
"all_atom_positions": np.random.rand(n_res, 37, 3).astype(
np.float32
),
"all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
np.float32
),
}
def _build_extra_feats_np():
......@@ -751,10 +763,9 @@ class TestLoss(unittest.TestCase):
value = {
"sidechains": {
"frames": random_affines_4x4((c_sm.num_layer, n_res, 8)),
"atom_pos":
np.random.rand(
c_sm.num_layer, n_res, 14, 3
).astype(np.float32),
"atom_pos": np.random.rand(c_sm.num_layer, n_res, 14, 3).astype(
np.float32
),
}
}
......@@ -784,6 +795,7 @@ class TestLoss(unittest.TestCase):
def test_tm_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_tm = config.model.heads.predicted_aligned_error
def run_tm_loss(representations, batch, value):
head = alphafold.model.modules.PredictedAlignedErrorHead(
c_tm, config.model.global_config
......@@ -798,15 +810,15 @@ class TestLoss(unittest.TestCase):
n_res = consts.n_res
representations = {
"pair":
np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
"pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
}
batch = {
"backbone_affine_tensor": random_affines_vector((n_res,)),
"backbone_affine_mask":
np.random.randint(0, 2, (n_res,)).astype(np.float32),
"resolution": np.array(1.).astype(np.float32),
"backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
np.float32
),
"resolution": np.array(1.0).astype(np.float32),
}
value = {
......@@ -827,11 +839,11 @@ class TestLoss(unittest.TestCase):
batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = (
affine_vector_to_4x4(batch["backbone_affine_tensor"])
batch["backbone_affine_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"]
)
value["structure_module"]["final_affines"] = (
affine_vector_to_4x4(value["structure_module"]["final_affines"])
value["structure_module"]["final_affines"] = affine_vector_to_4x4(
value["structure_module"]["final_affines"]
)
model = compare_utils.get_global_pretrained_openfold()
......
......@@ -29,7 +29,7 @@ from tests.data_utils import (
random_extra_msa_feats,
)
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -51,28 +51,21 @@ class TestModel(unittest.TestCase):
model = AlphaFold(c)
batch = {}
tf = torch.randint(
c.input_embedder.tf_dim - 1, size=(n_res,)
)
tf = torch.randint(c.input_embedder.tf_dim - 1, size=(n_res,))
batch["target_feat"] = nn.functional.one_hot(
tf, c.input_embedder.tf_dim).float()
tf, c.input_embedder.tf_dim
).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand(
(n_seq, n_res, c.input_embedder.msa_dim)
)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.input_embedder.msa_dim))
t_feats = random_template_feats(n_templ, n_res)
batch.update({k:torch.tensor(v) for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats(
n_extra_seq, n_res
)
batch.update({k:torch.tensor(v) for k, v in extra_feats.items()})
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
batch["msa_mask"] = torch.randint(
low=0, high=2, size=(n_seq, n_res)
).float()
batch["seq_mask"] = torch.randint(
low=0, high=2, size=(n_res,)
).float()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
batch.update(make_atom14_masks(batch))
add_recycling_dims = lambda t: (
......@@ -89,12 +82,14 @@ class TestModel(unittest.TestCase):
config = compare_utils.get_alphafold_config()
model = alphafold.model.modules.AlphaFold(config.model)
return model(
batch=batch, is_training=False, return_representations=True,
batch=batch,
is_training=False,
return_representations=True,
)
f = hk.transform(run_alphafold)
params = compare_utils.fetch_alphafold_module_weights('')
params = compare_utils.fetch_alphafold_module_weights("")
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
batch = pickle.load(fp)
......@@ -108,13 +103,13 @@ class TestModel(unittest.TestCase):
out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
batch = {
k:torch.as_tensor(v).cuda() for k,v in batch.items()
}
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
batch["aatype"] = batch["aatype"].long()
batch["template_aatype"] = batch["template_aatype"].long()
batch["extra_msa"] = batch["extra_msa"].long()
batch["residx_atom37_to_atom14"] = batch["residx_atom37_to_atom14"].long()
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
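As a quick hedged illustration of the move_dim lambda above (shapes chosen arbitrarily), it rotates the leading recycling dimension to the end:

import torch

t = torch.rand(3, 5, 7)                          # (n_recycle, n_res, c)
moved = t.permute(*range(len(t.shape))[1:], 0)   # -> (n_res, c, n_recycle)
assert moved.shape == (5, 7, 3)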
......@@ -130,4 +125,3 @@ class TestModel(unittest.TestCase):
out_repro = out_repro.squeeze(0)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3))
......@@ -24,7 +24,7 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -39,7 +39,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
c_z = consts.c_z
c = 52
no_heads = 4
chunk_size=None
chunk_size = None
mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads, chunk_size)
......@@ -58,12 +58,9 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
msa_row = alphafold.model.modules.MSARowAttentionWithPairBias(
c_e.msa_row_attention_with_pair_bias,
config.model.global_config
)
act = msa_row(
msa_act=msa_act, msa_mask=msa_mask, pair_act=pair_act
c_e.msa_row_attention_with_pair_bias, config.model.global_config
)
act = msa_row(msa_act=msa_act, msa_mask=msa_mask, pair_act=pair_act)
return act
f = hk.transform(run_msa_row_att)
......@@ -72,15 +69,15 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
n_seq = consts.n_seq
msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
msa_mask = np.random.randint(
low=0, high=2, size=(n_seq, n_res)
).astype(np.float32)
msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res)).astype(
np.float32
)
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"msa_row_attention"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_row_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
......@@ -90,11 +87,15 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].msa_att_row(
out_repro = (
model.evoformer.blocks[0]
.msa_att_row(
torch.as_tensor(msa_act).cuda(),
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(msa_mask).cuda(),
).cpu()
)
.cpu()
)
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
......@@ -124,12 +125,9 @@ class TestMSAColumnAttention(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
msa_col = alphafold.model.modules.MSAColumnAttention(
c_e.msa_column_attention,
config.model.global_config
)
act = msa_col(
msa_act=msa_act, msa_mask=msa_mask
c_e.msa_column_attention, config.model.global_config
)
act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
return act
f = hk.transform(run_msa_col_att)
......@@ -138,27 +136,29 @@ class TestMSAColumnAttention(unittest.TestCase):
n_seq = consts.n_seq
msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
msa_mask = np.random.randint(
low=0, high=2, size=(n_seq, n_res)
).astype(np.float32)
msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res)).astype(
np.float32
)
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"msa_column_attention"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_column_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].msa_att_col(
out_repro = (
model.evoformer.blocks[0]
.msa_att_col(
torch.as_tensor(msa_act).cuda(),
torch.as_tensor(msa_mask).cuda(),
).cpu()
)
.cpu()
)
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
......@@ -190,7 +190,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
msa_col = alphafold.model.modules.MSAColumnGlobalAttention(
c_e.msa_column_attention,
config.model.global_config,
name="msa_column_global_attention"
name="msa_column_global_attention",
)
act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
return act
......@@ -206,21 +206,23 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/extra_msa_stack/" +
"msa_column_global_attention"
"alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
+ "msa_column_global_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.extra_msa_stack.stack.blocks[0].msa_att_col(
out_repro = (
model.extra_msa_stack.stack.blocks[0]
.msa_att_col(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
).cpu()
)
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
......
......@@ -19,7 +19,8 @@ from openfold.model.outer_product_mean import OuterProductMean
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -40,7 +41,8 @@ class TestOuterProductMean(unittest.TestCase):
m = opm(m, mask)
self.assertTrue(
m.shape == (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
m.shape
== (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
)
@compare_utils.skip_unless_alphafold_installed()
......@@ -63,27 +65,29 @@ class TestOuterProductMean(unittest.TestCase):
c_m = consts.c_m
msa_act = np.random.rand(n_seq, n_res, c_m).astype(np.float32) * 100
msa_mask = np.random.randint(
low=0, high=2, size=(n_seq, n_res)
).astype(np.float32)
msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res)).astype(
np.float32
)
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/" +
"evoformer_iteration/outer_product_mean"
"alphafold/alphafold_iteration/evoformer/"
+ "evoformer_iteration/outer_product_mean"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].outer_product_mean(
out_repro = (
model.evoformer.blocks[0]
.outer_product_mean(
torch.as_tensor(msa_act).cuda(),
mask=torch.as_tensor(msa_mask).cuda(),
).cpu()
)
.cpu()
)
# Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps.
......
......@@ -20,7 +20,7 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -52,7 +52,7 @@ class TestPairTransition(unittest.TestCase):
pt = alphafold.model.modules.Transition(
c_e.pair_transition,
config.model.global_config,
name="pair_transition"
name="pair_transition",
)
act = pt(act=pair_act, mask=pair_mask)
return act
......@@ -66,26 +66,26 @@ class TestPairTransition(unittest.TestCase):
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"pair_transition"
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "pair_transition"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, pair_act, pair_mask
).block_until_ready()
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].pair_transition(
out_repro = (
model.evoformer.blocks[0]
.pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
).cpu()
)
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
if __name__ == "__main__":
unittest.main()
......@@ -39,7 +39,7 @@ from tests.data_utils import (
random_affines_4x4,
)
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -89,9 +89,7 @@ class TestStructureModule(unittest.TestCase):
out = sm(s, z, f)
self.assertTrue(
out["frames"].shape == (no_layers, batch_size, n, 4, 4)
)
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
)
......@@ -121,14 +119,13 @@ class TestStructureModule(unittest.TestCase):
config = compare_utils.get_alphafold_config()
c_sm = config.model.heads.structure_module
c_global = config.model.global_config
def run_sm(representations, batch):
sm = alphafold.model.folding.StructureModule(c_sm, c_global)
representations = {
k:jax.lax.stop_gradient(v) for k,v in representations.items()
}
batch = {
k:jax.lax.stop_gradient(v) for k,v in batch.items()
k: jax.lax.stop_gradient(v) for k, v in representations.items()
}
batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
return sm(representations, batch, is_training=False)
f = hk.transform(run_sm)
......@@ -136,26 +133,21 @@ class TestStructureModule(unittest.TestCase):
n_res = 200
representations = {
'single': np.random.rand(n_res, consts.c_s).astype(np.float32),
'pair':
np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
"single": np.random.rand(n_res, consts.c_s).astype(np.float32),
"pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
}
batch = {
'seq_mask': np.random.randint(0, 2, (n_res,)).astype(np.float32),
'aatype': np.random.randint(0, 21, (n_res,)),
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"aatype": np.random.randint(0, 21, (n_res,)),
}
batch['atom14_atom_exists'] = np.take(
restype_atom14_mask,
batch['aatype'],
axis=0
batch["atom14_atom_exists"] = np.take(
restype_atom14_mask, batch["aatype"], axis=0
)
batch['atom37_atom_exists'] = np.take(
restype_atom37_mask,
batch['aatype'],
axis=0
batch["atom37_atom_exists"] = np.take(
restype_atom37_mask, batch["aatype"], axis=0
)
batch.update(make_atom14_masks_np(batch))
......@@ -165,9 +157,7 @@ class TestStructureModule(unittest.TestCase):
)
key = jax.random.PRNGKey(42)
out_gt = f.apply(
params, key, representations, batch
)
out_gt = f.apply(params, key, representations, batch)
out_gt = torch.as_tensor(
np.array(out_gt["final_atom14_positions"].block_until_ready())
)
......@@ -246,7 +236,7 @@ class TestInvariantPointAttention(unittest.TestCase):
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
affine=affine
affine=affine,
)
return attn
......@@ -263,15 +253,13 @@ class TestInvariantPointAttention(unittest.TestCase):
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids)
transformations = T.from_4x4(
torch.as_tensor(affines).float().cuda()
)
transformations = T.from_4x4(torch.as_tensor(affines).float().cuda())
sample_affine = quats
ipa_params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/structure_module/" +
"fold_iteration/invariant_point_attention"
"alphafold/alphafold_iteration/structure_module/"
+ "fold_iteration/invariant_point_attention"
)
out_gt = f.apply(
......
......@@ -24,7 +24,7 @@ import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -67,8 +67,8 @@ class TestTemplatePairStack(unittest.TestCase):
n_res = consts.n_res
blocks_per_ckpt = None
chunk_size = 4
inf=1e7
eps=1e-7
inf = 1e7
eps = 1e-7
tpe = TemplatePairStack(
c_t,
......@@ -100,7 +100,7 @@ class TestTemplatePairStack(unittest.TestCase):
tps = alphafold.model.modules.TemplatePairStack(
c_ee.template.template_pair_stack,
config.model.global_config,
name="template_pair_stack"
name="template_pair_stack",
)
act = tps(pair_act, pair_mask, is_training=False)
ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
......@@ -117,13 +117,15 @@ class TestTemplatePairStack(unittest.TestCase):
).astype(np.float32)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/" +
"single_template_embedding/template_pair_stack"
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_pair_stack"
)
params.update(
compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/output_layer_norm"
)
)
params.update(compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/" +
"single_template_embedding/output_layer_norm"
))
out_gt = f.apply(
params, jax.random.PRNGKey(42), pair_act, pair_mask
......@@ -147,7 +149,7 @@ class Template(unittest.TestCase):
config = compare_utils.get_alphafold_config()
te = alphafold.model.modules.TemplateEmbedding(
config.model.embeddings_and_evoformer.template,
config.model.global_config
config.model.global_config,
)
act = te(pair, batch, mask_2d, is_training=False)
return act
......@@ -176,7 +178,7 @@ class Template(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.embed_templates(
{k:torch.as_tensor(v).cuda() for k,v in batch.items()},
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
......
......@@ -21,7 +21,7 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -34,12 +34,7 @@ class TestTriangularAttention(unittest.TestCase):
no_heads = 4
starting = True
tan = TriangleAttention(
c_z,
c,
no_heads,
starting
)
tan = TriangleAttention(c_z, c, no_heads, starting)
batch_size = consts.batch_size
n_res = consts.n_res
......@@ -53,16 +48,18 @@ class TestTriangularAttention(unittest.TestCase):
def _tri_att_compare(self, starting=False):
name = (
"triangle_attention_" +
("starting" if starting else "ending") +
"_node"
"triangle_attention_"
+ ("starting" if starting else "ending")
+ "_node"
)
def run_tri_att(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
tri_att = alphafold.model.modules.TriangleAttention(
c_e.triangle_attention_starting_node if starting else
c_e.triangle_attention_ending_node,
c_e.triangle_attention_starting_node
if starting
else c_e.triangle_attention_ending_node,
config.model.global_config,
name=name,
)
......@@ -78,20 +75,19 @@ class TestTriangularAttention(unittest.TestCase):
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
name
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, pair_act, pair_mask
).block_until_ready()
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].tri_att_start if starting else
model.evoformer.blocks[0].tri_att_end
model.evoformer.blocks[0].tri_att_start
if starting
else model.evoformer.blocks[0].tri_att_end
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
......
......@@ -20,7 +20,7 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
......@@ -50,16 +50,17 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
self.assertTrue(shape_before == shape_after)
def _tri_mul_compare(self, incoming=False):
name = (
"triangle_multiplication_" +
("incoming" if incoming else "outgoing")
name = "triangle_multiplication_" + (
"incoming" if incoming else "outgoing"
)
def run_tri_mul(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
tri_mul = alphafold.model.modules.TriangleMultiplication(
c_e.triangle_multiplication_incoming if incoming else
c_e.triangle_multiplication_outgoing,
c_e.triangle_multiplication_incoming
if incoming
else c_e.triangle_multiplication_outgoing,
config.model.global_config,
name=name,
)
......@@ -76,20 +77,19 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
# Fetch pretrained parameters (but only from one block)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
name
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, pair_act, pair_mask
).block_until_ready()
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].tri_mul_in if incoming else
model.evoformer.blocks[0].tri_mul_out
model.evoformer.blocks[0].tri_mul_in
if incoming
else model.evoformer.blocks[0].tri_mul_out
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
......@@ -109,4 +109,3 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
if __name__ == "__main__":
unittest.main()
......@@ -20,17 +20,21 @@ from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.tensor_utils import chunk_layer
X_90_ROT = torch.tensor([
X_90_ROT = torch.tensor(
[
[1, 0, 0],
[0, 0,-1],
[0, 0, -1],
[0, 1, 0],
])
]
)
X_NEG_90_ROT = torch.tensor([
X_NEG_90_ROT = torch.tensor(
[
[1, 0, 0],
[0, 0, 1],
[0,-1, 0],
])
[0, -1, 0],
]
)
class TestAffineT(unittest.TestCase):
......@@ -53,7 +57,7 @@ class TestAffineT(unittest.TestCase):
batch_size = 2
transf = [
[1, 0, 0, 1],
[0, 0,-1, 2],
[0, 0, -1, 2],
[0, 1, 0, 3],
[0, 0, 0, 1],
]
......@@ -62,10 +66,7 @@ class TestAffineT(unittest.TestCase):
true_rot = transf[:3, :3]
true_trans = transf[:3, 3]
transf = torch.stack(
[transf for _ in range(batch_size)],
dim=0
)
transf = torch.stack([transf for _ in range(batch_size)], dim=0)
t = T.from_4x4(transf)
......@@ -78,8 +79,7 @@ class TestAffineT(unittest.TestCase):
batch_size = 2
n = 5
transf = T(
torch.rand((batch_size, n, 3, 3)),
torch.rand((batch_size, n, 3))
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
)
self.assertTrue(transf.shape == (batch_size, n))
......@@ -88,8 +88,7 @@ class TestAffineT(unittest.TestCase):
batch_size = 2
n = 5
transf = T(
torch.rand((batch_size, n, 3, 3)),
torch.rand((batch_size, n, 3))
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
)
transf_concat = T.concat([transf, transf], dim=0)
......