import torch
import torch.nn as nn

from .common import (
    residual,
)
from .featurization import (
    pseudo_beta_fn,
    build_extra_msa_feat,
    build_template_angle_feat,
    build_template_pair_feat,
    build_template_pair_feat_v2,
    atom14_to_atom37,
)
from .embedders import (
    InputEmbedder,
    RecyclingEmbedder,
    TemplateAngleEmbedder,
    TemplatePairEmbedder,
    ExtraMSAEmbedder,
)
from .evoformer import EvoformerStack, ExtraMSAStack
from .auxillary_heads import AuxiliaryHeads
from unifold.data import residue_constants
from .structure_module import StructureModule
from .template import (
    TemplatePairStack,
    TemplatePointwiseAttention,
    TemplateProjection,
)
from unicore.utils import (
    tensor_tree_map,
)
from .attentions import (
    gen_msa_attn_mask,
    gen_tri_attn_mask,
)


class AlphaFold(nn.Module):
    def __init__(self, config):
        super(AlphaFold, self).__init__()

        self.globals = config.globals
        config = config.model
        template_config = config.template
        extra_msa_config = config.extra_msa

        self.input_embedder = InputEmbedder(
            **config["input_embedder"],
            use_chain_relative=config.is_multimer,
        )
        self.recycling_embedder = RecyclingEmbedder(
            **config["recycling_embedder"],
        )
        if config.template.enabled:
            self.template_angle_embedder = TemplateAngleEmbedder(
                **template_config["template_angle_embedder"],
            )
            self.template_pair_embedder = TemplatePairEmbedder(
                **template_config["template_pair_embedder"],
            )
            self.template_pair_stack = TemplatePairStack(
                **template_config["template_pair_stack"],
            )
        else:
            self.template_pair_stack = None
        self.enable_template_pointwise_attention = template_config[
            "template_pointwise_attention"
        ].enabled
        if self.enable_template_pointwise_attention:
            self.template_pointwise_att = TemplatePointwiseAttention(
                **template_config["template_pointwise_attention"],
            )
        else:
            self.template_proj = TemplateProjection(
                **template_config["template_pointwise_attention"],
            )
        self.extra_msa_embedder = ExtraMSAEmbedder(
            **extra_msa_config["extra_msa_embedder"],
        )
        self.extra_msa_stack = ExtraMSAStack(
            **extra_msa_config["extra_msa_stack"],
        )
        self.evoformer = EvoformerStack(
            **config["evoformer_stack"],
        )
        self.structure_module = StructureModule(
            **config["structure_module"],
        )
        self.aux_heads = AuxiliaryHeads(
            config["heads"],
        )
        self.config = config
        self.dtype = torch.float
        self.inf = self.globals.inf
        if self.globals.alphafold_original_mode:
            self.alphafold_original_mode()

    def __make_input_float__(self):
        self.input_embedder = self.input_embedder.float()
        self.recycling_embedder = self.recycling_embedder.float()

    def half(self):
        super().half()
        if not getattr(self, "inference", False):
            self.__make_input_float__()
        self.dtype = torch.half
        return self

    def bfloat16(self):
        super().bfloat16()
        if not getattr(self, "inference", False):
            self.__make_input_float__()
        self.dtype = torch.bfloat16
        return self

    def alphafold_original_mode(self):
        def set_alphafold_original_mode(module):
            if hasattr(module, "apply_alphafold_original_mode"):
                module.apply_alphafold_original_mode()
            if hasattr(module, "act"):
                module.act = nn.ReLU()

        self.apply(set_alphafold_original_mode)

    def inference_mode(self):
        def set_inference_mode(module):
            setattr(module, "inference", True)

        self.apply(set_inference_mode)

    def __convert_input_dtype__(self, batch):
        for key in batch:
            # only convert features with mask
            if batch[key].dtype != self.dtype and "mask" in key:
                batch[key] = batch[key].type(self.dtype)
        return batch
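
    # Note on precision handling: half()/bfloat16() cast the whole model, but
    # outside of inference mode __make_input_float__ keeps the input and
    # recycling embedders in float32; iteration_evoformer later casts their
    # outputs back to self.dtype. A minimal usage sketch (hypothetical; the
    # `config` object comes from the UniFold config system, not this file):
    #
    #     model = AlphaFold(config)
    #     model.inference_mode()    # sets module.inference = True everywhere
    #     model = model.bfloat16()  # embedders stay low-precision in inference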

    def embed_templates_pair_core(
        self,
        batch,
        z,
        pair_mask,
        tri_start_attn_mask,
        tri_end_attn_mask,
        templ_dim,
        multichain_mask_2d,
    ):
        if self.config.template.template_pair_embedder.v2_feature:
            t = build_template_pair_feat_v2(
                batch,
                inf=self.config.template.inf,
                eps=self.config.template.eps,
                multichain_mask_2d=multichain_mask_2d,
                **self.config.template.distogram,
            )
            num_template = t[0].shape[-4]
            single_templates = [
                self.template_pair_embedder([x[..., ti, :, :, :] for x in t], z)
                for ti in range(num_template)
            ]
        else:
            t = build_template_pair_feat(
                batch,
                inf=self.config.template.inf,
                eps=self.config.template.eps,
                **self.config.template.distogram,
            )
            single_templates = [
                self.template_pair_embedder(x, z)
                for x in torch.unbind(t, dim=templ_dim)
            ]
        t = self.template_pair_stack(
            single_templates,
            pair_mask,
            tri_start_attn_mask=tri_start_attn_mask,
            tri_end_attn_mask=tri_end_attn_mask,
            templ_dim=templ_dim,
            chunk_size=self.globals.chunk_size,
            block_size=self.globals.block_size,
            return_mean=not self.enable_template_pointwise_attention,
        )
        return t

    def embed_templates_pair(
        self, batch, z, pair_mask, tri_start_attn_mask, tri_end_attn_mask, templ_dim
    ):
        if (
            self.config.template.template_pair_embedder.v2_feature
            and "asym_id" in batch
        ):
            multichain_mask_2d = (
                batch["asym_id"][..., :, None] == batch["asym_id"][..., None, :]
            )
            multichain_mask_2d = multichain_mask_2d.unsqueeze(0)
        else:
            multichain_mask_2d = None

        if self.training or self.enable_template_pointwise_attention:
            t = self.embed_templates_pair_core(
                batch,
                z,
                pair_mask,
                tri_start_attn_mask,
                tri_end_attn_mask,
                templ_dim,
                multichain_mask_2d,
            )
            if self.enable_template_pointwise_attention:
                t = self.template_pointwise_att(
                    t,
                    z,
                    template_mask=batch["template_mask"],
                    chunk_size=self.globals.chunk_size,
                )
                t_mask = torch.sum(batch["template_mask"], dim=-1, keepdims=True) > 0
                t_mask = t_mask[..., None, None].type(t.dtype)
                t *= t_mask
            else:
                t = self.template_proj(t, z)
        else:
            template_aatype_shape = batch["template_aatype"].shape
            # template_aatype is either [n_template, n_res]
            # or [1, n_template, n_res]
            batch_templ_dim = 1 if len(template_aatype_shape) == 3 else 0
            n_templ = batch["template_aatype"].shape[batch_templ_dim]
            if n_templ <= 0:
                t = None
            else:
                template_batch = {
                    k: v for k, v in batch.items() if k.startswith("template_")
                }

                def embed_one_template(i):
                    def slice_template_tensor(t):
                        s = [slice(None) for _ in t.shape]
                        s[batch_templ_dim] = slice(i, i + 1)
                        return t[s]

                    template_feats = tensor_tree_map(
                        slice_template_tensor,
                        template_batch,
                    )
                    t = self.embed_templates_pair_core(
                        template_feats,
                        z,
                        pair_mask,
                        tri_start_attn_mask,
                        tri_end_attn_mask,
                        templ_dim,
                        multichain_mask_2d,
                    )
                    return t

                t = embed_one_template(0)
                # iterate templates one by one
                for i in range(1, n_templ):
                    t += embed_one_template(i)
                t /= n_templ
                t = self.template_proj(t, z)
        return t

    def embed_templates_angle(self, batch):
        template_angle_feat, template_angle_mask = build_template_angle_feat(
            batch,
            v2_feature=self.config.template.template_pair_embedder.v2_feature,
        )
        t = self.template_angle_embedder(template_angle_feat)
        return t, template_angle_mask
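
    # In the multimer (v2_feature) path above, multichain_mask_2d marks residue
    # pairs that belong to the same chain (asym_id[i] == asym_id[j]), so the v2
    # template pair features are restricted to intra-chain pairs. A toy
    # illustration (not part of the model):
    #
    #     asym_id = torch.tensor([1, 1, 2])
    #     mask = asym_id[:, None] == asym_id[None, :]
    #     # -> [[True,  True,  False],
    #     #     [True,  True,  False],
    #     #     [False, False, True ]]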

    def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev):
        batch_dims = feats["target_feat"].shape[:-2]
        n = feats["target_feat"].shape[-2]
        seq_mask = feats["seq_mask"]
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        msa_mask = feats["msa_mask"]

        m, z = self.input_embedder(
            feats["target_feat"],
            feats["msa_feat"],
        )

        if m_1_prev is None:
            m_1_prev = m.new_zeros(
                (*batch_dims, n, self.config.input_embedder.d_msa),
                requires_grad=False,
            )
        if z_prev is None:
            z_prev = z.new_zeros(
                (*batch_dims, n, n, self.config.input_embedder.d_pair),
                requires_grad=False,
            )
        if x_prev is None:
            x_prev = z.new_zeros(
                (*batch_dims, n, residue_constants.atom_type_num, 3),
                requires_grad=False,
            )

        x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)

        z += self.recycling_embedder.recyle_pos(x_prev)

        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
        )

        m[..., 0, :, :] += m_1_prev_emb
        z += z_prev_emb

        z += self.input_embedder.relpos_emb(
            feats["residue_index"].long(),
            feats.get("sym_id", None),
            feats.get("asym_id", None),
            feats.get("entity_id", None),
            feats.get("num_sym", None),
        )

        m = m.type(self.dtype)
        z = z.type(self.dtype)

        tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask(pair_mask, self.inf)

        if self.config.template.enabled:
            template_mask = feats["template_mask"]
            if torch.any(template_mask):
                z = residual(
                    z,
                    self.embed_templates_pair(
                        feats,
                        z,
                        pair_mask,
                        tri_start_attn_mask,
                        tri_end_attn_mask,
                        templ_dim=-4,
                    ),
                    self.training,
                )

        if self.config.extra_msa.enabled:
            a = self.extra_msa_embedder(build_extra_msa_feat(feats))
            extra_msa_row_mask = gen_msa_attn_mask(
                feats["extra_msa_mask"],
                inf=self.inf,
                gen_col_mask=False,
            )
            z = self.extra_msa_stack(
                a,
                z,
                msa_mask=feats["extra_msa_mask"],
                chunk_size=self.globals.chunk_size,
                block_size=self.globals.block_size,
                pair_mask=pair_mask,
                msa_row_attn_mask=extra_msa_row_mask,
                msa_col_attn_mask=None,
                tri_start_attn_mask=tri_start_attn_mask,
                tri_end_attn_mask=tri_end_attn_mask,
            )

        if self.config.template.embed_angles:
            template_1d_feat, template_1d_mask = self.embed_templates_angle(feats)
            m = torch.cat([m, template_1d_feat], dim=-3)
            msa_mask = torch.cat([feats["msa_mask"], template_1d_mask], dim=-2)

        msa_row_mask, msa_col_mask = gen_msa_attn_mask(
            msa_mask,
            inf=self.inf,
        )

        m, z, s = self.evoformer(
            m,
            z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            msa_row_attn_mask=msa_row_mask,
            msa_col_attn_mask=msa_col_mask,
            tri_start_attn_mask=tri_start_attn_mask,
            tri_end_attn_mask=tri_end_attn_mask,
            chunk_size=self.globals.chunk_size,
            block_size=self.globals.block_size,
        )
        return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb
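
    # Recycling inputs to iteration_evoformer: m_1_prev is the first MSA row,
    # z_prev the pair representation, and x_prev the predicted atom37
    # coordinates from the previous cycle (all zeros on the first cycle).
    # x_prev is reduced to pseudo-beta coordinates before being embedded into
    # the pair representation via recycling_embedder.recyle_pos.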

    def iteration_evoformer_structure_module(
        self, batch, m_1_prev, z_prev, x_prev, cycle_no, num_recycling, num_ensembles=1
    ):
        z, s = 0, 0
        n_seq = batch["msa_feat"].shape[-3]
        assert num_ensembles >= 1
        for ensemble_no in range(num_ensembles):
            idx = cycle_no * num_ensembles + ensemble_no
            fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...]
            feats = tensor_tree_map(fetch_cur_batch, batch)
            m, z0, s0, msa_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer(
                feats, m_1_prev, z_prev, x_prev
            )
            z += z0
            s += s0
            del z0, s0
        if num_ensembles > 1:
            z /= float(num_ensembles)
            s /= float(num_ensembles)

        outputs = {}

        outputs["msa"] = m[..., :n_seq, :, :]
        outputs["pair"] = z
        outputs["single"] = s

        # norm loss
        if not getattr(self, "inference", False) and num_recycling == (cycle_no + 1):
            delta_msa = m
            delta_msa[..., 0, :, :] = delta_msa[..., 0, :, :] - m_1_prev_emb.detach()
            delta_pair = z - z_prev_emb.detach()
            outputs["delta_msa"] = delta_msa
            outputs["delta_pair"] = delta_pair
            outputs["msa_norm_mask"] = msa_mask

        outputs["sm"] = self.structure_module(
            s,
            z,
            feats["aatype"],
            mask=feats["seq_mask"],
        )
        outputs["final_atom_positions"] = atom14_to_atom37(
            outputs["sm"]["positions"], feats
        )
        outputs["final_atom_mask"] = feats["atom37_atom_exists"]
        outputs["pred_frame_tensor"] = outputs["sm"]["frames"][-1]

        # use float32 for numerical stability
        if not getattr(self, "inference", False):
            m_1_prev = m[..., 0, :, :].float()
            z_prev = z.float()
            x_prev = outputs["final_atom_positions"].float()
        else:
            m_1_prev = m[..., 0, :, :]
            z_prev = z
            x_prev = outputs["final_atom_positions"]

        return outputs, m_1_prev, z_prev, x_prev

    def forward(self, batch):
        m_1_prev = batch.get("m_1_prev", None)
        z_prev = batch.get("z_prev", None)
        x_prev = batch.get("x_prev", None)

        is_grad_enabled = torch.is_grad_enabled()

        num_iters = int(batch["num_recycling_iters"]) + 1
        num_ensembles = int(batch["msa_mask"].shape[0]) // num_iters
        if self.training:
            # don't use ensemble during training
            assert num_ensembles == 1

        # convert dtypes in batch
        batch = self.__convert_input_dtype__(batch)
        for cycle_no in range(num_iters):
            is_final_iter = cycle_no == (num_iters - 1)
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                (
                    outputs,
                    m_1_prev,
                    z_prev,
                    x_prev,
                ) = self.iteration_evoformer_structure_module(
                    batch,
                    m_1_prev,
                    z_prev,
                    x_prev,
                    cycle_no=cycle_no,
                    num_recycling=num_iters,
                    num_ensembles=num_ensembles,
                )
            if not is_final_iter:
                del outputs

        if "asym_id" in batch:
            outputs["asym_id"] = batch["asym_id"][0, ...]
        outputs.update(self.aux_heads(outputs))
        return outputs
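

# Example end-to-end usage (sketch only; `model_config` and the feature
# pipeline that produces `batch` live outside this file, and the import path
# and model name below are assumptions):
#
#     from unifold.config import model_config
#
#     cfg = model_config("model_2_ft")
#     model = AlphaFold(cfg).eval()
#     model.inference_mode()
#     model = model.bfloat16()
#     with torch.no_grad():
#         out = model(batch)                 # batch: dict of feature tensors
#     coords = out["final_atom_positions"]   # [..., n_res, 37, 3]
#     mask = out["final_atom_mask"]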