Commit ab09ded4 authored by Christina Floristean

Fixes for memory usage in multimer dataloader, spatial cropping, and tensor type conversions within model.
parent da5d0e7d
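Many of the model-side changes in this commit apply the same pattern: cast tensors that were built in float32 to the dtype of an existing activation, so that bf16/fp16 mixed-precision runs do not silently promote intermediates (and their memory footprint) back to float32. A minimal sketch of that pattern, using placeholder tensors rather than code from this commit:

import torch

act = torch.zeros(8, 8, 64, dtype=torch.bfloat16)      # activation in bf16
mask_2d = torch.ones(8, 8, dtype=torch.float32)        # feature built in float32 elsewhere

# bf16 * float32 promotes the result to float32 and doubles activation memory;
# the explicit cast keeps everything in bf16.
act = act * mask_2d[..., None].to(dtype=act.dtype)
print(act.dtype)  # torch.bfloat16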
......@@ -431,15 +431,15 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
)
self.data_pipeline = data_pipeline.DataPipeline(
data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
self.multimer_data_pipeline = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=self.data_pipeline
self.data_pipeline = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor
)
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
......@@ -457,7 +457,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
alignment_index=alignment_index
)
......@@ -473,82 +472,49 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
mmcif_id = self.idx_to_mmcif_id(idx)
chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
print(f"mmcif_id is :{mmcif_id} idx:{idx} and has {len(chains)}chains")
seqs = self.mmcif_data_cache[mmcif_id]['seqs']
fasta_str = ""
for c,s in zip(chains,seqs):
fasta_str+=f">{mmcif_id}_{c}\n{s}\n"
with temp_fasta_file(fasta_str) as fasta_file:
all_chain_features = self.multimer_data_pipeline.process_fasta(fasta_file,self.alignment_dir)
# process all_chain_features
all_chain_features = self.feature_pipeline.process_features(all_chain_features,
mode=self.mode,
is_multimer=True)
alignment_index = None
ground_truth=[]
if(self.mode == 'train' or self.mode == 'eval'):
for chain in chains:
path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
ext = e
break
path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
ext = e
break
if(ext is None):
raise ValueError("Invalid file type")
if(ext is None):
raise ValueError("Invalid file type")
path += ext
alignment_dir = os.path.join(self.alignment_dir,f"{mmcif_id}_{chain.upper()}")
if(ext == ".cif"):
data = self._parse_mmcif(
path, mmcif_id, chain, alignment_dir, alignment_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False)
#remove recycling dimension
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
elif(ext == ".pdb"):
structure_index = None
data = self.data_pipeline.process_pdb(
pdb_path=path,
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain,
alignment_index=alignment_index,
_structure_index=structure_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
else:
raise ValueError("Extension branch missing")
all_chain_features["batch_idx"] = torch.tensor(
[idx for _ in range(all_chain_features["aatype"].shape[-1])],
dtype=torch.int64,
device=all_chain_features["aatype"].device)
# if it's training now, then return both all_chain_features and ground_truth
return all_chain_features,ground_truth
#TODO: Add pdb and core exts to data_pipeline for multimer
path += ext
if(ext == ".cif"):
data = self._parse_mmcif(
path, mmcif_id, self.alignment_dir, alignment_index,
)
else:
raise ValueError("Extension branch missing")
else:
# if it's inference mode, only need all_chain_features
all_chain_features["batch_idx"] = torch.tensor(
[idx for _ in range(all_chain_features["aatype"].shape[-1])],
dtype=torch.int64,
device=all_chain_features["aatype"].device)
return all_chain_features
path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=self.alignment_dir
)
if (self._output_raw):
return data
# process all_chain_features
data = self.feature_pipeline.process_features(data,
mode=self.mode,
is_multimer=True)
# if it's inference mode, only need all_chain_features
data["batch_idx"] = torch.tensor(
[idx for _ in range(data["aatype"].shape[-1])],
dtype=torch.int64,
device=data["aatype"].device)
return data
def __len__(self):
return len(self._chain_ids)
......
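The reworked __getitem__ above writes the cached per-chain sequences to a temporary FASTA file via temp_fasta_file before calling the multimer data pipeline. A minimal sketch of what such a context manager might look like (an assumed implementation, not code from this repository):

import contextlib
import tempfile

@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
    # Write the FASTA string to a named temporary file, yield its path,
    # and remove the file when the caller is done with it.
    with tempfile.NamedTemporaryFile("w", suffix=".fasta") as f:
        f.write(fasta_str)
        f.seek(0)
        yield f.name

# Usage mirroring the dataloader: build a FASTA string per chain, then hand
# the temporary path to the multimer pipeline.
with temp_fasta_file(">1abc_A\nMKT\n>1abc_B\nGSS\n") as fasta_path:
    print(fasta_path)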
......@@ -1188,4 +1188,75 @@ class DataPipelineMultimer:
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
def get_mmcif_features(
self, mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
mmcif_feats = {}
all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=chain_id
)
mmcif_feats["all_atom_positions"] = all_atom_positions
mmcif_feats["all_atom_mask"] = all_atom_mask
mmcif_feats["resolution"] = np.array(
mmcif_object.header["resolution"], dtype=np.float32
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
)
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
return mmcif_feats
def process_mmcif(
self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
alignment_index: Optional[str] = None,
) -> FeatureDict:
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1
for chain_id, seq in mmcif.chain_to_seqres.items():
desc= "_".join([mmcif.file_id, chain_id])
if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy(
sequence_features[seq]
)
continue
chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
description=desc,
chain_alignment_dir=os.path.join(alignment_dir, desc),
is_homomer_or_monomer=is_homomer_or_monomer
)
chain_features = convert_monomer_features(
chain_features,
chain_id=desc
)
mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
chain_features.update(mmcif_feats)
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
\ No newline at end of file
......@@ -307,15 +307,25 @@ def make_msa_profile(batch):
return batch
def randint(lower, upper, generator, device):
return int(torch.randint(
lower,
upper + 1,
(1,),
device=device,
generator=generator,
)[0])
def get_interface_residues(positions, atom_mask, asym_id, interface_threshold):
coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :]
pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float()
pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :]
mask = diff_chain_mask[..., None] * pair_mask
mask = (diff_chain_mask[..., None] * pair_mask).bool()
min_dist_per_res = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)
min_dist_per_res, _ = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)
valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1)
interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0]
......@@ -336,8 +346,12 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
if not torch.any(interface_residues):
return get_contiguous_crop_idx(protein, crop_size, generator)
target_res = interface_residues[int(torch.randint(0, interface_residues.shape[-1], (1,),
device=positions.device, generator=generator)[0])]
target_res_idx = randint(lower=0,
upper=interface_residues.shape[-1],
generator=generator,
device=positions.device)
target_res = interface_residues[target_res_idx]
ca_idx = rc.atom_order["CA"]
ca_positions = positions[..., ca_idx, :]
......@@ -353,33 +367,24 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
).float()
* 1e-3
)
to_target_distances = torch.where(ca_mask[..., None], to_target_distances, torch.inf) + break_tie
to_target_distances = torch.where(ca_mask, to_target_distances, torch.inf) + break_tie
ret = torch.argsort(to_target_distances)[:crop_size]
return ret.sort().values
def randint(lower, upper, generator, device):
return int(torch.randint(
lower,
upper + 1,
(1,),
device=device,
generator=generator,
)[0])
def get_contiguous_crop_idx(protein, crop_size, generator):
num_res = protein["aatype"].shape[0]
if num_res <= crop_size:
return torch.arange(num_res)
_, chain_lens = protein["asym_id"].unique(return_counts=True)
unique_asym_ids, chain_lens = protein["asym_id"].unique(return_counts=True)
shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator)
num_remaining = int(chain_lens.sum())
num_budget = crop_size
crop_idxs = []
asym_offset = torch.tensor(0, dtype=torch.int64)
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (protein["asym_id"]== cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(protein["asym_id"], asym_mask)[0]
for j, idx in enumerate(shuffle_idx):
this_len = int(chain_lens[idx])
num_remaining -= this_len
......@@ -396,6 +401,8 @@ def get_contiguous_crop_idx(protein, crop_size, generator):
upper=this_len - chain_crop_size + 1,
generator=generator,
device=chain_lens.device)
asym_offset = per_asym_residue_index[int(idx)]
crop_idxs.append(
torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size)
)
......@@ -427,7 +434,11 @@ def random_crop_to_size(
use_spatial_crop = torch.rand((1,),
device=protein["seq_length"].device,
generator=g) < spatial_crop_prob
if use_spatial_crop:
num_res = protein["aatype"].shape[0]
if num_res <= crop_size:
crop_idxs = torch.arange(num_res)
elif use_spatial_crop:
crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g)
else:
crop_idxs = get_contiguous_crop_idx(protein, crop_size, g)
......@@ -469,9 +480,8 @@ def random_crop_to_size(
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"):
crop_size = num_templates_crop_size
crop_start = templates_crop_start
v = v[slice(crop_start, crop_start + crop_size)]
v = v[slice(crop_start, crop_start + num_templates_crop_size)]
elif is_num_res:
v = torch.index_select(v, i, crop_idxs)
......
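The randint helper moved to the top of this file draws an integer from the closed interval [lower, upper]; torch.randint's high argument is exclusive, hence the + 1. A quick usage sketch with a seeded generator (values here are illustrative):

import torch

def randint(lower, upper, generator, device):
    # Inclusive on both ends, unlike torch.randint itself.
    return int(torch.randint(lower, upper + 1, (1,), device=device, generator=generator)[0])

g = torch.Generator(device="cpu").manual_seed(0)
chain_start = randint(lower=0, upper=9, generator=g, device="cpu")
assert 0 <= chain_start <= 9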
......@@ -228,12 +228,13 @@ def process_unmerged_features(
chain_features['deletion_matrix'], axis=0
)
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])
if 'all_atom_positions' not in chain_features:
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask.astype(dtype=np.float32)
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])
# Add assembly_num_chains.
chain_features['assembly_num_chains'] = np.asarray(num_chains)
......
......@@ -31,6 +31,18 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.make_atom14_masks,
]
if mode_cfg.supervised:
transforms.extend(
[
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
]
)
return transforms
......
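The supervised transforms added above are plain callables that each take a feature dict and return an updated feature dict; they are applied one after another. A simplified sketch of composing such a transform list (illustrative only, not the repository's own compose helper):

from typing import Callable, Dict, List
import torch

FeatureDict = Dict[str, torch.Tensor]

def apply_transforms(feats: FeatureDict, transforms: List[Callable[[FeatureDict], FeatureDict]]) -> FeatureDict:
    # Each transform consumes the current feature dict and returns the next one.
    for fn in transforms:
        feats = fn(feats)
    return feats

# Toy usage with a placeholder transform.
feats = {"aatype": torch.zeros(16, dtype=torch.long)}
feats = apply_transforms(feats, [lambda f: {**f, "seq_length": torch.tensor(f["aatype"].shape[0])}])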
......@@ -706,12 +706,12 @@ class TemplatePairEmbedderMultimer(nn.Module):
backbone_mask[..., None] * backbone_mask[..., None, :]
)
backbone_mask_2d *= multichain_mask_2d
x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
x, y, z = [(coord * backbone_mask_2d).to(dtype=query_embedding.dtype) for coord in unit_vector]
act += self.x_linear(x[..., None])
act += self.y_linear(y[..., None])
act += self.z_linear(z[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None].to(dtype=query_embedding.dtype))
query_embedding = self.query_embedding_layer_norm(query_embedding)
act += self.query_embedding_linear(query_embedding)
......@@ -735,6 +735,8 @@ class TemplateSingleEmbedderMultimer(nn.Module):
):
out = {}
dtype = batch["template_all_atom_positions"].dtype
template_chi_angles, template_chi_mask = (
all_atom_multimer.compute_chi_angles(
atom_pos,
......@@ -751,9 +753,9 @@ class TemplateSingleEmbedderMultimer(nn.Module):
template_chi_mask,
],
dim=-1,
)
).to(dtype=dtype)
template_mask = template_chi_mask[..., 0]
template_mask = template_chi_mask[..., 0].to(dtype=dtype)
template_activations = self.template_single_embedder(
template_features
......@@ -829,8 +831,10 @@ class TemplateEmbedderMultimer(nn.Module):
)
raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
# Vec3Arrays are required to be float32
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos.to(dtype=torch.float32))
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos,
single_template_feats["template_all_atom_mask"],
......
......@@ -348,7 +348,7 @@ class AlphaFold(nn.Module):
extra_msa_fn = build_extra_msa_feat
# [*, S_e, N, C_e]
extra_msa_feat = extra_msa_fn(feats)
extra_msa_feat = extra_msa_fn(feats).to(dtype=z.dtype)
a = self.extra_msa_embedder(extra_msa_feat)
if(self.globals.offload_inference):
......@@ -527,6 +527,7 @@ class AlphaFold(nn.Module):
# Main recycling loop
num_iters = batch["aatype"].shape[-1]
early_stop = False
num_recycles = 0
for cycle_no in range(num_iters):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
......@@ -547,6 +548,8 @@ class AlphaFold(nn.Module):
_recycle=(num_iters > 1)
)
num_recycles += 1
if not is_final_iter:
del outputs
prevs = [m_1_prev, z_prev, x_prev]
......@@ -554,6 +557,8 @@ class AlphaFold(nn.Module):
else:
break
outputs["num_recycles"] = torch.tensor(num_recycles, device=feats["aatype"].device)
if "asym_id" in batch:
outputs["asym_id"] = feats["asym_id"]
......
......@@ -130,6 +130,7 @@ class Linear(nn.Linear):
bias: bool = True,
init: str = "default",
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
precision=None
):
"""
Args:
......@@ -181,6 +182,26 @@ class Linear(nn.Linear):
else:
raise ValueError("Invalid init string.")
self.precision = precision
def forward(self, input: torch.Tensor) -> torch.Tensor:
d = input.dtype
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
)
if self.precision is not None:
with torch.cuda.amp.autocast(enabled=False):
return nn.functional.linear(input.to(dtype=self.precision),
self.weight.to(dtype=self.precision),
self.bias.to(dtype=self.precision)).to(dtype=d)
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
return nn.functional.linear(input, self.weight.to(dtype=d), self.bias.to(dtype=d))
return nn.functional.linear(input, self.weight, self.bias)
class LayerNorm(nn.Module):
def __init__(self, c_in, eps=1e-5):
......
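The new precision argument lets selected layers opt out of reduced-precision autocast and run their matmul in a fixed dtype (float32 for PointProjection and QuatRigid below). A standalone sketch of the same idea, simplified and without the DeepSpeed bf16 special case handled in the actual diff:

import torch
import torch.nn as nn

class PrecisionLinear(nn.Linear):
    # Linear layer whose matmul can be forced to run in a fixed precision.
    def __init__(self, in_features: int, out_features: int, bias: bool = True, precision=None):
        super().__init__(in_features, out_features, bias=bias)
        self.precision = precision

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.precision is not None:
            out_dtype = x.dtype
            with torch.cuda.amp.autocast(enabled=False):
                return nn.functional.linear(
                    x.to(self.precision),
                    self.weight.to(self.precision),
                    None if self.bias is None else self.bias.to(self.precision),
                ).to(out_dtype)
        return nn.functional.linear(x, self.weight, self.bias)

# e.g. PrecisionLinear(128, 36, precision=torch.float32) mirrors the PointProjection change below.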
......@@ -175,7 +175,7 @@ class PointProjection(nn.Module):
self.num_points = num_points
self.is_multimer = is_multimer
self.linear = Linear(c_hidden, no_heads * 3 * num_points)
self.linear = Linear(c_hidden, no_heads * 3 * num_points, precision=torch.float32)
def forward(self,
activations: torch.Tensor,
......@@ -642,6 +642,7 @@ class InvariantPointAttentionMultimer(nn.Module):
pt_att = square_euclidean_distance(q_pts.unsqueeze(-3), k_pts.unsqueeze(-4), epsilon=0.)
pt_att = torch.sum(pt_att * point_weights[..., None], dim=-1) * (-0.5)
pt_att = pt_att.to(dtype=s.dtype)
a = a + pt_att
scalar_variance = max(self.c_hidden, 1) * 1.
......@@ -707,6 +708,7 @@ class InvariantPointAttentionMultimer(nn.Module):
# [*, N_res, H, P_v]
o_pt = r[..., None].apply_inverse_to_point(o_pt)
o_pt_flat = [o_pt.x, o_pt.y, o_pt.z]
o_pt_flat = [x.to(dtype=a.dtype) for x in o_pt_flat]
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(epsilon=1e-8)
......@@ -1136,6 +1138,8 @@ class StructureModule(nn.Module):
"positions": pred_xyz,
}
preds = {k: v.to(dtype=s.dtype) for k, v in preds.items()}
outputs.append(preds)
rigids = rigids.stop_rot_gradient()
......
......@@ -67,7 +67,7 @@ def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
residx_atom14_to_atom37,
dim=no_batch_dims + 1,
no_batch_dims=no_batch_dims + 1,
).to(torch.float32)
).to(all_atom_pos.dtype)
# create a mask for known groundtruth positions
atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype)
# gather the groundtruth positions
......@@ -143,14 +143,14 @@ def atom37_to_frames(
# Compute a mask whether ground truth exists for the group
gt_atoms_exist = tensor_utils.batched_gather( # shape (N, 8, 3)
all_atom_mask.to(dtype=torch.float32),
all_atom_mask.to(dtype=all_atom_positions.dtype),
residx_rigidgroup_base_atom37_idx,
batch_dims=no_batch_dims + 1,
)
gt_exists = torch.min(gt_atoms_exist, dim=-1) * group_exists # (N, 8)
# Adapt backbone frame to old convention (mirror x-axis and z-axis).
rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
rots = np.tile(np.eye(3, dtype=all_atom_positions.dtype), [8, 1, 1])
rots[0, 0, 0] = -1
rots[0, 2, 2] = -1
gt_frames = gt_frames.compose_rotation(
......@@ -161,9 +161,9 @@ def atom37_to_frames(
# The frames for ambiguous rigid groups are just rotated by 180 degree around
# the x-axis. The ambiguous group is always the last chi-group.
restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=all_atom_positions.dtype)
restype_rigidgroup_rots = np.tile(
np.eye(3, dtype=np.float32), [21, 8, 1, 1]
np.eye(3, dtype=all_atom_positions.dtype), [21, 8, 1, 1]
)
for resname, _ in rc.residue_atom_renaming_swaps.items():
......@@ -334,7 +334,7 @@ def extreme_ca_ca_distance_violations(
next_ca_mask = mask[..., 1:, 1] # (N - 1)
has_no_gap_mask = (
(residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
).astype(torch.float32)
).astype(positions.x.dtype)
ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, eps)
violations = (ca_ca_distance - rc.ca_ca) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
......@@ -441,7 +441,7 @@ def compute_chi_angles(
)
# Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4].
chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1)
chi_mask = chi_mask * chi_angle_atoms_mask.to(torch.float32)
chi_mask = chi_mask * chi_angle_atoms_mask.to(chi_angles.dtype)
return chi_angles, chi_mask
......
......@@ -16,7 +16,7 @@ class QuatRigid(nn.Module):
else:
rigid_dim = 6
self.linear = Linear(c_hidden, rigid_dim, init="final")
self.linear = Linear(c_hidden, rigid_dim, init="final", precision=torch.float32)
def forward(self, activations: torch.Tensor) -> Rigid3Array:
# NOTE: During training, this needs to be run in higher precision
......
......@@ -1672,7 +1672,7 @@ def chain_center_of_mass_loss(
one_hot = torch.nn.functional.one_hot(asym_id.long()).to(dtype=all_atom_mask.dtype)
one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1)
chain_exists = torch.any(chain_pos_mask, dim=-1).float()
chain_exists = torch.any(chain_pos_mask, dim=-1).to(dtype=all_atom_positions.dtype)
def get_chain_center_of_mass(pos):
center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2)
......@@ -1694,6 +1694,26 @@ def chain_center_of_mass_loss(
# #
# below are the functions required for permutations
# #
def compute_rmsd(
true_atom_pos: torch.Tensor,
pred_atom_pos: torch.Tensor,
atom_mask: torch.Tensor = None,
eps: float = 1e-6,
) -> torch.Tensor:
# shape check
true_atom_pos = true_atom_pos
pred_atom_pos = pred_atom_pos
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
del true_atom_pos
del pred_atom_pos
gc.collect()
if atom_mask is not None:
sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device))
msd = torch.mean(sq_diff)
msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps) # prevent sqrt 0
def kabsch_rotation(P, Q):
"""
Use procrustes package to calculate best rotation that minimises
......@@ -1712,11 +1732,12 @@ def kabsch_rotation(P, Q):
"""
assert P.shape == torch.Size([Q.shape[0],Q.shape[1]])
rotation = procrustes.rotational(P.detach().cpu().numpy(),
Q.detach().cpu().numpy(),translate=False,scale=False)
rotation = torch.tensor(rotation.t,dtype=torch.float) # rotation.t doesn't mean transpose, t only means get the matrix out of the procruste object
rotation = procrustes.rotational(P.detach().cpu().float().numpy(),
Q.detach().cpu().float().numpy(),translate=False,scale=False)
# rotation.t doesn't mean transpose; .t simply retrieves the rotation matrix from the Procrustes result object
rotation = torch.tensor(rotation.t,dtype=torch.float)
assert rotation.shape == torch.Size([3,3])
return rotation.to('cuda')
return rotation.to(device=P.device, dtype=P.dtype)
def get_optimal_transform(
......@@ -1731,7 +1752,8 @@ def get_optimal_transform(
"""
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
assert src_atoms.shape[-1] == 3
assert len(mask.shape) ==1,"mask should have the shape of [num_res]"
if mask is not None:
assert len(mask.shape) == 1, "mask should have the shape of [num_res]"
if torch.isnan(src_atoms).any() or torch.isinf(src_atoms).any():
#
# sometimes using fake test inputs generates NaN in the predicted atom positions
......@@ -1743,7 +1765,7 @@ def get_optimal_transform(
assert mask.dtype == torch.bool
assert mask.shape[-1] == src_atoms.shape[-2]
if mask.sum() == 0:
src_atoms = torch.zeros((1, 3), device=src_atoms.device).float()
src_atoms = torch.zeros((1, 3), device=src_atoms.device, dtype=src_atoms.dtype)
tgt_atoms = src_atoms
else:
src_atoms = src_atoms[mask, :]
......@@ -1754,33 +1776,12 @@ def get_optimal_transform(
del src_atoms,tgt_atoms,
gc.collect()
tgt_center,src_center = tgt_center.to('cuda'),src_center.to('cuda')
x = tgt_center.to('cpu') - src_center.to('cpu') @ r.to('cpu')
x = tgt_center - src_center @ r
del tgt_center,src_center,mask
gc.collect()
return r, x.to('cuda')
def compute_rmsd(
true_atom_pos: torch.Tensor,
pred_atom_pos: torch.Tensor,
atom_mask: torch.Tensor = None,
eps: float = 1e-6,
) -> torch.Tensor:
# shape check
true_atom_pos = true_atom_pos.to('cuda:0')
pred_atom_pos = pred_atom_pos.to('cuda:0')
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
del true_atom_pos
del pred_atom_pos
gc.collect()
if atom_mask is not None:
sq_diff = sq_diff.to('cpu')[atom_mask.to('cpu')] # somehow it causes overflow on cuda so moved to cpu
msd = torch.mean(sq_diff)
msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps)
return r, x
def get_least_asym_entity_or_longest_length(batch):
......@@ -1834,43 +1835,35 @@ def greedy_align(
true_ca_poses,
true_ca_masks,
):
"""
Implements Algorithm 4 from the Supplementary Information of the AlphaFold-Multimer paper:
Evans, R. et al., 2022, Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
"""
used = [False for _ in range(len(true_ca_poses))]
align = []
for cur_asym_id in unique_asym_ids:
# skip padding
if cur_asym_id == 0:
continue
i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id
entity_id = batch["entity_id"][asym_mask][0]
# don't need to align
if (entity_id) == 1:
align.append((i, i))
assert used[i] == False
used[i] = True
continue
cur_entity_ids = batch["entity_id"][asym_mask][0]
best_rmsd = torch.inf
best_rmsd = torch.inf
best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask]
for next_asym_id in cur_asym_list:
if next_asym_id == 0:
continue
j = int(next_asym_id - 1)
if not used[j]: # possible candidate
while best_idx is None:
cropped_pos = true_ca_poses[j]
mask = true_ca_masks[j][cur_residue_index]
rmsd = compute_rmsd(
cropped_pos, cur_pred_pos, (cur_pred_mask.to('cuda:0') * mask.to('cuda:0')).bool()
)
if (rmsd is not None) and (rmsd < best_rmsd):
best_rmsd = rmsd
best_idx = j
cropped_pos = torch.index_select(true_ca_poses[j],1,cur_residue_index)
mask = torch.index_select(true_ca_masks[j],1,cur_residue_index)
rmsd = compute_rmsd(
torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0),
(cur_pred_mask * mask).bool()
)
if (rmsd is not None) and (rmsd < best_rmsd):
best_rmsd = rmsd
best_idx = j
assert best_idx is not None
used[best_idx] = True
align.append((i, best_idx))
......@@ -1882,6 +1875,9 @@ def merge_labels(per_asym_residue_index, labels, align):
per_asym_residue_index: a dictionary that records which asym_id corresponds to which regions of residues in the multimer complex.
labels: list of original ground truth feats
align: list of tuples, each entry specifies the corresponding label of the asym.
Modified from UniFold:
https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1
"""
outs = {}
for k, v in labels[0].items():
......@@ -1891,10 +1887,12 @@ def merge_labels(per_asym_residue_index, labels, align):
cur_num_res = labels[j]['aatype'].shape[-1]
# to 1-based
cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape)==0 or "template" in k:
if len(v.shape)<=1 or "template" in k or "row_mask" in k :
continue
else:
else:
dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0
if k =='all_atom_positions':
dimension_to_merge=1
cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0:
......@@ -2012,6 +2010,7 @@ class AlphaFoldLoss(nn.Module):
cum_loss,losses = self.loss(out,batch,_return_breakdown)
return cum_loss, losses
class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
Add multi-chain permutation on top of
......@@ -2021,7 +2020,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
super(AlphaFoldMultimerLoss, self).__init__(config)
self.config = config
def multi_chain_perm_align(self,out, batch, labels, shuffle_times=2):
@staticmethod
def multi_chain_perm_align(out, batch, labels, permutate_chains=True):
"""
A static method that permutes chains in the ground truth
before calculating the loss.
......@@ -2031,99 +2031,95 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :].float() # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].float() # [bsz, nres]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :].float() for l in labels
l["all_atom_positions"][..., ca_idx, :] for l in labels
] # list([nres, 3])
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].float() for l in labels
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
unique_asym_ids = torch.unique(batch["asym_id"])
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"],asym_mask)
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_residue_idx)
anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
# anchor_pred_pos = anchor_pred_pos.to('cuda')
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_residue_idx)
anchor_pred_mask =pred_ca_mask[0][asym_mask[0]]
# anchor_pred_mask = anchor_pred_mask.to('cuda')
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
r, x = get_optimal_transform(
anchor_pred_pos,anchor_true_pos[0],
mask=input_mask[0]
)
del input_mask # just to save memory
del anchor_pred_mask
del anchor_true_mask
gc.collect()
aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
align = greedy_align(
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_residue_idx)
anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_residue_idx)
anchor_pred_mask = pred_ca_mask[0][asym_mask[0]]
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
r, x = get_optimal_transform(
anchor_pred_pos, anchor_true_pos[0],
mask=input_mask[0]
)
del input_mask # just to save memory
del anchor_pred_mask
del anchor_true_mask
gc.collect()
aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
del true_ca_poses
gc.collect()
align = greedy_align(
batch,
per_asym_residue_index,
unique_asym_ids ,
unique_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)
del aligned_true_ca_poses
del r,x
del pred_ca_pos,pred_ca_mask,true_ca_poses,true_ca_masks
del anchor_pred_pos,anchor_true_pos
gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
merged_labels = merge_labels(
per_asym_residue_index,
labels,
align,
)
return merged_labels
def forward(self,out,batch,_return_breakdown=False):
del aligned_true_ca_poses, true_ca_masks
del r, x
del pred_ca_pos, pred_ca_mask
del anchor_pred_pos, anchor_true_pos
gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
else:
align = list(enumerate(range(len(labels))))
return align, per_asym_residue_index
def forward(self, out, features, _return_breakdown=False, permutate_chains=True):
"""
Overwrites the AlphaFoldLoss forward function so that
it first computes the multi-chain permutation
args:
out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure
"""
features,labels = batch
# first remove the recycling dimension of input features
features = tensor_tree_map(lambda t: t[..., -1], features)
features['resolution'] = labels[0]['resolution']
# then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels)
permutated_labels.pop('aatype')
features.update(permutated_labels)
move_to_cpu = lambda t: (t.to('cpu'))
# features = tensor_tree_map(move_to_cpu,features)
# permutate ground truth chains before calculating the loss
# align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features, labels,
# permutate_chains=permutate_chains)
# permutated_labels = merge_labels(per_asym_residue_index, labels, align)
# permutated_labels.pop('aatype')
# features.update(permutated_labels)
if (not _return_breakdown):
cum_loss = self.loss(out,features,_return_breakdown)
cum_loss = self.loss(out, features, _return_breakdown)
print(f"cum_loss: {cum_loss}")
return cum_loss
else:
cum_loss,losses = self.loss(out,features,_return_breakdown)
cum_loss, losses = self.loss(out, features, _return_breakdown)
print(f"cum_loss: {cum_loss} losses: {losses}")
return cum_loss, losses
\ No newline at end of file
return cum_loss, losses
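For orientation, greedy_align returns a list of (predicted_chain_index, ground_truth_chain_index) pairs and merge_labels then reorders the per-chain ground-truth features to match; a toy sketch of applying such an alignment (simplified, not the repository's merge_labels):

import torch

def reorder_labels(align, labels):
    # align: list of (pred_chain_idx, gt_chain_idx); labels: per-chain ground-truth dicts.
    # Returns the ground-truth chains in the order the prediction expects them.
    return [labels[gt_idx] for _, gt_idx in sorted(align)]

labels = [
    {"all_atom_positions": torch.zeros(10, 37, 3)},
    {"all_atom_positions": torch.ones(10, 37, 3)},
]
align = [(0, 1), (1, 0)]  # prediction chain 0 matches ground-truth chain 1, and vice versa
permuted = reorder_labels(align, labels)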
......@@ -260,9 +260,6 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
def __init__(self, config):
super(OpenFoldMultimerWrapper, self).__init__(config)
self.config = config
self.config.loss.masked_msa.num_classes = 22 # somehow need overwrite this part in multimer loss config
self.config.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
self.config.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
self.model = AlphaFold(config)
self.loss = AlphaFoldMultimerLoss(config.loss)
self.ema = ExponentialMovingAverage(
......@@ -276,24 +273,27 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
return self.model(batch)
def training_step(self, batch, batch_idx):
all_chain_features,ground_truth = batch
if(self.ema.device != all_chain_features["aatype"].device):
self.ema.to(all_chain_features["aatype"].device)
# Log it
if(self.ema.device != batch["aatype"].device):
self.ema.to(batch["aatype"].device)
# Run the model
outputs = self(all_chain_features)
outputs = self(batch)
# Remove the recycling dimension
batch = tensor_tree_map(lambda t: t[..., -1], batch)
# Compute loss
loss = self.loss(
outputs, (all_chain_features,ground_truth), _return_breakdown=False
loss, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True
)
# Log it
self._log(loss, all_chain_features, outputs)
self._log(loss_breakdown, batch, outputs)
return loss
def validation_step(self, batch, batch_idx):
all_chain_features,ground_truth = batch
# At the start of validation, load the EMA weights
if(self.cached_weights is None):
# model.state_dict() contains references to model weights rather
......@@ -304,21 +304,22 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model
outputs = self(all_chain_features)
outputs = self(batch)
# Compute loss and other metrics
all_chain_features["use_clamped_fape"] = 0.
batch["use_clamped_fape"] = 0.
_, loss_breakdown = self.loss(
outputs, all_chain_features, _return_breakdown=True
outputs, batch, _return_breakdown=True
)
self._log(loss_breakdown, all_chain_features, outputs, train=False)
self._log(loss_breakdown, batch, outputs, train=False)
def validation_epoch_end(self, _):
# Restore the model weights to normal
self.model.load_state_dict(self.cached_weights)
self.cached_weights = None
def main(args):
if(args.seed is not None):
seed_everything(args.seed)
......
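The reworked training and validation steps strip the recycling dimension with tensor_tree_map(lambda t: t[..., -1], batch) before computing the loss. A toy illustration of that pattern with a simplified stand-in for tensor_tree_map, which maps a function over every tensor leaf of a nested dict:

import torch

def tree_map_tensors(fn, tree):
    # Simplified stand-in for tensor_tree_map: apply fn to every tensor leaf of a nested dict.
    if isinstance(tree, dict):
        return {k: tree_map_tensors(fn, v) for k, v in tree.items()}
    return fn(tree)

batch = {"aatype": torch.zeros(256, 4, dtype=torch.long)}   # trailing dim = recycling iterations
batch = tree_map_tensors(lambda t: t[..., -1], batch)       # keep only the final recycling slice
print(batch["aatype"].shape)  # torch.Size([256])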