Unverified Commit da5fe1a6 authored by shenggan, committed by GitHub

add support for the habana platform (#131)



* add habana

* add mask

* fix mask in outer_product_mean

* add dap

* add hmp

* merge training code

* add chunk for inference

* fix extra-msa stack for training

* support ddp in training

* fix inference bugs

* code refactoring for habana

* support hmp training

* enable all inference and training on Gaudi/Gaudi2 with optimized perf on the latest base (#139)

* enable all inference and training on Gaudi/Gaudi2 with optimized perf

* refine code to adapt to the new base

* refine code to fix issues in code review
Co-authored-by: habanachina <habanachina@habana.ai>
Co-authored-by: Leo Zhao <48052473+LeoZhao-Habana@users.noreply.github.com>
Co-authored-by: habanachina <habanachina@habana.ai>
parent e9db72d6
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from fastfold.habana.fastnn import EvoformerStack, ExtraMSAStack
#from fastfold.model.fastnn.embedders import TemplateEmbedder
#from fastfold.model.fastnn.embedders_multimer import TemplateEmbedderMultimer
#from fastfold.model.fastnn.ops import RecyclingEmbedder, InputEmbedder
def copy_layernorm(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
model_fast.bias.copy_(model_ori.bias)
def copy_linear(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
if model_fast.use_bias:
model_fast.bias.copy_(model_ori.bias)
def copy_native_linear(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
try:
model_fast.bias.copy_(model_ori.bias)
except:
pass
def copy_kv_linear(model_fast, ori_k, ori_v):
model_fast.weight.copy_(torch.cat((ori_k.weight, ori_v.weight), dim=0))
def copy_qkv_linear(model_fast, ori_q, ori_k, ori_v):
model_fast.weight.copy_(torch.cat((ori_q.weight, ori_k.weight, ori_v.weight), dim=0))
def copy_attention(model_fast, model_ori):
copy_qkv_linear(model_fast.to_qkv, model_ori.linear_q, model_ori.linear_k, model_ori.linear_v)
copy_linear(model_fast.gating_linear, model_ori.linear_g)
copy_linear(model_fast.o_linear, model_ori.linear_o)
try:
model_fast.gating_bias.copy_(model_ori.linear_g.bias)
except:
print("no gating_bias need copy")
def copy_left_right(model_fast, ori_left, ori_right):
model_fast.weight.copy_(torch.cat((ori_left.weight, ori_right.weight), dim=0))
model_fast.bias.copy_(torch.cat((ori_left.bias, ori_right.bias), dim=0))
def copy_transition(model_fast, model_ori):
copy_layernorm(model_fast.norm, model_ori.layer_norm)
copy_linear(model_fast.linear1, model_ori.linear_1)
copy_linear(model_fast.linear2, model_ori.linear_2)
def copy_triangle(model_fast, model_ori):
copy_layernorm(model_fast.layernorm1, model_ori.layer_norm_in)
copy_layernorm(model_fast.layernorm2, model_ori.layer_norm_out)
copy_linear(model_fast.output_gate, model_ori.linear_g)
copy_linear(model_fast.output_projection, model_ori.linear_z)
model_fast.output_bias.copy_(model_ori.linear_z.bias)
copy_linear(model_fast.left_projection, model_ori.linear_a_p)
copy_linear(model_fast.right_projection, model_ori.linear_b_p)
copy_linear(model_fast.left_gate, model_ori.linear_a_g)
copy_linear(model_fast.right_gate, model_ori.linear_b_g)
def copy_triangle_att(model_fast, model_ori):
copy_layernorm(model_fast.layernorm1, model_ori.layer_norm)
model_fast.linear_b_weights = model_ori.linear.weight
copy_attention(model_fast.attention, model_ori.mha)
model_fast.out_bias.copy_(model_ori.mha.linear_o.bias)
def copy_native_att(model_fast, model_ori):
copy_native_linear(model_fast.linear_q, model_ori.linear_q)
copy_native_linear(model_fast.linear_k, model_ori.linear_k)
copy_native_linear(model_fast.linear_v, model_ori.linear_v)
copy_native_linear(model_fast.linear_o, model_ori.linear_o)
if model_ori.gating:
copy_native_linear(model_fast.linear_g, model_ori.linear_g)
def copy_evoformer_para(block_fast, block_ori):
# msa_stack
# MSARowAttentionWithPairBias
copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormM,
block_ori.msa_att_row.layer_norm_m)
copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormZ,
block_ori.msa_att_row.layer_norm_z)
copy_attention(block_fast.msa.MSARowAttentionWithPairBias.attention, block_ori.msa_att_row.mha)
block_fast.msa.MSARowAttentionWithPairBias.linear_b_weights.copy_(
block_ori.msa_att_row.linear_z.weight)
block_fast.msa.MSARowAttentionWithPairBias.out_bias.copy_(
block_ori.msa_att_row.mha.linear_o.bias)
# MSAColumnAttention
copy_layernorm(block_fast.msa.MSAColumnAttention.layernormM,
block_ori.msa_att_col._msa_att.layer_norm_m)
copy_attention(block_fast.msa.MSAColumnAttention.attention, block_ori.msa_att_col._msa_att.mha)
# MSATransition
copy_transition(block_fast.msa.MSATransition, block_ori.core.msa_transition)
# communication
copy_layernorm(block_fast.communication.layernormM,
block_ori.core.outer_product_mean.layer_norm)
copy_linear(block_fast.communication.linear_a, block_ori.core.outer_product_mean.linear_1)
copy_linear(block_fast.communication.linear_b, block_ori.core.outer_product_mean.linear_2)
copy_linear(block_fast.communication.o_linear, block_ori.core.outer_product_mean.linear_out)
# pair_stack
# TriangleMultiplicationOutgoing
copy_triangle(block_fast.pair.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
# TriangleMultiplicationIncoming
copy_triangle(block_fast.pair.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
# TriangleAttentionStartingNode
copy_triangle_att(block_fast.pair.TriangleAttentionStartingNode, block_ori.core.tri_att_start)
copy_triangle_att(block_fast.pair.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
copy_transition(block_fast.pair.PairTransition, block_ori.core.pair_transition)
def copy_global_attention(model_fast, model_ori):
copy_linear(model_fast.to_q, model_ori.linear_q)
copy_kv_linear(model_fast.to_kv, model_ori.linear_k, model_ori.linear_v)
copy_linear(model_fast.gating_linear, model_ori.linear_g)
copy_linear(model_fast.o_linear, model_ori.linear_o)
try:
model_fast.gating_bias.copy_(model_ori.linear_g.bias)
except:
print("no gating_bias need copy")
def copy_extra_msa_para(block_fast, block_ori):
# msa_stack
# MSARowAttentionWithPairBias
copy_layernorm(
block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM,
block_ori.msa_att_row.layer_norm_m,
)
copy_layernorm(
block_fast.msa_stack.MSARowAttentionWithPairBias.layernormZ,
block_ori.msa_att_row.layer_norm_z,
)
copy_attention(
block_fast.msa_stack.MSARowAttentionWithPairBias.attention,
block_ori.msa_att_row.mha,
)
block_fast.msa_stack.MSARowAttentionWithPairBias.linear_b_weights.copy_(
block_ori.msa_att_row.linear_z.weight)
block_fast.msa_stack.MSARowAttentionWithPairBias.out_bias.copy_(
block_ori.msa_att_row.mha.linear_o.bias)
# MSAColumnAttention
copy_layernorm(
block_fast.msa_stack.MSAColumnAttention.layernormM,
block_ori.msa_att_col.layer_norm_m,
)
copy_global_attention(
block_fast.msa_stack.MSAColumnAttention.global_attention,
block_ori.msa_att_col.global_attention,
)
# MSATransition
copy_transition(block_fast.msa_stack.MSATransition, block_ori.core.msa_transition)
# communication
comm_model = (
block_ori.core.
outer_product_mean # if not block_ori.is_multimer else block_ori.outer_product_mean
)
copy_layernorm(block_fast.communication.layernormM, comm_model.layer_norm)
copy_linear(block_fast.communication.linear_a, comm_model.linear_1)
copy_linear(block_fast.communication.linear_b, comm_model.linear_2)
copy_linear(block_fast.communication.o_linear, comm_model.linear_out)
# pair_stack
# TriangleMultiplicationOutgoing
copy_triangle(block_fast.pair_stack.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
# TriangleMultiplicationIncoming
copy_triangle(block_fast.pair_stack.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
# TriangleAttentionStartingNode
copy_triangle_att(
block_fast.pair_stack.TriangleAttentionStartingNode,
block_ori.core.tri_att_start,
)
copy_triangle_att(block_fast.pair_stack.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
copy_transition(block_fast.pair_stack.PairTransition, block_ori.core.pair_transition)
def copy_template_pair_stack_para(block_fast, block_ori):
# TriangleMultiplicationOutgoing
copy_triangle(block_fast.TriangleMultiplicationOutgoing, block_ori.tri_mul_out)
# TriangleMultiplicationIncoming
copy_triangle(block_fast.TriangleMultiplicationIncoming, block_ori.tri_mul_in)
# TriangleAttentionStartingNode
copy_triangle_att(block_fast.TriangleAttentionStartingNode, block_ori.tri_att_start)
copy_triangle_att(block_fast.TriangleAttentionEndingNode, block_ori.tri_att_end)
copy_transition(block_fast.PairTransition, block_ori.pair_transition)
def copy_template_pair_block_para(fast_module, target_module):
with torch.no_grad():
for ori_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_template_pair_stack_para(fast_block, ori_block)
if ori_block.training == False:
fast_block.eval()
def copy_template_para(block_fast, block_ori):
# TemplateAngleEmbedder
copy_linear(block_fast.template_angle_embedder.linear_1,
block_ori.template_angle_embedder.linear_1)
copy_linear(block_fast.template_angle_embedder.linear_2,
block_ori.template_angle_embedder.linear_2)
# TemplatePairEmbedder
copy_linear(block_fast.template_pair_embedder.linear, block_ori.template_pair_embedder.linear)
# TemplatePairStack
copy_template_pair_block_para(block_fast.template_pair_stack, block_ori.template_pair_stack)
copy_layernorm(block_fast.template_pair_stack.layer_norm,
block_ori.template_pair_stack.layer_norm)
# TemplatePointwiseAttention
copy_native_att(block_fast.template_pointwise_att.mha, block_ori.template_pointwise_att.mha)
def copy_template_multimer_para(block_fast, block_ori):
# TemplatePairEmbedderMultimer
copy_linear(block_fast.template_pair_embedder.dgram_linear,
block_ori.template_pair_embedder.dgram_linear)
copy_linear(block_fast.template_pair_embedder.aatype_linear_1,
block_ori.template_pair_embedder.aatype_linear_1)
copy_linear(block_fast.template_pair_embedder.aatype_linear_2,
block_ori.template_pair_embedder.aatype_linear_2)
copy_layernorm(block_fast.template_pair_embedder.query_embedding_layer_norm,
block_ori.template_pair_embedder.query_embedding_layer_norm)
copy_linear(block_fast.template_pair_embedder.query_embedding_linear,
block_ori.template_pair_embedder.query_embedding_linear)
copy_linear(block_fast.template_pair_embedder.pseudo_beta_mask_linear,
block_ori.template_pair_embedder.pseudo_beta_mask_linear)
copy_linear(block_fast.template_pair_embedder.x_linear,
block_ori.template_pair_embedder.x_linear)
copy_linear(block_fast.template_pair_embedder.y_linear,
block_ori.template_pair_embedder.y_linear)
copy_linear(block_fast.template_pair_embedder.z_linear,
block_ori.template_pair_embedder.z_linear)
copy_linear(block_fast.template_pair_embedder.backbone_mask_linear,
block_ori.template_pair_embedder.backbone_mask_linear)
# TemplateSingleEmbedderMultimer
copy_linear(block_fast.template_single_embedder.template_single_embedder,
block_ori.template_single_embedder.template_single_embedder)
copy_linear(block_fast.template_single_embedder.template_projector,
block_ori.template_single_embedder.template_projector)
# TemplatePairStack
copy_template_pair_block_para(block_fast.template_pair_stack, block_ori.template_pair_stack)
copy_layernorm(block_fast.template_pair_stack.layer_norm,
block_ori.template_pair_stack.layer_norm)
# linear_t
copy_linear(block_fast.linear_t, block_ori.linear_t)
def inject_evoformer(model):
with torch.no_grad():
target_module = model.evoformer
fast_module = EvoformerStack(
c_m=target_module.blocks[0].msa_att_row.c_in,
c_z=target_module.blocks[0].msa_att_row.c_z,
c_s=target_module.linear.out_features,
no_blocks=len(target_module.blocks),
blocks_per_ckpt=target_module.blocks_per_ckpt,
clear_cache_between_blocks=target_module.clear_cache_between_blocks,
is_multimer=target_module.blocks[0].is_multimer,
)
for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_evoformer_para(fast_block, target_block)
if target_module.training == False:
fast_module.eval()
copy_linear(fast_module.linear, target_module.linear)
model.evoformer = fast_module
def inject_extramsa(model):
with torch.no_grad():
target_module = model.extra_msa_stack
fast_module = ExtraMSAStack(
c_m=target_module.blocks[0].msa_att_row.c_in,
c_z=target_module.blocks[0].msa_att_row.c_z,
no_blocks=len(target_module.blocks),
blocks_per_ckpt=1,
clear_cache_between_blocks=target_module.clear_cache_between_blocks,
is_multimer=target_module.blocks[0].is_multimer,
)
for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_extra_msa_para(fast_block, target_block)
if target_module.training == False:
fast_module.eval()
model.extra_msa_stack = fast_module
def inject_template(model):
with torch.no_grad():
if model.evoformer.blocks[0].is_multimer:
target_module = model.template_embedder
fast_module = TemplateEmbedderMultimer(config=model.template_embedder.config)
copy_template_multimer_para(fast_module, target_module)
if target_module.training == False:
fast_module.eval()
model.template_embedder = fast_module
else:
target_module = model.template_embedder
fast_module = TemplateEmbedder(config=model.template_embedder.config)
copy_template_para(fast_module, target_module)
if target_module.training == False:
fast_module.eval()
model.template_embedder = fast_module
def inject_embedder(model):
if model.evoformer.blocks[0].is_multimer:
return
# recycle embedder
with torch.no_grad():
target_module = model.recycling_embedder
fast_module = RecyclingEmbedder(c_m=target_module.c_m,
c_z=target_module.c_z,
min_bin=target_module.min_bin,
max_bin=target_module.max_bin,
no_bins=target_module.no_bins,
inf=target_module.inf)
copy_native_linear(fast_module.linear, target_module.linear)
copy_layernorm(fast_module.layer_norm_m, target_module.layer_norm_m)
copy_layernorm(fast_module.layer_norm_z, target_module.layer_norm_z)
if target_module.training == False:
fast_module.eval()
model.recycling_embedder = fast_module
# input embedder
with torch.no_grad():
target_module = model.input_embedder
fast_module = InputEmbedder(
tf_dim=target_module.tf_dim,
msa_dim=target_module.msa_dim,
c_z=target_module.c_z,
c_m=target_module.c_m,
relpos_k=target_module.relpos_k,
)
copy_linear(fast_module.linear_tf_z_i, target_module.linear_tf_z_i)
copy_linear(fast_module.linear_tf_z_j, target_module.linear_tf_z_j)
copy_linear(fast_module.linear_tf_m, target_module.linear_tf_m)
copy_linear(fast_module.linear_msa_m, target_module.linear_msa_m)
copy_linear(fast_module.linear_relpos, target_module.linear_relpos)
if target_module.training == False:
fast_module.eval()
model.input_embedder = fast_module
def inject_habana(model):
inject_evoformer(model)
inject_extramsa(model)
#inject_template(model)
#inject_embedder(model)
return model
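For context, a minimal sketch of how these injection helpers are typically used, mirroring the habana/inference.py script later in this change (the model name and device handling are illustrative, not a definitive API):

import torch
import fastfold.habana as habana
from fastfold.config import model_config
from fastfold.habana.distributed import init_dist
from fastfold.habana.inject_habana import inject_habana
from fastfold.model.hub import AlphaFold

def build_habana_model(model_name="model_1"):
    habana.enable_habana()
    init_dist()  # set up the process group used by Dynamic Axial Parallelism
    config = model_config(model_name)
    # swap the Evoformer and ExtraMSA stacks for their Habana fastnn versions
    model = inject_habana(AlphaFold(config))
    return model.eval().to(device=torch.device("hpu"))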
......@@ -41,6 +41,7 @@ from fastfold.utils.tensor_utils import (
tensor_tree_map,
)
import fastfold.habana as habana
class AlphaFold(nn.Module):
"""
......@@ -173,6 +174,9 @@ class AlphaFold(nn.Module):
# Primary output dictionary
outputs = {}
if habana.is_habana():
from habana.hpuhelper import hpu_perf
perf = hpu_perf("iteration", sync=False)
dtype = next(self.parameters()).dtype
for k in feats:
if(feats[k].dtype == torch.float32):
......@@ -190,7 +194,8 @@ class AlphaFold(nn.Module):
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
# Initialize the MSA and pair representations
if habana.is_habana():
perf.checkahead("1: Initialize the MSA and pair representations")
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
......@@ -252,7 +257,8 @@ class AlphaFold(nn.Module):
# Possibly prevents memory fragmentation
del m_1_prev, z_prev, x_prev
# Embed the templates + merge with MSA/pair embeddings
if habana.is_habana():
perf.checkahead("2: Embed the templates + merge with MSA/pair embeddings")
if self.config.template.enabled:
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
......@@ -320,7 +326,8 @@ class AlphaFold(nn.Module):
)
del template_feats, template_embeds
# Embed extra MSA features + merge with pairwise embeddings
if habana.is_habana():
perf.checkahead("3: Embed extra MSA features + merge with pairwise embeddings")
if self.config.extra_msa.enabled:
if(self.globals.is_multimer):
extra_msa_fn = data_transforms_multimer.build_extra_msa_feat
......@@ -354,7 +361,8 @@ class AlphaFold(nn.Module):
)[0]
del extra_msa_feat, extra_msa_fn
# Run MSA + pair embeddings through the trunk of the network
if habana.is_habana():
perf.checkahead("4: Run MSA + pair embeddings through the trunk of the network")
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
......@@ -385,7 +393,8 @@ class AlphaFold(nn.Module):
outputs["pair"] = z
outputs["single"] = s
# Predict 3D structure
if habana.is_habana():
perf.checkahead("5: Predict 3D structure")
outputs["sm"] = self.structure_module(
s,
z,
......@@ -409,6 +418,9 @@ class AlphaFold(nn.Module):
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
if habana.is_habana():
perf.checkahead("6: stop iteration")
return outputs, m_1_prev, z_prev, x_prev
def _disable_activation_checkpointing(self):
......@@ -490,6 +502,9 @@ class AlphaFold(nn.Module):
# Main recycling loop
num_iters = batch["aatype"].shape[-1]
for cycle_no in range(num_iters):
if habana.is_habana():
from habana.hpuhelper import hpu_perf
perf = hpu_perf(f"cycle {cycle_no+1}/{num_iters}")
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
......@@ -511,7 +526,8 @@ class AlphaFold(nn.Module):
x_prev,
_recycle=(num_iters > 1)
)
if habana.is_habana():
perf.checknow("cycle finish")
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
......
......@@ -22,6 +22,7 @@ import torch.nn as nn
from torch.distributions.bernoulli import Bernoulli
from typing import Dict, Optional, Tuple
import fastfold.habana as habana
from fastfold.common import residue_constants
from fastfold.utils import feats
from fastfold.utils.rigid_utils import Rotation, Rigid
......@@ -486,8 +487,12 @@ def lddt_loss(
score = score.detach()
bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1))
if habana.is_habana():
bin_index = torch.floor(score * no_bins)
bin_index = torch.clamp(bin_index, max=(no_bins - 1)).float().long()
else:
bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1))
lddt_ca_one_hot = torch.nn.functional.one_hot(
bin_index, num_classes=no_bins
)
......@@ -931,16 +936,23 @@ def between_residue_clash_loss(
)
# Backbone C--N bond between subsequent residues is no clash.
c_one_hot = torch.nn.functional.one_hot(
residue_index.new_tensor(2), num_classes=14
)
if habana.is_habana():
c_one_hot = torch.tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device=residue_index.device)
else:
c_one_hot = torch.nn.functional.one_hot(
residue_index.new_tensor(2), num_classes=14
)
c_one_hot = c_one_hot.reshape(
*((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
)
c_one_hot = c_one_hot.type(fp_type)
n_one_hot = torch.nn.functional.one_hot(
residue_index.new_tensor(0), num_classes=14
)
if habana.is_habana():
n_one_hot = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device=residue_index.device)
else:
n_one_hot = torch.nn.functional.one_hot(
residue_index.new_tensor(0), num_classes=14
)
n_one_hot = n_one_hot.reshape(
*((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape
)
......@@ -963,7 +975,11 @@ def between_residue_clash_loss(
cys_sg_idx = cys_sg_idx.reshape(
*((1,) * len(residue_index.shape[:-1])), 1
).squeeze(-1)
cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
if habana.is_habana():
cys_sg_one_hot = torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], device=n_one_hot.device)
else:
cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
disulfide_bonds = (
cys_sg_one_hot[..., None, None, :, None]
* cys_sg_one_hot[..., None, None, None, :]
......@@ -1596,12 +1612,16 @@ class AlphaFoldLoss(nn.Module):
out["sm"]["unnormalized_angles"],
**{**batch, **self.config.supervised_chi},
),
# Habana TODO: the violation loss below is dropped on Habana (via the is_habana() check after this dict) as a workaround for an HMP error
"violation": lambda: violation_loss(
out["violation"],
**batch,
),
}
if habana.is_habana():
del loss_fns["violation"]
if(self.config.tm.enabled):
loss_fns["tm"] = lambda: tm_loss(
logits=out["tm_logits"],
......@@ -1632,4 +1652,4 @@ class AlphaFoldLoss(nn.Module):
if(not _return_breakdown):
return cum_loss
return cum_loss, losses
\ No newline at end of file
return cum_loss, losses
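The one_hot workarounds in the hunks above all follow the same pattern: on Habana, a dynamic one_hot of a fixed scalar index is replaced by an equivalent constant tensor. A minimal illustration of that equivalence (not part of the patch):

import torch

# one_hot of a fixed index equals the hand-written constant used on Habana
from_one_hot = torch.nn.functional.one_hot(torch.tensor(2), num_classes=14)
constant = torch.tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
assert torch.equal(from_one_hot, constant)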
......@@ -360,7 +360,7 @@ class TemplateAngleEmbedder(nn.Module):
Returns:
x: [*, N_templ, N_res, C_out] embedding
"""
x = self.linear_1(x)
x = self.linear_1(x.to(dtype=self.linear_1.weight.dtype))
x = self.relu(x)
x = self.linear_2(x)
......@@ -446,6 +446,6 @@ class ExtraMSAEmbedder(nn.Module):
Returns:
[*, N_extra_seq, N_res, C_out] embedding
"""
x = self.linear(x)
x = self.linear(x.to(dtype=self.linear.weight.dtype))
return x
\ No newline at end of file
return x
......@@ -29,6 +29,7 @@ from fastfold.utils.tensor_utils import (
_chunk_slice,
)
import fastfold.habana as habana
def _prod(nums):
out = 1
......@@ -214,13 +215,20 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
# [*, H, Q, K]
a = torch.matmul(query, key)
if habana.is_habana():
from fastfold.habana.fastnn.custom_op import fused_softmax, fused_softmax_bias
if len(biases) == 1:
a = fused_softmax(a, biases[0], -1)
else:
a = fused_softmax_bias(a, biases[0], biases[1], -1)
else:
for b in biases:
a += b
for b in biases:
a += b
a = softmax(a, -1)
a = softmax(a, -1)
# [*, H, Q, C_hidden]
a = a.to(dtype=value.dtype)
a = torch.matmul(a, value)
# [*, Q, H, C_hidden]
......@@ -463,10 +471,15 @@ class GlobalAttention(nn.Module):
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = softmax(a)
if habana.is_habana():
from fastfold.habana.fastnn.custom_op import fused_softmax, fused_softmax_bias
a = fused_softmax(a, bias, -1)
else:
a += bias
a = softmax(a)
# [*, N_res, H, C_hidden]
a = a.to(dtype=v.dtype)
o = torch.matmul(
a,
v,
......
......@@ -39,6 +39,7 @@ from fastfold.utils.tensor_utils import (
flatten_final_dims,
)
import fastfold.habana as habana
class AngleResnetBlock(nn.Module):
def __init__(self, c_hidden):
......@@ -397,10 +398,20 @@ class InvariantPointAttention(nn.Module):
pt_att = sum([c**2 for c in pt_att])
else:
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att**2
# [*, N_res, N_res, H, P_q]
pt_att = sum(torch.unbind(pt_att, dim=-1))
######################################
q_pts_t0 = q_pts.unsqueeze(-4)
q_shape = q_pts_t0.shape
q_pts_t0 = q_pts_t0.reshape([q_shape[0], q_shape[1], -1])
k_pts_t0 = k_pts.unsqueeze(-5)
k_shape = k_pts_t0.shape
k_pts_t0 = k_pts_t0.reshape([k_shape[0], k_shape[1], -1])
q_k = q_pts_t0 - k_pts_t0
q_k = q_k ** 2
q_k_shape = q_k.shape
pt_att = q_k.reshape(q_k_shape[:2] + q_shape[-3:])
#####################################
pt_att = pt_att.permute(0, 4, 1, 2, 3)
pt_att = torch.sum(pt_att, 1)
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
......@@ -408,7 +419,12 @@ class InvariantPointAttention(nn.Module):
head_weights = head_weights * math.sqrt(
1.0 / (3 * (self.no_qk_points * 9.0 / 2))
)
pt_att = pt_att * head_weights
##############################
pt_att_t0 = pt_att.permute(0, 3, 1, 2)
head_weights_t0 = head_weights.permute(0, 3, 1, 2)
pt_att_o = pt_att_t0 * head_weights_t0
pt_att = pt_att_o.permute(0, 2, 3, 1)
##############################
# [*, N_res, N_res, H]
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
......@@ -448,13 +464,14 @@ class InvariantPointAttention(nn.Module):
o_pt_norm = o_pt.norm(self.eps)
else:
# [*, H, 3, N_res, P_v]
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
###################################
a1 = a[..., None, :, :, None]
a1 = a1.permute(0, 1, 2, 4, 3)
b = permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
b = b.permute(0, 1, 2, 4, 3)
c = a1 * b
o_pt = torch.sum(c, -1)
###################################
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
......@@ -788,6 +805,10 @@ class StructureModule(nn.Module):
if i < (self.no_blocks - 1):
rigids = rigids.stop_rot_gradient()
if habana.is_habana():
import habana_frameworks.torch.core as htcore
htcore.mark_step()
outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s
......
......@@ -40,6 +40,7 @@ from fastfold.utils.tensor_utils import (
flatten_final_dims,
)
import fastfold.habana as habana
class TemplatePointwiseAttention(nn.Module):
"""
......@@ -121,10 +122,13 @@ class TemplatePointwiseAttention(nn.Module):
# [*, N_res, N_res, 1, C_z]
biases = [bias]
if chunk_size is not None:
z = self._chunk(z, t, biases, chunk_size)
else:
if habana.is_habana():
z = self.mha(q_x=z, kv_x=t, biases=biases)
else:
if chunk_size is not None:
z = self._chunk(z, t, biases, chunk_size)
else:
z = self.mha(q_x=z, kv_x=t, biases=biases)
# [*, N_res, N_res, C_z]
z = z.squeeze(-2)
......
......@@ -44,7 +44,7 @@ def dgram_from_positions(
)
lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
dgram = ((dgram > lower).type(dgram.dtype) * (dgram < upper)).type(dgram.dtype)
return dgram
......@@ -91,7 +91,7 @@ def build_template_angle_feat(template_feats: Dict[str, Any]) -> torch.Tensor:
torsion_angles_mask = template_feats["template_torsion_angles_mask"]
template_angle_feat = torch.cat(
[
nn.functional.one_hot(template_aatype, 22),
nn.functional.one_hot(template_aatype, 22).to(torch.float32),
torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
alt_torsion_angles_sin_cos.reshape(
*alt_torsion_angles_sin_cos.shape[:-2], 14
......@@ -136,10 +136,10 @@ def build_template_pair_feat(
to_concat.append(
aatype_one_hot[..., None, :, :].expand(
*aatype_one_hot.shape[:-2], n_res, -1, -1
)
).to(dgram.dtype)
)
to_concat.append(
aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1)
aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1).to(dgram.dtype)
)
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
......@@ -179,7 +179,7 @@ def build_template_pair_feat(
def build_extra_msa_feat(batch: Dict[str, Any]) -> torch.Tensor:
msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
msa_feat = [
msa_1hot,
msa_1hot.to(torch.float32),
batch["extra_has_deletion"].unsqueeze(-1),
batch["extra_deletion_value"].unsqueeze(-1),
]
......
......@@ -19,6 +19,7 @@ import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
import fastfold.habana as habana
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
zero_index = -1 * len(inds)
......@@ -407,4 +408,8 @@ def chunk_layer(
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
if habana.is_habana():
import habana_frameworks.torch.core as htcore
htcore.mark_step()
return out
import time
import habana_frameworks.torch as ht
class hpu_perf:
def __init__(self, module, log=True, mark_step=True, memoryinfo=False, sync=False):
if log:
print(f" {module}: start")
self.module = module
self.stime = time.perf_counter()
self.mark = mark_step
self.mem = memoryinfo
self.sync = sync
self.log = log
if self.mem:
ht.hpu.reset_peak_memory_stats()
self.prelog = None
def checknow(self, log):
if self.mark:
ht.core.mark_step()
if self.sync:
ht.core.hpu.default_stream().synchronize()
if self.mem:
print(ht.hpu.memory_summary())
tmp = time.perf_counter()
if self.log:
print(" {}: {} takes {:.2f} ms".format(self.module, log, (tmp - self.stime)*1000))
self.stime = tmp
def checkahead(self, log):
if self.mark:
ht.core.mark_step()
if self.sync:
ht.core.hpu.default_stream().synchronize()
if self.mem:
print(ht.hpu.memory_summary())
tmp = time.perf_counter()
if self.prelog is not None and self.log:
print(" {}: {} takes {:.2f} ms".format(self.module, self.prelog, (tmp - self.stime)*1000))
self.stime = tmp
self.prelog = log
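A minimal sketch of how hpu_perf is used by the training loop in this change (model, criterion and batch are placeholders):

from habana.hpuhelper import hpu_perf

def timed_step(model, criterion, batch):
    perf = hpu_perf("train step")   # prints "train step: start"
    output = model(batch)
    perf.checknow("forward")        # marks the HPU step and prints the time spent in the forward pass
    loss = criterion(output, batch)
    perf.checknow("loss")           # prints the time spent computing the loss
    return loss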
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import os
import pickle
import random
import shutil
import sys
import tempfile
import time
from datetime import date
import numpy as np
import torch
import torch.multiprocessing as mp
import habana_frameworks.torch.core as htcore
import fastfold.habana as habana
import fastfold.relax.relax as relax
from fastfold.common import protein, residue_constants
from fastfold.config import model_config
from fastfold.data import data_pipeline, feature_pipeline, templates
from fastfold.data.parsers import parse_fasta
from fastfold.habana.distributed import init_dist
from fastfold.habana.fastnn.ops import set_chunk_size
from fastfold.habana.inject_habana import inject_habana
from fastfold.model.hub import AlphaFold
from fastfold.model.nn.triangular_multiplicative_update import \
set_fused_triangle_multiplication
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map
from fastfold.workflow.template import (FastFoldDataWorkFlow, FastFoldMultimerDataWorkFlow)
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
fasta_file.write(fasta_str)
fasta_file.seek(0)
yield fasta_file.name
def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
'--uniref90_database_path',
type=str,
default=None,
)
parser.add_argument(
'--mgnify_database_path',
type=str,
default=None,
)
parser.add_argument(
'--pdb70_database_path',
type=str,
default=None,
)
parser.add_argument(
'--uniclust30_database_path',
type=str,
default=None,
)
parser.add_argument(
'--bfd_database_path',
type=str,
default=None,
)
parser.add_argument(
"--pdb_seqres_database_path",
type=str,
default=None,
)
parser.add_argument(
"--uniprot_database_path",
type=str,
default=None,
)
parser.add_argument('--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer')
parser.add_argument('--hhblits_binary_path', type=str, default='/usr/bin/hhblits')
parser.add_argument('--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch')
parser.add_argument('--kalign_binary_path', type=str, default='/usr/bin/kalign')
parser.add_argument("--hmmsearch_binary_path", type=str, default="hmmsearch")
parser.add_argument("--hmmbuild_binary_path", type=str, default="hmmbuild")
parser.add_argument(
'--max_template_date',
type=str,
default=date.today().strftime("%Y-%m-%d"),
)
parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
parser.add_argument('--release_dates_path', type=str, default=None)
parser.add_argument('--chunk_size', type=int, default=None)
parser.add_argument('--enable_workflow',
default=False,
action='store_true',
help='run inference with ray workflow or not')
parser.add_argument('--inplace', default=False, action='store_true')
def inference_model(rank, world_size, result_q, batch, args):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
habana.enable_habana()
init_dist()
device = torch.device("hpu")
config = model_config(args.model_name)
if args.chunk_size:
config.globals.chunk_size = args.chunk_size
if "v3" in args.param_path:
set_fused_triangle_multiplication()
config.globals.inplace = False
config.globals.is_multimer = args.model_preset == 'multimer'
model = AlphaFold(config)
import_jax_weights_(model, args.param_path, version=args.model_name)
model = inject_habana(model)
model = model.eval()
model = model.to(device=device)
set_chunk_size(model.globals.chunk_size)
with torch.no_grad():
batch = {k: torch.as_tensor(v).to(device=device) for k, v in batch.items()}
t = time.perf_counter()
out = model(batch)
htcore.mark_step()
print(f"Inference time: {time.perf_counter() - t}")
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
result_q.put(out)
torch.distributed.barrier()
def main(args):
if args.model_preset == "multimer":
inference_multimer_model(args)
else:
inference_monomer_model(args)
def inference_multimer_model(args):
print("running in multimer mode...")
config = model_config(args.model_name)
predict_max_templates = 4
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=predict_max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path,
)
if (not args.use_precomputed_alignments):
if args.enable_workflow:
print("Running alignment with ray workflow...")
alignment_runner = FastFoldMultimerDataWorkFlow(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hmmsearch_binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniprot_database_path=args.uniprot_database_path,
pdb_seqres_database_path=args.pdb_seqres_database_path,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus)
else:
alignment_runner = data_pipeline.AlignmentRunnerMultimer(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hmmsearch_binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniprot_database_path=args.uniprot_database_path,
pdb_seqres_database_path=args.pdb_seqres_database_path,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus)
else:
alignment_runner = None
monomer_data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,)
data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=monomer_data_processor,)
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
if (not args.use_precomputed_alignments):
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
# Gather input sequences
fasta_path = args.fasta_path
with open(fasta_path, "r") as fp:
data = fp.read()
lines = [l.replace('\n', '') for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2]
for tag, seq in zip(tags, seqs):
local_alignment_dir = os.path.join(alignment_dir, tag)
if (args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
else:
shutil.rmtree(local_alignment_dir)
os.makedirs(local_alignment_dir)
chain_fasta_str = f'>chain_{tag}\n{seq}\n'
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
if args.enable_workflow:
print("Running alignment with ray workflow...")
t = time.perf_counter()
alignment_runner.run(chain_fasta_path, alignment_dir=local_alignment_dir)
print(f"Alignment data workflow time: {time.perf_counter() - t}")
else:
alignment_runner.run(chain_fasta_path, local_alignment_dir)
print(f"Finished running alignment for {tag}")
local_alignment_dir = alignment_dir
feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
alignment_dir=local_alignment_dir)
# feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
processed_feature_dict = feature_processor.process_features(
feature_dict,
mode='predict',
is_multimer=True,
)
batch = processed_feature_dict
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model,
nprocs=args.hpus,
args=(args.hpus, result_q, batch, args))
out = result_q.get()
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation(
use_gpu=False,
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
def inference_monomer_model(args):
print("running in monomer mode...")
config = model_config(args.model_name)
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path)
use_small_bfd = args.preset == 'reduced_dbs' # (args.bfd_database_path is None)
if use_small_bfd:
assert args.bfd_database_path is not None
else:
assert args.bfd_database_path is not None
assert args.uniclust30_database_path is not None
data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,)
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
if (args.use_precomputed_alignments is None):
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
# Gather input sequences
with open(args.fasta_path, "r") as fp:
fasta = fp.read()
seqs, tags = parse_fasta(fasta)
seq, tag = seqs[0], tags[0]
print(f"tag:{tag}\nseq[{len(seq)}]:{seq}")
batch = [None]
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
print("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
if (args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
if args.enable_workflow:
print("Running alignment with ray workflow...")
alignment_data_workflow_runner = FastFoldDataWorkFlow(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
t = time.perf_counter()
alignment_data_workflow_runner.run(fasta_path, alignment_dir=local_alignment_dir)
print(f"Alignment data workflow time: {time.perf_counter() - t}")
else:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(fasta_path, local_alignment_dir)
feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
alignment_dir=local_alignment_dir)
# Remove temporary FASTA file
os.remove(fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict,
mode='predict',
)
batch = processed_feature_dict
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model,
nprocs=args.hpus,
args=(args.hpus, result_q, batch, args))
out = result_q.get()
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation(
use_gpu=False,
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"fasta_path",
type=str,
)
parser.add_argument(
"template_mmcif_dir",
type=str,
)
parser.add_argument("--use_precomputed_alignments",
type=str,
default=None,
help="""Path to alignment directory. If provided, alignment computation
is skipped and database path arguments are ignored.""")
parser.add_argument(
"--output_dir",
type=str,
default=os.getcwd(),
help="""Name of the directory in which to output the prediction""",
)
parser.add_argument("--model_name",
type=str,
default="model_1",
help="""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm or model_{1-5}_multimer, as defined on the AlphaFold GitHub.""")
parser.add_argument("--param_path",
type=str,
default=None,
help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from
./data/params""")
parser.add_argument("--cpus",
type=int,
default=12,
help="""Number of CPUs with which to run alignment tools""")
parser.add_argument("--hpus",
type=int,
default=1,
help="""Number of GPUs with which to run inference""")
parser.add_argument('--preset',
type=str,
default='full_dbs',
choices=('reduced_dbs', 'full_dbs'))
parser.add_argument('--data_random_seed', type=str, default=None)
parser.add_argument(
"--model_preset",
type=str,
default="monomer",
choices=["monomer", "multimer"],
help="Choose preset model configuration - the monomer model, the monomer model with "
"extra ensembling, monomer model with pTM head, or multimer model",
)
add_data_args(parser)
args = parser.parse_args()
if (args.param_path is None):
args.param_path = os.path.join("data", "params", "params_" + args.model_name + ".npz")
main(args)
# add '--hpus [N]' to use N HPUs for inference
# add '--enable_workflow' to use parallel workflow for data processing
# add '--use_precomputed_alignments [path_to_alignments]' to use precomputed msa
# add '--chunk_size [N]' to use chunk to reduce peak memory
# add '--inplace' to use inplace to save memory
python habana/inference.py target.fasta data/pdb_mmcif/mmcif_files \
--output_dir ./ \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path `which jackhmmer` \
--hhblits_binary_path `which hhblits` \
--hhsearch_binary_path `which hhsearch` \
--kalign_binary_path `which kalign`
import pickle
import time
import habana_frameworks.torch.core as htcore
import torch
import fastfold.habana as habana
from fastfold.config import model_config
from fastfold.habana.distributed import init_dist
from fastfold.habana.fastnn.ops import set_chunk_size
from fastfold.habana.inject_habana import inject_habana
from fastfold.model.hub import AlphaFold
def main():
habana.enable_habana()
init_dist()
batch = pickle.load(open('./test_batch.pkl', 'rb'))
model_name = "model_1"
device = torch.device("hpu")
config = model_config(model_name)
config.globals.inplace = False
config.globals.chunk_size = 512
# habana.enable_hmp()
model = AlphaFold(config)
model = inject_habana(model)
model = model.eval()
model = model.to(device=device)
if config.globals.chunk_size is not None:
set_chunk_size(model.globals.chunk_size + 1)
if habana.is_hmp():
from habana_frameworks.torch.hpex import hmp
hmp.convert(opt_level='O1',
bf16_file_path='./habana/ops_bf16.txt',
fp32_file_path='./habana/ops_fp32.txt',
isVerbose=False)
print("========= AMP ENABLED!!")
with torch.no_grad():
batch = {k: torch.as_tensor(v).to(device=device) for k, v in batch.items()}
for _ in range(5):
t = time.perf_counter()
out = model(batch)
htcore.mark_step()
htcore.hpu.default_stream().synchronize()
print(f"Inference time: {time.perf_counter() - t}")
if __name__ == '__main__':
main()
addmm
conv2d
max_pool2d
sum
relu
mm
bmm
mv
linear
t
mul
sub
add
truediv
layer_norm
cross_entropy
log_softmax
nll_loss
softmax
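The bare operator list above is presumably habana/ops_bf16.txt: the ops that HMP (Habana mixed precision) is allowed to run in bf16. It is consumed through hmp.convert, as done in the inference test and training scripts in this change:

from habana_frameworks.torch.hpex import hmp

hmp.convert(opt_level='O1',
            bf16_file_path='./habana/ops_bf16.txt',  # ops allowed to run in bf16 (the list above)
            fp32_file_path='./habana/ops_fp32.txt',  # ops kept in fp32
            isVerbose=False)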
import argparse
import logging
import random
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
import fastfold.habana as habana
from fastfold.config import model_config
from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader
from fastfold.habana.distributed import init_dist
from fastfold.habana.inject_habana import inject_habana
from fastfold.model.hub import AlphaFold, AlphaFoldLoss, AlphaFoldLRScheduler
from fastfold.utils.tensor_utils import tensor_tree_map
import habana_frameworks.torch.core as htcore
from habana_frameworks.torch.hpex import hmp
logging.disable(logging.WARNING)
torch.multiprocessing.set_sharing_strategy('file_system')
from habana.hpuhelper import *
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--from_torch', default=False, action='store_true')
parser.add_argument("--template_mmcif_dir",
type=str,
help="Directory containing mmCIF files to search for templates")
parser.add_argument("--max_template_date",
type=str,
help='''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target''')
parser.add_argument("--train_data_dir",
type=str,
help="Directory containing training mmCIF files")
parser.add_argument("--train_alignment_dir",
type=str,
help="Directory containing precomputed training alignments")
parser.add_argument(
"--train_chain_data_cache_path",
type=str,
default=None,
)
parser.add_argument("--distillation_data_dir",
type=str,
default=None,
help="Directory containing training PDB files")
parser.add_argument("--distillation_alignment_dir",
type=str,
default=None,
help="Directory containing precomputed distillation alignments")
parser.add_argument(
"--distillation_chain_data_cache_path",
type=str,
default=None,
)
parser.add_argument("--val_data_dir",
type=str,
default=None,
help="Directory containing validation mmCIF files")
parser.add_argument("--val_alignment_dir",
type=str,
default=None,
help="Directory containing precomputed validation alignments")
parser.add_argument("--kalign_binary_path",
type=str,
default='/usr/bin/kalign',
help="Path to the kalign binary")
parser.add_argument("--train_filter_path",
type=str,
default=None,
help='''Optional path to a text file containing names of training
examples to include, one per line. Used to filter the training
set''')
parser.add_argument("--distillation_filter_path",
type=str,
default=None,
help="""See --train_filter_path""")
parser.add_argument("--obsolete_pdbs_file_path",
type=str,
default=None,
help="""Path to obsolete.dat file containing list of obsolete PDBs and
their replacements.""")
parser.add_argument("--template_release_dates_cache_path",
type=str,
default=None,
help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
files.""")
parser.add_argument("--train_epoch_len",
type=int,
default=10000,
help=("The virtual length of each training epoch. Stochastic filtering "
"of training data means that training datasets have no "
"well-defined length. This virtual length affects frequency of "
"validation & checkpointing (by default, one of each per epoch)."))
parser.add_argument("--_alignment_index_path",
type=str,
default=None,
help="Training alignment index. See the README for instructions.")
parser.add_argument("--config_preset",
type=str,
default="initial_training",
help=('Config setting. Choose e.g. "initial_training", "finetuning", '
'"model_1", etc. By default, the actual values in the config are '
'used.'))
parser.add_argument(
"--_distillation_structure_index_path",
type=str,
default=None,
)
parser.add_argument("--distillation_alignment_index_path",
type=str,
default=None,
help="Distillation alignment index. See the README for instructions.")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
# habana arguments
parser.add_argument("--hmp",
action='store_true',
default=False,
help="Whether to use habana mixed precision")
parser.add_argument("--hmp-bf16",
type=str,
default="./habana/ops_bf16.txt",
help="Path to bf16 ops list in hmp O1 mode")
parser.add_argument("--hmp-fp32",
type=str,
default="./habana/ops_fp32.txt",
help="Path to fp32 ops list in hmp O1 mode")
parser.add_argument("--hmp-opt-level",
type=str,
default='O1',
help="Choose optimization level for hmp")
parser.add_argument("--hmp-verbose",
action='store_true',
default=False,
help='Enable verbose mode for hmp')
args = parser.parse_args()
habana.enable_habana()
init_dist()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
config = model_config(args.config_preset, train=True)
config.globals.inplace = False
model = AlphaFold(config)
model = inject_habana(model)
model = model.to(device="hpu")
model = DDP(model)
train_dataset, test_dataset = SetupTrainDataset(
config=config.data,
template_mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
train_data_dir=args.train_data_dir,
train_alignment_dir=args.train_alignment_dir,
train_chain_data_cache_path=args.train_chain_data_cache_path,
distillation_data_dir=args.distillation_data_dir,
distillation_alignment_dir=args.distillation_alignment_dir,
distillation_chain_data_cache_path=args.distillation_chain_data_cache_path,
val_data_dir=args.val_data_dir,
val_alignment_dir=args.val_alignment_dir,
kalign_binary_path=args.kalign_binary_path,
# train_mapping_path=args.train_mapping_path,
# distillation_mapping_path=args.distillation_mapping_path,
obsolete_pdbs_file_path=args.obsolete_pdbs_file_path,
template_release_dates_cache_path=args.template_release_dates_cache_path,
train_epoch_len=args.train_epoch_len,
_alignment_index_path=args._alignment_index_path,
)
train_dataloader, test_dataloader = TrainDataLoader(
config=config.data,
train_dataset=train_dataset,
test_dataset=test_dataset,
batch_seed=args.seed,
)
criterion = AlphaFoldLoss(config.loss)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8)
from habana_frameworks.torch.hpex.optimizers import FusedAdamW
optimizer = FusedAdamW(model.parameters(), lr=1e-3, eps=1e-8)
lr_scheduler = AlphaFoldLRScheduler(optimizer)
if args.hmp:
hmp.convert(opt_level='O1',
bf16_file_path=args.hmp_bf16,
fp32_file_path=args.hmp_fp32,
isVerbose=args.hmp_verbose)
print("========= HMP ENABLED!!")
for epoch in range(200):
model.train()
train_dataloader = tqdm(train_dataloader)
for batch in train_dataloader:
perf = hpu_perf("train step")
batch = {k: torch.as_tensor(v).to(device="hpu") for k, v in batch.items()}
optimizer.zero_grad()
output = model(batch)
perf.checknow("forward")
batch = tensor_tree_map(lambda t: t[..., -1], batch)
loss, loss_breakdown = criterion(output, batch, _return_breakdown=True)
perf.checknow("loss")
loss.backward()
train_dataloader.set_postfix(loss=float(loss))
perf.checknow("backward")
with hmp.disable_casts():
optimizer.step()
perf.checknow("optimizer")
lr_scheduler.step()
if test_dataloader is not None:
model.eval()
test_dataloader = tqdm(test_dataloader)
for batch in test_dataloader:
batch = {k: torch.as_tensor(v).to(device="hpu") for k, v in batch.items()}
with torch.no_grad():
output = model(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch)
val_loss, loss_breakdown = criterion(output, batch, _return_breakdown=True)
htcore.mark_step()
test_dataloader.set_postfix(loss=float(val_loss))
if __name__ == "__main__":
main()
DATA_DIR=/mnt/usb/training-demo
hpus_per_node=1
max_template_date=2021-10-10
train_data_dir=${DATA_DIR}/mmcif_dir # specify the dir contains *.cif or *.pdb
train_alignment_dir=${DATA_DIR}/alignment_dir # a dir to save template and features.pkl of training sequence
mkdir -p ${train_alignment_dir}
# val_data_dir=${PROJECT_DIR}/dataset/val_pdb
# val_alignment_dir=${PROJECT_DIR}/dataset/alignment_val_pdb # a dir to save template and features.pkl of vld sequence
template_mmcif_dir=${DATA_DIR}/data/pdb_mmcif/mmcif_files
template_release_dates_cache_path=${DATA_DIR}/mmcif_cache.json # a cache used to pre-filter templates
train_chain_data_cache_path=${DATA_DIR}/chain_data_cache.json # a separate chain-level cache with data used for training-time data filtering
train_epoch_len=10000 # virtual length of each training epoch, which affects frequency of validation & checkpointing
mpirun --allow-run-as-root --bind-to none -np ${hpus_per_node} python habana/train.py \
--from_torch \
--template_mmcif_dir=${template_mmcif_dir} \
--max_template_date=${max_template_date} \
--train_data_dir=${train_data_dir} \
--train_alignment_dir=${train_alignment_dir} \
--train_chain_data_cache_path=${train_chain_data_cache_path} \
--template_release_dates_cache_path=${template_release_dates_cache_path} \
--train_epoch_len=${train_epoch_len} \
......@@ -46,23 +46,24 @@ def append_nvcc_threads(nvcc_extra_args):
if not torch.cuda.is_available():
# https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
# which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
print(
'\nWarning: Torch did not find available GPUs on this system.\n',
'If your intention is to cross-compile, this is not an error.\n'
'By default, FastFold will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
'If you wish to cross-compile for a single specific architecture,\n'
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("======== NOTICE: torch.cuda.is_available == False")
# # https://github.com/NVIDIA/apex/issues/486
# # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
# # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
# print(
# '\nWarning: Torch did not find available GPUs on this system.\n',
# 'If your intention is to cross-compile, this is not an error.\n'
# 'By default, FastFold will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
# 'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
# 'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
# 'If you wish to cross-compile for a single specific architecture,\n'
# 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
# if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
# _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
# if int(bare_metal_major) == 11:
# os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
# else:
# os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split('.')[0])
......@@ -82,11 +83,7 @@ ext_modules = []
# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
if CUDA_HOME is None:
raise RuntimeError(
"Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc."
)
else:
if CUDA_HOME:
# check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)
def cuda_ext_helper(name, sources, extra_cuda_flags):
......@@ -126,6 +123,8 @@ else:
ext_modules.append(
cuda_ext_helper('fastfold_softmax_cuda', ['softmax_cuda.cpp', 'softmax_cuda_kernel.cu'],
extra_cuda_flags + cc_flag))
else:
print("======== NOTICE: install without cuda kernel")
setup(
name='fastfold',
......