Commit 15cd3506 authored by mashun1

Merge branch 'dtk24.04.1'

parents 24e633dc 19085464
*.egg*
tryme.ipynb
build/
dist/
test/
temp_output/
temp_fasta/
*pycache*
FROM image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04.1-py3.10
# RUN apt update
# WORKDIR /app
# WORKDIR /app/softwares
# RUN git clone https://github.com/soedinglab/hh-suite.git
# RUN mkdir -p hh-suite/build && cd hh-suite/build && cmake -DCMAKE_INSTALL_PREFIX=. .. && make -j 4 && make install
# ENV PATH=/app/softwares/hh-suite/build/bin:/app/softwares/hh-suite/build/scripts:$PATH
# WORKDIR /app/softwares
# RUN wget https://github.com/TimoLassmann/kalign/archive/refs/tags/v3.4.0.zip && unzip v3.4.0.zip && cd kalign-3.4.0 && mkdir build && cd build && cmake .. && make && make install
# WORKDIR /app/softwares
# RUN sudo apt install doxygen -y
# RUN wget https://github.com/openmm/openmm/archive/refs/tags/8.0.0.zip && unzip 8.0.0.zip && cd openmm-8.0.0 && mkdir build && cd build && cmake .. && make && sudo make install && sudo make PythonInstall
# WORKDIR /app/softwares
# RUN wget https://github.com/openmm/pdbfixer/archive/refs/tags/1.9.zip && unzip 1.9.zip && cd pdbfixer-1.9 && python setup.py install
# RUN sudo apt install hmmer -y
# WORKDIR /app
# COPY . /app/alphafold2
# RUN ls
# RUN pip install --no-cache-dir -r /app/alphafold2/requirements_dcu.txt -i https://mirrors.ustc.edu.cn/pypi/web/simple
# RUN pip install dm-haiku==0.0.11 flax==0.7.1 jmp==0.0.2 tabulate==0.8.9 --no-deps jax -i https://mirrors.ustc.edu.cn/pypi/web/simple
# RUN pip install orbax==0.1.6 orbax-checkpoint==0.1.6 optax==0.2.2 -i https://mirrors.ustc.edu.cn/pypi/web/simple
# WORKDIR /app/alphafold2
# RUN python setup.py install
<!--
* @Author: zhuww
* @email: zhuww@sugon.com
* @Date: 2023-04-06 18:04:07
* @LastEditTime: 2023-12-26 15:54:01
-->
# AF2
## Paper
- [https://www.nature.com/articles/s41586-021-03819-2](https://www.nature.com/articles/s41586-021-03819-2)
...@@ -19,9 +14,17 @@ AlphaFold2 extracts information from protein sequence and structure data and uses neural
![img](./docs/alphafold2_1.png)
## Environment Setup
### Docker (Method 1)
# With this method you do not need to download this repository; the image already contains runnable code, but the corresponding data files must be mounted.
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:alphafold2-dtk24.04.1-py310
docker run --shm-size 100g --network=host --name=alphafold2 --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <host data path>:<container data path> -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
### Docker (Method 2)
docker pull image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04.1-py3.10
...@@ -45,7 +48,7 @@ AlphaFold2 extracts information from protein sequence and structure data and uses neural
export PATH="$(pwd)/bin:$(pwd)/scripts:$PATH"
wget https://github.com/TimoLassmann/kalign/archive/refs/tags/v3.4.0.zip
unzip v3.4.0.zip && cd kalign-3.4.0
mkdir build
cd build
cmake ..
...@@ -65,23 +68,8 @@ AlphaFold2 extracts information from protein sequence and structure data and uses neural
wget https://github.com/openmm/pdbfixer/archive/refs/tags/1.9.zip
unzip 1.9.zip && cd pdbfixer-1.9 && python setup.py install
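
Before running a prediction, it is worth confirming that the JAX build inside the container can actually see the accelerator devices. A minimal sanity check (a sketch, not from this repository; the exact output depends on the image):

```python
import jax

print(jax.__version__)  # should match the image, e.g. 0.4.23
print(jax.devices())    # should list the available DCU/accelerator devices
```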
## Dataset
We recommend the open-source datasets used by AlphaFold2, including BFD, MGnify, PDB70, Uniclust, Uniref90, etc.; the total size is about 2.62 TB. The dataset layout is as follows:
...@@ -171,12 +159,12 @@ $DOWNLOAD_DIR/
```
[View protein 3D structures](https://www.pdbus.org/3d-view)
ID: 8U23

Blue is the predicted structure, green is the ground-truth structure.

![alt text](image.png)
### Accuracy
Test data: [casp15](https://www.predictioncenter.org/casp15/targetlist.cgi), [uniprot](https://www.uniprot.org/)
...@@ -196,6 +184,8 @@ $DOWNLOAD_DIR/
| fp32 | Monomer | T1024 | 408 | 0.664 | 0.470 | 87.076 | 0.829 | 0.518 | 3.516 |
| fp32 | Multimer | H1106 | 236 | 0.203 | 0.144 | 0.860 | 0.181 | 0.151 | 20.457 |
## Application Scenarios
### Algorithm Category
......
...@@ -14,7 +14,9 @@
"""Functions for processing confidence metrics."""

import json
from typing import Dict, Optional, Tuple

import numpy as np
import scipy.special
...@@ -36,6 +38,43 @@ def compute_plddt(logits: np.ndarray) -> np.ndarray:
  return predicted_lddt_ca * 100
def _confidence_category(score: float) -> str:
  """Categorizes pLDDT into: disordered (D), low (L), medium (M), high (H)."""
  if 0 <= score < 50:
    return 'D'
  elif 50 <= score < 70:
    return 'L'
  elif 70 <= score < 90:
    return 'M'
  elif 90 <= score <= 100:
    return 'H'
  else:
    raise ValueError(f'Invalid pLDDT score {score}')
def confidence_json(plddt: np.ndarray) -> str:
  """Returns JSON with confidence score and category for every residue.

  Args:
    plddt: Per-residue confidence metric data.

  Returns:
    String with a formatted JSON.

  Raises:
    ValueError: If `plddt` has a rank different than 1.
  """
  if plddt.ndim != 1:
    raise ValueError(f'The plddt array must be rank 1, got: {plddt.shape}.')

  confidence = {
      'residueNumber': list(range(1, len(plddt) + 1)),
      'confidenceScore': [round(float(s), 2) for s in plddt],
      'confidenceCategory': [_confidence_category(s) for s in plddt],
  }
  return json.dumps(confidence, indent=None, separators=(',', ':'))
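

# A rough usage sketch of confidence_json; the pLDDT values below are invented
# for illustration only.
if __name__ == '__main__':
  example_plddt = np.array([35.2, 68.9, 91.4])
  print(confidence_json(example_plddt))
  # {"residueNumber":[1,2,3],"confidenceScore":[35.2,68.9,91.4],
  #  "confidenceCategory":["D","L","H"]}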
def _calculate_bin_centers(breaks: np.ndarray):
  """Gets the bin centers from the bin edges.
...@@ -108,6 +147,32 @@ def compute_predicted_aligned_error(
  }
def pae_json(pae: np.ndarray, max_pae: float) -> str:
  """Returns the PAE in the same format as is used in the AFDB.

  Note that the values are presented as floats to 1 decimal place, whereas AFDB
  returns integer values.

  Args:
    pae: The n_res x n_res PAE array.
    max_pae: The maximum possible PAE value.

  Returns:
    PAE output format as a JSON string.
  """
  # Check the PAE array is the correct shape.
  if pae.ndim != 2 or pae.shape[0] != pae.shape[1]:
    raise ValueError(f'PAE must be a square matrix, got {pae.shape}')

  # Round the predicted aligned errors to 1 decimal place.
  rounded_errors = np.round(pae.astype(np.float64), decimals=1)
  formatted_output = [{
      'predicted_aligned_error': rounded_errors.tolist(),
      'max_predicted_aligned_error': max_pae,
  }]
  return json.dumps(formatted_output, indent=None, separators=(',', ':'))
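

# Illustrative round trip, using the same numbers as the unit test further
# down in this commit.
if __name__ == '__main__':
  example_pae = np.array([[0.01, 13.12345], [20.0987, 0.0]])
  print(pae_json(pae=example_pae, max_pae=31.75))
  # [{"predicted_aligned_error":[[0.0,13.1],[20.1,0.0]],
  #   "max_predicted_aligned_error":31.75}]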
def predicted_tm_score(
    logits: np.ndarray,
    breaks: np.ndarray,
......
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test confidence metrics."""
from absl.testing import absltest
from alphafold.common import confidence
import numpy as np
class ConfidenceTest(absltest.TestCase):

  def test_pae_json(self):
    pae = np.array([[0.01, 13.12345], [20.0987, 0.0]])
    pae_json = confidence.pae_json(pae=pae, max_pae=31.75)
    self.assertEqual(
        pae_json, '[{"predicted_aligned_error":[[0.0,13.1],[20.1,0.0]],'
        '"max_predicted_aligned_error":31.75}]')

  def test_confidence_json(self):
    plddt = np.array([42, 42.42])
    confidence_json = confidence.confidence_json(plddt=plddt)
    self.assertEqual(
        confidence_json,
        ('{"residueNumber":[1,2],'
         '"confidenceScore":[42.0,42.42],'
         '"confidenceCategory":["D","D"]}'),
    )


if __name__ == '__main__':
  absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""mmCIF metadata."""
from typing import Mapping, Sequence
from alphafold import version
import numpy as np
_DISCLAIMER = """ALPHAFOLD DATA, COPYRIGHT (2021) DEEPMIND TECHNOLOGIES LIMITED.
THE INFORMATION PROVIDED IS THEORETICAL MODELLING ONLY AND CAUTION SHOULD BE
EXERCISED IN ITS USE. IT IS PROVIDED "AS-IS" WITHOUT ANY WARRANTY OF ANY KIND,
WHETHER EXPRESSED OR IMPLIED. NO WARRANTY IS GIVEN THAT USE OF THE INFORMATION
SHALL NOT INFRINGE THE RIGHTS OF ANY THIRD PARTY. DISCLAIMER: THE INFORMATION IS
NOT INTENDED TO BE A SUBSTITUTE FOR PROFESSIONAL MEDICAL ADVICE, DIAGNOSIS, OR
TREATMENT, AND DOES NOT CONSTITUTE MEDICAL OR OTHER PROFESSIONAL ADVICE. IT IS
AVAILABLE FOR ACADEMIC AND COMMERCIAL PURPOSES, UNDER CC-BY 4.0 LICENCE."""
# Authors of the Nature methods paper we reference in the mmCIF.
_MMCIF_PAPER_AUTHORS = (
'Jumper, John',
'Evans, Richard',
'Pritzel, Alexander',
'Green, Tim',
'Figurnov, Michael',
'Ronneberger, Olaf',
'Tunyasuvunakool, Kathryn',
'Bates, Russ',
'Zidek, Augustin',
'Potapenko, Anna',
'Bridgland, Alex',
'Meyer, Clemens',
'Kohl, Simon A. A.',
'Ballard, Andrew J.',
'Cowie, Andrew',
'Romera-Paredes, Bernardino',
'Nikolov, Stanislav',
'Jain, Rishub',
'Adler, Jonas',
'Back, Trevor',
'Petersen, Stig',
'Reiman, David',
'Clancy, Ellen',
'Zielinski, Michal',
'Steinegger, Martin',
'Pacholska, Michalina',
'Berghammer, Tamas',
'Silver, David',
'Vinyals, Oriol',
'Senior, Andrew W.',
'Kavukcuoglu, Koray',
'Kohli, Pushmeet',
'Hassabis, Demis',
)
# Authors of the mmCIF - we set them to be equal to the authors of the paper.
_MMCIF_AUTHORS = _MMCIF_PAPER_AUTHORS
def add_metadata_to_mmcif(
    old_cif: Mapping[str, Sequence[str]], model_type: str
) -> Mapping[str, Sequence[str]]:
  """Adds AlphaFold metadata in the given mmCIF."""
  cif = {}

  # ModelCIF conformance dictionary.
  cif['_audit_conform.dict_name'] = ['mmcif_ma.dic']
  cif['_audit_conform.dict_version'] = ['1.3.9']
  cif['_audit_conform.dict_location'] = [
      'https://raw.githubusercontent.com/ihmwg/ModelCIF/master/dist/'
      'mmcif_ma.dic'
  ]

  # License and disclaimer.
  cif['_pdbx_data_usage.id'] = ['1', '2']
  cif['_pdbx_data_usage.type'] = ['license', 'disclaimer']
  cif['_pdbx_data_usage.details'] = [
      'Data in this file is available under a CC-BY-4.0 license.',
      _DISCLAIMER,
  ]
  cif['_pdbx_data_usage.url'] = [
      'https://creativecommons.org/licenses/by/4.0/',
      '?',
  ]
  cif['_pdbx_data_usage.name'] = ['CC-BY-4.0', '?']

  # Structure author details.
  cif['_audit_author.name'] = []
  cif['_audit_author.pdbx_ordinal'] = []
  for author_index, author_name in enumerate(_MMCIF_AUTHORS, start=1):
    cif['_audit_author.name'].append(author_name)
    cif['_audit_author.pdbx_ordinal'].append(str(author_index))

  # Paper author details.
  cif['_citation_author.citation_id'] = []
  cif['_citation_author.name'] = []
  cif['_citation_author.ordinal'] = []
  for author_index, author_name in enumerate(_MMCIF_PAPER_AUTHORS, start=1):
    cif['_citation_author.citation_id'].append('primary')
    cif['_citation_author.name'].append(author_name)
    cif['_citation_author.ordinal'].append(str(author_index))

  # Paper citation details.
  cif['_citation.id'] = ['primary']
  cif['_citation.title'] = [
      'Highly accurate protein structure prediction with AlphaFold'
  ]
  cif['_citation.journal_full'] = ['Nature']
  cif['_citation.journal_volume'] = ['596']
  cif['_citation.page_first'] = ['583']
  cif['_citation.page_last'] = ['589']
  cif['_citation.year'] = ['2021']
  cif['_citation.journal_id_ASTM'] = ['NATUAS']
  cif['_citation.country'] = ['UK']
  cif['_citation.journal_id_ISSN'] = ['0028-0836']
  cif['_citation.journal_id_CSD'] = ['0006']
  cif['_citation.book_publisher'] = ['?']
  cif['_citation.pdbx_database_id_PubMed'] = ['34265844']
  cif['_citation.pdbx_database_id_DOI'] = ['10.1038/s41586-021-03819-2']

  # Type of data in the dataset including data used in the model generation.
  cif['_ma_data.id'] = ['1']
  cif['_ma_data.name'] = ['Model']
  cif['_ma_data.content_type'] = ['model coordinates']

  # Description of number of instances for each entity.
  cif['_ma_target_entity_instance.asym_id'] = old_cif['_struct_asym.id']
  cif['_ma_target_entity_instance.entity_id'] = old_cif[
      '_struct_asym.entity_id'
  ]
  cif['_ma_target_entity_instance.details'] = ['.'] * len(
      cif['_ma_target_entity_instance.entity_id']
  )

  # Details about the target entities.
  cif['_ma_target_entity.entity_id'] = cif[
      '_ma_target_entity_instance.entity_id'
  ]
  cif['_ma_target_entity.data_id'] = ['1'] * len(
      cif['_ma_target_entity.entity_id']
  )
  cif['_ma_target_entity.origin'] = ['.'] * len(
      cif['_ma_target_entity.entity_id']
  )

  # Details of the models being deposited.
  cif['_ma_model_list.ordinal_id'] = ['1']
  cif['_ma_model_list.model_id'] = ['1']
  cif['_ma_model_list.model_group_id'] = ['1']
  cif['_ma_model_list.model_name'] = ['Top ranked model']
  cif['_ma_model_list.model_group_name'] = [
      f'AlphaFold {model_type} v{version.__version__} model'
  ]
  cif['_ma_model_list.data_id'] = ['1']
  cif['_ma_model_list.model_type'] = ['Ab initio model']

  # Software used.
  cif['_software.pdbx_ordinal'] = ['1']
  cif['_software.name'] = ['AlphaFold']
  cif['_software.version'] = [f'v{version.__version__}']
  cif['_software.type'] = ['package']
  cif['_software.description'] = ['Structure prediction']
  cif['_software.classification'] = ['other']
  cif['_software.date'] = ['?']

  # Collection of software into groups.
  cif['_ma_software_group.ordinal_id'] = ['1']
  cif['_ma_software_group.group_id'] = ['1']
  cif['_ma_software_group.software_id'] = ['1']

  # Method description to conform with ModelCIF.
  cif['_ma_protocol_step.ordinal_id'] = ['1', '2', '3']
  cif['_ma_protocol_step.protocol_id'] = ['1', '1', '1']
  cif['_ma_protocol_step.step_id'] = ['1', '2', '3']
  cif['_ma_protocol_step.method_type'] = [
      'coevolution MSA',
      'template search',
      'modeling',
  ]

  # Details of the metrics used to assess model confidence.
  cif['_ma_qa_metric.id'] = ['1', '2']
  cif['_ma_qa_metric.name'] = ['pLDDT', 'pLDDT']
  # Accepted values are distance, energy, normalised score, other, zscore.
  cif['_ma_qa_metric.type'] = ['pLDDT', 'pLDDT']
  cif['_ma_qa_metric.mode'] = ['global', 'local']
  cif['_ma_qa_metric.software_group_id'] = ['1', '1']

  # Global model confidence metric value.
  cif['_ma_qa_metric_global.ordinal_id'] = ['1']
  cif['_ma_qa_metric_global.model_id'] = ['1']
  cif['_ma_qa_metric_global.metric_id'] = ['1']
  global_plddt = np.mean(
      [float(v) for v in old_cif['_atom_site.B_iso_or_equiv']]
  )
  cif['_ma_qa_metric_global.metric_value'] = [f'{global_plddt:.2f}']

  cif['_atom_type.symbol'] = sorted(set(old_cif['_atom_site.type_symbol']))
  return cif
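

# A minimal sketch of how this hook is driven. The tiny `old_cif` below is
# invented; in the pipeline the input is the mmCIF dict built by
# protein.to_mmcif.
if __name__ == '__main__':
  fake_old_cif = {
      '_struct_asym.id': ['A'],
      '_struct_asym.entity_id': ['1'],
      '_atom_site.B_iso_or_equiv': ['91.20', '88.80'],
      '_atom_site.type_symbol': ['C', 'N'],
  }
  meta = add_metadata_to_mmcif(fake_old_cif, model_type='Monomer')
  print(meta['_ma_qa_metric_global.metric_value'])  # ['90.00'] (mean pLDDT)
  print(meta['_atom_type.symbol'])  # ['C', 'N']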
...@@ -13,11 +13,18 @@
# limitations under the License.
"""Protein data type."""

import collections
import dataclasses
import functools
import io
from typing import Any, Dict, List, Mapping, Optional, Tuple

from alphafold.common import mmcif_metadata
from alphafold.common import residue_constants
from Bio.PDB import MMCIFParser
from Bio.PDB import PDBParser
from Bio.PDB.mmcifio import MMCIFIO
from Bio.PDB.Structure import Structure
import numpy as np

FeatureDict = Mapping[str, np.ndarray]
...@@ -27,6 +34,32 @@ ModelOutput = Mapping[str, Any]  # Is a nested dict.
PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # := 62.
# Data to fill the _chem_comp table when writing mmCIFs.
_CHEM_COMP: Mapping[str, Tuple[Tuple[str, str], ...]] = {
    'L-peptide linking': (
        ('ALA', 'ALANINE'),
        ('ARG', 'ARGININE'),
        ('ASN', 'ASPARAGINE'),
        ('ASP', 'ASPARTIC ACID'),
        ('CYS', 'CYSTEINE'),
        ('GLN', 'GLUTAMINE'),
        ('GLU', 'GLUTAMIC ACID'),
        ('HIS', 'HISTIDINE'),
        ('ILE', 'ISOLEUCINE'),
        ('LEU', 'LEUCINE'),
        ('LYS', 'LYSINE'),
        ('MET', 'METHIONINE'),
        ('PHE', 'PHENYLALANINE'),
        ('PRO', 'PROLINE'),
        ('SER', 'SERINE'),
        ('THR', 'THREONINE'),
        ('TRP', 'TRYPTOPHAN'),
        ('TYR', 'TYROSINE'),
        ('VAL', 'VALINE'),
    ),
    'peptide linking': (('GLY', 'GLYCINE'),),
}
@dataclasses.dataclass(frozen=True)
class Protein:
...@@ -63,27 +96,32 @@ class Protein:
          'because these cannot be written to PDB format.')
def _from_bio_structure(
    structure: Structure, chain_id: Optional[str] = None
) -> Protein:
  """Takes a Biopython structure and creates a `Protein` instance.

  WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

  Args:
    structure: Structure from the Biopython library.
    chain_id: If chain_id is specified (e.g. A), then only that chain is parsed.
      Otherwise all chains are parsed.

  Returns:
    A new `Protein` created from the structure contents.

  Raises:
    ValueError: If the number of models included in the structure is not 1.
    ValueError: If insertion code is detected at a residue.
  """
  models = list(structure.get_models())
  if len(models) != 1:
    raise ValueError(
        'Only single model PDBs/mmCIFs are supported. Found'
        f' {len(models)} models.'
    )
  model = models[0]

  atom_positions = []
...@@ -99,8 +137,9 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    for res in chain:
      if res.id[2] != ' ':
        raise ValueError(
            f'PDB/mmCIF contains an insertion code at chain {chain.id} and'
            f' residue index {res.id[1]}. These are not supported.'
        )
      res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
      restype_idx = residue_constants.restype_order.get(
          res_shortname, residue_constants.restype_num)
...@@ -137,6 +176,48 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
      b_factors=np.array(b_factors))
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
  """Takes a PDB string and constructs a `Protein` object.

  WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

  Args:
    pdb_str: The contents of the pdb file
    chain_id: If chain_id is specified (e.g. A), then only that chain is parsed.
      Otherwise all chains are parsed.

  Returns:
    A new `Protein` parsed from the pdb contents.
  """
  with io.StringIO(pdb_str) as pdb_fh:
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(id='none', file=pdb_fh)
    return _from_bio_structure(structure, chain_id)
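

# Hedged usage sketch; the input file path below is hypothetical.
if __name__ == '__main__':
  with open('my_structure.pdb') as f:  # hypothetical input file
    example_prot = from_pdb_string(f.read(), chain_id='A')
  print(example_prot.aatype.shape, example_prot.atom_positions.shape)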
def from_mmcif_string(
    mmcif_str: str, chain_id: Optional[str] = None
) -> Protein:
  """Takes a mmCIF string and constructs a `Protein` object.

  WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

  Args:
    mmcif_str: The contents of the mmCIF file
    chain_id: If chain_id is specified (e.g. A), then only that chain is parsed.
      Otherwise all chains are parsed.

  Returns:
    A new `Protein` parsed from the mmCIF contents.
  """
  with io.StringIO(mmcif_str) as mmcif_fh:
    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure(structure_id='none', filename=mmcif_fh)
    return _from_bio_structure(structure, chain_id)
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
  chain_end = 'TER'
  return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
...@@ -276,3 +357,223 @@ def from_prediction(
      residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1,
      chain_index=chain_index,
      b_factors=b_factors)
def to_mmcif(
    prot: Protein,
    file_id: str,
    model_type: str,
) -> str:
  """Converts a `Protein` instance to an mmCIF string.

  WARNING 1: The _entity_poly_seq is filled with unknown (UNK) residues for any
    missing residue indices in the range from min(1, min(residue_index)) to
    max(residue_index). E.g. for a protein object with positions for residues
    2 (MET), 3 (LYS), 6 (GLY), this method would set the _entity_poly_seq to:
    1 UNK
    2 MET
    3 LYS
    4 UNK
    5 UNK
    6 GLY
    This is done to preserve the residue numbering.

  WARNING 2: Converting ground truth mmCIF file to Protein and then back to
    mmCIF using this method will convert all non-standard residue types to UNK.
    If you need this behaviour, you need to store more mmCIF metadata in the
    Protein object (e.g. all fields except for the _atom_site loop).

  WARNING 3: Converting ground truth mmCIF file to Protein and then back to
    mmCIF using this method will not retain the original chain indices.

  WARNING 4: In case of multiple identical chains, they are assigned different
    `_atom_site.label_entity_id` values.

  Args:
    prot: A protein to convert to mmCIF string.
    file_id: The file ID (usually the PDB ID) to be used in the mmCIF.
    model_type: 'Multimer' or 'Monomer'.

  Returns:
    A valid mmCIF string.

  Raises:
    ValueError: If aminoacid types array contains entries with too many protein
      types.
  """
  atom_mask = prot.atom_mask
  aatype = prot.aatype
  atom_positions = prot.atom_positions
  residue_index = prot.residue_index.astype(np.int32)
  chain_index = prot.chain_index.astype(np.int32)
  b_factors = prot.b_factors

  # Construct a mapping from chain integer indices to chain ID strings.
  chain_ids = {}
  # We count unknown residues as protein residues.
  for entity_id in np.unique(chain_index):  # np.unique gives sorted output.
    chain_ids[entity_id] = _int_id_to_str_id(entity_id + 1)

  mmcif_dict = collections.defaultdict(list)

  mmcif_dict['data_'] = file_id.upper()
  mmcif_dict['_entry.id'] = file_id.upper()

  label_asym_id_to_entity_id = {}
  # Entity and chain information.
  for entity_id, chain_id in chain_ids.items():
    # Add all chain information to the _struct_asym table.
    label_asym_id_to_entity_id[str(chain_id)] = str(entity_id)
    mmcif_dict['_struct_asym.id'].append(chain_id)
    mmcif_dict['_struct_asym.entity_id'].append(str(entity_id))
    # Add information about the entity to the _entity_poly table.
    mmcif_dict['_entity_poly.entity_id'].append(str(entity_id))
    mmcif_dict['_entity_poly.type'].append(residue_constants.PROTEIN_CHAIN)
    mmcif_dict['_entity_poly.pdbx_strand_id'].append(chain_id)
    # Generate the _entity table.
    mmcif_dict['_entity.id'].append(str(entity_id))
    mmcif_dict['_entity.type'].append(residue_constants.POLYMER_CHAIN)

  # Add the residues to the _entity_poly_seq table.
  for entity_id, (res_ids, aas) in _get_entity_poly_seq(
      aatype, residue_index, chain_index
  ).items():
    for res_id, aa in zip(res_ids, aas):
      mmcif_dict['_entity_poly_seq.entity_id'].append(str(entity_id))
      mmcif_dict['_entity_poly_seq.num'].append(str(res_id))
      mmcif_dict['_entity_poly_seq.mon_id'].append(
          residue_constants.resnames[aa]
      )

  # Populate the chem comp table.
  for chem_type, chem_comp in _CHEM_COMP.items():
    for chem_id, chem_name in chem_comp:
      mmcif_dict['_chem_comp.id'].append(chem_id)
      mmcif_dict['_chem_comp.type'].append(chem_type)
      mmcif_dict['_chem_comp.name'].append(chem_name)

  # Add all atom sites.
  atom_index = 1
  for i in range(aatype.shape[0]):
    res_name_3 = residue_constants.resnames[aatype[i]]
    if aatype[i] <= len(residue_constants.restypes):
      atom_names = residue_constants.atom_types
    else:
      raise ValueError(
          'Amino acid types array contains entries with too many protein types.'
      )
    for atom_name, pos, mask, b_factor in zip(
        atom_names, atom_positions[i], atom_mask[i], b_factors[i]
    ):
      if mask < 0.5:
        continue
      type_symbol = residue_constants.atom_id_to_type(atom_name)

      mmcif_dict['_atom_site.group_PDB'].append('ATOM')
      mmcif_dict['_atom_site.id'].append(str(atom_index))
      mmcif_dict['_atom_site.type_symbol'].append(type_symbol)
      mmcif_dict['_atom_site.label_atom_id'].append(atom_name)
      mmcif_dict['_atom_site.label_alt_id'].append('.')
      mmcif_dict['_atom_site.label_comp_id'].append(res_name_3)
      mmcif_dict['_atom_site.label_asym_id'].append(chain_ids[chain_index[i]])
      mmcif_dict['_atom_site.label_entity_id'].append(
          label_asym_id_to_entity_id[chain_ids[chain_index[i]]]
      )
      mmcif_dict['_atom_site.label_seq_id'].append(str(residue_index[i]))
      mmcif_dict['_atom_site.pdbx_PDB_ins_code'].append('.')
      mmcif_dict['_atom_site.Cartn_x'].append(f'{pos[0]:.3f}')
      mmcif_dict['_atom_site.Cartn_y'].append(f'{pos[1]:.3f}')
      mmcif_dict['_atom_site.Cartn_z'].append(f'{pos[2]:.3f}')
      mmcif_dict['_atom_site.occupancy'].append('1.00')
      mmcif_dict['_atom_site.B_iso_or_equiv'].append(f'{b_factor:.2f}')
      mmcif_dict['_atom_site.auth_seq_id'].append(str(residue_index[i]))
      mmcif_dict['_atom_site.auth_asym_id'].append(chain_ids[chain_index[i]])
      mmcif_dict['_atom_site.pdbx_PDB_model_num'].append('1')
      atom_index += 1

  metadata_dict = mmcif_metadata.add_metadata_to_mmcif(mmcif_dict, model_type)
  mmcif_dict.update(metadata_dict)

  return _create_mmcif_string(mmcif_dict)
@functools.lru_cache(maxsize=256)
def _int_id_to_str_id(num: int) -> str:
  """Encodes a number as a string, using reverse spreadsheet style naming.

  Args:
    num: A positive integer.

  Returns:
    A string that encodes the positive integer using reverse spreadsheet style,
    naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
    usual way to encode chain IDs in mmCIF files.
  """
  if num <= 0:
    raise ValueError(f'Only positive integers allowed, got {num}.')

  num = num - 1  # 1-based indexing.
  output = []
  while num >= 0:
    output.append(chr(num % 26 + ord('A')))
    num = num // 26 - 1
  return ''.join(output)
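

# Reverse spreadsheet naming traced through the loop above (illustration only).
if __name__ == '__main__':
  for n in (1, 2, 26, 27, 28, 29, 53):
    print(n, _int_id_to_str_id(n))
  # 1 A, 2 B, 26 Z, 27 AA, 28 BA, 29 CA, 53 AB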
def _get_entity_poly_seq(
    aatypes: np.ndarray, residue_indices: np.ndarray, chain_indices: np.ndarray
) -> Dict[int, Tuple[List[int], List[int]]]:
  """Constructs gapless residue index and aatype lists for each chain.

  Args:
    aatypes: A numpy array with aatypes.
    residue_indices: A numpy array with residue indices.
    chain_indices: A numpy array with chain indices.

  Returns:
    A dictionary mapping chain indices to a tuple with list of residue indices
    and a list of aatypes. Missing residues are filled with UNK residue type.
  """
  if (
      aatypes.shape[0] != residue_indices.shape[0]
      or aatypes.shape[0] != chain_indices.shape[0]
  ):
    raise ValueError(
        'aatypes, residue_indices, chain_indices must have the same length.'
    )

  # Group the present residues by chain index.
  present = collections.defaultdict(list)
  for chain_id, res_id, aa in zip(chain_indices, residue_indices, aatypes):
    present[chain_id].append((res_id, aa))

  # Add any missing residues (from 1 to the first residue and for any gaps).
  entity_poly_seq = {}
  for chain_id, present_residues in present.items():
    present_residue_indices = set([x[0] for x in present_residues])
    min_res_id = min(present_residue_indices)  # Could be negative.
    max_res_id = max(present_residue_indices)

    new_residue_indices = []
    new_aatypes = []
    present_index = 0
    for i in range(min(1, min_res_id), max_res_id + 1):
      new_residue_indices.append(i)
      if i in present_residue_indices:
        new_aatypes.append(present_residues[present_index][1])
        present_index += 1
      else:
        new_aatypes.append(20)  # Unknown amino acid type.

    entity_poly_seq[chain_id] = (new_residue_indices, new_aatypes)
  return entity_poly_seq
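

# Gap filling in action, mirroring WARNING 1 of to_mmcif. The inputs are
# invented; the aatype indices assume the standard AlphaFold restype order
# (MET=12, LYS=11, GLY=7), with 20 meaning UNK.
if __name__ == '__main__':
  example = _get_entity_poly_seq(
      aatypes=np.array([12, 11, 7]),
      residue_indices=np.array([2, 3, 6]),
      chain_indices=np.array([0, 0, 0]),
  )
  print(example)  # {0: ([1, 2, 3, 4, 5, 6], [20, 12, 11, 20, 20, 7])}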
def _create_mmcif_string(mmcif_dict: Dict[str, Any]) -> str:
  """Converts mmCIF dictionary into mmCIF string."""
  mmcifio = MMCIFIO()
  mmcifio.set_dict(mmcif_dict)

  with io.StringIO() as file_handle:
    mmcifio.save(file_handle)
    return file_handle.getvalue()
...@@ -82,16 +82,55 @@ class ProteinTest(parameterized.TestCase):
    np.testing.assert_array_almost_equal(
        prot_reconstr.b_factors, prot.b_factors)
  @parameterized.named_parameters(
      dict(
          testcase_name='glucagon',
          pdb_file='glucagon.pdb',
          model_type='Monomer',
      ),
      dict(testcase_name='7bui', pdb_file='5nmu.pdb', model_type='Multimer'),
  )
  def test_to_mmcif(self, pdb_file, model_type):
    with open(
        os.path.join(
            absltest.get_default_test_srcdir(), TEST_DATA_DIR, pdb_file
        )
    ) as f:
      pdb_string = f.read()
    prot = protein.from_pdb_string(pdb_string)
    file_id = 'test'
    mmcif_string = protein.to_mmcif(prot, file_id, model_type)
    prot_reconstr = protein.from_mmcif_string(mmcif_string)

    np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype)
    np.testing.assert_array_almost_equal(
        prot_reconstr.atom_positions, prot.atom_positions
    )
    np.testing.assert_array_almost_equal(
        prot_reconstr.atom_mask, prot.atom_mask
    )
    np.testing.assert_array_equal(
        prot_reconstr.residue_index, prot.residue_index
    )
    np.testing.assert_array_equal(prot_reconstr.chain_index, prot.chain_index)
    np.testing.assert_array_almost_equal(
        prot_reconstr.b_factors, prot.b_factors
    )
  def test_ideal_atom_mask(self):
    with open(
        os.path.join(
            absltest.get_default_test_srcdir(), TEST_DATA_DIR, '2rbg.pdb'
        )
    ) as f:
      pdb_string = f.read()
    prot = protein.from_pdb_string(pdb_string)
    ideal_mask = protein.ideal_atom_mask(prot)
    non_ideal_residues = set([102] + list(range(127, 286)))
    for i, (res, atom_mask) in enumerate(
        zip(prot.residue_index, prot.atom_mask)
    ):
      if res in non_ideal_residues:
        self.assertFalse(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
      else:
......
...@@ -17,7 +17,7 @@
import collections
import functools
import os
from typing import Final, List, Mapping, Tuple

import numpy as np
import tree
...@@ -609,6 +609,35 @@ restype_1to3 = {
    'V': 'VAL',
}
PROTEIN_CHAIN: Final[str] = 'polypeptide(L)'
POLYMER_CHAIN: Final[str] = 'polymer'
def atom_id_to_type(atom_id: str) -> str:
  """Convert atom ID to atom type, works only for standard protein residues.

  Args:
    atom_id: Atom ID to be converted.

  Returns:
    String corresponding to atom type.

  Raises:
    ValueError: If atom ID not recognized.
  """
  if atom_id.startswith('C'):
    return 'C'
  elif atom_id.startswith('N'):
    return 'N'
  elif atom_id.startswith('O'):
    return 'O'
  elif atom_id.startswith('H'):
    return 'H'
  elif atom_id.startswith('S'):
    return 'S'
  raise ValueError('Atom ID not recognized.')
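

# The mapping keys off the first letter of the atom name, which is safe only
# for standard protein residues (here 'CA' is the alpha carbon, not calcium).
if __name__ == '__main__':
  for example_atom_id in ('CA', 'ND1', 'OXT', 'SD'):
    print(example_atom_id, atom_id_to_type(example_atom_id))
  # CA C / ND1 N / OXT O / SD S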
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
......
...@@ -315,6 +315,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
    try:
      raw_resolution = parsed_info[res_key][0]
      header['resolution'] = float(raw_resolution)
      break
    except ValueError:
      logging.debug('Invalid resolution format: %s', parsed_info[res_key])
......
...@@ -15,9 +15,7 @@
"""Pairing logic for multimer data pipeline."""

import collections
from typing import cast, Dict, Iterable, List, Sequence

from alphafold.common import residue_constants
from alphafold.data import pipeline
...@@ -135,7 +133,7 @@ def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
  """Creates mapping from species to msa dataframe of that species."""
  species_lookup = {}
  for species, species_df in msa_df.groupby('msa_species_identifiers'):
    species_lookup[cast(bytes, species)] = species_df
  return species_lookup
......
...@@ -449,6 +449,7 @@ def _get_atom_positions(
    mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
    res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][res_index]
    if not res_at_position.is_missing:
      assert res_at_position.position is not None
      res = chain[(res_at_position.hetflag,
                   res_at_position.position.residue_number,
                   res_at_position.position.insertion_code)]
......
...@@ -775,7 +775,7 @@ def compute_atom14_gt(
  gt_mask = (1. - use_alt) * gt_mask + use_alt * alt_gt_mask
  gt_positions = (1. - use_alt) * gt_positions + use_alt * alt_gt_positions
  return gt_positions, gt_mask, alt_naming_is_better


def backbone_loss(gt_rigid: geometry.Rigid3Array,
......
...@@ -61,9 +61,9 @@ def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
  np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-5, rtol=0.)
  np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-5, rtol=0.)
  np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-5, rtol=0.)


def assert_array_close_to_vector(array: jnp.ndarray, vec: vector.Vec3Array):
......
...@@ -29,8 +29,7 @@ class PrngTest(absltest.TestCase):
    raw_key = safe_key.get()
    self.assertFalse((raw_key == init_key).all())

    with self.assertRaises(RuntimeError):
      safe_key.get()
......
...@@ -160,8 +160,14 @@ def padding_consistent_rng(f):
    return jax.vmap(functools.partial(grid_keys, shape=shape[1:]))(new_keys)

  def inner(key, shape, **kwargs):
    keys = grid_keys(key, shape)
    signature = (
        '()->()'
        if jax.dtypes.issubdtype(keys.dtype, jax.dtypes.prng_key)
        else '(2)->()'
    )
    return jnp.vectorize(
        functools.partial(f, shape=(), **kwargs), signature=signature
    )(keys)

  return inner
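

# Why the signature switch above matters: newer JAX can represent PRNG keys
# either as raw uint32[2] arrays or as typed key arrays with a scalar
# `prng_key` dtype, and jnp.vectorize needs a matching core signature.
# Standalone sketch, not repo code:
if __name__ == '__main__':
  import jax
  raw_key = jax.random.PRNGKey(0)  # legacy style: shape (2,), dtype uint32
  typed_key = jax.random.key(0)    # new style: shape (), extended key dtype
  print(raw_key.shape, raw_key.dtype)      # (2,) uint32  -> '(2)->()'
  print(typed_key.shape, typed_key.dtype)  # () key<fry>  -> '()->()'
  print(jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key))  # True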
...@@ -13,7 +13,6 @@
# limitations under the License.
"""Helper methods for the AlphaFold Colab notebook."""

from typing import AbstractSet, Any, Mapping, Optional, Sequence

from alphafold.common import residue_constants
...@@ -143,31 +142,6 @@ def empty_placeholder_template_features(
  }
def check_cell_execution_order(
    cells_ran: AbstractSet[int], cell_number: int) -> None:
  """Check that the cell execution order is correct.
......