Commit ab09ded4 authored by Christina Floristean's avatar Christina Floristean
Browse files

Fixes for memory usage in multimer dataloader, spatial cropping, and tensor...

Fixes for memory usage in multimer dataloader, spatial cropping, and tensor type conversions within model.
parent da5d0e7d
...@@ -431,15 +431,15 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -431,15 +431,15 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered, _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
) )
self.data_pipeline = data_pipeline.DataPipeline( data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
) )
self.multimer_data_pipeline = data_pipeline.DataPipelineMultimer( self.data_pipeline = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=self.data_pipeline monomer_data_pipeline=data_processor
) )
self.feature_pipeline = feature_pipeline.FeaturePipeline(config) self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index): def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
with open(path, 'r') as f: with open(path, 'r') as f:
mmcif_string = f.read() mmcif_string = f.read()
...@@ -457,7 +457,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -457,7 +457,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
data = self.data_pipeline.process_mmcif( data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object, mmcif=mmcif_object,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
chain_id=chain_id,
alignment_index=alignment_index alignment_index=alignment_index
) )
...@@ -473,82 +472,49 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -473,82 +472,49 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
mmcif_id = self.idx_to_mmcif_id(idx) mmcif_id = self.idx_to_mmcif_id(idx)
chains = self.mmcif_data_cache[mmcif_id]['chain_ids'] chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
print(f"mmcif_id is :{mmcif_id} idx:{idx} and has {len(chains)}chains") print(f"mmcif_id is :{mmcif_id} idx:{idx} and has {len(chains)}chains")
seqs = self.mmcif_data_cache[mmcif_id]['seqs']
fasta_str = ""
for c,s in zip(chains,seqs):
fasta_str+=f">{mmcif_id}_{c}\n{s}\n"
with temp_fasta_file(fasta_str) as fasta_file:
all_chain_features = self.multimer_data_pipeline.process_fasta(fasta_file,self.alignment_dir)
# process all_chain_features
all_chain_features = self.feature_pipeline.process_features(all_chain_features,
mode=self.mode,
is_multimer=True)
alignment_index = None alignment_index = None
ground_truth=[]
if(self.mode == 'train' or self.mode == 'eval'): if(self.mode == 'train' or self.mode == 'eval'):
for chain in chains: path = os.path.join(self.data_dir, f"{mmcif_id}")
path = os.path.join(self.data_dir, f"{mmcif_id}") ext = None
ext = None for e in self.supported_exts:
for e in self.supported_exts: if(os.path.exists(path + e)):
if(os.path.exists(path + e)): ext = e
ext = e break
break
if(ext is None): if(ext is None):
raise ValueError("Invalid file type") raise ValueError("Invalid file type")
path += ext #TODO: Add pdb and core exts to data_pipeline for multimer
alignment_dir = os.path.join(self.alignment_dir,f"{mmcif_id}_{chain.upper()}") path += ext
if(ext == ".cif"): if(ext == ".cif"):
data = self._parse_mmcif( data = self._parse_mmcif(
path, mmcif_id, chain, alignment_dir, alignment_index, path, mmcif_id, self.alignment_dir, alignment_index,
) )
ground_truth_feats = self.feature_pipeline.process_features(data, "train", else:
is_multimer=False) raise ValueError("Extension branch missing")
#remove recycling dimension
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
elif(ext == ".pdb"):
structure_index = None
data = self.data_pipeline.process_pdb(
pdb_path=path,
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain,
alignment_index=alignment_index,
_structure_index=structure_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
else:
raise ValueError("Extension branch missing")
all_chain_features["batch_idx"] = torch.tensor(
[idx for _ in range(all_chain_features["aatype"].shape[-1])],
dtype=torch.int64,
device=all_chain_features["aatype"].device)
# if it's training now, then return both all_chain_features and ground_truth
return all_chain_features,ground_truth
else: else:
# if it's inference mode, only need all_chain_features path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
all_chain_features["batch_idx"] = torch.tensor( data = self.data_pipeline.process_fasta(
[idx for _ in range(all_chain_features["aatype"].shape[-1])], fasta_path=path,
dtype=torch.int64, alignment_dir=self.alignment_dir
device=all_chain_features["aatype"].device) )
return all_chain_features
if (self._output_raw):
return data
# process all_chain_features
data = self.feature_pipeline.process_features(data,
mode=self.mode,
is_multimer=True)
# if it's inference mode, only need all_chain_features
data["batch_idx"] = torch.tensor(
[idx for _ in range(data["aatype"].shape[-1])],
dtype=torch.int64,
device=data["aatype"].device)
return data
def __len__(self): def __len__(self):
return len(self._chain_ids) return len(self._chain_ids)
......
...@@ -1188,4 +1188,75 @@ class DataPipelineMultimer: ...@@ -1188,4 +1188,75 @@ class DataPipelineMultimer:
# Pad MSA to avoid zero-sized extra_msa. # Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512) np_example = pad_msa(np_example, 512)
return np_example
def get_mmcif_features(
    self, mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
    """Build structure-derived features for one chain of a parsed mmCIF.

    Args:
        mmcif_object: Parsed mmCIF with header metadata and atom records.
        chain_id: Author chain ID whose coordinates are extracted.

    Returns:
        Dict with atom coordinates/mask, resolution, release date, and an
        ``is_distillation`` flag fixed at 0 (experimental structure).
    """
    coords, coord_mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object, chain_id=chain_id
    )
    header = mmcif_object.header
    features = {
        "all_atom_positions": coords,
        "all_atom_mask": coord_mask,
        "resolution": np.array(header["resolution"], dtype=np.float32),
        # Release date is stored as a length-1 bytes array, matching the
        # object-dtype convention used elsewhere in the feature dicts.
        "release_date": np.array(
            [header["release_date"].encode("utf-8")], dtype=np.object_
        ),
        "is_distillation": np.array(0., dtype=np.float32),
    }
    return features
def process_mmcif(
    self,
    mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
    alignment_dir: str,
    alignment_index: Optional[str] = None,
) -> FeatureDict:
    """Assemble paired-and-merged multimer features from a parsed mmCIF.

    Runs the monomer pipeline once per unique chain sequence, attaches the
    experimental coordinates for each chain, then pairs and merges all
    chains into a single multimer example.

    Args:
        mmcif: Already-parsed mmCIF object (caller parses to avoid repeat cost).
        alignment_dir: Directory containing per-chain alignment subdirectories
            named ``<file_id>_<chain_id>``.
        alignment_index: Optional alignment index forwarded to the per-chain
            pipeline (unused directly here).

    Returns:
        A merged, MSA-padded feature dict for the whole assembly.
    """
    all_chain_features = {}
    sequence_features = {}
    # A single unique sequence means homomer (or monomer); this flag controls
    # MSA handling in the per-chain pipeline.
    is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1
    for chain_id, seq in mmcif.chain_to_seqres.items():
        desc= "_".join([mmcif.file_id, chain_id])
        if seq in sequence_features:
            # Reuse the expensive per-sequence features for duplicate chains.
            # NOTE(review): the deep copy also carries the *first* chain's
            # atom coordinates/mask (set via get_mmcif_features below) —
            # confirm that is intended for duplicated chains.
            all_chain_features[desc] = copy.deepcopy(
                sequence_features[seq]
            )
            continue
        chain_features = self._process_single_chain(
            chain_id=desc,
            sequence=seq,
            description=desc,
            chain_alignment_dir=os.path.join(alignment_dir, desc),
            is_homomer_or_monomer=is_homomer_or_monomer
        )
        chain_features = convert_monomer_features(
            chain_features,
            chain_id=desc
        )
        # Attach experimentally determined coordinates and header metadata.
        mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
        chain_features.update(mmcif_feats)
        all_chain_features[desc] = chain_features
        sequence_features[seq] = chain_features
    all_chain_features = add_assembly_features(all_chain_features)
    np_example = feature_processing_multimer.pair_and_merge(
        all_chain_features=all_chain_features,
    )
    # Pad MSA to avoid zero-sized extra_msa.
    np_example = pad_msa(np_example, 512)
return np_example return np_example
\ No newline at end of file
...@@ -307,15 +307,25 @@ def make_msa_profile(batch): ...@@ -307,15 +307,25 @@ def make_msa_profile(batch):
return batch return batch
def randint(lower, upper, generator, device):
    """Draw one integer uniformly from the closed interval [lower, upper].

    Args:
        lower: Inclusive lower bound.
        upper: Inclusive upper bound.
        generator: ``torch.Generator`` providing reproducible randomness.
        device: Device on which the sample is drawn.

    Returns:
        A plain Python ``int``.
    """
    # torch.randint samples from the half-open range [low, high), so +1
    # makes `upper` itself reachable.
    sample = torch.randint(
        lower, upper + 1, (1,), device=device, generator=generator
    )
    return int(sample[0])
def get_interface_residues(positions, atom_mask, asym_id, interface_threshold): def get_interface_residues(positions, atom_mask, asym_id, interface_threshold):
coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :] coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :]
pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1)) pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float() diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float()
pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :] pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :]
mask = diff_chain_mask[..., None] * pair_mask mask = (diff_chain_mask[..., None] * pair_mask).bool()
min_dist_per_res = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1) min_dist_per_res, _ = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)
valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1) valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1)
interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0] interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0]
...@@ -336,8 +346,12 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator): ...@@ -336,8 +346,12 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
if not torch.any(interface_residues): if not torch.any(interface_residues):
return get_contiguous_crop_idx(protein, crop_size, generator) return get_contiguous_crop_idx(protein, crop_size, generator)
target_res = interface_residues[int(torch.randint(0, interface_residues.shape[-1], (1,), target_res_idx = randint(lower=0,
device=positions.device, generator=generator)[0])] upper=interface_residues.shape[-1],
generator=generator,
device=positions.device)
target_res = interface_residues[target_res_idx]
ca_idx = rc.atom_order["CA"] ca_idx = rc.atom_order["CA"]
ca_positions = positions[..., ca_idx, :] ca_positions = positions[..., ca_idx, :]
...@@ -353,33 +367,24 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator): ...@@ -353,33 +367,24 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
).float() ).float()
* 1e-3 * 1e-3
) )
to_target_distances = torch.where(ca_mask[..., None], to_target_distances, torch.inf) + break_tie to_target_distances = torch.where(ca_mask, to_target_distances, torch.inf) + break_tie
ret = torch.argsort(to_target_distances)[:crop_size] ret = torch.argsort(to_target_distances)[:crop_size]
return ret.sort().values return ret.sort().values
def randint(lower, upper, generator, device):
return int(torch.randint(
lower,
upper + 1,
(1,),
device=device,
generator=generator,
)[0])
def get_contiguous_crop_idx(protein, crop_size, generator): def get_contiguous_crop_idx(protein, crop_size, generator):
num_res = protein["aatype"].shape[0] unique_asym_ids, chain_lens = protein["asym_id"].unique(return_counts=True)
if num_res <= crop_size:
return torch.arange(num_res)
_, chain_lens = protein["asym_id"].unique(return_counts=True)
shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator) shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator)
num_remaining = int(chain_lens.sum()) num_remaining = int(chain_lens.sum())
num_budget = crop_size num_budget = crop_size
crop_idxs = [] crop_idxs = []
asym_offset = torch.tensor(0, dtype=torch.int64)
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (protein["asym_id"]== cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(protein["asym_id"], asym_mask)[0]
for j, idx in enumerate(shuffle_idx): for j, idx in enumerate(shuffle_idx):
this_len = int(chain_lens[idx]) this_len = int(chain_lens[idx])
num_remaining -= this_len num_remaining -= this_len
...@@ -396,6 +401,8 @@ def get_contiguous_crop_idx(protein, crop_size, generator): ...@@ -396,6 +401,8 @@ def get_contiguous_crop_idx(protein, crop_size, generator):
upper=this_len - chain_crop_size + 1, upper=this_len - chain_crop_size + 1,
generator=generator, generator=generator,
device=chain_lens.device) device=chain_lens.device)
asym_offset = per_asym_residue_index[int(idx)]
crop_idxs.append( crop_idxs.append(
torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size) torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size)
) )
...@@ -427,7 +434,11 @@ def random_crop_to_size( ...@@ -427,7 +434,11 @@ def random_crop_to_size(
use_spatial_crop = torch.rand((1,), use_spatial_crop = torch.rand((1,),
device=protein["seq_length"].device, device=protein["seq_length"].device,
generator=g) < spatial_crop_prob generator=g) < spatial_crop_prob
if use_spatial_crop:
num_res = protein["aatype"].shape[0]
if num_res <= crop_size:
crop_idxs = torch.arange(num_res)
elif use_spatial_crop:
crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g) crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g)
else: else:
crop_idxs = get_contiguous_crop_idx(protein, crop_size, g) crop_idxs = get_contiguous_crop_idx(protein, crop_size, g)
...@@ -469,9 +480,8 @@ def random_crop_to_size( ...@@ -469,9 +480,8 @@ def random_crop_to_size(
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)): for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"): if i == 0 and k.startswith("template"):
crop_size = num_templates_crop_size
crop_start = templates_crop_start crop_start = templates_crop_start
v = v[slice(crop_start, crop_start + crop_size)] v = v[slice(crop_start, crop_start + num_templates_crop_size)]
elif is_num_res: elif is_num_res:
v = torch.index_select(v, i, crop_idxs) v = torch.index_select(v, i, crop_idxs)
......
...@@ -228,12 +228,13 @@ def process_unmerged_features( ...@@ -228,12 +228,13 @@ def process_unmerged_features(
chain_features['deletion_matrix'], axis=0 chain_features['deletion_matrix'], axis=0
) )
# Add all_atom_mask and dummy all_atom_positions based on aatype. if 'all_atom_positions' not in chain_features:
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[ # Add all_atom_mask and dummy all_atom_positions based on aatype.
chain_features['aatype']] all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['all_atom_mask'] = all_atom_mask chain_features['aatype']]
chain_features['all_atom_positions'] = np.zeros( chain_features['all_atom_mask'] = all_atom_mask.astype(dtype=np.float32)
list(all_atom_mask.shape) + [3]) chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])
# Add assembly_num_chains. # Add assembly_num_chains.
chain_features['assembly_num_chains'] = np.asarray(num_chains) chain_features['assembly_num_chains'] = np.asarray(num_chains)
......
...@@ -31,6 +31,18 @@ def nonensembled_transform_fns(common_cfg, mode_cfg): ...@@ -31,6 +31,18 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.make_atom14_masks, data_transforms.make_atom14_masks,
] ]
if mode_cfg.supervised:
transforms.extend(
[
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
]
)
return transforms return transforms
......
...@@ -706,12 +706,12 @@ class TemplatePairEmbedderMultimer(nn.Module): ...@@ -706,12 +706,12 @@ class TemplatePairEmbedderMultimer(nn.Module):
backbone_mask[..., None] * backbone_mask[..., None, :] backbone_mask[..., None] * backbone_mask[..., None, :]
) )
backbone_mask_2d *= multichain_mask_2d backbone_mask_2d *= multichain_mask_2d
x, y, z = [coord * backbone_mask_2d for coord in unit_vector] x, y, z = [(coord * backbone_mask_2d).to(dtype=query_embedding.dtype) for coord in unit_vector]
act += self.x_linear(x[..., None]) act += self.x_linear(x[..., None])
act += self.y_linear(y[..., None]) act += self.y_linear(y[..., None])
act += self.z_linear(z[..., None]) act += self.z_linear(z[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None]) act += self.backbone_mask_linear(backbone_mask_2d[..., None].to(dtype=query_embedding.dtype))
query_embedding = self.query_embedding_layer_norm(query_embedding) query_embedding = self.query_embedding_layer_norm(query_embedding)
act += self.query_embedding_linear(query_embedding) act += self.query_embedding_linear(query_embedding)
...@@ -735,6 +735,8 @@ class TemplateSingleEmbedderMultimer(nn.Module): ...@@ -735,6 +735,8 @@ class TemplateSingleEmbedderMultimer(nn.Module):
): ):
out = {} out = {}
dtype = batch["template_all_atom_positions"].dtype
template_chi_angles, template_chi_mask = ( template_chi_angles, template_chi_mask = (
all_atom_multimer.compute_chi_angles( all_atom_multimer.compute_chi_angles(
atom_pos, atom_pos,
...@@ -751,9 +753,9 @@ class TemplateSingleEmbedderMultimer(nn.Module): ...@@ -751,9 +753,9 @@ class TemplateSingleEmbedderMultimer(nn.Module):
template_chi_mask, template_chi_mask,
], ],
dim=-1, dim=-1,
) ).to(dtype=dtype)
template_mask = template_chi_mask[..., 0] template_mask = template_chi_mask[..., 0].to(dtype=dtype)
template_activations = self.template_single_embedder( template_activations = self.template_single_embedder(
template_features template_features
...@@ -829,8 +831,10 @@ class TemplateEmbedderMultimer(nn.Module): ...@@ -829,8 +831,10 @@ class TemplateEmbedderMultimer(nn.Module):
) )
raw_atom_pos = single_template_feats["template_all_atom_positions"] raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos) # Vec3Arrays are required to be float32
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos.to(dtype=torch.float32))
rigid, backbone_mask = all_atom_multimer.make_backbone_affine( rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos, atom_pos,
single_template_feats["template_all_atom_mask"], single_template_feats["template_all_atom_mask"],
......
...@@ -348,7 +348,7 @@ class AlphaFold(nn.Module): ...@@ -348,7 +348,7 @@ class AlphaFold(nn.Module):
extra_msa_fn = build_extra_msa_feat extra_msa_fn = build_extra_msa_feat
# [*, S_e, N, C_e] # [*, S_e, N, C_e]
extra_msa_feat = extra_msa_fn(feats) extra_msa_feat = extra_msa_fn(feats).to(dtype=z.dtype)
a = self.extra_msa_embedder(extra_msa_feat) a = self.extra_msa_embedder(extra_msa_feat)
if(self.globals.offload_inference): if(self.globals.offload_inference):
...@@ -527,6 +527,7 @@ class AlphaFold(nn.Module): ...@@ -527,6 +527,7 @@ class AlphaFold(nn.Module):
# Main recycling loop # Main recycling loop
num_iters = batch["aatype"].shape[-1] num_iters = batch["aatype"].shape[-1]
early_stop = False early_stop = False
num_recycles = 0
for cycle_no in range(num_iters): for cycle_no in range(num_iters):
# Select the features for the current recycling cycle # Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no] fetch_cur_batch = lambda t: t[..., cycle_no]
...@@ -547,6 +548,8 @@ class AlphaFold(nn.Module): ...@@ -547,6 +548,8 @@ class AlphaFold(nn.Module):
_recycle=(num_iters > 1) _recycle=(num_iters > 1)
) )
num_recycles += 1
if not is_final_iter: if not is_final_iter:
del outputs del outputs
prevs = [m_1_prev, z_prev, x_prev] prevs = [m_1_prev, z_prev, x_prev]
...@@ -554,6 +557,8 @@ class AlphaFold(nn.Module): ...@@ -554,6 +557,8 @@ class AlphaFold(nn.Module):
else: else:
break break
outputs["num_recycles"] = torch.tensor(num_recycles, device=feats["aatype"].device)
if "asym_id" in batch: if "asym_id" in batch:
outputs["asym_id"] = feats["asym_id"] outputs["asym_id"] = feats["asym_id"]
......
...@@ -130,6 +130,7 @@ class Linear(nn.Linear): ...@@ -130,6 +130,7 @@ class Linear(nn.Linear):
bias: bool = True, bias: bool = True,
init: str = "default", init: str = "default",
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
precision=None
): ):
""" """
Args: Args:
...@@ -181,6 +182,26 @@ class Linear(nn.Linear): ...@@ -181,6 +182,26 @@ class Linear(nn.Linear):
else: else:
raise ValueError("Invalid init string.") raise ValueError("Invalid init string.")
self.precision = precision
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Linear transform with optional forced-precision computation.

    If ``self.precision`` is set, the matmul runs at that dtype with
    autocast disabled and the result is cast back to the input dtype.
    Otherwise, bfloat16 inputs outside deepspeed get an explicit
    weight/bias cast so the matmul stays in bf16 under autocast.

    Args:
        input: Input tensor of shape [..., in_features].

    Returns:
        Output tensor of shape [..., out_features], in ``input``'s dtype.
    """
    d = input.dtype

    if self.precision is not None:
        with torch.cuda.amp.autocast(enabled=False):
            # Guard against bias=False layers: self.bias is None then,
            # and None has no .to() — the original crashed here.
            bias = (
                self.bias.to(dtype=self.precision)
                if self.bias is not None else None
            )
            return nn.functional.linear(
                input.to(dtype=self.precision),
                self.weight.to(dtype=self.precision),
                bias,
            ).to(dtype=d)

    if d is torch.bfloat16:
        # Only consulted on the bf16 path; hoisted out of the unconditional
        # prologue so the common fp32/fp16 path skips the deepspeed query.
        deepspeed_is_initialized = (
            deepspeed_is_installed and
            deepspeed.utils.is_initialized()
        )
        if not deepspeed_is_initialized:
            with torch.cuda.amp.autocast(enabled=False):
                bias = self.bias.to(dtype=d) if self.bias is not None else None
                return nn.functional.linear(input, self.weight.to(dtype=d), bias)

    return nn.functional.linear(input, self.weight, self.bias)
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, c_in, eps=1e-5): def __init__(self, c_in, eps=1e-5):
......
...@@ -175,7 +175,7 @@ class PointProjection(nn.Module): ...@@ -175,7 +175,7 @@ class PointProjection(nn.Module):
self.num_points = num_points self.num_points = num_points
self.is_multimer = is_multimer self.is_multimer = is_multimer
self.linear = Linear(c_hidden, no_heads * 3 * num_points) self.linear = Linear(c_hidden, no_heads * 3 * num_points, precision=torch.float32)
def forward(self, def forward(self,
activations: torch.Tensor, activations: torch.Tensor,
...@@ -642,6 +642,7 @@ class InvariantPointAttentionMultimer(nn.Module): ...@@ -642,6 +642,7 @@ class InvariantPointAttentionMultimer(nn.Module):
pt_att = square_euclidean_distance(q_pts.unsqueeze(-3), k_pts.unsqueeze(-4), epsilon=0.) pt_att = square_euclidean_distance(q_pts.unsqueeze(-3), k_pts.unsqueeze(-4), epsilon=0.)
pt_att = torch.sum(pt_att * point_weights[..., None], dim=-1) * (-0.5) pt_att = torch.sum(pt_att * point_weights[..., None], dim=-1) * (-0.5)
pt_att = pt_att.to(dtype=s.dtype)
a = a + pt_att a = a + pt_att
scalar_variance = max(self.c_hidden, 1) * 1. scalar_variance = max(self.c_hidden, 1) * 1.
...@@ -707,6 +708,7 @@ class InvariantPointAttentionMultimer(nn.Module): ...@@ -707,6 +708,7 @@ class InvariantPointAttentionMultimer(nn.Module):
# [*, N_res, H, P_v] # [*, N_res, H, P_v]
o_pt = r[..., None].apply_inverse_to_point(o_pt) o_pt = r[..., None].apply_inverse_to_point(o_pt)
o_pt_flat = [o_pt.x, o_pt.y, o_pt.z] o_pt_flat = [o_pt.x, o_pt.y, o_pt.z]
o_pt_flat = [x.to(dtype=a.dtype) for x in o_pt_flat]
# [*, N_res, H * P_v] # [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(epsilon=1e-8) o_pt_norm = o_pt.norm(epsilon=1e-8)
...@@ -1136,6 +1138,8 @@ class StructureModule(nn.Module): ...@@ -1136,6 +1138,8 @@ class StructureModule(nn.Module):
"positions": pred_xyz, "positions": pred_xyz,
} }
preds = {k: v.to(dtype=s.dtype) for k, v in preds.items()}
outputs.append(preds) outputs.append(preds)
rigids = rigids.stop_rot_gradient() rigids = rigids.stop_rot_gradient()
......
...@@ -67,7 +67,7 @@ def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask): ...@@ -67,7 +67,7 @@ def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
residx_atom14_to_atom37, residx_atom14_to_atom37,
dim=no_batch_dims + 1, dim=no_batch_dims + 1,
no_batch_dims=no_batch_dims + 1, no_batch_dims=no_batch_dims + 1,
).to(torch.float32) ).to(all_atom_pos.dtype)
# create a mask for known groundtruth positions # create a mask for known groundtruth positions
atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype) atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype)
# gather the groundtruth positions # gather the groundtruth positions
...@@ -143,14 +143,14 @@ def atom37_to_frames( ...@@ -143,14 +143,14 @@ def atom37_to_frames(
# Compute a mask whether ground truth exists for the group # Compute a mask whether ground truth exists for the group
gt_atoms_exist = tensor_utils.batched_gather( # shape (N, 8, 3) gt_atoms_exist = tensor_utils.batched_gather( # shape (N, 8, 3)
all_atom_mask.to(dtype=torch.float32), all_atom_mask.to(dtype=all_atom_positions.dtype),
residx_rigidgroup_base_atom37_idx, residx_rigidgroup_base_atom37_idx,
batch_dims=no_batch_dims + 1, batch_dims=no_batch_dims + 1,
) )
gt_exists = torch.min(gt_atoms_exist, dim=-1) * group_exists # (N, 8) gt_exists = torch.min(gt_atoms_exist, dim=-1) * group_exists # (N, 8)
# Adapt backbone frame to old convention (mirror x-axis and z-axis). # Adapt backbone frame to old convention (mirror x-axis and z-axis).
rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1]) rots = np.tile(np.eye(3, dtype=all_atom_positions.dtype), [8, 1, 1])
rots[0, 0, 0] = -1 rots[0, 0, 0] = -1
rots[0, 2, 2] = -1 rots[0, 2, 2] = -1
gt_frames = gt_frames.compose_rotation( gt_frames = gt_frames.compose_rotation(
...@@ -161,9 +161,9 @@ def atom37_to_frames( ...@@ -161,9 +161,9 @@ def atom37_to_frames(
# The frames for ambiguous rigid groups are just rotated by 180 degree around # The frames for ambiguous rigid groups are just rotated by 180 degree around
# the x-axis. The ambiguous group is always the last chi-group. # the x-axis. The ambiguous group is always the last chi-group.
restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32) restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=all_atom_positions.dtype)
restype_rigidgroup_rots = np.tile( restype_rigidgroup_rots = np.tile(
np.eye(3, dtype=np.float32), [21, 8, 1, 1] np.eye(3, dtype=all_atom_positions.dtype), [21, 8, 1, 1]
) )
for resname, _ in rc.residue_atom_renaming_swaps.items(): for resname, _ in rc.residue_atom_renaming_swaps.items():
...@@ -334,7 +334,7 @@ def extreme_ca_ca_distance_violations( ...@@ -334,7 +334,7 @@ def extreme_ca_ca_distance_violations(
next_ca_mask = mask[..., 1:, 1] # (N - 1) next_ca_mask = mask[..., 1:, 1] # (N - 1)
has_no_gap_mask = ( has_no_gap_mask = (
(residue_index[..., 1:] - residue_index[..., :-1]) == 1.0 (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
).astype(torch.float32) ).astype(positions.x.dtype)
ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, eps) ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, eps)
violations = (ca_ca_distance - rc.ca_ca) > max_angstrom_tolerance violations = (ca_ca_distance - rc.ca_ca) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask mask = this_ca_mask * next_ca_mask * has_no_gap_mask
...@@ -441,7 +441,7 @@ def compute_chi_angles( ...@@ -441,7 +441,7 @@ def compute_chi_angles(
) )
# Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4]. # Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4].
chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1) chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1)
chi_mask = chi_mask * chi_angle_atoms_mask.to(torch.float32) chi_mask = chi_mask * chi_angle_atoms_mask.to(chi_angles.dtype)
return chi_angles, chi_mask return chi_angles, chi_mask
......
...@@ -16,7 +16,7 @@ class QuatRigid(nn.Module): ...@@ -16,7 +16,7 @@ class QuatRigid(nn.Module):
else: else:
rigid_dim = 6 rigid_dim = 6
self.linear = Linear(c_hidden, rigid_dim, init="final") self.linear = Linear(c_hidden, rigid_dim, init="final", precision=torch.float32)
def forward(self, activations: torch.Tensor) -> Rigid3Array: def forward(self, activations: torch.Tensor) -> Rigid3Array:
# NOTE: During training, this needs to be run in higher precision # NOTE: During training, this needs to be run in higher precision
......
...@@ -1672,7 +1672,7 @@ def chain_center_of_mass_loss( ...@@ -1672,7 +1672,7 @@ def chain_center_of_mass_loss(
one_hot = torch.nn.functional.one_hot(asym_id.long()).to(dtype=all_atom_mask.dtype) one_hot = torch.nn.functional.one_hot(asym_id.long()).to(dtype=all_atom_mask.dtype)
one_hot = one_hot * all_atom_mask one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1) chain_pos_mask = one_hot.transpose(-2, -1)
chain_exists = torch.any(chain_pos_mask, dim=-1).float() chain_exists = torch.any(chain_pos_mask, dim=-1).to(dtype=all_atom_positions.dtype)
def get_chain_center_of_mass(pos): def get_chain_center_of_mass(pos):
center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2) center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2)
...@@ -1694,6 +1694,26 @@ def chain_center_of_mass_loss( ...@@ -1694,6 +1694,26 @@ def chain_center_of_mass_loss(
# # # #
# below are the functions required for permutations # below are the functions required for permutations
# # # #
def compute_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Root-mean-square deviation between two coordinate tensors.

    Args:
        true_atom_pos: [..., 3] ground-truth coordinates.
        pred_atom_pos: [..., 3] predicted coordinates, same shape.
        atom_mask: Optional per-atom mask selecting which atoms contribute;
            any dtype accepted, nonzero entries count as selected.
        eps: Small constant added before the sqrt so the gradient at zero
            deviation stays finite.

    Returns:
        Scalar tensor. If the mask selects no atoms, the empty mean is NaN
        and is mapped to 1e8 (kept from the original behavior).
    """
    # Per-atom squared distance; removed the original's no-op
    # self-assignments and del/gc.collect() calls, which freed nothing
    # (the caller still holds references to both tensors).
    sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1)
    if atom_mask is not None:
        # masked_select requires a bool mask on the same device; the
        # original crashed on float/int masks.
        mask = atom_mask.to(device=sq_diff.device, dtype=torch.bool)
        sq_diff = torch.masked_select(sq_diff, mask)
    msd = torch.nan_to_num(torch.mean(sq_diff), nan=1e8)
    return torch.sqrt(msd + eps)  # eps prevents sqrt(0)
def kabsch_rotation(P, Q): def kabsch_rotation(P, Q):
""" """
Use procrustes package to calculate best rotation that minimises Use procrustes package to calculate best rotation that minimises
...@@ -1712,11 +1732,12 @@ def kabsch_rotation(P, Q): ...@@ -1712,11 +1732,12 @@ def kabsch_rotation(P, Q):
""" """
assert P.shape == torch.Size([Q.shape[0],Q.shape[1]]) assert P.shape == torch.Size([Q.shape[0],Q.shape[1]])
rotation = procrustes.rotational(P.detach().cpu().numpy(), rotation = procrustes.rotational(P.detach().cpu().float().numpy(),
Q.detach().cpu().numpy(),translate=False,scale=False) Q.detach().cpu().float().numpy(),translate=False,scale=False)
rotation = torch.tensor(rotation.t,dtype=torch.float) # rotation.t doesn't mean transpose, t only means get the matrix out of the procruste object # Rotation.t doesn't mean transpose, t only means get the matrix out of the procruste object
rotation = torch.tensor(rotation.t,dtype=torch.float)
assert rotation.shape == torch.Size([3,3]) assert rotation.shape == torch.Size([3,3])
return rotation.to('cuda') return rotation.to(device=P.device, dtype=P.dtype)
def get_optimal_transform( def get_optimal_transform(
...@@ -1731,7 +1752,8 @@ def get_optimal_transform( ...@@ -1731,7 +1752,8 @@ def get_optimal_transform(
""" """
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape) assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
assert src_atoms.shape[-1] == 3 assert src_atoms.shape[-1] == 3
assert len(mask.shape) ==1,"mask should have the shape of [num_res]" if mask is not None:
assert len(mask.shape) == 1, "mask should have the shape of [num_res]"
if torch.isnan(src_atoms).any() or torch.isinf(src_atoms).any(): if torch.isnan(src_atoms).any() or torch.isinf(src_atoms).any():
# #
# sometimes using fake test inputs generates NaN in the predicted atom positions # sometimes using fake test inputs generates NaN in the predicted atom positions
...@@ -1743,7 +1765,7 @@ def get_optimal_transform( ...@@ -1743,7 +1765,7 @@ def get_optimal_transform(
assert mask.dtype == torch.bool assert mask.dtype == torch.bool
assert mask.shape[-1] == src_atoms.shape[-2] assert mask.shape[-1] == src_atoms.shape[-2]
if mask.sum() == 0: if mask.sum() == 0:
src_atoms = torch.zeros((1, 3), device=src_atoms.device).float() src_atoms = torch.zeros((1, 3), device=src_atoms.device, dtype=src_atoms.dtype)
tgt_atoms = src_atoms tgt_atoms = src_atoms
else: else:
src_atoms = src_atoms[mask, :] src_atoms = src_atoms[mask, :]
...@@ -1754,33 +1776,12 @@ def get_optimal_transform( ...@@ -1754,33 +1776,12 @@ def get_optimal_transform(
del src_atoms,tgt_atoms, del src_atoms,tgt_atoms,
gc.collect() gc.collect()
tgt_center,src_center = tgt_center.to('cuda'),src_center.to('cuda') x = tgt_center - src_center @ r
x = tgt_center.to('cpu') - src_center.to('cpu') @ r.to('cpu')
del tgt_center,src_center,mask del tgt_center,src_center,mask
gc.collect() gc.collect()
return r, x.to('cuda') return r, x
def compute_rmsd(
true_atom_pos: torch.Tensor,
pred_atom_pos: torch.Tensor,
atom_mask: torch.Tensor = None,
eps: float = 1e-6,
) -> torch.Tensor:
# shape check
true_atom_pos = true_atom_pos.to('cuda:0')
pred_atom_pos = pred_atom_pos.to('cuda:0')
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
del true_atom_pos
del pred_atom_pos
gc.collect()
if atom_mask is not None:
sq_diff = sq_diff.to('cpu')[atom_mask.to('cpu')] # somehow it causes overflow on cuda so moved to cpu
msd = torch.mean(sq_diff)
msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps)
def get_least_asym_entity_or_longest_length(batch): def get_least_asym_entity_or_longest_length(batch):
...@@ -1834,43 +1835,35 @@ def greedy_align( ...@@ -1834,43 +1835,35 @@ def greedy_align(
true_ca_poses, true_ca_poses,
true_ca_masks, true_ca_masks,
): ):
"""
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
"""
used = [False for _ in range(len(true_ca_poses))] used = [False for _ in range(len(true_ca_poses))]
align = [] align = []
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
# skip padding
if cur_asym_id == 0:
continue
i = int(cur_asym_id - 1) i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id asym_mask = batch["asym_id"] == cur_asym_id
entity_id = batch["entity_id"][asym_mask][0]
# don't need to align
if (entity_id) == 1:
align.append((i, i))
assert used[i] == False
used[i] = True
continue
cur_entity_ids = batch["entity_id"][asym_mask][0] cur_entity_ids = batch["entity_id"][asym_mask][0]
best_rmsd = torch.inf best_rmsd = torch.inf
best_idx = None best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)] cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)] cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask] cur_pred_pos = pred_ca_pos[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask] cur_pred_mask = pred_ca_mask[asym_mask]
for next_asym_id in cur_asym_list: for next_asym_id in cur_asym_list:
if next_asym_id == 0:
continue
j = int(next_asym_id - 1) j = int(next_asym_id - 1)
if not used[j]: # possible candidate if not used[j]: # possible candidate
while best_idx is None: cropped_pos = torch.index_select(true_ca_poses[j],1,cur_residue_index)
cropped_pos = true_ca_poses[j] mask = torch.index_select(true_ca_masks[j],1,cur_residue_index)
mask = true_ca_masks[j][cur_residue_index] rmsd = compute_rmsd(
rmsd = compute_rmsd( torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0),
cropped_pos, cur_pred_pos, (cur_pred_mask.to('cuda:0') * mask.to('cuda:0')).bool() (cur_pred_mask * mask).bool()
) )
if (rmsd is not None) and (rmsd < best_rmsd):
if (rmsd is not None) and (rmsd < best_rmsd): best_rmsd = rmsd
best_rmsd = rmsd best_idx = j
best_idx = j
assert best_idx is not None assert best_idx is not None
used[best_idx] = True used[best_idx] = True
align.append((i, best_idx)) align.append((i, best_idx))
...@@ -1882,6 +1875,9 @@ def merge_labels(per_asym_residue_index, labels, align): ...@@ -1882,6 +1875,9 @@ def merge_labels(per_asym_residue_index, labels, align):
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex. per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
labels: list of original ground truth feats labels: list of original ground truth feats
align: list of tuples, each entry specify the corresponding label of the asym. align: list of tuples, each entry specify the corresponding label of the asym.
modified based on UniFold:
https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1
""" """
outs = {} outs = {}
for k, v in labels[0].items(): for k, v in labels[0].items():
...@@ -1891,10 +1887,12 @@ def merge_labels(per_asym_residue_index, labels, align): ...@@ -1891,10 +1887,12 @@ def merge_labels(per_asym_residue_index, labels, align):
cur_num_res = labels[j]['aatype'].shape[-1] cur_num_res = labels[j]['aatype'].shape[-1]
# to 1-based # to 1-based
cur_residue_index = per_asym_residue_index[i + 1] cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape)==0 or "template" in k: if len(v.shape)<=1 or "template" in k or "row_mask" in k :
continue continue
else: else:
dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0 dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0
if k =='all_atom_positions':
dimension_to_merge=1
cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index) cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())] cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0: if len(cur_out)>0:
...@@ -2012,6 +2010,7 @@ class AlphaFoldLoss(nn.Module): ...@@ -2012,6 +2010,7 @@ class AlphaFoldLoss(nn.Module):
cum_loss,losses = self.loss(out,batch,_return_breakdown) cum_loss,losses = self.loss(out,batch,_return_breakdown)
return cum_loss, losses return cum_loss, losses
class AlphaFoldMultimerLoss(AlphaFoldLoss): class AlphaFoldMultimerLoss(AlphaFoldLoss):
""" """
Add multi-chain permutation on top of Add multi-chain permutation on top of
...@@ -2021,7 +2020,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2021,7 +2020,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
super(AlphaFoldMultimerLoss, self).__init__(config) super(AlphaFoldMultimerLoss, self).__init__(config)
self.config = config self.config = config
def multi_chain_perm_align(self,out, batch, labels, shuffle_times=2): @staticmethod
def multi_chain_perm_align(out, batch, labels, permutate_chains=True):
""" """
A class method that first permutate chains in ground truth first A class method that first permutate chains in ground truth first
before calculating the loss. before calculating the loss.
...@@ -2031,99 +2031,95 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2031,99 +2031,95 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
""" """
assert isinstance(labels, list) assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"] ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :].float() # [bsz, nres, 3] pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].float() # [bsz, nres] pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
true_ca_poses = [ true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :].float() for l in labels l["all_atom_positions"][..., ca_idx, :] for l in labels
] # list([nres, 3]) ] # list([nres, 3])
true_ca_masks = [ true_ca_masks = [
l["all_atom_mask"][..., ca_idx].float() for l in labels l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,]) ] # list([nres,])
unique_asym_ids = torch.unique(batch["asym_id"]) unique_asym_ids = torch.unique(batch["asym_id"])
per_asym_residue_index = {} per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool() asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"],asym_mask) per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"]) unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {} entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids: for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask]) cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool() asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)] anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_residue_idx) anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_residue_idx)
anchor_pred_pos = pred_ca_pos[0][asym_mask[0]] anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
# anchor_pred_pos = anchor_pred_pos.to('cuda') anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_residue_idx)
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_residue_idx) anchor_pred_mask = pred_ca_mask[0][asym_mask[0]]
anchor_pred_mask =pred_ca_mask[0][asym_mask[0]]
# anchor_pred_mask = anchor_pred_mask.to('cuda') input_mask = (anchor_true_mask * anchor_pred_mask).bool()
input_mask = (anchor_true_mask * anchor_pred_mask).bool() r, x = get_optimal_transform(
r, x = get_optimal_transform( anchor_pred_pos, anchor_true_pos[0],
anchor_pred_pos,anchor_true_pos[0], mask=input_mask[0]
mask=input_mask[0] )
) del input_mask # just to save memory
del input_mask # just to save memory del anchor_pred_mask
del anchor_pred_mask del anchor_true_mask
del anchor_true_mask gc.collect()
gc.collect() aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
del true_ca_poses
aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms gc.collect()
align = greedy_align( align = greedy_align(
batch, batch,
per_asym_residue_index, per_asym_residue_index,
unique_asym_ids , unique_asym_ids,
entity_2_asym_list, entity_2_asym_list,
pred_ca_pos, pred_ca_pos,
pred_ca_mask, pred_ca_mask,
aligned_true_ca_poses, aligned_true_ca_poses,
true_ca_masks, true_ca_masks,
) )
del aligned_true_ca_poses
del r,x
del pred_ca_pos,pred_ca_mask,true_ca_poses,true_ca_masks
del anchor_pred_pos,anchor_true_pos
gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
merged_labels = merge_labels(
per_asym_residue_index,
labels,
align,
)
return merged_labels
def forward(self,out,batch,_return_breakdown=False): del aligned_true_ca_poses, true_ca_masks
del r, x
del pred_ca_pos, pred_ca_mask
del anchor_pred_pos, anchor_true_pos
gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
else:
align = list(enumerate(range(len(labels))))
return align, per_asym_residue_index
def forward(self, out, features, _return_breakdown=False, permutate_chains=True):
""" """
Overwrite AlphaFoldLoss forward function so that Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation it first compute multi-chain permutation
args: args:
out: the output of model.forward() out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure batch: a pair of input features and its corresponding ground truth structure
""" """
features,labels = batch # permutate ground truth chains before calculating the loss
# first remove the recycling dimention of input features # align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features, labels,
features = tensor_tree_map(lambda t: t[..., -1], features) # permutate_chains=permutate_chains)
features['resolution'] = labels[0]['resolution'] # permutated_labels = merge_labels(per_asym_residue_index, labels, align)
# then permutate ground truth chains before calculating the loss # permutated_labels.pop('aatype')
permutated_labels = self.multi_chain_perm_align(out,features,labels) # features.update(permutated_labels)
permutated_labels.pop('aatype')
features.update(permutated_labels)
move_to_cpu = lambda t: (t.to('cpu'))
# features = tensor_tree_map(move_to_cpu,features)
if (not _return_breakdown): if (not _return_breakdown):
cum_loss = self.loss(out,features,_return_breakdown) cum_loss = self.loss(out, features, _return_breakdown)
print(f"cum_loss: {cum_loss}") print(f"cum_loss: {cum_loss}")
return cum_loss return cum_loss
else: else:
cum_loss,losses = self.loss(out,features,_return_breakdown) cum_loss, losses = self.loss(out, features, _return_breakdown)
print(f"cum_loss: {cum_loss} losses: {losses}") print(f"cum_loss: {cum_loss} losses: {losses}")
return cum_loss, losses return cum_loss, losses
\ No newline at end of file
...@@ -260,9 +260,6 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper): ...@@ -260,9 +260,6 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
def __init__(self, config): def __init__(self, config):
super(OpenFoldMultimerWrapper, self).__init__(config) super(OpenFoldMultimerWrapper, self).__init__(config)
self.config = config self.config = config
self.config.loss.masked_msa.num_classes = 22 # somehow need overwrite this part in multimer loss config
self.config.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
self.config.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
self.model = AlphaFold(config) self.model = AlphaFold(config)
self.loss = AlphaFoldMultimerLoss(config.loss) self.loss = AlphaFoldMultimerLoss(config.loss)
self.ema = ExponentialMovingAverage( self.ema = ExponentialMovingAverage(
...@@ -276,24 +273,27 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper): ...@@ -276,24 +273,27 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
return self.model(batch) return self.model(batch)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
all_chain_features,ground_truth = batch # Log it
if(self.ema.device != all_chain_features["aatype"].device): if(self.ema.device != batch["aatype"].device):
self.ema.to(all_chain_features["aatype"].device) self.ema.to(batch["aatype"].device)
# Run the model # Run the model
outputs = self(all_chain_features) outputs = self(batch)
# Remove the recycling dimension
batch = tensor_tree_map(lambda t: t[..., -1], batch)
# Compute loss # Compute loss
loss = self.loss( loss, loss_breakdown = self.loss(
outputs, (all_chain_features,ground_truth), _return_breakdown=False outputs, batch, _return_breakdown=True
) )
# Log it # Log it
self._log(loss, all_chain_features, outputs) self._log(loss_breakdown, batch, outputs)
return loss return loss
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
all_chain_features,ground_truth = batch
# At the start of validation, load the EMA weights # At the start of validation, load the EMA weights
if(self.cached_weights is None): if(self.cached_weights is None):
# model.state_dict() contains references to model weights rather # model.state_dict() contains references to model weights rather
...@@ -304,21 +304,22 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper): ...@@ -304,21 +304,22 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
self.model.load_state_dict(self.ema.state_dict()["params"]) self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model # Run the model
outputs = self(all_chain_features) outputs = self(batch)
# Compute loss and other metrics # Compute loss and other metrics
all_chain_features["use_clamped_fape"] = 0. batch["use_clamped_fape"] = 0.
_, loss_breakdown = self.loss( _, loss_breakdown = self.loss(
outputs, all_chain_features, _return_breakdown=True outputs, batch, _return_breakdown=True
) )
self._log(loss_breakdown, all_chain_features, outputs, train=False) self._log(loss_breakdown, batch, outputs, train=False)
def validation_epoch_end(self, _): def validation_epoch_end(self, _):
# Restore the model weights to normal # Restore the model weights to normal
self.model.load_state_dict(self.cached_weights) self.model.load_state_dict(self.cached_weights)
self.cached_weights = None self.cached_weights = None
def main(args): def main(args):
if(args.seed is not None): if(args.seed is not None):
seed_everything(args.seed) seed_everything(args.seed)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment