Commit 080e37bb authored by Sam DeLuca's avatar Sam DeLuca
Browse files

Merge remote-tracking branch 'origin/main' into improved_model_outputs

parents 2ada4f8d efcf80f5
![header ](imgs/of_banner.png)
_Figure: Comparison of OpenFold, AlphaFold2, and experimental structure of Streptomyces tokunonesis TokK protein (pdb code 7KDX), related to novel antibiotics used for rare infections including during COVID-19 infection._
# OpenFold
......@@ -22,7 +24,8 @@ are available via scripts in this repository while the MSAs are hosted by the
[Registry of Open Data on AWS (RODA)](https://registry.opendata.aws/openfold).
Try out running inference for yourself with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
OpenFold also supports inference using AlphaFold's official parameters.
OpenFold also supports inference using AlphaFold's official parameters, and
vice versa (see `scripts/convert_of_weights_to_jax.py`).
OpenFold has the following advantages over the reference implementation:
......
......@@ -254,6 +254,7 @@ class MSAAttention(nn.Module):
use_lma=use_lma,
)
else:
m = self.layer_norm_m(m)
m = self.mha(
q_x=m,
kv_x=m,
......
......@@ -29,6 +29,7 @@ import random
import sys
import time
import torch
import re
from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline
......@@ -163,6 +164,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
def parse_fasta(data):
data = re.sub('>$', '', data, flags=re.M)
lines = [
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
......@@ -272,6 +274,8 @@ def load_models_from_command_line(args, config):
"be specified."
)
def list_files_with_extensions(dir, extensions):
    """Return the names (not paths) of entries in ``dir`` that end with
    one of ``extensions``.

    Args:
        dir: Directory to scan (non-recursive).
        extensions: A single suffix or a tuple of suffixes, as accepted
            by ``str.endswith`` (e.g. ``(".fasta", ".fa")``).

    Returns:
        Sorted list of matching entry names. Sorting makes the caller's
        processing order deterministic — bare ``os.listdir`` order is
        filesystem-dependent.
    """
    # NOTE(review): `dir` shadows the builtin, but the parameter name is
    # kept for backward compatibility with keyword callers.
    return sorted(f for f in os.listdir(dir) if f.endswith(extensions))
def main(args):
# Create the output directory
......@@ -307,7 +311,8 @@ def main(args):
prediction_dir = os.path.join(args.output_dir, "predictions")
os.makedirs(prediction_dir, exist_ok=True)
for fasta_file in os.listdir(args.fasta_dir):
for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
# Gather input sequences
with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
data = fp.read()
......
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
# used to run inference using DeepMind's JAX code.
import argparse
import numpy as np
import torch
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import (
Param,
ParamType,
generate_translation_dict,
process_translation_dict,
)
from openfold.utils.tensor_utils import tree_map
def reshape_fn(of_param, af_weight):
    """Convert one OpenFold parameter into the layout of the matching
    AlphaFold weight array.

    Args:
        of_param: OpenFold ``Param`` record (``param``, ``param_type``,
            ``stacked`` fields).
        af_weight: The target AlphaFold array; only its shape is used.

    Returns:
        A ``torch.Tensor`` shaped/transposed to match ``af_weight``.
    """
    # Stacked params arrive as a sequence of per-block arrays; fuse them
    # into a single tensor before applying the layout transformation.
    if of_param.stacked:
        of_weight = torch.stack([torch.Tensor(p) for p in of_param.param])
    else:
        of_weight = torch.Tensor(of_param.param)

    target_shape = af_weight.shape

    # Linear weights are stored transposed relative to AF; MHA/OPM
    # weights additionally need to be folded into AF's (multi-head) shape.
    swap_last_two = lambda w: w.transpose(-1, -2)
    swap_and_fold = lambda w: w.transpose(-1, -2).reshape(target_shape)
    dispatch = {
        ParamType.LinearWeight: swap_last_two,
        ParamType.LinearWeightMHA: swap_and_fold,
        ParamType.LinearMHAOutputWeight: swap_and_fold,
        ParamType.LinearBiasMHA: lambda w: w.reshape(target_shape),
        ParamType.LinearWeightOPM: swap_and_fold,
        ParamType.Other: lambda w: w,
    }

    # Unknown param types raise KeyError, matching the original behavior.
    return dispatch[of_param.param_type](of_weight)
def transfer(of_dict, af_weight_template):
    """Recursively copy converted OpenFold params into the AlphaFold
    weight template, in place.

    Args:
        of_dict: Nested dict mirroring the AF parameter tree; leaves are
            OpenFold ``Param`` records.
        af_weight_template: Nested dict with the same structure whose
            leaves are ``np.ndarray``s to be overwritten.
    """
    # .items() avoids a second dict lookup per key; isinstance is the
    # idiomatic (and subclass-tolerant) replacement for type(x) == dict.
    for key, value in of_dict.items():
        if isinstance(value, dict):
            # Internal node: descend into the matching template subtree.
            transfer(value, af_weight_template[key])
        else:
            # Leaf: convert to the AF layout and write into the template
            # array in place so the top-level dict structure is preserved.
            reshaped = reshape_fn(value, af_weight_template[key])
            np.copyto(af_weight_template[key], reshaped.detach().numpy())
def main(args):
    """Convert an OpenFold ``.pt`` checkpoint into an AlphaFold ``.npz``
    file usable by DeepMind's JAX inference code.

    Args:
        args: Parsed CLI namespace with ``of_pt_path``, ``config_preset``,
            ``out_path``, and ``template_npz_path`` attributes.
    """
    # map_location lets checkpoints saved on GPU load on CPU-only hosts.
    d = torch.load(args.of_pt_path, map_location="cpu")

    # Instantiate the model so the translation dict can be generated from
    # its (named) modules; load_state_dict also validates the checkpoint.
    config = model_config(args.config_preset)
    model = AlphaFold(config)
    model.load_state_dict(d)

    translation = generate_translation_dict(model, args.config_preset)
    translation = process_translation_dict(translation)

    # Start from an AF checkpoint whose params are a superset of ours,
    # keeping only the keys our translation actually covers.
    af_weight_template = np.load(args.template_npz_path)
    af_weight_template = {
        k: v for k, v in af_weight_template.items() if k in translation
    }

    # Zero the template so stale AF values cannot leak into the output.
    zero = lambda n: n * 0
    af_weight_template = tree_map(zero, af_weight_template, np.ndarray)

    transfer(translation, af_weight_template)

    np.savez(args.out_path, **af_weight_template)
if __name__ == "__main__":
    # CLI entry point: three positionals (input checkpoint, config
    # preset, output path) plus an optional AF template override.
    parser = argparse.ArgumentParser()
    parser.add_argument("of_pt_path", type=str, help="Path to OpenFold .pt checkpoint file")
    parser.add_argument("config_preset", type=str, help="The corresponding config preset")
    parser.add_argument("out_path", type=str, help="Path for output .npz file")
    parser.add_argument(
        "--template_npz_path",
        type=str,
        default="openfold/resources/params/params_model_1_ptm.npz",
        help="""Path to an AlphaFold checkpoint w/ a superset of the OF
        checkpoint's parameters. params_model_1_ptm.npz always works.
        """
    )
    main(parser.parse_args())
#!/bin/bash
# Build HH-suite v3.3.0 from source, install it under /opt/hhsuite,
# expose its binaries on PATH, and remove the build tree.
#
# NOTE: the previous revision of this script omitted the trailing
# backslashes, so each `&& ...` line was parsed as a separate (invalid)
# command; the single corrected pipeline is kept below.
git clone --branch v3.3.0 https://github.com/soedinglab/hh-suite.git /tmp/hh-suite \
    && mkdir /tmp/hh-suite/build \
    && pushd /tmp/hh-suite/build \
    && cmake -DCMAKE_INSTALL_PREFIX=/opt/hhsuite .. \
    && make -j 4 && make install \
    && ln -s /opt/hhsuite/bin/* /usr/bin \
    && popd \
    && rm -rf /tmp/hh-suite
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment