Commit eb93322b authored by mashun1's avatar mashun1
Browse files

dtk24.04.1

parents
HEADER HORMONE 17-OCT-77 1GCN
TITLE X-RAY ANALYSIS OF GLUCAGON AND ITS RELATIONSHIP TO RECEPTOR
TITLE 2 BINDING
COMPND MOL_ID: 1;
COMPND 2 MOLECULE: GLUCAGON;
COMPND 3 CHAIN: A;
COMPND 4 ENGINEERED: YES
SOURCE MOL_ID: 1;
SOURCE 2 ORGANISM_SCIENTIFIC: SUS SCROFA;
SOURCE 3 ORGANISM_COMMON: PIG;
SOURCE 4 ORGANISM_TAXID: 9823
KEYWDS HORMONE
EXPDTA X-RAY DIFFRACTION
AUTHOR T.L.BLUNDELL,K.SASAKI,S.DOCKERILL,I.J.TICKLE
REVDAT 6 24-FEB-09 1GCN 1 VERSN
REVDAT 5 30-SEP-83 1GCN 1 REVDAT
REVDAT 4 31-DEC-80 1GCN 1 REMARK
REVDAT 3 22-OCT-79 1GCN 3 ATOM
REVDAT 2 29-AUG-79 1GCN 3 CRYST1
REVDAT 1 28-NOV-77 1GCN 0
JRNL AUTH K.SASAKI,S.DOCKERILL,D.A.ADAMIAK,I.J.TICKLE,
JRNL AUTH 2 T.BLUNDELL
JRNL TITL X-RAY ANALYSIS OF GLUCAGON AND ITS RELATIONSHIP TO
JRNL TITL 2 RECEPTOR BINDING.
JRNL REF NATURE V. 257 751 1975
JRNL REFN ISSN 0028-0836
JRNL PMID 171582
JRNL DOI 10.1038/257751A0
REMARK 1
REMARK 1 REFERENCE 1
REMARK 1 EDIT M.O.DAYHOFF
REMARK 1 REF ATLAS OF PROTEIN SEQUENCE V. 5 125 1976
REMARK 1 REF 2 AND STRUCTURE,SUPPLEMENT 2
REMARK 1 PUBL NATIONAL BIOMEDICAL RESEARCH FOUNDATION, SILVER
REMARK 1 PUBL 2 SPRING,MD.
REMARK 1 REFN ISSN 0-912466-05-7
REMARK 2
REMARK 2 RESOLUTION. 3.00 ANGSTROMS.
REMARK 3
REMARK 3 REFINEMENT.
REMARK 3 PROGRAM : NULL
REMARK 3 AUTHORS : NULL
REMARK 3
REMARK 3 DATA USED IN REFINEMENT.
REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 3.00
REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : NULL
REMARK 3 DATA CUTOFF (SIGMA(F)) : NULL
REMARK 3 DATA CUTOFF HIGH (ABS(F)) : NULL
REMARK 3 DATA CUTOFF LOW (ABS(F)) : NULL
REMARK 3 COMPLETENESS (WORKING+TEST) (%) : NULL
REMARK 3 NUMBER OF REFLECTIONS : NULL
REMARK 3
REMARK 3 FIT TO DATA USED IN REFINEMENT.
REMARK 3 CROSS-VALIDATION METHOD : NULL
REMARK 3 FREE R VALUE TEST SET SELECTION : NULL
REMARK 3 R VALUE (WORKING SET) : NULL
REMARK 3 FREE R VALUE : NULL
REMARK 3 FREE R VALUE TEST SET SIZE (%) : NULL
REMARK 3 FREE R VALUE TEST SET COUNT : NULL
REMARK 3 ESTIMATED ERROR OF FREE R VALUE : NULL
REMARK 3
REMARK 3 FIT IN THE HIGHEST RESOLUTION BIN.
REMARK 3 TOTAL NUMBER OF BINS USED : NULL
REMARK 3 BIN RESOLUTION RANGE HIGH (A) : NULL
REMARK 3 BIN RESOLUTION RANGE LOW (A) : NULL
REMARK 3 BIN COMPLETENESS (WORKING+TEST) (%) : NULL
REMARK 3 REFLECTIONS IN BIN (WORKING SET) : NULL
REMARK 3 BIN R VALUE (WORKING SET) : NULL
REMARK 3 BIN FREE R VALUE : NULL
REMARK 3 BIN FREE R VALUE TEST SET SIZE (%) : NULL
REMARK 3 BIN FREE R VALUE TEST SET COUNT : NULL
REMARK 3 ESTIMATED ERROR OF BIN FREE R VALUE : NULL
REMARK 3
REMARK 3 NUMBER OF NON-HYDROGEN ATOMS USED IN REFINEMENT.
REMARK 3 PROTEIN ATOMS : 246
REMARK 3 NUCLEIC ACID ATOMS : 0
REMARK 3 HETEROGEN ATOMS : 0
REMARK 3 SOLVENT ATOMS : 0
REMARK 3
REMARK 3 B VALUES.
REMARK 3 FROM WILSON PLOT (A**2) : NULL
REMARK 3 MEAN B VALUE (OVERALL, A**2) : NULL
REMARK 3 OVERALL ANISOTROPIC B VALUE.
REMARK 3 B11 (A**2) : NULL
REMARK 3 B22 (A**2) : NULL
REMARK 3 B33 (A**2) : NULL
REMARK 3 B12 (A**2) : NULL
REMARK 3 B13 (A**2) : NULL
REMARK 3 B23 (A**2) : NULL
REMARK 3
REMARK 3 ESTIMATED COORDINATE ERROR.
REMARK 3 ESD FROM LUZZATI PLOT (A) : NULL
REMARK 3 ESD FROM SIGMAA (A) : NULL
REMARK 3 LOW RESOLUTION CUTOFF (A) : NULL
REMARK 3
REMARK 3 CROSS-VALIDATED ESTIMATED COORDINATE ERROR.
REMARK 3 ESD FROM C-V LUZZATI PLOT (A) : NULL
REMARK 3 ESD FROM C-V SIGMAA (A) : NULL
REMARK 3
REMARK 3 RMS DEVIATIONS FROM IDEAL VALUES.
REMARK 3 BOND LENGTHS (A) : NULL
REMARK 3 BOND ANGLES (DEGREES) : NULL
REMARK 3 DIHEDRAL ANGLES (DEGREES) : NULL
REMARK 3 IMPROPER ANGLES (DEGREES) : NULL
REMARK 3
REMARK 3 ISOTROPIC THERMAL MODEL : NULL
REMARK 3
REMARK 3 ISOTROPIC THERMAL FACTOR RESTRAINTS. RMS SIGMA
REMARK 3 MAIN-CHAIN BOND (A**2) : NULL ; NULL
REMARK 3 MAIN-CHAIN ANGLE (A**2) : NULL ; NULL
REMARK 3 SIDE-CHAIN BOND (A**2) : NULL ; NULL
REMARK 3 SIDE-CHAIN ANGLE (A**2) : NULL ; NULL
REMARK 3
REMARK 3 NCS MODEL : NULL
REMARK 3
REMARK 3 NCS RESTRAINTS. RMS SIGMA/WEIGHT
REMARK 3 GROUP 1 POSITIONAL (A) : NULL ; NULL
REMARK 3 GROUP 1 B-FACTOR (A**2) : NULL ; NULL
REMARK 3
REMARK 3 PARAMETER FILE 1 : NULL
REMARK 3 TOPOLOGY FILE 1 : NULL
REMARK 3
REMARK 3 OTHER REFINEMENT REMARKS: NULL
REMARK 4
REMARK 4 1GCN COMPLIES WITH FORMAT V. 3.15, 01-DEC-08
REMARK 100
REMARK 100 THIS ENTRY HAS BEEN PROCESSED BY BNL.
REMARK 200
REMARK 200 EXPERIMENTAL DETAILS
REMARK 200 EXPERIMENT TYPE : X-RAY DIFFRACTION
REMARK 200 DATE OF DATA COLLECTION : NULL
REMARK 200 TEMPERATURE (KELVIN) : NULL
REMARK 200 PH : NULL
REMARK 200 NUMBER OF CRYSTALS USED : NULL
REMARK 200
REMARK 200 SYNCHROTRON (Y/N) : NULL
REMARK 200 RADIATION SOURCE : NULL
REMARK 200 BEAMLINE : NULL
REMARK 200 X-RAY GENERATOR MODEL : NULL
REMARK 200 MONOCHROMATIC OR LAUE (M/L) : NULL
REMARK 200 WAVELENGTH OR RANGE (A) : NULL
REMARK 200 MONOCHROMATOR : NULL
REMARK 200 OPTICS : NULL
REMARK 200
REMARK 200 DETECTOR TYPE : NULL
REMARK 200 DETECTOR MANUFACTURER : NULL
REMARK 200 INTENSITY-INTEGRATION SOFTWARE : NULL
REMARK 200 DATA SCALING SOFTWARE : NULL
REMARK 200
REMARK 200 NUMBER OF UNIQUE REFLECTIONS : NULL
REMARK 200 RESOLUTION RANGE HIGH (A) : NULL
REMARK 200 RESOLUTION RANGE LOW (A) : NULL
REMARK 200 REJECTION CRITERIA (SIGMA(I)) : NULL
REMARK 200
REMARK 200 OVERALL.
REMARK 200 COMPLETENESS FOR RANGE (%) : NULL
REMARK 200 DATA REDUNDANCY : NULL
REMARK 200 R MERGE (I) : NULL
REMARK 200 R SYM (I) : NULL
REMARK 200 <I/SIGMA(I)> FOR THE DATA SET : NULL
REMARK 200
REMARK 200 IN THE HIGHEST RESOLUTION SHELL.
REMARK 200 HIGHEST RESOLUTION SHELL, RANGE HIGH (A) : NULL
REMARK 200 HIGHEST RESOLUTION SHELL, RANGE LOW (A) : NULL
REMARK 200 COMPLETENESS FOR SHELL (%) : NULL
REMARK 200 DATA REDUNDANCY IN SHELL : NULL
REMARK 200 R MERGE FOR SHELL (I) : NULL
REMARK 200 R SYM FOR SHELL (I) : NULL
REMARK 200 <I/SIGMA(I)> FOR SHELL : NULL
REMARK 200
REMARK 200 DIFFRACTION PROTOCOL: NULL
REMARK 200 METHOD USED TO DETERMINE THE STRUCTURE: NULL
REMARK 200 SOFTWARE USED: NULL
REMARK 200 STARTING MODEL: NULL
REMARK 200
REMARK 200 REMARK: NULL
REMARK 280
REMARK 280 CRYSTAL
REMARK 280 SOLVENT CONTENT, VS (%): 50.74
REMARK 280 MATTHEWS COEFFICIENT, VM (ANGSTROMS**3/DA): 2.50
REMARK 280
REMARK 280 CRYSTALLIZATION CONDITIONS: NULL
REMARK 290
REMARK 290 CRYSTALLOGRAPHIC SYMMETRY
REMARK 290 SYMMETRY OPERATORS FOR SPACE GROUP: P 21 3
REMARK 290
REMARK 290 SYMOP SYMMETRY
REMARK 290 NNNMMM OPERATOR
REMARK 290 1555 X,Y,Z
REMARK 290 2555 -X+1/2,-Y,Z+1/2
REMARK 290 3555 -X,Y+1/2,-Z+1/2
REMARK 290 4555 X+1/2,-Y+1/2,-Z
REMARK 290 5555 Z,X,Y
REMARK 290 6555 Z+1/2,-X+1/2,-Y
REMARK 290 7555 -Z+1/2,-X,Y+1/2
REMARK 290 8555 -Z,X+1/2,-Y+1/2
REMARK 290 9555 Y,Z,X
REMARK 290 10555 -Y,Z+1/2,-X+1/2
REMARK 290 11555 Y+1/2,-Z+1/2,-X
REMARK 290 12555 -Y+1/2,-Z,X+1/2
REMARK 290
REMARK 290 WHERE NNN -> OPERATOR NUMBER
REMARK 290 MMM -> TRANSLATION VECTOR
REMARK 290
REMARK 290 CRYSTALLOGRAPHIC SYMMETRY TRANSFORMATIONS
REMARK 290 THE FOLLOWING TRANSFORMATIONS OPERATE ON THE ATOM/HETATM
REMARK 290 RECORDS IN THIS ENTRY TO PRODUCE CRYSTALLOGRAPHICALLY
REMARK 290 RELATED MOLECULES.
REMARK 290 SMTRY1 1 1.000000 0.000000 0.000000 0.00000
REMARK 290 SMTRY2 1 0.000000 1.000000 0.000000 0.00000
REMARK 290 SMTRY3 1 0.000000 0.000000 1.000000 0.00000
REMARK 290 SMTRY1 2 -1.000000 0.000000 0.000000 23.55000
REMARK 290 SMTRY2 2 0.000000 -1.000000 0.000000 0.00000
REMARK 290 SMTRY3 2 0.000000 0.000000 1.000000 23.55000
REMARK 290 SMTRY1 3 -1.000000 0.000000 0.000000 0.00000
REMARK 290 SMTRY2 3 0.000000 1.000000 0.000000 23.55000
REMARK 290 SMTRY3 3 0.000000 0.000000 -1.000000 23.55000
REMARK 290 SMTRY1 4 1.000000 0.000000 0.000000 23.55000
REMARK 290 SMTRY2 4 0.000000 -1.000000 0.000000 23.55000
REMARK 290 SMTRY3 4 0.000000 0.000000 -1.000000 0.00000
REMARK 290 SMTRY1 5 0.000000 0.000000 1.000000 0.00000
REMARK 290 SMTRY2 5 1.000000 0.000000 0.000000 0.00000
REMARK 290 SMTRY3 5 0.000000 1.000000 0.000000 0.00000
REMARK 290 SMTRY1 6 0.000000 0.000000 1.000000 23.55000
REMARK 290 SMTRY2 6 -1.000000 0.000000 0.000000 23.55000
REMARK 290 SMTRY3 6 0.000000 -1.000000 0.000000 0.00000
REMARK 290 SMTRY1 7 0.000000 0.000000 -1.000000 23.55000
REMARK 290 SMTRY2 7 -1.000000 0.000000 0.000000 0.00000
REMARK 290 SMTRY3 7 0.000000 1.000000 0.000000 23.55000
REMARK 290 SMTRY1 8 0.000000 0.000000 -1.000000 0.00000
REMARK 290 SMTRY2 8 1.000000 0.000000 0.000000 23.55000
REMARK 290 SMTRY3 8 0.000000 -1.000000 0.000000 23.55000
REMARK 290 SMTRY1 9 0.000000 1.000000 0.000000 0.00000
REMARK 290 SMTRY2 9 0.000000 0.000000 1.000000 0.00000
REMARK 290 SMTRY3 9 1.000000 0.000000 0.000000 0.00000
REMARK 290 SMTRY1 10 0.000000 -1.000000 0.000000 0.00000
REMARK 290 SMTRY2 10 0.000000 0.000000 1.000000 23.55000
REMARK 290 SMTRY3 10 -1.000000 0.000000 0.000000 23.55000
REMARK 290 SMTRY1 11 0.000000 1.000000 0.000000 23.55000
REMARK 290 SMTRY2 11 0.000000 0.000000 -1.000000 23.55000
REMARK 290 SMTRY3 11 -1.000000 0.000000 0.000000 0.00000
REMARK 290 SMTRY1 12 0.000000 -1.000000 0.000000 23.55000
REMARK 290 SMTRY2 12 0.000000 0.000000 -1.000000 0.00000
REMARK 290 SMTRY3 12 1.000000 0.000000 0.000000 23.55000
REMARK 290
REMARK 290 REMARK: NULL
REMARK 300
REMARK 300 BIOMOLECULE: 1
REMARK 300 SEE REMARK 350 FOR THE AUTHOR PROVIDED AND/OR PROGRAM
REMARK 300 GENERATED ASSEMBLY INFORMATION FOR THE STRUCTURE IN
REMARK 300 THIS ENTRY. THE REMARK MAY ALSO PROVIDE INFORMATION ON
REMARK 300 BURIED SURFACE AREA.
REMARK 350
REMARK 350 COORDINATES FOR A COMPLETE MULTIMER REPRESENTING THE KNOWN
REMARK 350 BIOLOGICALLY SIGNIFICANT OLIGOMERIZATION STATE OF THE
REMARK 350 MOLECULE CAN BE GENERATED BY APPLYING BIOMT TRANSFORMATIONS
REMARK 350 GIVEN BELOW. BOTH NON-CRYSTALLOGRAPHIC AND
REMARK 350 CRYSTALLOGRAPHIC OPERATIONS ARE GIVEN.
REMARK 350
REMARK 350 BIOMOLECULE: 1
REMARK 350 AUTHOR DETERMINED BIOLOGICAL UNIT: MONOMERIC
REMARK 350 APPLY THE FOLLOWING TO CHAINS: A
REMARK 350 BIOMT1 1 1.000000 0.000000 0.000000 0.00000
REMARK 350 BIOMT2 1 0.000000 1.000000 0.000000 0.00000
REMARK 350 BIOMT3 1 0.000000 0.000000 1.000000 0.00000
REMARK 500
REMARK 500 GEOMETRY AND STEREOCHEMISTRY
REMARK 500 SUBTOPIC: COVALENT BOND LENGTHS
REMARK 500
REMARK 500 THE STEREOCHEMICAL PARAMETERS OF THE FOLLOWING RESIDUES
REMARK 500 HAVE VALUES WHICH DEVIATE FROM EXPECTED VALUES BY MORE
REMARK 500 THAN 6*RMSD (M=MODEL NUMBER; RES=RESIDUE NAME; C=CHAIN
REMARK 500 IDENTIFIER; SSEQ=SEQUENCE NUMBER; I=INSERTION CODE).
REMARK 500
REMARK 500 STANDARD TABLE:
REMARK 500 FORMAT: (10X,I3,1X,2(A3,1X,A1,I4,A1,1X,A4,3X),1X,F6.3)
REMARK 500
REMARK 500 EXPECTED VALUES PROTEIN: ENGH AND HUBER, 1999
REMARK 500 EXPECTED VALUES NUCLEIC ACID: CLOWNEY ET AL 1996
REMARK 500
REMARK 500 M RES CSSEQI ATM1 RES CSSEQI ATM2 DEVIATION
REMARK 500 TYR A 10 CZ TYR A 10 OH -0.387
REMARK 500 TRP A 25 CD1 TRP A 25 NE1 0.287
REMARK 500 TRP A 25 NE1 TRP A 25 CE2 0.109
REMARK 500
REMARK 500 REMARK: NULL
REMARK 500
REMARK 500 GEOMETRY AND STEREOCHEMISTRY
REMARK 500 SUBTOPIC: COVALENT BOND ANGLES
REMARK 500
REMARK 500 THE STEREOCHEMICAL PARAMETERS OF THE FOLLOWING RESIDUES
REMARK 500 HAVE VALUES WHICH DEVIATE FROM EXPECTED VALUES BY MORE
REMARK 500 THAN 6*RMSD (M=MODEL NUMBER; RES=RESIDUE NAME; C=CHAIN
REMARK 500 IDENTIFIER; SSEQ=SEQUENCE NUMBER; I=INSERTION CODE).
REMARK 500
REMARK 500 STANDARD TABLE:
REMARK 500 FORMAT: (10X,I3,1X,A3,1X,A1,I4,A1,3(1X,A4,2X),12X,F5.1)
REMARK 500
REMARK 500 EXPECTED VALUES PROTEIN: ENGH AND HUBER, 1999
REMARK 500 EXPECTED VALUES NUCLEIC ACID: CLOWNEY ET AL 1996
REMARK 500
REMARK 500 M RES CSSEQI ATM1 ATM2 ATM3
REMARK 500 TRP A 25 CG - CD1 - NE1 ANGL. DEV. = 6.7 DEGREES
REMARK 500 TRP A 25 CD1 - NE1 - CE2 ANGL. DEV. = -21.5 DEGREES
REMARK 500 TRP A 25 NE1 - CE2 - CZ2 ANGL. DEV. = -11.0 DEGREES
REMARK 500 TRP A 25 NE1 - CE2 - CD2 ANGL. DEV. = 9.6 DEGREES
REMARK 500
REMARK 500 REMARK: NULL
REMARK 500
REMARK 500 GEOMETRY AND STEREOCHEMISTRY
REMARK 500 SUBTOPIC: TORSION ANGLES
REMARK 500
REMARK 500 TORSION ANGLES OUTSIDE THE EXPECTED RAMACHANDRAN REGIONS:
REMARK 500 (M=MODEL NUMBER; RES=RESIDUE NAME; C=CHAIN IDENTIFIER;
REMARK 500 SSEQ=SEQUENCE NUMBER; I=INSERTION CODE).
REMARK 500
REMARK 500 STANDARD TABLE:
REMARK 500 FORMAT:(10X,I3,1X,A3,1X,A1,I4,A1,4X,F7.2,3X,F7.2)
REMARK 500
REMARK 500 EXPECTED VALUES: GJ KLEYWEGT AND TA JONES (1996). PHI/PSI-
REMARK 500 CHOLOGY: RAMACHANDRAN REVISITED. STRUCTURE 4, 1395 - 1400
REMARK 500
REMARK 500 M RES CSSEQI PSI PHI
REMARK 500 SER A 2 -57.57 -21.14
REMARK 500 THR A 5 54.62 -63.85
REMARK 500 SER A 11 9.62 -51.97
REMARK 500 MET A 27 -93.98 -145.30
REMARK 500 ASN A 28 64.02 15.67
REMARK 500
REMARK 500 REMARK: NULL
REMARK 500
REMARK 500 GEOMETRY AND STEREOCHEMISTRY
REMARK 500 SUBTOPIC: PLANAR GROUPS
REMARK 500
REMARK 500 PLANAR GROUPS IN THE FOLLOWING RESIDUES HAVE A TOTAL
REMARK 500 RMS DISTANCE OF ALL ATOMS FROM THE BEST-FIT PLANE
REMARK 500 BY MORE THAN AN EXPECTED VALUE OF 6*RMSD, WITH AN
REMARK 500 RMSD 0.02 ANGSTROMS, OR AT LEAST ONE ATOM HAS
REMARK 500 AN RMSD GREATER THAN THIS VALUE
REMARK 500 (M=MODEL NUMBER; RES=RESIDUE NAME; C=CHAIN IDENTIFIER;
REMARK 500 SSEQ=SEQUENCE NUMBER; I=INSERTION CODE).
REMARK 500
REMARK 500 M RES CSSEQI RMS TYPE
REMARK 500 ASN A 28 0.08 SIDE_CHAIN
REMARK 500
REMARK 500 REMARK: NULL
REMARK 500
REMARK 500 GEOMETRY AND STEREOCHEMISTRY
REMARK 500 SUBTOPIC: MAIN CHAIN PLANARITY
REMARK 500
REMARK 500 THE FOLLOWING RESIDUES HAVE A PSEUDO PLANARITY
REMARK 500 TORSION, C(I) - CA(I) - N(I+1) - O(I), GREATER
REMARK 500 10.0 DEGREES. (M=MODEL NUMBER; RES=RESIDUE NAME;
REMARK 500 C=CHAIN IDENTIFIER; SSEQ=SEQUENCE NUMBER;
REMARK 500 I=INSERTION CODE).
REMARK 500
REMARK 500 M RES CSSEQI ANGLE
REMARK 500 HIS A 1 19.48
REMARK 500 GLN A 3 -15.78
REMARK 500 GLY A 4 -17.23
REMARK 500 THR A 5 -10.38
REMARK 500 PHE A 6 -12.06
REMARK 500 THR A 7 -14.66
REMARK 500 SER A 11 -15.10
REMARK 500 LYS A 12 14.46
REMARK 500 ALA A 19 -10.92
REMARK 500 GLN A 20 -13.40
REMARK 500 VAL A 23 -15.87
REMARK 500 LEU A 26 -14.56
REMARK 500 MET A 27 -16.22
REMARK 500
REMARK 500 REMARK: NULL
DBREF 1GCN A 1 29 UNP P01274 GLUC_PIG 33 61
SEQRES 1 A 29 HIS SER GLN GLY THR PHE THR SER ASP TYR SER LYS TYR
SEQRES 2 A 29 LEU ASP SER ARG ARG ALA GLN ASP PHE VAL GLN TRP LEU
SEQRES 3 A 29 MET ASN THR
HELIX 1 A PHE A 6 LEU A 26 1 21
CRYST1 47.100 47.100 47.100 90.00 90.00 90.00 P 21 3 12
ORIGX1 0.021231 0.000000 0.000000 0.00000
ORIGX2 0.000000 0.021231 0.000000 0.00000
ORIGX3 0.000000 0.000000 0.021231 0.00000
SCALE1 0.021231 0.000000 0.000000 0.00000
SCALE2 0.000000 0.021231 0.000000 0.00000
SCALE3 0.000000 0.000000 0.021231 0.00000
ATOM 1 N HIS A 1 49.668 24.248 10.436 1.00 25.00 N
ATOM 2 CA HIS A 1 50.197 25.578 10.784 1.00 16.00 C
ATOM 3 C HIS A 1 49.169 26.701 10.917 1.00 16.00 C
ATOM 4 O HIS A 1 48.241 26.524 11.749 1.00 16.00 O
ATOM 5 CB HIS A 1 51.312 26.048 9.843 1.00 16.00 C
ATOM 6 CG HIS A 1 50.958 26.068 8.340 1.00 16.00 C
ATOM 7 ND1 HIS A 1 49.636 26.144 7.860 1.00 16.00 N
ATOM 8 CD2 HIS A 1 51.797 26.043 7.286 1.00 16.00 C
ATOM 9 CE1 HIS A 1 49.691 26.152 6.454 1.00 17.00 C
ATOM 10 NE2 HIS A 1 51.046 26.090 6.098 1.00 17.00 N
ATOM 11 N SER A 2 49.788 27.850 10.784 1.00 16.00 N
ATOM 12 CA SER A 2 49.138 29.147 10.620 1.00 15.00 C
ATOM 13 C SER A 2 47.713 29.006 10.110 1.00 15.00 C
ATOM 14 O SER A 2 46.740 29.251 10.864 1.00 15.00 O
ATOM 15 CB SER A 2 49.875 29.930 9.569 1.00 16.00 C
ATOM 16 OG SER A 2 49.145 31.057 9.176 1.00 19.00 O
ATOM 17 N GLN A 3 47.620 28.367 8.973 1.00 15.00 N
ATOM 18 CA GLN A 3 46.287 28.193 8.308 1.00 14.00 C
ATOM 19 C GLN A 3 45.406 27.172 8.963 1.00 14.00 C
ATOM 20 O GLN A 3 44.198 27.508 9.014 1.00 14.00 O
ATOM 21 CB GLN A 3 46.489 27.963 6.806 1.00 18.00 C
ATOM 22 CG GLN A 3 45.138 27.800 6.111 1.00 21.00 C
ATOM 23 CD GLN A 3 45.304 27.952 4.603 1.00 24.00 C
ATOM 24 OE1 GLN A 3 46.432 28.202 4.112 1.00 24.00 O
ATOM 25 NE2 GLN A 3 44.233 27.647 3.897 1.00 26.00 N
ATOM 26 N GLY A 4 46.014 26.394 9.871 1.00 14.00 N
ATOM 27 CA GLY A 4 45.422 25.287 10.680 1.00 14.00 C
ATOM 28 C GLY A 4 43.892 25.215 10.719 1.00 14.00 C
ATOM 29 O GLY A 4 43.287 26.155 11.288 1.00 14.00 O
ATOM 30 N THR A 5 43.406 23.993 10.767 1.00 14.00 N
ATOM 31 CA THR A 5 42.004 23.642 10.443 1.00 12.00 C
ATOM 32 C THR A 5 40.788 24.146 11.252 1.00 12.00 C
ATOM 33 O THR A 5 39.804 23.384 11.410 1.00 12.00 O
ATOM 34 CB THR A 5 41.934 22.202 9.889 1.00 14.00 C
ATOM 35 OG1 THR A 5 41.080 21.317 10.609 1.00 15.00 O
ATOM 36 CG2 THR A 5 43.317 21.556 9.849 1.00 15.00 C
ATOM 37 N PHE A 6 40.628 25.463 11.441 1.00 12.00 N
ATOM 38 CA PHE A 6 39.381 25.950 12.104 1.00 12.00 C
ATOM 39 C PHE A 6 38.156 25.684 11.232 1.00 12.00 C
ATOM 40 O PHE A 6 37.231 25.002 11.719 1.00 12.00 O
ATOM 41 CB PHE A 6 39.407 27.425 12.584 1.00 12.00 C
ATOM 42 CG PHE A 6 38.187 27.923 13.430 1.00 12.00 C
ATOM 43 CD1 PHE A 6 36.889 27.518 13.163 1.00 12.00 C
ATOM 44 CD2 PHE A 6 38.386 28.862 14.419 1.00 12.00 C
ATOM 45 CE1 PHE A 6 35.813 27.967 13.909 1.00 12.00 C
ATOM 46 CE2 PHE A 6 37.306 29.328 15.177 1.00 12.00 C
ATOM 47 CZ PHE A 6 36.019 28.871 14.928 1.00 12.00 C
ATOM 48 N THR A 7 38.341 25.794 9.956 1.00 12.00 N
ATOM 49 CA THR A 7 37.249 25.666 8.991 1.00 12.00 C
ATOM 50 C THR A 7 36.324 24.452 9.101 1.00 12.00 C
ATOM 51 O THR A 7 35.111 24.637 9.387 1.00 12.00 O
ATOM 52 CB THR A 7 37.884 25.743 7.628 1.00 13.00 C
ATOM 53 OG1 THR A 7 37.940 27.122 7.317 1.00 14.00 O
ATOM 54 CG2 THR A 7 37.073 25.003 6.585 1.00 14.00 C
ATOM 55 N SER A 8 36.964 23.356 9.442 1.00 12.00 N
ATOM 56 CA SER A 8 36.286 22.063 9.486 1.00 12.00 C
ATOM 57 C SER A 8 35.575 21.813 10.813 1.00 11.00 C
ATOM 58 O SER A 8 35.203 20.650 11.111 1.00 10.00 O
ATOM 59 CB SER A 8 37.291 20.958 9.189 1.00 16.00 C
ATOM 60 OG SER A 8 37.917 21.247 7.943 1.00 20.00 O
ATOM 61 N ASP A 9 35.723 22.783 11.694 1.00 10.00 N
ATOM 62 CA ASP A 9 35.004 22.803 12.977 1.00 10.00 C
ATOM 63 C ASP A 9 33.532 23.121 12.749 1.00 10.00 C
ATOM 64 O ASP A 9 32.645 22.360 13.210 1.00 10.00 O
ATOM 65 CB ASP A 9 35.556 23.874 13.919 1.00 11.00 C
ATOM 66 CG ASP A 9 36.280 23.230 15.096 1.00 13.00 C
ATOM 67 OD1 ASP A 9 36.088 22.010 15.324 1.00 16.00 O
ATOM 68 OD2 ASP A 9 36.821 23.974 15.951 1.00 16.00 O
ATOM 69 N TYR A 10 33.316 24.220 12.040 1.00 10.00 N
ATOM 70 CA TYR A 10 31.967 24.742 11.748 1.00 10.00 C
ATOM 71 C TYR A 10 31.203 23.973 10.685 1.00 10.00 C
ATOM 72 O TYR A 10 29.980 23.772 10.885 1.00 10.00 O
ATOM 73 CB TYR A 10 31.951 26.230 11.367 1.00 10.00 C
ATOM 74 CG TYR A 10 30.613 26.678 10.713 1.00 10.00 C
ATOM 75 CD1 TYR A 10 30.563 26.886 9.350 1.00 10.00 C
ATOM 76 CD2 TYR A 10 29.463 26.824 11.461 1.00 10.00 C
ATOM 77 CE1 TYR A 10 29.377 27.275 8.733 1.00 10.00 C
ATOM 78 CE2 TYR A 10 28.272 27.214 10.848 1.00 10.00 C
ATOM 79 CZ TYR A 10 28.226 27.452 9.483 1.00 10.00 C
ATOM 80 OH TYR A 10 27.365 27.683 9.060 1.00 11.00 O
ATOM 81 N SER A 11 31.796 23.909 9.491 1.00 10.00 N
ATOM 82 CA SER A 11 31.146 23.418 8.250 1.00 10.00 C
ATOM 83 C SER A 11 30.463 22.048 8.303 1.00 10.00 C
ATOM 84 O SER A 11 29.615 21.759 7.422 1.00 10.00 O
ATOM 85 CB SER A 11 32.004 23.615 6.998 1.00 14.00 C
ATOM 86 OG SER A 11 32.013 24.995 6.632 1.00 19.00 O
ATOM 87 N LYS A 12 30.402 21.619 9.544 1.00 10.00 N
ATOM 88 CA LYS A 12 29.792 20.460 10.189 1.00 9.00 C
ATOM 89 C LYS A 12 28.494 20.817 10.932 1.00 9.00 C
ATOM 90 O LYS A 12 27.597 19.943 10.980 1.00 9.00 O
ATOM 91 CB LYS A 12 30.811 20.013 11.224 1.00 10.00 C
ATOM 92 CG LYS A 12 30.482 18.661 11.833 1.00 14.00 C
ATOM 93 CD LYS A 12 31.413 18.365 12.999 1.00 18.00 C
ATOM 94 CE LYS A 12 31.243 16.937 13.498 1.00 22.00 C
ATOM 95 NZ LYS A 12 32.121 16.717 14.652 1.00 26.00 N
ATOM 96 N TYR A 13 28.583 21.742 11.894 1.00 9.00 N
ATOM 97 CA TYR A 13 27.396 22.283 12.612 1.00 8.00 C
ATOM 98 C TYR A 13 26.214 22.497 11.670 1.00 8.00 C
ATOM 99 O TYR A 13 25.037 22.245 12.029 1.00 8.00 O
ATOM 100 CB TYR A 13 27.730 23.578 13.385 1.00 8.00 C
ATOM 101 CG TYR A 13 26.516 24.500 13.692 1.00 8.00 C
ATOM 102 CD1 TYR A 13 25.798 24.377 14.867 1.00 8.00 C
ATOM 103 CD2 TYR A 13 26.185 25.498 12.796 1.00 8.00 C
ATOM 104 CE1 TYR A 13 24.713 25.228 15.120 1.00 8.00 C
ATOM 105 CE2 TYR A 13 25.108 26.342 13.035 1.00 8.00 C
ATOM 106 CZ TYR A 13 24.370 26.210 14.196 1.00 8.00 C
ATOM 107 OH TYR A 13 23.202 26.933 14.347 1.00 10.00 O
ATOM 108 N LEU A 14 26.522 22.993 10.494 1.00 8.00 N
ATOM 109 CA LEU A 14 25.461 23.263 9.523 1.00 8.00 C
ATOM 110 C LEU A 14 24.912 21.978 8.907 1.00 8.00 C
ATOM 111 O LEU A 14 24.122 22.025 7.933 1.00 8.00 O
ATOM 112 CB LEU A 14 25.923 24.242 8.447 1.00 13.00 C
ATOM 113 CG LEU A 14 25.064 25.509 8.412 1.00 19.00 C
ATOM 114 CD1 LEU A 14 25.564 26.496 7.505 1.00 25.00 C
ATOM 115 CD2 LEU A 14 23.582 25.209 8.199 1.00 25.00 C
ATOM 116 N ASP A 15 25.556 20.886 9.263 1.00 8.00 N
ATOM 117 CA ASP A 15 25.075 19.552 8.885 1.00 8.00 C
ATOM 118 C ASP A 15 24.208 19.002 10.009 1.00 8.00 C
ATOM 119 O ASP A 15 23.550 17.940 9.861 1.00 8.00 O
ATOM 120 CB ASP A 15 26.246 18.601 8.644 1.00 11.00 C
ATOM 121 CG ASP A 15 26.260 18.121 7.196 1.00 16.00 C
ATOM 122 OD1 ASP A 15 26.021 18.946 6.280 1.00 21.00 O
ATOM 123 OD2 ASP A 15 26.732 16.984 6.946 1.00 21.00 O
ATOM 124 N SER A 16 24.015 19.861 10.986 1.00 8.00 N
ATOM 125 CA SER A 16 23.180 19.548 12.149 1.00 7.00 C
ATOM 126 C SER A 16 21.923 20.414 12.167 1.00 7.00 C
ATOM 127 O SER A 16 20.841 19.941 12.598 1.00 7.00 O
ATOM 128 CB SER A 16 23.981 19.746 13.437 1.00 9.00 C
ATOM 129 OG SER A 16 23.327 19.102 14.524 1.00 11.00 O
ATOM 130 N ARG A 17 22.037 21.605 11.597 1.00 7.00 N
ATOM 131 CA ARG A 17 20.875 22.504 11.583 1.00 6.00 C
ATOM 132 C ARG A 17 19.868 22.156 10.491 1.00 6.00 C
ATOM 133 O ARG A 17 18.665 22.015 10.809 1.00 6.00 O
ATOM 134 CB ARG A 17 21.214 23.997 11.557 1.00 7.00 C
ATOM 135 CG ARG A 17 20.010 24.800 12.063 1.00 9.00 C
ATOM 136 CD ARG A 17 19.570 25.929 11.132 1.00 11.00 C
ATOM 137 NE ARG A 17 20.149 27.218 11.537 1.00 12.00 N
ATOM 138 CZ ARG A 17 19.828 28.351 10.936 1.00 13.00 C
ATOM 139 NH1 ARG A 17 19.319 28.304 9.720 1.00 14.00 N
ATOM 140 NH2 ARG A 17 20.351 29.485 11.362 1.00 14.00 N
ATOM 141 N ARG A 18 20.378 21.725 9.348 1.00 6.00 N
ATOM 142 CA ARG A 18 19.530 21.258 8.235 1.00 5.00 C
ATOM 143 C ARG A 18 19.148 19.796 8.478 1.00 5.00 C
ATOM 144 O ARG A 18 18.326 19.189 7.741 1.00 5.00 O
ATOM 145 CB ARG A 18 20.237 21.481 6.888 1.00 8.00 C
ATOM 146 CG ARG A 18 19.384 21.236 5.634 1.00 9.00 C
ATOM 147 CD ARG A 18 19.623 19.860 5.005 1.00 11.00 C
ATOM 148 NE ARG A 18 20.029 19.997 3.600 1.00 12.00 N
ATOM 149 CZ ARG A 18 19.398 19.415 2.597 1.00 13.00 C
ATOM 150 NH1 ARG A 18 18.483 18.493 2.835 1.00 14.00 N
ATOM 151 NH2 ARG A 18 19.831 19.597 1.364 1.00 14.00 N
ATOM 152 N ALA A 19 19.560 19.319 9.623 1.00 6.00 N
ATOM 153 CA ALA A 19 19.126 17.991 10.053 1.00 6.00 C
ATOM 154 C ALA A 19 18.002 18.136 11.071 1.00 6.00 C
ATOM 155 O ALA A 19 16.933 17.494 10.922 1.00 7.00 O
ATOM 156 CB ALA A 19 20.285 17.187 10.629 1.00 15.00 C
ATOM 157 N GLN A 20 18.094 19.241 11.783 1.00 7.00 N
ATOM 158 CA GLN A 20 17.013 19.632 12.689 1.00 7.00 C
ATOM 159 C GLN A 20 15.897 20.314 11.905 1.00 7.00 C
ATOM 160 O GLN A 20 14.701 20.031 12.162 1.00 7.00 O
ATOM 161 CB GLN A 20 17.513 20.538 13.821 1.00 11.00 C
ATOM 162 CG GLN A 20 16.699 21.829 13.936 1.00 16.00 C
ATOM 163 CD GLN A 20 16.591 22.277 15.393 1.00 22.00 C
ATOM 164 OE1 GLN A 20 17.533 22.060 16.194 1.00 24.00 O
ATOM 165 NE2 GLN A 20 15.356 22.544 15.773 1.00 24.00 N
ATOM 166 N ASP A 21 16.292 20.724 10.714 1.00 7.00 N
ATOM 167 CA ASP A 21 15.405 21.490 9.835 1.00 7.00 C
ATOM 168 C ASP A 21 14.451 20.565 9.120 1.00 7.00 C
ATOM 169 O ASP A 21 13.245 20.850 8.962 1.00 7.00 O
ATOM 170 CB ASP A 21 16.212 22.278 8.809 1.00 14.00 C
ATOM 171 CG ASP A 21 15.427 23.525 8.413 1.00 21.00 C
ATOM 172 OD1 ASP A 21 15.031 24.298 9.321 1.00 28.00 O
ATOM 173 OD2 ASP A 21 15.316 23.827 7.200 1.00 28.00 O
ATOM 174 N PHE A 22 14.987 19.373 8.843 1.00 7.00 N
ATOM 175 CA PHE A 22 14.216 18.253 8.289 1.00 7.00 C
ATOM 176 C PHE A 22 13.098 17.860 9.246 1.00 7.00 C
ATOM 177 O PHE A 22 11.956 17.556 8.818 1.00 7.00 O
ATOM 178 CB PHE A 22 15.134 17.038 8.105 1.00 8.00 C
ATOM 179 CG PHE A 22 14.349 15.761 7.724 1.00 10.00 C
ATOM 180 CD1 PHE A 22 14.022 15.527 6.410 1.00 12.00 C
ATOM 181 CD2 PHE A 22 13.992 14.842 8.689 1.00 12.00 C
ATOM 182 CE1 PHE A 22 13.302 14.391 6.050 1.00 14.00 C
ATOM 183 CE2 PHE A 22 13.269 13.708 8.340 1.00 14.00 C
ATOM 184 CZ PHE A 22 12.917 13.483 7.018 1.00 16.00 C
ATOM 185 N VAL A 23 13.455 17.883 10.517 1.00 7.00 N
ATOM 186 CA VAL A 23 12.574 17.403 11.589 1.00 7.00 C
ATOM 187 C VAL A 23 11.283 18.205 11.729 1.00 7.00 C
ATOM 188 O VAL A 23 10.233 17.600 12.052 1.00 7.00 O
ATOM 189 CB VAL A 23 13.339 17.278 12.906 1.00 10.00 C
ATOM 190 CG1 VAL A 23 12.441 17.004 14.108 1.00 13.00 C
ATOM 191 CG2 VAL A 23 14.455 16.248 12.794 1.00 13.00 C
ATOM 192 N GLN A 24 11.255 19.253 10.941 1.00 8.00 N
ATOM 193 CA GLN A 24 10.082 20.114 10.818 1.00 8.00 C
ATOM 194 C GLN A 24 9.158 19.638 9.692 1.00 8.00 C
ATOM 195 O GLN A 24 7.959 19.990 9.663 1.00 8.00 O
ATOM 196 CB GLN A 24 10.575 21.521 10.498 1.00 14.00 C
ATOM 197 CG GLN A 24 9.505 22.591 10.661 1.00 20.00 C
ATOM 198 CD GLN A 24 9.964 23.862 9.956 1.00 26.00 C
ATOM 199 OE1 GLN A 24 10.079 24.941 10.587 1.00 32.00 O
ATOM 200 NE2 GLN A 24 10.086 23.739 8.649 1.00 32.00 N
ATOM 201 N TRP A 25 9.723 19.074 8.651 1.00 8.00 N
ATOM 202 CA TRP A 25 8.899 18.676 7.495 1.00 9.00 C
ATOM 203 C TRP A 25 8.118 17.395 7.751 1.00 9.00 C
ATOM 204 O TRP A 25 6.860 17.395 7.725 1.00 9.00 O
ATOM 205 CB TRP A 25 9.761 18.442 6.262 1.00 11.00 C
ATOM 206 CG TRP A 25 8.871 18.331 5.004 1.00 12.00 C
ATOM 207 CD1 TRP A 25 8.097 19.279 4.442 1.00 12.00 C
ATOM 208 CD2 TRP A 25 8.640 17.180 4.249 1.00 12.00 C
ATOM 209 NE1 TRP A 25 7.041 18.780 3.259 1.00 12.00 N
ATOM 210 CE2 TRP A 25 7.873 17.564 3.121 1.00 12.00 C
ATOM 211 CE3 TRP A 25 9.124 15.884 4.378 1.00 12.00 C
ATOM 212 CZ2 TRP A 25 7.726 16.765 2.003 1.00 12.00 C
ATOM 213 CZ3 TRP A 25 8.870 15.038 3.296 1.00 12.00 C
ATOM 214 CH2 TRP A 25 8.216 15.469 2.140 1.00 12.00 C
ATOM 215 N LEU A 26 8.857 16.484 8.346 1.00 9.00 N
ATOM 216 CA LEU A 26 8.377 15.159 8.741 1.00 10.00 C
ATOM 217 C LEU A 26 7.534 15.279 10.012 1.00 11.00 C
ATOM 218 O LEU A 26 6.755 14.347 10.331 1.00 11.00 O
ATOM 219 CB LEU A 26 9.611 14.267 8.924 1.00 10.00 C
ATOM 220 CG LEU A 26 9.342 12.810 9.303 1.00 10.00 C
ATOM 221 CD1 LEU A 26 8.223 12.149 8.505 1.00 10.00 C
ATOM 222 CD2 LEU A 26 10.637 11.982 9.250 1.00 10.00 C
ATOM 223 N MET A 27 7.281 16.544 10.320 1.00 11.00 N
ATOM 224 CA MET A 27 6.446 16.959 11.451 1.00 11.00 C
ATOM 225 C MET A 27 5.607 18.227 11.219 1.00 13.00 C
ATOM 226 O MET A 27 4.823 18.240 10.244 1.00 13.00 O
ATOM 227 CB MET A 27 7.327 17.118 12.679 1.00 11.00 C
ATOM 228 CG MET A 27 6.518 17.289 13.953 1.00 11.00 C
ATOM 229 SD MET A 27 7.301 18.326 15.196 1.00 11.00 S
ATOM 230 CE MET A 27 5.833 18.677 16.178 1.00 11.00 C
ATOM 231 N ASN A 28 6.147 19.366 11.620 1.00 14.00 N
ATOM 232 CA ASN A 28 5.399 20.637 11.728 1.00 14.00 C
ATOM 233 C ASN A 28 3.878 20.587 11.716 1.00 17.00 C
ATOM 234 O ASN A 28 3.252 21.114 10.763 1.00 19.00 O
ATOM 235 CB ASN A 28 5.874 21.774 10.843 1.00 14.00 C
ATOM 236 CG ASN A 28 6.246 22.905 11.791 1.00 14.00 C
ATOM 237 OD1 ASN A 28 6.929 22.629 12.807 1.00 14.00 O
ATOM 238 ND2 ASN A 28 6.271 24.085 11.229 1.00 14.00 N
ATOM 239 N THR A 29 3.391 19.940 12.762 1.00 21.00 N
ATOM 240 CA THR A 29 2.014 19.761 13.283 1.00 21.00 C
ATOM 241 C THR A 29 0.826 19.943 12.332 1.00 23.00 C
ATOM 242 O THR A 29 0.932 19.600 11.133 1.00 30.00 O
ATOM 243 CB THR A 29 1.845 20.667 14.505 1.00 21.00 C
ATOM 244 OG1 THR A 29 1.214 21.893 14.153 1.00 21.00 O
ATOM 245 CG2 THR A 29 3.180 20.968 15.185 1.00 21.00 C
ATOM 246 OXT THR A 29 -0.317 20.109 12.824 1.00 25.00 O
TER 247 THR A 29
MASTER 344 1 0 1 0 0 0 6 246 1 0 3
END
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data pipeline for model features."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data pipeline."""
from typing import Iterable, MutableMapping, List
from alphafold.common import residue_constants
from alphafold.data import msa_pairing
from alphafold.data import pipeline
import numpy as np
REQUIRED_FEATURES = frozenset({
'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
'num_templates', 'queue_size', 'residue_index', 'resolution',
'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
'template_all_atom_mask', 'template_all_atom_positions'
})
MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048
def _is_homomer_or_monomer(chains: Iterable[pipeline.FeatureDict]) -> bool:
"""Checks if a list of chains represents a homomer/monomer example."""
# Note that an entity_id of 0 indicates padding.
num_unique_chains = len(np.unique(np.concatenate(
[np.unique(chain['entity_id'][chain['entity_id'] > 0]) for
chain in chains])))
return num_unique_chains == 1
def pair_and_merge(
all_chain_features: MutableMapping[str, pipeline.FeatureDict]
) -> pipeline.FeatureDict:
"""Runs processing on features to augment, pair and merge.
Args:
all_chain_features: A MutableMap of dictionaries of features for each chain.
Returns:
A dictionary of features.
"""
process_unmerged_features(all_chain_features)
np_chains_list = list(all_chain_features.values())
pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)
if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(
chains=np_chains_list)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = crop_chains(
np_chains_list,
msa_crop_size=MSA_CROP_SIZE,
pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES)
np_example = msa_pairing.merge_chain_features(
np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES)
np_example = process_final(np_example)
return np_example
def crop_chains(
chains_list: List[pipeline.FeatureDict],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> List[pipeline.FeatureDict]:
"""Crops the MSAs for a set of chains.
Args:
chains_list: A list of chains to be cropped.
msa_crop_size: The total number of sequences to crop from the MSA.
pair_msa_sequences: Whether we are operating in sequence-pairing mode.
max_templates: The maximum templates to use per chain.
Returns:
The chains cropped.
"""
# Apply the cropping.
cropped_chains = []
for chain in chains_list:
cropped_chain = _crop_single_chain(
chain,
msa_crop_size=msa_crop_size,
pair_msa_sequences=pair_msa_sequences,
max_templates=max_templates)
cropped_chains.append(cropped_chain)
return cropped_chains
def _crop_single_chain(chain: pipeline.FeatureDict,
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> pipeline.FeatureDict:
"""Crops msa sequences to `msa_crop_size`."""
msa_size = chain['num_alignments']
if pair_msa_sequences:
msa_size_all_seq = chain['num_alignments_all_seq']
msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)
# We reduce the number of un-paired sequences, by the number of times a
# sequence from this chain's MSA is included in the paired MSA. This keeps
# the MSA size for each chain roughly constant.
msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
num_non_gapped_pairs = np.sum(
np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
num_non_gapped_pairs = np.minimum(num_non_gapped_pairs,
msa_crop_size_all_seq)
# Restrict the unpaired crop size so that paired+unpaired sequences do not
# exceed msa_seqs_per_chain for each chain.
max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
else:
msa_crop_size = np.minimum(msa_size, msa_crop_size)
include_templates = 'template_aatype' in chain and max_templates
if include_templates:
num_templates = chain['template_aatype'].shape[0]
templates_crop_size = np.minimum(num_templates, max_templates)
for k in chain:
k_split = k.split('_all_seq')[0]
if k_split in msa_pairing.TEMPLATE_FEATURES:
chain[k] = chain[k][:templates_crop_size, :]
elif k_split in msa_pairing.MSA_FEATURES:
if '_all_seq' in k and pair_msa_sequences:
chain[k] = chain[k][:msa_crop_size_all_seq, :]
else:
chain[k] = chain[k][:msa_crop_size, :]
chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
if include_templates:
chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
if pair_msa_sequences:
chain['num_alignments_all_seq'] = np.asarray(
msa_crop_size_all_seq, dtype=np.int32)
return chain
def process_final(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
"""Final processing steps in data pipeline, after merging and pairing."""
np_example = _correct_msa_restypes(np_example)
np_example = _make_seq_mask(np_example)
np_example = _make_msa_mask(np_example)
np_example = _filter_features(np_example)
return np_example
def _correct_msa_restypes(np_example):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
np_example['msa'] = np_example['msa'].astype(np.int32)
return np_example
def _make_seq_mask(np_example):
np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
return np_example
def _make_msa_mask(np_example):
"""Mask features are all ones, but will later be zero-padded."""
np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)
seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
np_example['msa_mask'] *= seq_mask[None]
return np_example
def _filter_features(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
"""Filters features of example to only those requested."""
return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}
def process_unmerged_features(
all_chain_features: MutableMapping[str, pipeline.FeatureDict]):
"""Postprocessing stage for per-chain features before merging."""
num_chains = len(all_chain_features)
for chain_features in all_chain_features.values():
# Convert deletion matrices to float.
chain_features['deletion_matrix'] = np.asarray(
chain_features.pop('deletion_matrix_int'), dtype=np.float32)
if 'deletion_matrix_int_all_seq' in chain_features:
chain_features['deletion_matrix_all_seq'] = np.asarray(
chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32)
chain_features['deletion_mean'] = np.mean(
chain_features['deletion_matrix'], axis=0)
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])
# Add assembly_num_chains.
chain_features['assembly_num_chains'] = np.asarray(num_chains)
# Add entity_mask.
for chain_features in all_chain_features.values():
chain_features['entity_mask'] = (
chain_features['entity_id'] != 0).astype(np.int32)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parses the mmCIF file format."""
import collections
import dataclasses
import functools
import io
from typing import Any, Mapping, Optional, Sequence, Tuple
from absl import logging
from Bio import PDB
from Bio.Data import SCOPData
# Type aliases:
ChainId = str
PdbHeader = Mapping[str, Any]
PdbStructure = PDB.Structure.Structure
SeqRes = str
MmCIFDict = Mapping[str, Sequence[str]]
@dataclasses.dataclass(frozen=True)
class Monomer:
id: str
num: int
# Note - mmCIF format provides no guarantees on the type of author-assigned
# sequence numbers. They need not be integers.
@dataclasses.dataclass(frozen=True)
class AtomSite:
residue_name: str
author_chain_id: str
mmcif_chain_id: str
author_seq_num: str
mmcif_seq_num: int
insertion_code: str
hetatm_atom: str
model_num: int
# Used to map SEQRES index to a residue in the structure.
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
chain_id: str
residue_number: int
insertion_code: str
@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
position: Optional[ResiduePosition]
name: str
is_missing: bool
hetflag: str
@dataclasses.dataclass(frozen=True)
class MmcifObject:
"""Representation of a parsed mmCIF file.
Contains:
file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
files being processed.
header: Biopython header.
structure: Biopython structure.
chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
{'A': 'ABCDEFG'}
seqres_to_structure: Dict; for each chain_id contains a mapping between
SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
1: ResidueAtPosition,
...}}
raw_string: The raw string used to construct the MmcifObject.
"""
file_id: str
header: PdbHeader
structure: PdbStructure
chain_to_seqres: Mapping[ChainId, SeqRes]
seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
raw_string: Any
@dataclasses.dataclass(frozen=True)
class ParsingResult:
"""Returned by the parse function.
Contains:
mmcif_object: A MmcifObject, may be None if no chain could be successfully
parsed.
errors: A dict mapping (file_id, chain_id) to any exception generated.
"""
mmcif_object: Optional[MmcifObject]
errors: Mapping[Tuple[str, str], Any]
class ParseError(Exception):
"""An error indicating that an mmCIF file could not be parsed."""
def mmcif_loop_to_list(prefix: str,
parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a list.
Reference for loop_ in mmCIF:
http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html
Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.
Returns:
Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
"""
cols = []
data = []
for key, value in parsed_info.items():
if key.startswith(prefix):
cols.append(key)
data.append(value)
assert all([len(xs) == len(data[0]) for xs in data]), (
'mmCIF error: Not all loops are the same length: %s' % cols)
return [dict(zip(cols, xs)) for xs in zip(*data)]
def mmcif_loop_to_dict(prefix: str,
index: str,
parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a dictionary.
Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
index: Which item of loop data should serve as the key.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.
Returns:
Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
indexed by the index column.
"""
entries = mmcif_loop_to_list(prefix, parsed_info)
return {entry[index]: entry for entry in entries}
@functools.lru_cache(16, typed=False)
def parse(*,
file_id: str,
mmcif_string: str,
catch_all_errors: bool = True) -> ParsingResult:
"""Entry point, parses an mmcif_string.
Args:
file_id: A string identifier for this file. Should be unique within the
collection of files being processed.
mmcif_string: Contents of an mmCIF file.
catch_all_errors: If True, all exceptions are caught and error messages are
returned as part of the ParsingResult. If False exceptions will be allowed
to propagate.
Returns:
A ParsingResult.
"""
errors = {}
try:
parser = PDB.MMCIFParser(QUIET=True)
handle = io.StringIO(mmcif_string)
full_structure = parser.get_structure('', handle)
first_model_structure = _get_first_model(full_structure)
# Extract the _mmcif_dict from the parser, which contains useful fields not
# reflected in the Biopython structure.
parsed_info = parser._mmcif_dict # pylint:disable=protected-access
# Ensure all values are lists, even if singletons.
for key, value in parsed_info.items():
if not isinstance(value, list):
parsed_info[key] = [value]
header = _get_header(parsed_info)
# Determine the protein chains, and their start numbers according to the
# internal mmCIF numbering scheme (likely but not guaranteed to be 1).
valid_chains = _get_protein_chains(parsed_info=parsed_info)
if not valid_chains:
return ParsingResult(
None, {(file_id, ''): 'No protein chains found in this file.'})
seq_start_num = {chain_id: min([monomer.num for monomer in seq])
for chain_id, seq in valid_chains.items()}
# Loop over the atoms for which we have coordinates. Populate two mappings:
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
# the authors / Biopython).
# -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
mmcif_to_author_chain_id = {}
seq_to_structure_mappings = {}
for atom in _get_atom_site_list(parsed_info):
if atom.model_num != '1':
# We only process the first model at the moment.
continue
mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id
if atom.mmcif_chain_id in valid_chains:
hetflag = ' '
if atom.hetatm_atom == 'HETATM':
# Water atoms are assigned a special hetflag of W in Biopython. We
# need to do the same, so that this hetflag can be used to fetch
# a residue from the Biopython structure by id.
if atom.residue_name in ('HOH', 'WAT'):
hetflag = 'W'
else:
hetflag = 'H_' + atom.residue_name
insertion_code = atom.insertion_code
if not _is_set(atom.insertion_code):
insertion_code = ' '
position = ResiduePosition(chain_id=atom.author_chain_id,
residue_number=int(atom.author_seq_num),
insertion_code=insertion_code)
seq_idx = int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
current = seq_to_structure_mappings.get(atom.author_chain_id, {})
current[seq_idx] = ResidueAtPosition(position=position,
name=atom.residue_name,
is_missing=False,
hetflag=hetflag)
seq_to_structure_mappings[atom.author_chain_id] = current
# Add missing residue information to seq_to_structure_mappings.
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
current_mapping = seq_to_structure_mappings[author_chain]
for idx, monomer in enumerate(seq_info):
if idx not in current_mapping:
current_mapping[idx] = ResidueAtPosition(position=None,
name=monomer.id,
is_missing=True,
hetflag=' ')
author_chain_to_sequence = {}
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
seq = []
for monomer in seq_info:
code = SCOPData.protein_letters_3to1.get(monomer.id, 'X')
seq.append(code if len(code) == 1 else 'X')
seq = ''.join(seq)
author_chain_to_sequence[author_chain] = seq
mmcif_object = MmcifObject(
file_id=file_id,
header=header,
structure=first_model_structure,
chain_to_seqres=author_chain_to_sequence,
seqres_to_structure=seq_to_structure_mappings,
raw_string=parsed_info)
return ParsingResult(mmcif_object=mmcif_object, errors=errors)
except Exception as e: # pylint:disable=broad-except
errors[(file_id, '')] = e
if not catch_all_errors:
raise
return ParsingResult(mmcif_object=None, errors=errors)
def _get_first_model(structure: PdbStructure) -> PdbStructure:
"""Returns the first model in a Biopython structure."""
return next(structure.get_models())
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
def get_release_date(parsed_info: MmCIFDict) -> str:
"""Returns the oldest revision date."""
revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date']
return min(revision_dates)
def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
"""Returns a basic header containing method, release date and resolution."""
header = {}
experiments = mmcif_loop_to_list('_exptl.', parsed_info)
header['structure_method'] = ','.join([
experiment['_exptl.method'].lower() for experiment in experiments])
# Note: The release_date here corresponds to the oldest revision. We prefer to
# use this for dataset filtering over the deposition_date.
if '_pdbx_audit_revision_history.revision_date' in parsed_info:
header['release_date'] = get_release_date(parsed_info)
else:
logging.warning('Could not determine release_date: %s',
parsed_info['_entry.id'])
header['resolution'] = 0.00
for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution',
'_reflns.d_resolution_high'):
if res_key in parsed_info:
try:
raw_resolution = parsed_info[res_key][0]
header['resolution'] = float(raw_resolution)
break
except ValueError:
logging.debug('Invalid resolution format: %s', parsed_info[res_key])
return header
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
"""Returns list of atom sites; contains data not present in the structure."""
return [AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension
parsed_info['_atom_site.label_comp_id'],
parsed_info['_atom_site.auth_asym_id'],
parsed_info['_atom_site.label_asym_id'],
parsed_info['_atom_site.auth_seq_id'],
parsed_info['_atom_site.label_seq_id'],
parsed_info['_atom_site.pdbx_PDB_ins_code'],
parsed_info['_atom_site.group_PDB'],
parsed_info['_atom_site.pdbx_PDB_model_num'],
)]
def _get_protein_chains(
*, parsed_info: Mapping[str, Any]) -> Mapping[ChainId, Sequence[Monomer]]:
"""Extracts polymer information for protein chains only.
Args:
parsed_info: _mmcif_dict produced by the Biopython parser.
Returns:
A dict mapping mmcif chain id to a list of Monomers.
"""
# Get polymer information for each entity in the structure.
entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info)
polymers = collections.defaultdict(list)
for entity_poly_seq in entity_poly_seqs:
polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append(
Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'],
num=int(entity_poly_seq['_entity_poly_seq.num'])))
# Get chemical compositions. Will allow us to identify which of these polymers
# are proteins.
chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', parsed_info)
# Get chains information for each entity. Necessary so that we can return a
# dict keyed on chain id rather than entity.
struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info)
entity_to_mmcif_chains = collections.defaultdict(list)
for struct_asym in struct_asyms:
chain_id = struct_asym['_struct_asym.id']
entity_id = struct_asym['_struct_asym.entity_id']
entity_to_mmcif_chains[entity_id].append(chain_id)
# Identify and return the valid protein chains.
valid_chains = {}
for entity_id, seq_info in polymers.items():
chain_ids = entity_to_mmcif_chains[entity_id]
# Reject polymers without any peptide-like components, such as DNA/RNA.
if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type'].lower()
for monomer in seq_info]):
for chain_id in chain_ids:
valid_chains[chain_id] = seq_info
return valid_chains
def _is_set(data: str) -> bool:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return data not in ('.', '?')
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""
import dataclasses
import re
from typing import Optional
# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
r"""
^
# UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
(?:tr|sp)
\|
# A primary accession number of the UniProtKB entry.
(?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
# Occasionally there is a _0 or _1 isoform suffix, which we ignore.
(?:_\d)?
\|
# TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
# protein ID code.
(?:[A-Za-z0-9]+)
_
# A mnemonic species identification code.
(?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
# Small BFD uses a final value after an underscore, which we ignore.
(?:_\d+)?
$
""",
re.VERBOSE)
@dataclasses.dataclass(frozen=True)
class Identifiers:
species_id: str = ''
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets species from an msa sequence identifier.
The sequence identifier has the format specified by
_UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
Args:
msa_sequence_identifier: a sequence identifier.
Returns:
An `Identifiers` instance with species_id. These
can be empty in the case where no identifier was found.
"""
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
return Identifiers(
species_id=matches.group('SpeciesIdentifier'))
return Identifiers()
def _extract_sequence_identifier(description: str) -> Optional[str]:
"""Extracts sequence identifier from description. Returns None if no match."""
split_description = description.split()
if split_description:
return split_description[0].partition('/')[0]
else:
return None
def get_identifiers(description: str) -> Identifiers:
"""Computes extra MSA features from the description."""
sequence_identifier = _extract_sequence_identifier(description)
if sequence_identifier is None:
return Identifiers()
else:
return _parse_sequence_identifier(sequence_identifier)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pairing logic for multimer data pipeline."""
import collections
from typing import cast, Dict, Iterable, List, Sequence
from alphafold.common import residue_constants
from alphafold.data import pipeline
import numpy as np
import pandas as pd
import scipy.linalg
MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9
MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
'msa_mask_all_seq': 1,
'deletion_matrix_all_seq': 0,
'deletion_matrix_int_all_seq': 0,
'msa': MSA_GAP_IDX,
'msa_mask': 1,
'deletion_matrix': 0,
'deletion_matrix_int': 0}
MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
'all_atom_mask', 'seq_mask', 'between_segment_residues',
'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
'sym_id', 'entity_mask', 'deletion_mean',
'prediction_atom_mask',
'literature_positions', 'atom_indices_to_group_indices',
'rigid_group_default_frame')
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
'template_all_atom_mask')
CHAIN_FEATURES = ('num_alignments', 'seq_length')
def create_paired_features(
chains: Iterable[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
"""Returns the original chains with paired NUM_SEQ features.
Args:
chains: A list of feature dictionaries for each chain.
Returns:
A list of feature dictionaries with sequence features including only
rows to be paired.
"""
chains = list(chains)
chain_keys = chains[0].keys()
if len(chains) < 2:
return chains
else:
updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences(chains)
paired_rows = reorder_paired_rows(
paired_chains_to_paired_row_indices)
for chain_num, chain in enumerate(chains):
new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
for feature_name in chain_keys:
if feature_name.endswith('_all_seq'):
feats_padded = pad_features(chain[feature_name], feature_name)
new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
new_chain['num_alignments_all_seq'] = np.asarray(
len(paired_rows[:, chain_num]))
updated_chains.append(new_chain)
return updated_chains
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
"""Add a 'padding' row at the end of the features list.
The padding row will be selected as a 'paired' row in the case of partial
alignment - for the chain that doesn't have paired alignment.
Args:
feature: The feature to be padded.
feature_name: The name of the feature to be padded.
Returns:
The feature with an additional padding row.
"""
assert feature.dtype != np.dtype(np.string_)
if feature_name in ('msa_all_seq', 'msa_mask_all_seq',
'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'):
num_res = feature.shape[1]
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
feature.dtype)
elif feature_name == 'msa_species_identifiers_all_seq':
padding = [b'']
else:
return feature
feats_padded = np.concatenate([feature, padding], axis=0)
return feats_padded
def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame:
"""Makes dataframe with msa features needed for msa pairing."""
chain_msa = chain_features['msa_all_seq']
query_seq = chain_msa[0]
per_seq_similarity = np.sum(
query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
msa_df = pd.DataFrame({
'msa_species_identifiers':
chain_features['msa_species_identifiers_all_seq'],
'msa_row':
np.arange(len(
chain_features['msa_species_identifiers_all_seq'])),
'msa_similarity': per_seq_similarity,
'gap': per_seq_gap
})
return msa_df
def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
"""Creates mapping from species to msa dataframe of that species."""
species_lookup = {}
for species, species_df in msa_df.groupby('msa_species_identifiers'):
species_lookup[cast(bytes, species)] = species_df
return species_lookup
def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
) -> List[List[int]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
Each chain's MSA sequences are first sorted by their sequence similarity to
their respective target sequence. The sequences are then paired, starting
from the sequences most similar to their target sequence.
Args:
this_species_msa_dfs: a list of dataframes containing MSA features for
sequences for a specific species.
Returns:
A list of lists, each containing M indices corresponding to paired MSA rows,
where M is the number of chains.
"""
all_paired_msa_rows = []
num_seqs = [len(species_df) for species_df in this_species_msa_dfs
if species_df is not None]
take_num_seqs = np.min(num_seqs)
sort_by_similarity = (
lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))
for species_df in this_species_msa_dfs:
if species_df is not None:
species_df_sorted = sort_by_similarity(species_df)
msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
else:
msa_rows = [-1] * take_num_seqs # take the last 'padding' row
all_paired_msa_rows.append(msa_rows)
all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
return all_paired_msa_rows
def pair_sequences(examples: List[pipeline.FeatureDict]
) -> Dict[int, np.ndarray]:
"""Returns indices for paired MSA sequences across chains."""
num_examples = len(examples)
all_chain_species_dict = []
common_species = set()
for chain_features in examples:
msa_df = _make_msa_df(chain_features)
species_dict = _create_species_dict(msa_df)
all_chain_species_dict.append(species_dict)
common_species.update(set(species_dict))
common_species = sorted(common_species)
common_species.remove(b'') # Remove target sequence species.
all_paired_msa_rows = [np.zeros(len(examples), int)]
all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]
for species in common_species:
if not species:
continue
this_species_msa_dfs = []
species_dfs_present = 0
for species_dict in all_chain_species_dict:
if species in species_dict:
this_species_msa_dfs.append(species_dict[species])
species_dfs_present += 1
else:
this_species_msa_dfs.append(None)
# Skip species that are present in only one chain.
if species_dfs_present <= 1:
continue
if np.any(
np.array([len(species_df) for species_df in
this_species_msa_dfs if
isinstance(species_df, pd.DataFrame)]) > 600):
continue
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
all_paired_msa_rows.extend(paired_msa_rows)
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
all_paired_msa_rows_dict = {
num_examples: np.array(paired_msa_rows) for
num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
}
return all_paired_msa_rows_dict
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
) -> np.ndarray:
"""Creates a list of indices of paired MSA rows across chains.
Args:
all_paired_msa_rows_dict: a mapping from the number of paired chains to the
paired indices.
Returns:
a list of lists, each containing indices of paired MSA rows across chains.
The paired-index lists are ordered by:
1) the number of chains in the paired alignment, i.e, all-chain pairings
will come first.
2) e-values
"""
all_paired_msa_rows = []
for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
paired_rows = all_paired_msa_rows_dict[num_pairings]
paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows]))
paired_rows_sort_index = np.argsort(paired_rows_product)
all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])
return np.array(all_paired_msa_rows)
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
"""Like scipy.linalg.block_diag but with an optional padding value."""
ones_arrs = [np.ones_like(x) for x in arrs]
off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
diag = scipy.linalg.block_diag(*arrs)
diag += (off_diag_mask * pad_value).astype(diag.dtype)
return diag
def _correct_post_merged_feats(
np_example: pipeline.FeatureDict,
np_chains_list: Sequence[pipeline.FeatureDict],
pair_msa_sequences: bool) -> pipeline.FeatureDict:
"""Adds features that need to be computed/recomputed post merging."""
np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0],
dtype=np.int32)
np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0],
dtype=np.int32)
if not pair_msa_sequences:
# Generate a bias that is 1 for the first row of every block in the
# block diagonal MSA - i.e. make sure the cluster stack always includes
# the query sequences for each chain (since the first row is the query
# sequence).
cluster_bias_masks = []
for chain in np_chains_list:
mask = np.zeros(chain['msa'].shape[0])
mask[0] = 1
cluster_bias_masks.append(mask)
np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)
# Initialize Bert mask with masked out off diagonals.
msa_masks = [np.ones(x['msa'].shape, dtype=np.float32)
for x in np_chains_list]
np_example['bert_mask'] = block_diag(
*msa_masks, pad_value=0)
else:
np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
np_example['cluster_bias_mask'][0] = 1
# Initialize Bert mask with masked out off diagonals.
msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for
x in np_chains_list]
msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
x in np_chains_list]
msa_mask_block_diag = block_diag(
*msa_masks, pad_value=0)
msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
np_example['bert_mask'] = np.concatenate(
[msa_mask_all_seq, msa_mask_block_diag], axis=0)
return np_example
def _pad_templates(chains: Sequence[pipeline.FeatureDict],
max_templates: int) -> Sequence[pipeline.FeatureDict]:
"""For each chain pad the number of templates to a fixed size.
Args:
chains: A list of protein chains.
max_templates: Each chain will be padded to have this many templates.
Returns:
The list of chains, updated to have template features padded to
max_templates.
"""
for chain in chains:
for k, v in chain.items():
if k in TEMPLATE_FEATURES:
padding = np.zeros_like(v.shape)
padding[0] = max_templates - v.shape[0]
padding = [(0, p) for p in padding]
chain[k] = np.pad(v, padding, mode='constant')
return chains
def _merge_features_from_multiple_chains(
chains: Sequence[pipeline.FeatureDict],
pair_msa_sequences: bool) -> pipeline.FeatureDict:
"""Merge features from multiple chains.
Args:
chains: A list of feature dictionaries that we want to merge.
pair_msa_sequences: Whether to concatenate MSA features along the
num_res dimension (if True), or to block diagonalize them (if False).
Returns:
A feature dictionary for the merged example.
"""
merged_example = {}
for feature_name in chains[0]:
feats = [x[feature_name] for x in chains]
feature_name_split = feature_name.split('_all_seq')[0]
if feature_name_split in MSA_FEATURES:
if pair_msa_sequences or '_all_seq' in feature_name:
merged_example[feature_name] = np.concatenate(feats, axis=1)
else:
merged_example[feature_name] = block_diag(
*feats, pad_value=MSA_PAD_VALUES[feature_name])
elif feature_name_split in SEQ_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=0)
elif feature_name_split in TEMPLATE_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=1)
elif feature_name_split in CHAIN_FEATURES:
merged_example[feature_name] = np.sum(x for x in feats).astype(np.int32)
else:
merged_example[feature_name] = feats[0]
return merged_example
def _merge_homomers_dense_msa(
chains: Iterable[pipeline.FeatureDict]) -> Sequence[pipeline.FeatureDict]:
"""Merge all identical chains, making the resulting MSA dense.
Args:
chains: An iterable of features for each chain.
Returns:
A list of feature dictionaries. All features with the same entity_id
will be merged - MSA features will be concatenated along the num_res
dimension - making them dense.
"""
entity_chains = collections.defaultdict(list)
for chain in chains:
entity_id = chain['entity_id'][0]
entity_chains[entity_id].append(chain)
grouped_chains = []
for entity_id in sorted(entity_chains):
chains = entity_chains[entity_id]
grouped_chains.append(chains)
chains = [
_merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
for chains in grouped_chains]
return chains
def _concatenate_paired_and_unpaired_features(
example: pipeline.FeatureDict) -> pipeline.FeatureDict:
"""Merges paired and block-diagonalised features."""
features = MSA_FEATURES
for feature_name in features:
if feature_name in example:
feat = example[feature_name]
feat_all_seq = example[feature_name + '_all_seq']
merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
example[feature_name] = merged_feat
example['num_alignments'] = np.array(example['msa'].shape[0],
dtype=np.int32)
return example
def merge_chain_features(np_chains_list: List[pipeline.FeatureDict],
pair_msa_sequences: bool,
max_templates: int) -> pipeline.FeatureDict:
"""Merges features for multiple chains to single FeatureDict.
Args:
np_chains_list: List of FeatureDicts for each chain.
pair_msa_sequences: Whether to merge paired MSAs.
max_templates: The maximum number of templates to include.
Returns:
Single FeatureDict for entire complex.
"""
np_chains_list = _pad_templates(
np_chains_list, max_templates=max_templates)
np_chains_list = _merge_homomers_dense_msa(np_chains_list)
# Unpaired MSA features will be always block-diagonalised; paired MSA
# features will be concatenated.
np_example = _merge_features_from_multiple_chains(
np_chains_list, pair_msa_sequences=False)
if pair_msa_sequences:
np_example = _concatenate_paired_and_unpaired_features(np_example)
np_example = _correct_post_merged_feats(
np_example=np_example,
np_chains_list=np_chains_list,
pair_msa_sequences=pair_msa_sequences)
return np_example
def deduplicate_unpaired_sequences(
np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
"""Removes unpaired sequences which duplicate a paired sequence."""
feature_names = np_chains[0].keys()
msa_features = MSA_FEATURES
for chain in np_chains:
# Convert the msa_all_seq numpy array to a tuple for hashing.
sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
keep_rows = []
# Go through unpaired MSA seqs and remove any rows that correspond to the
# sequences that are already present in the paired MSA.
for row_num, seq in enumerate(chain['msa']):
if tuple(seq) not in sequence_set:
keep_rows.append(row_num)
for feature_name in feature_names:
if feature_name in msa_features:
chain[feature_name] = chain[feature_name][keep_rows]
chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
return np_chains
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for parsing various file formats."""
import collections
import dataclasses
import itertools
import re
import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
# Internal import (7716).
DeletionMatrix = Sequence[Sequence[int]]
@dataclasses.dataclass(frozen=True)
class Msa:
"""Class representing a parsed MSA file."""
sequences: Sequence[str]
deletion_matrix: DeletionMatrix
descriptions: Sequence[str]
def __post_init__(self):
if not (len(self.sequences) ==
len(self.deletion_matrix) ==
len(self.descriptions)):
raise ValueError(
'All fields for an MSA must have the same length. '
f'Got {len(self.sequences)} sequences, '
f'{len(self.deletion_matrix)} rows in the deletion matrix and '
f'{len(self.descriptions)} descriptions.')
def __len__(self):
return len(self.sequences)
def truncate(self, max_seqs: int):
return Msa(sequences=self.sequences[:max_seqs],
deletion_matrix=self.deletion_matrix[:max_seqs],
descriptions=self.descriptions[:max_seqs])
@dataclasses.dataclass(frozen=True)
class TemplateHit:
"""Class representing a template hit."""
index: int
name: str
aligned_cols: int
sum_probs: Optional[float]
query: str
hit_sequence: str
indices_query: List[int]
indices_hit: List[int]
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
"""Parses FASTA string and returns list of strings with amino-acid sequences.
Arguments:
fasta_string: The string contents of a FASTA file.
Returns:
A tuple of two lists:
* A list of sequences.
* A list of sequence descriptions taken from the comment lines. In the
same order as the sequences.
"""
sequences = []
descriptions = []
index = -1
for line in fasta_string.splitlines():
line = line.strip()
if line.startswith('>'):
index += 1
descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append('')
continue
elif not line:
continue # Skip blank lines.
sequences[index] += line
return sequences, descriptions
def parse_stockholm(stockholm_string: str) -> Msa:
"""Parses sequences and deletion matrix from stockholm format alignment.
Args:
stockholm_string: The string contents of a stockholm file. The first
sequence in the file should be the query sequence.
Returns:
A tuple of:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
* The names of the targets matched, including the jackhmmer subsequence
suffix.
"""
name_to_sequence = collections.OrderedDict()
for line in stockholm_string.splitlines():
line = line.strip()
if not line or line.startswith(('#', '//')):
continue
name, sequence = line.split()
if name not in name_to_sequence:
name_to_sequence[name] = ''
name_to_sequence[name] += sequence
msa = []
deletion_matrix = []
query = ''
keep_columns = []
for seq_index, sequence in enumerate(name_to_sequence.values()):
if seq_index == 0:
# Gather the columns with gaps from the query
query = sequence
keep_columns = [i for i, res in enumerate(query) if res != '-']
# Remove the columns with gaps in the query from all sequences.
aligned_sequence = ''.join([sequence[c] for c in keep_columns])
msa.append(aligned_sequence)
# Count the number of deletions w.r.t. query.
deletion_vec = []
deletion_count = 0
for seq_res, query_res in zip(sequence, query):
if seq_res != '-' or query_res != '-':
if query_res == '-':
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)
return Msa(sequences=msa,
deletion_matrix=deletion_matrix,
descriptions=list(name_to_sequence.keys()))
def parse_a3m(a3m_string: str) -> Msa:
"""Parses sequences and deletion matrix from a3m format alignment.
Args:
a3m_string: The string contents of a a3m file. The first sequence in the
file should be the query sequence.
Returns:
A tuple of:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
* A list of descriptions, one per sequence, from the a3m file.
"""
sequences, descriptions = parse_fasta(a3m_string)
deletion_matrix = []
for msa_sequence in sequences:
deletion_vec = []
deletion_count = 0
for j in msa_sequence:
if j.islower():
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)
# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans('', '', string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences]
return Msa(sequences=aligned_sequences,
deletion_matrix=deletion_matrix,
descriptions=descriptions)
def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]:
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
if is_query_res_non_gap:
yield sequence_res
elif sequence_res != '-':
yield sequence_res.lower()
def convert_stockholm_to_a3m(stockholm_format: str,
max_sequences: Optional[int] = None,
remove_first_row_gaps: bool = True) -> str:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions = {}
sequences = {}
reached_max_sequences = False
for line in stockholm_format.splitlines():
reached_max_sequences = max_sequences and len(sequences) >= max_sequences
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname, aligned_seq = line.split(maxsplit=1)
if seqname not in sequences:
if reached_max_sequences:
continue
sequences[seqname] = ''
sequences[seqname] += aligned_seq
for line in stockholm_format.splitlines():
if line[:4] == '#=GS':
# Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns = line.split(maxsplit=3)
seqname, feature = columns[1:3]
value = columns[3] if len(columns) == 4 else ''
if feature != 'DE':
continue
if reached_max_sequences and seqname not in sequences:
continue
descriptions[seqname] = value
if len(descriptions) == len(sequences):
break
# Convert sto format to a3m line by line
a3m_sequences = {}
if remove_first_row_gaps:
# query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != '-' for res in query_sequence]
for seqname, sto_sequence in sequences.items():
# Dots are optional in a3m format and are commonly removed.
out_sequence = sto_sequence.replace('.', '')
if remove_first_row_gaps:
out_sequence = ''.join(
_convert_sto_seq_to_a3m(query_non_gaps, out_sequence))
a3m_sequences[seqname] = out_sequence
fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
for k in a3m_sequences)
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline.
def _keep_line(line: str, seqnames: Set[str]) -> bool:
"""Function to decide which lines to keep."""
if not line.strip():
return True
if line.strip() == '//': # End tag
return True
if line.startswith('# STOCKHOLM'): # Start tag
return True
if line.startswith('#=GC RF'): # Reference Annotation Line
return True
if line[:4] == '#=GS': # Description lines - keep if sequence in list.
_, seqname, _ = line.split(maxsplit=2)
return seqname in seqnames
elif line.startswith('#'): # Other markup - filter out
return False
else: # Alignment data - keep if sequence in list.
seqname = line.partition(' ')[0]
return seqname in seqnames
def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
"""Reads + truncates a Stockholm file while preventing excessive RAM usage."""
seqnames = set()
filtered_lines = []
with open(stockholm_msa_path) as f:
for line in f:
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname = line.partition(' ')[0]
seqnames.add(seqname)
if len(seqnames) >= max_sequences:
break
f.seek(0)
for line in f:
if _keep_line(line, seqnames):
filtered_lines.append(line)
return ''.join(filtered_lines)
def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
"""Removes empty columns (dashes-only) from a Stockholm MSA."""
processed_lines = {}
unprocessed_lines = {}
for i, line in enumerate(stockholm_msa.splitlines()):
if line.startswith('#=GC RF'):
reference_annotation_i = i
reference_annotation_line = line
# Reached the end of this chunk of the alignment. Process chunk.
_, _, first_alignment = line.rpartition(' ')
mask = []
for j in range(len(first_alignment)):
for _, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
if alignment[j] != '-':
mask.append(True)
break
else: # Every row contained a hyphen - empty column.
mask.append(False)
# Add reference annotation for processing with mask.
unprocessed_lines[reference_annotation_i] = reference_annotation_line
if not any(mask): # All columns were empty. Output empty lines for chunk.
for line_index in unprocessed_lines:
processed_lines[line_index] = ''
else:
for line_index, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
masked_alignment = ''.join(itertools.compress(alignment, mask))
processed_lines[line_index] = f'{prefix} {masked_alignment}'
# Clear raw_alignments.
unprocessed_lines = {}
elif line.strip() and not line.startswith(('#', '//')):
unprocessed_lines[i] = line
else:
processed_lines[i] = line
return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))
def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
"""Remove duplicate sequences (ignoring insertions wrt query)."""
sequence_dict = collections.defaultdict(str)
# First we must extract all sequences from the MSA.
for line in stockholm_msa.splitlines():
# Only consider the alignments - ignore reference annotation, empty lines,
# descriptions or markup.
if line.strip() and not line.startswith(('#', '//')):
line = line.strip()
seqname, alignment = line.split()
sequence_dict[seqname] += alignment
seen_sequences = set()
seqnames = set()
# First alignment is the query.
query_align = next(iter(sequence_dict.values()))
mask = [c != '-' for c in query_align] # Mask is False for insertions.
for seqname, alignment in sequence_dict.items():
# Apply mask to remove all insertions from the string.
masked_alignment = ''.join(itertools.compress(alignment, mask))
if masked_alignment in seen_sequences:
continue
else:
seen_sequences.add(masked_alignment)
seqnames.add(seqname)
filtered_lines = []
for line in stockholm_msa.splitlines():
if _keep_line(line, seqnames):
filtered_lines.append(line)
return '\n'.join(filtered_lines) + '\n'
def _get_hhr_line_regex_groups(
regex_pattern: str, line: str) -> Sequence[Optional[str]]:
match = re.match(regex_pattern, line)
if match is None:
raise RuntimeError(f'Could not parse query line {line}')
return match.groups()
def _update_hhr_residue_indices_list(
sequence: str, start_index: int, indices_list: List[int]):
"""Computes the relative indices for each residue with respect to the original sequence."""
counter = start_index
for symbol in sequence:
if symbol == '-':
indices_list.append(-1)
else:
indices_list.append(counter)
counter += 1
def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
"""Parses the detailed HMM HMM comparison section for a single Hit.
This works on .hhr files generated from both HHBlits and HHSearch.
Args:
detailed_lines: A list of lines from a single comparison section between 2
sequences (which each have their own HMM's)
Returns:
A dictionary with the information from that detailed comparison section
Raises:
RuntimeError: If a certain line cannot be processed
"""
# Parse first 2 lines.
number_of_hit = int(detailed_lines[0].split()[-1])
name_hit = detailed_lines[1][1:]
# Parse the summary line.
pattern = (
'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t'
' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t '
']*Template_Neff=(.*)')
match = re.match(pattern, detailed_lines[2])
if match is None:
raise RuntimeError(
'Could not parse section: %s. Expected this: \n%s to contain summary.' %
(detailed_lines, detailed_lines[2]))
(_, _, _, aligned_cols, _, _, sum_probs, _) = [float(x)
for x in match.groups()]
# The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to
# assume that each block starts with the query sequence line, and to parse
# that with a regexp in order to deduce the fixed length used for that block.
query = ''
hit_sequence = ''
indices_query = []
indices_hit = []
length_block = None
for line in detailed_lines[3:]:
# Parse the query sequence line
if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and
not line.startswith('Q ss_pred') and
not line.startswith('Q Consensus')):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)'
groups = _get_hhr_line_regex_groups(patt, line[17:])
# Get the length of the parsed block using the start and finish indices,
# and ensure it is the same as the actual block length.
start = int(groups[0]) - 1 # Make index zero based.
delta_query = groups[1]
end = int(groups[2])
num_insertions = len([x for x in delta_query if x == '-'])
length_block = end - start + num_insertions
assert length_block == len(delta_query)
# Update the query sequence and indices list.
query += delta_query
_update_hhr_residue_indices_list(delta_query, start, indices_query)
elif line.startswith('T '):
# Parse the hit sequence.
if (not line.startswith('T ss_dssp') and
not line.startswith('T ss_pred') and
not line.startswith('T Consensus')):
# Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)'
groups = _get_hhr_line_regex_groups(patt, line[17:])
start = int(groups[0]) - 1 # Make index zero based.
delta_hit_sequence = groups[1]
assert length_block == len(delta_hit_sequence)
# Update the hit sequence and indices list.
hit_sequence += delta_hit_sequence
_update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit)
return TemplateHit(
index=number_of_hit,
name=name_hit,
aligned_cols=int(aligned_cols),
sum_probs=sum_probs,
query=query,
hit_sequence=hit_sequence,
indices_query=indices_query,
indices_hit=indices_hit,
)
def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
"""Parses the content of an entire HHR file."""
lines = hhr_string.splitlines()
# Each .hhr file starts with a results table, then has a sequence of hit
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit.
block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')]
hits = []
if block_starts:
block_starts.append(len(lines)) # Add the end of the final block.
for i in range(len(block_starts) - 1):
hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]]))
return hits
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
"""Parse target to e-value mapping parsed from Jackhmmer tblout string."""
e_values = {'query': 0}
lines = [line for line in tblout.splitlines() if line[0] != '#']
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
# space-delimited. Relevant fields are (1) target name: and
# (5) E-value (full sequence) (numbering from 1).
for line in lines:
fields = line.split()
e_value = fields[4]
target_name = fields[0]
e_values[target_name] = float(e_value)
return e_values
def _get_indices(sequence: str, start: int) -> List[int]:
"""Returns indices for non-gap/insert residues starting at the given index."""
indices = []
counter = start
for symbol in sequence:
# Skip gaps but add a placeholder so that the alignment is preserved.
if symbol == '-':
indices.append(-1)
# Skip deleted residues, but increase the counter.
elif symbol.islower():
counter += 1
# Normal aligned residue. Increase the counter and append to indices.
else:
indices.append(counter)
counter += 1
return indices
@dataclasses.dataclass(frozen=True)
class HitMetadata:
pdb_id: str
chain: str
start: int
end: int
length: int
text: str
def _parse_hmmsearch_description(description: str) -> HitMetadata:
"""Parses the hmmsearch A3M sequence description line."""
# Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
# Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
match = re.match(
r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
description.strip())
if not match:
raise ValueError(f'Could not parse description: "{description}".')
return HitMetadata(
pdb_id=match[1],
chain=match[2],
start=int(match[3]),
end=int(match[4]),
length=int(match[5]),
text=match[6])
def parse_hmmsearch_a3m(query_sequence: str,
a3m_string: str,
skip_first: bool = True) -> Sequence[TemplateHit]:
"""Parses an a3m string produced by hmmsearch.
Args:
query_sequence: The query sequence.
a3m_string: The a3m string produced by hmmsearch.
skip_first: Whether to skip the first sequence in the a3m string.
Returns:
A sequence of `TemplateHit` results.
"""
# Zip the descriptions and MSAs together, skip the first query sequence.
parsed_a3m = list(zip(*parse_fasta(a3m_string)))
if skip_first:
parsed_a3m = parsed_a3m[1:]
indices_query = _get_indices(query_sequence, start=0)
hits = []
for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
if 'mol:protein' not in hit_description:
continue # Skip non-protein chains.
metadata = _parse_hmmsearch_description(hit_description)
# Aligned columns are only the match states.
aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence])
indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)
hit = TemplateHit(
index=i,
name=f'{metadata.pdb_id}_{metadata.chain}',
aligned_cols=aligned_cols,
sum_probs=None,
query=query_sequence,
hit_sequence=hit_sequence.upper(),
indices_query=indices_query,
indices_hit=indices_hit,
)
hits.append(hit)
return hits
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for building the input features for the AlphaFold model."""
import os
from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union
from absl import logging
from alphafold.common import residue_constants
from alphafold.data import msa_identifiers
from alphafold.data import parsers
from alphafold.data import templates
from alphafold.data.tools import hhblits
from alphafold.data.tools import hhsearch
from alphafold.data.tools import hmmsearch
from alphafold.data.tools import jackhmmer
import numpy as np
# Internal import (7716).
FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def make_sequence_features(
sequence: str, description: str, num_res: int) -> FeatureDict:
"""Constructs a feature dict of sequence features."""
features = {}
features['aatype'] = residue_constants.sequence_to_onehot(
sequence=sequence,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True)
features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
features['domain_name'] = np.array([description.encode('utf-8')],
dtype=np.object_)
features['residue_index'] = np.array(range(num_res), dtype=np.int32)
features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_)
return features
def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError('At least one MSA must be provided.')
int_msa = []
deletion_matrix = []
species_ids = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
for sequence_index, sequence in enumerate(msa.sequences):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
deletion_matrix.append(msa.deletion_matrix[sequence_index])
identifiers = msa_identifiers.get_identifiers(
msa.descriptions[sequence_index])
species_ids.append(identifiers.species_id.encode('utf-8'))
num_res = len(msas[0].sequences[0])
num_alignments = len(int_msa)
features = {}
features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
features['msa'] = np.array(int_msa, dtype=np.int32)
features['num_alignments'] = np.array(
[num_alignments] * num_res, dtype=np.int32)
features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_)
return features
def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str,
msa_format: str, use_precomputed_msas: bool,
max_sto_sequences: Optional[int] = None
) -> Mapping[str, Any]:
"""Runs an MSA tool, checking if output already exists first."""
if not use_precomputed_msas or not os.path.exists(msa_out_path):
if msa_format == 'sto' and max_sto_sequences is not None:
result = msa_runner.query(input_fasta_path, max_sto_sequences)[0] # pytype: disable=wrong-arg-count
else:
result = msa_runner.query(input_fasta_path)[0]
with open(msa_out_path, 'w') as f:
f.write(result[msa_format])
else:
logging.warning('Reading MSA from file %s', msa_out_path)
if msa_format == 'sto' and max_sto_sequences is not None:
precomputed_msa = parsers.truncate_stockholm_msa(
msa_out_path, max_sto_sequences)
result = {'sto': precomputed_msa}
else:
with open(msa_out_path, 'r') as f:
result = {msa_format: f.read()}
return result
class DataPipeline:
"""Runs the alignment tools and assembles the input features."""
def __init__(self,
jackhmmer_binary_path: str,
hhblits_binary_path: str,
uniref90_database_path: str,
mgnify_database_path: str,
bfd_database_path: Optional[str],
uniref30_database_path: Optional[str],
small_bfd_database_path: Optional[str],
template_searcher: TemplateSearcher,
template_featurizer: templates.TemplateHitFeaturizer,
use_small_bfd: bool,
mgnify_max_hits: int = 501,
uniref_max_hits: int = 10000,
use_precomputed_msas: bool = False):
"""Initializes the data pipeline."""
self._use_small_bfd = use_small_bfd
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniref90_database_path)
if use_small_bfd:
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=small_bfd_database_path)
else:
self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=[bfd_database_path, uniref30_database_path])
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path)
self.template_searcher = template_searcher
self.template_featurizer = template_featurizer
self.mgnify_max_hits = mgnify_max_hits
self.uniref_max_hits = uniref_max_hits
self.use_precomputed_msas = use_precomputed_msas
def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
"""Runs alignment tools on the input sequence and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {input_fasta_path}.')
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
jackhmmer_uniref90_result = run_msa_tool(
msa_runner=self.jackhmmer_uniref90_runner,
input_fasta_path=input_fasta_path,
msa_out_path=uniref90_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas,
max_sto_sequences=self.uniref_max_hits)
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
jackhmmer_mgnify_result = run_msa_tool(
msa_runner=self.jackhmmer_mgnify_runner,
input_fasta_path=input_fasta_path,
msa_out_path=mgnify_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas,
max_sto_sequences=self.mgnify_max_hits)
msa_for_templates = jackhmmer_uniref90_result['sto']
msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
msa_for_templates)
if self.template_searcher.input_format == 'sto':
pdb_templates_result = self.template_searcher.query(msa_for_templates)
elif self.template_searcher.input_format == 'a3m':
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
else:
raise ValueError('Unrecognized template input format: '
f'{self.template_searcher.input_format}')
pdb_hits_out_path = os.path.join(
msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
with open(pdb_hits_out_path, 'w') as f:
f.write(pdb_templates_result)
uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
pdb_template_hits = self.template_searcher.get_template_hits(
output_string=pdb_templates_result, input_sequence=input_sequence)
if self._use_small_bfd:
bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
jackhmmer_small_bfd_result = run_msa_tool(
msa_runner=self.jackhmmer_small_bfd_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
else:
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
hhblits_bfd_uniref_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniref_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='a3m',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
hits=pdb_template_hits)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res)
msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa))
logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa))
logging.info('BFD MSA size: %d sequences.', len(bfd_msa))
logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa))
logging.info('Final (deduplicated) MSA size: %d sequences.',
msa_features['num_alignments'][0])
logging.info('Total number of templates (NB: this can include bad '
'templates and is later filtered to top 4): %d.',
templates_result.features['template_domain_names'].shape[0])
return {**sequence_features, **msa_features, **templates_result.features}
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for building the features for the AlphaFold multimer model."""
import collections
import contextlib
import copy
import dataclasses
import json
import os
import tempfile
from typing import Mapping, MutableMapping, Sequence
from absl import logging
from alphafold.common import protein
from alphafold.common import residue_constants
from alphafold.data import feature_processing
from alphafold.data import msa_pairing
from alphafold.data import parsers
from alphafold.data import pipeline
from alphafold.data.tools import jackhmmer
import numpy as np
# Internal import (7716).
@dataclasses.dataclass(frozen=True)
class _FastaChain:
sequence: str
description: str
def _make_chain_id_map(*,
sequences: Sequence[str],
descriptions: Sequence[str],
) -> Mapping[str, _FastaChain]:
"""Makes a mapping from PDB-format chain ID to sequence and description."""
if len(sequences) != len(descriptions):
raise ValueError('sequences and descriptions must have equal length. '
f'Got {len(sequences)} != {len(descriptions)}.')
if len(sequences) > protein.PDB_MAX_CHAINS:
raise ValueError('Cannot process more chains than the PDB format supports. '
f'Got {len(sequences)} chains.')
chain_id_map = {}
for chain_id, sequence, description in zip(
protein.PDB_CHAIN_IDS, sequences, descriptions):
chain_id_map[chain_id] = _FastaChain(
sequence=sequence, description=description)
return chain_id_map
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
fasta_file.write(fasta_str)
fasta_file.seek(0)
yield fasta_file.name
def convert_monomer_features(
monomer_features: pipeline.FeatureDict,
chain_id: str) -> pipeline.FeatureDict:
"""Reshapes and modifies monomer features for multimer models."""
converted = {}
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
unnecessary_leading_dim_feats = {
'sequence', 'domain_name', 'num_alignments', 'seq_length'}
for feature_name, feature in monomer_features.items():
if feature_name in unnecessary_leading_dim_feats:
# asarray ensures it's a np.ndarray.
feature = np.asarray(feature[0], dtype=feature.dtype)
elif feature_name == 'aatype':
# The multimer model performs the one-hot operation itself.
feature = np.argmax(feature, axis=-1).astype(np.int32)
elif feature_name == 'template_aatype':
feature = np.argmax(feature, axis=-1).astype(np.int32)
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
elif feature_name == 'template_all_atom_masks':
feature_name = 'template_all_atom_mask'
converted[feature_name] = feature
return converted
def int_id_to_str_id(num: int) -> str:
"""Encodes a number as a string, using reverse spreadsheet style naming.
Args:
num: A positive integer.
Returns:
A string that encodes the positive integer using reverse spreadsheet style,
naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
usual way to encode chain IDs in mmCIF files.
"""
if num <= 0:
raise ValueError(f'Only positive integers allowed, got {num}.')
num = num - 1 # 1-based indexing.
output = []
while num >= 0:
output.append(chr(num % 26 + ord('A')))
num = num // 26 - 1
return ''.join(output)
def add_assembly_features(
all_chain_features: MutableMapping[str, pipeline.FeatureDict],
) -> MutableMapping[str, pipeline.FeatureDict]:
"""Add features to distinguish between chains.
Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.
Returns:
all_chain_features: A dictionary which maps strings of the form
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
chains from a homodimer would have keys A_1 and A_2. Two chains from a
heterodimer would have keys A_1 and B_1.
"""
# Group the chains by sequence
seq_to_entity_id = {}
grouped_chains = collections.defaultdict(list)
for chain_id, chain_features in all_chain_features.items():
seq = str(chain_features['sequence'])
if seq not in seq_to_entity_id:
seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
grouped_chains[seq_to_entity_id[seq]].append(chain_features)
new_all_chain_features = {}
chain_id = 1
for entity_id, group_chain_features in grouped_chains.items():
for sym_id, chain_features in enumerate(group_chain_features, start=1):
new_all_chain_features[
f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features
seq_length = chain_features['seq_length']
chain_features['asym_id'] = chain_id * np.ones(seq_length)
chain_features['sym_id'] = sym_id * np.ones(seq_length)
chain_features['entity_id'] = entity_id * np.ones(seq_length)
chain_id += 1
return new_all_chain_features
def pad_msa(np_example, min_num_seq):
np_example = dict(np_example)
num_seq = np_example['msa'].shape[0]
if num_seq < min_num_seq:
for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
np_example[feat] = np.pad(
np_example[feat], ((0, min_num_seq - num_seq), (0, 0)))
np_example['cluster_bias_mask'] = np.pad(
np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),))
return np_example
class DataPipeline:
"""Runs the alignment tools and assembles the input features."""
def __init__(self,
monomer_data_pipeline: pipeline.DataPipeline,
jackhmmer_binary_path: str,
uniprot_database_path: str,
max_uniprot_hits: int = 50000,
use_precomputed_msas: bool = False):
"""Initializes the data pipeline.
Args:
monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs
the data pipeline for the monomer AlphaFold system.
jackhmmer_binary_path: Location of the jackhmmer binary.
uniprot_database_path: Location of the unclustered uniprot sequences, that
will be searched with jackhmmer and used for MSA pairing.
max_uniprot_hits: The maximum number of hits to return from uniprot.
use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
"""
self._monomer_data_pipeline = monomer_data_pipeline
self._uniprot_msa_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniprot_database_path)
self._max_uniprot_hits = max_uniprot_hits
self.use_precomputed_msas = use_precomputed_msas
def _process_single_chain(
self,
chain_id: str,
sequence: str,
description: str,
msa_output_dir: str,
is_homomer_or_monomer: bool) -> pipeline.FeatureDict:
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str = f'>chain_{chain_id}\n{sequence}\n'
chain_msa_output_dir = os.path.join(msa_output_dir, chain_id)
if not os.path.exists(chain_msa_output_dir):
os.makedirs(chain_msa_output_dir)
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
logging.info('Running monomer pipeline on chain %s: %s',
chain_id, description)
chain_features = self._monomer_data_pipeline.process(
input_fasta_path=chain_fasta_path,
msa_output_dir=chain_msa_output_dir)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path,
chain_msa_output_dir)
chain_features.update(all_seq_msa_features)
return chain_features
def _all_seq_msa_features(self, input_fasta_path, msa_output_dir):
"""Get MSA features for unclustered uniprot, for pairing."""
out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
result = pipeline.run_msa_tool(
self._uniprot_msa_runner, input_fasta_path, out_path, 'sto',
self.use_precomputed_msas)
msa = parsers.parse_stockholm(result['sto'])
msa = msa.truncate(max_seqs=self._max_uniprot_hits)
all_seq_features = pipeline.make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers',
)
feats = {f'{k}_all_seq': v for k, v in all_seq_features.items()
if k in valid_feats}
return feats
def process(self,
input_fasta_path: str,
msa_output_dir: str) -> pipeline.FeatureDict:
"""Runs alignment tools on the input sequences and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
chain_id_map = _make_chain_id_map(sequences=input_seqs,
descriptions=input_descs)
chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json')
with open(chain_id_map_path, 'w') as f:
chain_id_map_dict = {chain_id: dataclasses.asdict(fasta_chain)
for chain_id, fasta_chain in chain_id_map.items()}
json.dump(chain_id_map_dict, f, indent=4, sort_keys=True)
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1
for chain_id, fasta_chain in chain_id_map.items():
if fasta_chain.sequence in sequence_features:
all_chain_features[chain_id] = copy.deepcopy(
sequence_features[fasta_chain.sequence])
continue
chain_features = self._process_single_chain(
chain_id=chain_id,
sequence=fasta_chain.sequence,
description=fasta_chain.description,
msa_output_dir=msa_output_dir,
is_homomer_or_monomer=is_homomer_or_monomer)
chain_features = convert_monomer_features(chain_features,
chain_id=chain_id)
all_chain_features[chain_id] = chain_features
sequence_features[fasta_chain.sequence] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing.pair_and_merge(
all_chain_features=all_chain_features)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for getting templates and calculating template features."""
import abc
import dataclasses
import datetime
import functools
import glob
import os
import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
from absl import logging
from alphafold.common import residue_constants
from alphafold.data import mmcif_parsing
from alphafold.data import parsers
from alphafold.data.tools import kalign
import numpy as np
# Internal import (7716).
class Error(Exception):
"""Base class for exceptions."""
class NoChainsError(Error):
"""An error indicating that template mmCIF didn't have any chains."""
class SequenceNotInTemplateError(Error):
"""An error indicating that template mmCIF didn't contain the sequence."""
class NoAtomDataInTemplateError(Error):
"""An error indicating that template mmCIF didn't contain atom positions."""
class TemplateAtomMaskAllZerosError(Error):
"""An error indicating that template mmCIF had all atom positions masked."""
class QueryToTemplateAlignError(Error):
"""An error indicating that the query can't be aligned to the template."""
class CaDistanceError(Error):
"""An error indicating that a CA atom distance exceeds a threshold."""
class MultipleChainsError(Error):
"""An error indicating that multiple chains were found for a given ID."""
# Prefilter exceptions.
class PrefilterError(Exception):
"""A base class for template prefilter exceptions."""
class DateError(PrefilterError):
"""An error indicating that the hit date was after the max allowed date."""
class AlignRatioError(PrefilterError):
"""An error indicating that the hit align ratio to the query was too small."""
class DuplicateError(PrefilterError):
"""An error indicating that the hit was an exact subsequence of the query."""
class LengthError(PrefilterError):
"""An error indicating that the hit was too short."""
TEMPLATE_FEATURES = {
'template_aatype': np.float32,
'template_all_atom_masks': np.float32,
'template_all_atom_positions': np.float32,
'template_domain_names': object,
'template_sequence': object,
'template_sum_probs': np.float32,
}
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
"""Returns PDB id and chain id for an HHSearch Hit."""
# PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name)
if not id_match:
raise ValueError(f'hit.name did not start with PDBID_chain: {hit.name}')
pdb_id, chain_id = id_match.group(0).split('_')
return pdb_id.lower(), chain_id
def _is_after_cutoff(
pdb_id: str,
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: Optional[datetime.datetime]) -> bool:
"""Checks if the template date is after the release date cutoff.
Args:
pdb_id: 4 letter pdb code.
release_dates: Dictionary mapping PDB ids to their structure release dates.
release_date_cutoff: Max release date that is valid for this query.
Returns:
True if the template release date is after the cutoff, False otherwise.
"""
if release_date_cutoff is None:
raise ValueError('The release_date_cutoff must not be None.')
if pdb_id in release_dates:
return release_dates[pdb_id] > release_date_cutoff
else:
# Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here.
return False
def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, Optional[str]]:
"""Parses the data file from PDB that lists which pdb_ids are obsolete."""
with open(obsolete_file_path) as f:
result = {}
for line in f:
line = line.strip()
# Format: Date From To
# 'OBSLTE 06-NOV-19 6G9Y' - Removed, rare
# 'OBSLTE 31-JUL-94 116L 216L' - Replaced, common
# 'OBSLTE 26-SEP-06 2H33 2JM5 2OWI' - Replaced by multiple, rare
if line.startswith('OBSLTE'):
if len(line) > 30:
# Replaced by at least one structure.
from_id = line[20:24].lower()
to_id = line[29:33].lower()
result[from_id] = to_id
elif len(line) == 24:
# Removed.
from_id = line[20:24].lower()
result[from_id] = None
return result
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
"""Parses release dates file, returns a mapping from PDBs to release dates."""
if path.endswith('txt'):
release_dates = {}
with open(path, 'r') as f:
for line in f:
pdb_id, date = line.split(':')
date = date.strip()
# Python 3.6 doesn't have datetime.date.fromisoformat() which is about
# 90x faster than strptime. However, splitting the string manually is
# about 10x faster than strptime.
release_dates[pdb_id.strip()] = datetime.datetime(
year=int(date[:4]), month=int(date[5:7]), day=int(date[8:10]))
return release_dates
else:
raise ValueError('Invalid format of the release date file %s.' % path)
def _assess_hhsearch_hit(
hit: parsers.TemplateHit,
hit_pdb_code: str,
query_sequence: str,
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: datetime.datetime,
max_subsequence_ratio: float = 0.95,
min_align_ratio: float = 0.1) -> bool:
"""Determines if template is valid (without parsing the template mmcif file).
Args:
hit: HhrHit for the template.
hit_pdb_code: The 4 letter pdb code of the template hit. This might be
different from the value in the actual hit since the original pdb might
have become obsolete.
query_sequence: Amino acid sequence of the query.
release_dates: Dictionary mapping pdb codes to their structure release
dates.
release_date_cutoff: Max release date that is valid for this query.
max_subsequence_ratio: Exclude any exact matches with this much overlap.
min_align_ratio: Minimum overlap between the template and query.
Returns:
True if the hit passed the prefilter. Raises an exception otherwise.
Raises:
DateError: If the hit date was after the max allowed date.
AlignRatioError: If the hit align ratio to the query was too small.
DuplicateError: If the hit was an exact subsequence of the query.
LengthError: If the hit was too short.
"""
aligned_cols = hit.aligned_cols
align_ratio = aligned_cols / len(query_sequence)
template_sequence = hit.hit_sequence.replace('-', '')
length_ratio = float(len(template_sequence)) / len(query_sequence)
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
duplicate = (template_sequence in query_sequence and
length_ratio > max_subsequence_ratio)
if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
raise DateError(f'Date ({release_dates[hit_pdb_code]}) > max template date '
f'({release_date_cutoff}).')
if align_ratio <= min_align_ratio:
raise AlignRatioError('Proportion of residues aligned to query too small. '
f'Align ratio: {align_ratio}.')
if duplicate:
raise DuplicateError('Template is an exact subsequence of query with large '
f'coverage. Length ratio: {length_ratio}.')
if len(template_sequence) < 10:
raise LengthError(f'Template too short. Length: {len(template_sequence)}.')
return True
def _find_template_in_pdb(
template_chain_id: str,
template_sequence: str,
mmcif_object: mmcif_parsing.MmcifObject) -> Tuple[str, str, int]:
"""Tries to find the template chain in the given pdb file.
This method tries the three following things in order:
1. Tries if there is an exact match in both the chain ID and the sequence.
If yes, the chain sequence is returned. Otherwise:
2. Tries if there is an exact match only in the sequence.
If yes, the chain sequence is returned. Otherwise:
3. Tries if there is a fuzzy match (X = wildcard) in the sequence.
If yes, the chain sequence is returned.
If none of these succeed, a SequenceNotInTemplateError is thrown.
Args:
template_chain_id: The template chain ID.
template_sequence: The template chain sequence.
mmcif_object: The PDB object to search for the template in.
Returns:
A tuple with:
* The chain sequence that was found to match the template in the PDB object.
* The ID of the chain that is being returned.
* The offset where the template sequence starts in the chain sequence.
Raises:
SequenceNotInTemplateError: If no match is found after the steps described
above.
"""
# Try if there is an exact match in both the chain ID and the (sub)sequence.
pdb_id = mmcif_object.file_id
chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
if chain_sequence and (template_sequence in chain_sequence):
logging.info(
'Found an exact template match %s_%s.', pdb_id, template_chain_id)
mapping_offset = chain_sequence.find(template_sequence)
return chain_sequence, template_chain_id, mapping_offset
# Try if there is an exact match in the (sub)sequence only.
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
if chain_sequence and (template_sequence in chain_sequence):
logging.info('Found a sequence-only match %s_%s.', pdb_id, chain_id)
mapping_offset = chain_sequence.find(template_sequence)
return chain_sequence, chain_id, mapping_offset
# Return a chain sequence that fuzzy matches (X = wildcard) the template.
# Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence]
regex = re.compile(''.join(regex))
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
match = re.search(regex, chain_sequence)
if match:
logging.info('Found a fuzzy sequence-only match %s_%s.', pdb_id, chain_id)
mapping_offset = match.start()
return chain_sequence, chain_id, mapping_offset
# No hits, raise an error.
raise SequenceNotInTemplateError(
'Could not find the template sequence in %s_%s. Template sequence: %s, '
'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence,
mmcif_object.chain_to_seqres))
def _realign_pdb_template_to_query(
old_template_sequence: str,
template_chain_id: str,
mmcif_object: mmcif_parsing.MmcifObject,
old_mapping: Mapping[int, int],
kalign_binary_path: str) -> Tuple[str, Mapping[int, int]]:
"""Aligns template from the mmcif_object to the query.
In case PDB70 contains a different version of the template sequence, we need
to perform a realignment to the actual sequence that is in the mmCIF file.
This method performs such realignment, but returns the new sequence and
mapping only if the sequence in the mmCIF file is 90% identical to the old
sequence.
Note that the old_template_sequence comes from the hit, and contains only that
part of the chain that matches with the query while the new_template_sequence
is the full chain.
Args:
old_template_sequence: The template sequence that was returned by the PDB
template search (typically done using HHSearch).
template_chain_id: The template chain id was returned by the PDB template
search (typically done using HHSearch). This is used to find the right
chain in the mmcif_object chain_to_seqres mapping.
mmcif_object: A mmcif_object which holds the actual template data.
old_mapping: A mapping from the query sequence to the template sequence.
This mapping will be used to compute the new mapping from the query
sequence to the actual mmcif_object template sequence by aligning the
old_template_sequence and the actual template sequence.
kalign_binary_path: The path to a kalign executable.
Returns:
A tuple (new_template_sequence, new_query_to_template_mapping) where:
* new_template_sequence is the actual template sequence that was found in
the mmcif_object.
* new_query_to_template_mapping is the new mapping from the query to the
actual template found in the mmcif_object.
Raises:
QueryToTemplateAlignError:
* If there was an error thrown by the alignment tool.
* Or if the actual template sequence differs by more than 10% from the
old_template_sequence.
"""
aligner = kalign.Kalign(binary_path=kalign_binary_path)
new_template_sequence = mmcif_object.chain_to_seqres.get(
template_chain_id, '')
# Sometimes the template chain id is unknown. But if there is only a single
# sequence within the mmcif_object, it is safe to assume it is that one.
if not new_template_sequence:
if len(mmcif_object.chain_to_seqres) == 1:
logging.info('Could not find %s in %s, but there is only 1 sequence, so '
'using that one.',
template_chain_id,
mmcif_object.file_id)
new_template_sequence = list(mmcif_object.chain_to_seqres.values())[0]
else:
raise QueryToTemplateAlignError(
f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. '
'If there are no mmCIF parsing errors, it is possible it was not a '
'protein chain.')
try:
parsed_a3m = parsers.parse_a3m(
aligner.align([old_template_sequence, new_template_sequence]))
old_aligned_template, new_aligned_template = parsed_a3m.sequences
except Exception as e:
raise QueryToTemplateAlignError(
'Could not align old template %s to template %s (%s_%s). Error: %s' %
(old_template_sequence, new_template_sequence, mmcif_object.file_id,
template_chain_id, str(e)))
logging.info('Old aligned template: %s\nNew aligned template: %s',
old_aligned_template, new_aligned_template)
old_to_new_template_mapping = {}
old_template_index = -1
new_template_index = -1
num_same = 0
for old_template_aa, new_template_aa in zip(
old_aligned_template, new_aligned_template):
if old_template_aa != '-':
old_template_index += 1
if new_template_aa != '-':
new_template_index += 1
if old_template_aa != '-' and new_template_aa != '-':
old_to_new_template_mapping[old_template_index] = new_template_index
if old_template_aa == new_template_aa:
num_same += 1
# Require at least 90 % sequence identity wrt to the shorter of the sequences.
if float(num_same) / min(
len(old_template_sequence), len(new_template_sequence)) < 0.9:
raise QueryToTemplateAlignError(
'Insufficient similarity of the sequence in the database: %s to the '
'actual sequence in the mmCIF file %s_%s: %s. We require at least '
'90 %% similarity wrt to the shorter of the sequences. This is not a '
'problem unless you think this is a template that should be included.' %
(old_template_sequence, mmcif_object.file_id, template_chain_id,
new_template_sequence))
new_query_to_template_mapping = {}
for query_index, old_template_index in old_mapping.items():
new_query_to_template_mapping[query_index] = (
old_to_new_template_mapping.get(old_template_index, -1))
new_template_sequence = new_template_sequence.replace('-', '')
return new_template_sequence, new_query_to_template_mapping
def _check_residue_distances(all_positions: np.ndarray,
all_positions_mask: np.ndarray,
max_ca_ca_distance: float):
"""Checks if the distance between unmasked neighbor residues is ok."""
ca_position = residue_constants.atom_order['CA']
prev_is_unmasked = False
prev_calpha = None
for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
this_is_unmasked = bool(mask[ca_position])
if this_is_unmasked:
this_calpha = coords[ca_position]
if prev_is_unmasked:
distance = np.linalg.norm(this_calpha - prev_calpha)
if distance > max_ca_ca_distance:
raise CaDistanceError(
'The distance between residues %d and %d is %f > limit %f.' % (
i, i + 1, distance, max_ca_ca_distance))
prev_calpha = this_calpha
prev_is_unmasked = this_is_unmasked
def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str,
max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues."""
num_res = len(mmcif_object.chain_to_seqres[auth_chain_id])
relevant_chains = [c for c in mmcif_object.structure.get_chains()
if c.id == auth_chain_id]
if len(relevant_chains) != 1:
raise MultipleChainsError(
f'Expected exactly one chain in structure with id {auth_chain_id}.')
chain = relevant_chains[0]
all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3])
all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num],
dtype=np.int64)
for res_index in range(num_res):
pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][res_index]
if not res_at_position.is_missing:
assert res_at_position.position is not None
res = chain[(res_at_position.hetflag,
res_at_position.position.residue_number,
res_at_position.position.insertion_code)]
for atom in res.get_atoms():
atom_name = atom.get_name()
x, y, z = atom.get_coord()
if atom_name in residue_constants.atom_order.keys():
pos[residue_constants.atom_order[atom_name]] = [x, y, z]
mask[residue_constants.atom_order[atom_name]] = 1.0
elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
# Put the coordinates of the selenium atom in the sulphur column.
pos[residue_constants.atom_order['SD']] = [x, y, z]
mask[residue_constants.atom_order['SD']] = 1.0
# Fix naming errors in arginine residues where NH2 is incorrectly
# assigned to be closer to CD than NH1.
cd = residue_constants.atom_order['CD']
nh1 = residue_constants.atom_order['NH1']
nh2 = residue_constants.atom_order['NH2']
if (res.get_resname() == 'ARG' and
all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
(np.linalg.norm(pos[nh1] - pos[cd]) >
np.linalg.norm(pos[nh2] - pos[cd]))):
pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()
all_positions[res_index] = pos
all_positions_mask[res_index] = mask
_check_residue_distances(
all_positions, all_positions_mask, max_ca_ca_distance)
return all_positions, all_positions_mask
def _extract_template_features(
mmcif_object: mmcif_parsing.MmcifObject,
pdb_id: str,
mapping: Mapping[int, int],
template_sequence: str,
query_sequence: str,
template_chain_id: str,
kalign_binary_path: str) -> Tuple[Dict[str, Any], Optional[str]]:
"""Parses atom positions in the target structure and aligns with the query.
Atoms for each residue in the template structure are indexed to coincide
with their corresponding residue in the query sequence, according to the
alignment mapping provided.
Args:
mmcif_object: mmcif_parsing.MmcifObject representing the template.
pdb_id: PDB code for the template.
mapping: Dictionary mapping indices in the query sequence to indices in
the template sequence.
template_sequence: String describing the amino acid sequence for the
template protein.
query_sequence: String describing the amino acid sequence for the query
protein.
template_chain_id: String ID describing which chain in the structure proto
should be used.
kalign_binary_path: The path to a kalign executable used for template
realignment.
Returns:
A tuple with:
* A dictionary containing the extra features derived from the template
protein structure.
* A warning message if the hit was realigned to the actual mmCIF sequence.
Otherwise None.
Raises:
NoChainsError: If the mmcif object doesn't contain any chains.
SequenceNotInTemplateError: If the given chain id / sequence can't
be found in the mmcif object.
QueryToTemplateAlignError: If the actual template in the mmCIF file
can't be aligned to the query.
NoAtomDataInTemplateError: If the mmcif object doesn't contain
atom positions.
TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
unmasked residues.
"""
if mmcif_object is None or not mmcif_object.chain_to_seqres:
raise NoChainsError('No chains in PDB: %s_%s' % (pdb_id, template_chain_id))
warning = None
try:
seqres, chain_id, mapping_offset = _find_template_in_pdb(
template_chain_id=template_chain_id,
template_sequence=template_sequence,
mmcif_object=mmcif_object)
except SequenceNotInTemplateError:
# If PDB70 contains a different version of the template, we use the sequence
# from the mmcif_object.
chain_id = template_chain_id
warning = (
f'The exact sequence {template_sequence} was not found in '
f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.')
logging.warning(warning)
# This throws an exception if it fails to realign the hit.
seqres, mapping = _realign_pdb_template_to_query(
old_template_sequence=template_sequence,
template_chain_id=template_chain_id,
mmcif_object=mmcif_object,
old_mapping=mapping,
kalign_binary_path=kalign_binary_path)
logging.info('Sequence in %s_%s: %s successfully realigned to %s',
pdb_id, chain_id, template_sequence, seqres)
# The template sequence changed.
template_sequence = seqres
# No mapping offset, the query is aligned to the actual sequence.
mapping_offset = 0
try:
# Essentially set to infinity - we don't want to reject templates unless
# they're really really bad.
all_atom_positions, all_atom_mask = _get_atom_positions(
mmcif_object, chain_id, max_ca_ca_distance=150.0)
except (CaDistanceError, KeyError) as ex:
raise NoAtomDataInTemplateError(
'Could not get atom data (%s_%s): %s' % (pdb_id, chain_id, str(ex))
) from ex
all_atom_positions = np.split(all_atom_positions, all_atom_positions.shape[0])
all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])
output_templates_sequence = []
templates_all_atom_positions = []
templates_all_atom_masks = []
for _ in query_sequence:
# Residues in the query_sequence that are not in the template_sequence:
templates_all_atom_positions.append(
np.zeros((residue_constants.atom_type_num, 3)))
templates_all_atom_masks.append(np.zeros(residue_constants.atom_type_num))
output_templates_sequence.append('-')
for k, v in mapping.items():
template_index = v + mapping_offset
templates_all_atom_positions[k] = all_atom_positions[template_index][0]
templates_all_atom_masks[k] = all_atom_masks[template_index][0]
output_templates_sequence[k] = template_sequence[v]
# Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
if np.sum(templates_all_atom_masks) < 5:
raise TemplateAtomMaskAllZerosError(
'Template all atom mask was all zeros: %s_%s. Residue range: %d-%d' %
(pdb_id, chain_id, min(mapping.values()) + mapping_offset,
max(mapping.values()) + mapping_offset))
output_templates_sequence = ''.join(output_templates_sequence)
templates_aatype = residue_constants.sequence_to_onehot(
output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID)
return (
{
'template_all_atom_positions': np.array(templates_all_atom_positions),
'template_all_atom_masks': np.array(templates_all_atom_masks),
'template_sequence': output_templates_sequence.encode(),
'template_aatype': np.array(templates_aatype),
'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
},
warning)
def _build_query_to_hit_index_mapping(
hit_query_sequence: str,
hit_sequence: str,
indices_hit: Sequence[int],
indices_query: Sequence[int],
original_query_sequence: str) -> Mapping[int, int]:
"""Gets mapping from indices in original query sequence to indices in the hit.
hit_query_sequence and hit_sequence are two aligned sequences containing gap
characters. hit_query_sequence contains only the part of the original query
sequence that matched the hit. When interpreting the indices from the .hhr, we
need to correct for this to recover a mapping from original query sequence to
the hit sequence.
Args:
hit_query_sequence: The portion of the query sequence that is in the .hhr
hit
hit_sequence: The portion of the hit sequence that is in the .hhr
indices_hit: The indices for each aminoacid relative to the hit sequence
indices_query: The indices for each aminoacid relative to the original query
sequence
original_query_sequence: String describing the original query sequence.
Returns:
Dictionary with indices in the original query sequence as keys and indices
in the hit sequence as values.
"""
# If the hit is empty (no aligned residues), return empty mapping
if not hit_query_sequence:
return {}
# Remove gaps and find the offset of hit.query relative to original query.
hhsearch_query_sequence = hit_query_sequence.replace('-', '')
hit_sequence = hit_sequence.replace('-', '')
hhsearch_query_offset = original_query_sequence.find(hhsearch_query_sequence)
# Index of -1 used for gap characters. Subtract the min index ignoring gaps.
min_idx = min(x for x in indices_hit if x > -1)
fixed_indices_hit = [
x - min_idx if x > -1 else -1 for x in indices_hit
]
min_idx = min(x for x in indices_query if x > -1)
fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]
# Zip the corrected indices, ignore case where both seqs have gap characters.
mapping = {}
for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
if q_t != -1 and q_i != -1:
if (q_t >= len(hit_sequence) or
q_i + hhsearch_query_offset >= len(original_query_sequence)):
continue
mapping[q_i + hhsearch_query_offset] = q_t
return mapping
@dataclasses.dataclass(frozen=True)
class SingleHitResult:
features: Optional[Mapping[str, Any]]
error: Optional[str]
warning: Optional[str]
@functools.lru_cache(16, typed=False)
def _read_file(path):
with open(path, 'r') as f:
file_data = f.read()
return file_data
def _process_single_hit(
query_sequence: str,
hit: parsers.TemplateHit,
mmcif_dir: str,
max_template_date: datetime.datetime,
release_dates: Mapping[str, datetime.datetime],
obsolete_pdbs: Mapping[str, Optional[str]],
kalign_binary_path: str,
strict_error_check: bool = False) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
# This hit has been removed (obsoleted) from PDB, skip it.
if hit_pdb_code in obsolete_pdbs and obsolete_pdbs[hit_pdb_code] is None:
return SingleHitResult(
features=None, error=None, warning=f'Hit {hit_pdb_code} is obsolete.')
if hit_pdb_code not in release_dates:
if hit_pdb_code in obsolete_pdbs:
hit_pdb_code = obsolete_pdbs[hit_pdb_code]
# Pass hit_pdb_code since it might have changed due to the pdb being obsolete.
try:
_assess_hhsearch_hit(
hit=hit,
hit_pdb_code=hit_pdb_code,
query_sequence=query_sequence,
release_dates=release_dates,
release_date_cutoff=max_template_date)
except PrefilterError as e:
msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}'
logging.info(msg)
if strict_error_check and isinstance(e, (DateError, DuplicateError)):
# In strict mode we treat some prefilter cases as errors.
return SingleHitResult(features=None, error=msg, warning=None)
return SingleHitResult(features=None, error=None, warning=None)
mapping = _build_query_to_hit_index_mapping(
hit.query, hit.hit_sequence, hit.indices_hit, hit.indices_query,
query_sequence)
# The mapping is from the query to the actual hit sequence, so we need to
# remove gaps (which regardless have a missing confidence score).
template_sequence = hit.hit_sequence.replace('-', '')
cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif')
logging.debug('Reading PDB entry from %s. Query: %s, template: %s', cif_path,
query_sequence, template_sequence)
# Fail if we can't find the mmCIF file.
cif_string = _read_file(cif_path)
parsing_result = mmcif_parsing.parse(
file_id=hit_pdb_code, mmcif_string=cif_string)
if parsing_result.mmcif_object is not None:
hit_release_date = datetime.datetime.strptime(
parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d')
if hit_release_date > max_template_date:
error = ('Template %s date (%s) > max template date (%s).' %
(hit_pdb_code, hit_release_date, max_template_date))
if strict_error_check:
return SingleHitResult(features=None, error=error, warning=None)
else:
logging.debug(error)
return SingleHitResult(features=None, error=None, warning=None)
try:
features, realign_warning = _extract_template_features(
mmcif_object=parsing_result.mmcif_object,
pdb_id=hit_pdb_code,
mapping=mapping,
template_sequence=template_sequence,
query_sequence=query_sequence,
template_chain_id=hit_chain_id,
kalign_binary_path=kalign_binary_path)
if hit.sum_probs is None:
features['template_sum_probs'] = [0]
else:
features['template_sum_probs'] = [hit.sum_probs]
# It is possible there were some errors when parsing the other chains in the
# mmCIF file, but the template features for the chain we want were still
# computed. In such case the mmCIF parsing errors are not relevant.
return SingleHitResult(
features=features, error=None, warning=realign_warning)
except (NoChainsError, NoAtomDataInTemplateError,
TemplateAtomMaskAllZerosError) as e:
# These 3 errors indicate missing mmCIF experimental data rather than a
# problem with the template search, so turn them into warnings.
warning = ('%s_%s (sum_probs: %s, rank: %s): feature extracting errors: '
'%s, mmCIF parsing errors: %s'
% (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
str(e), parsing_result.errors))
if strict_error_check:
return SingleHitResult(features=None, error=warning, warning=None)
else:
return SingleHitResult(features=None, error=None, warning=warning)
except Error as e:
error = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: '
'%s, mmCIF parsing errors: %s'
% (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
str(e), parsing_result.errors))
return SingleHitResult(features=None, error=error, warning=None)
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
features: Mapping[str, Any]
errors: Sequence[str]
warnings: Sequence[str]
class TemplateHitFeaturizer(abc.ABC):
"""An abstract base class for turning template hits to template features."""
def __init__(
self,
mmcif_dir: str,
max_template_date: str,
max_hits: int,
kalign_binary_path: str,
release_dates_path: Optional[str],
obsolete_pdbs_path: Optional[str],
strict_error_check: bool = False):
"""Initializes the Template Search.
Args:
mmcif_dir: Path to a directory with mmCIF structures. Once a template ID
is found by HHSearch, this directory is used to retrieve the template
data.
max_template_date: The maximum date permitted for template structures. No
template with date higher than this date will be returned. In ISO8601
date format, YYYY-MM-DD.
max_hits: The maximum number of templates that will be returned.
kalign_binary_path: The path to a kalign executable used for template
realignment.
release_dates_path: An optional path to a file with a mapping from PDB IDs
to their release dates. Thanks to this we don't have to redundantly
parse mmCIF files to get that information.
obsolete_pdbs_path: An optional path to a file containing a mapping from
obsolete PDB IDs to the PDB IDs of their replacements.
strict_error_check: If True, then the following will be treated as errors:
* If any template date is after the max_template_date.
* If any template has identical PDB ID to the query.
* If any template is a duplicate of the query.
* Any feature computation errors.
"""
self._mmcif_dir = mmcif_dir
if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')):
logging.error('Could not find CIFs in %s', self._mmcif_dir)
raise ValueError(f'Could not find CIFs in {self._mmcif_dir}')
try:
self._max_template_date = datetime.datetime.strptime(
max_template_date, '%Y-%m-%d')
except ValueError:
raise ValueError(
'max_template_date must be set and have format YYYY-MM-DD.')
self._max_hits = max_hits
self._kalign_binary_path = kalign_binary_path
self._strict_error_check = strict_error_check
if release_dates_path:
logging.info('Using precomputed release dates %s.', release_dates_path)
self._release_dates = _parse_release_dates(release_dates_path)
else:
self._release_dates = {}
if obsolete_pdbs_path:
logging.info('Using precomputed obsolete pdbs %s.', obsolete_pdbs_path)
self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
else:
self._obsolete_pdbs = {}
@abc.abstractmethod
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
"""Computes the templates for given query sequence."""
class HhsearchHitFeaturizer(TemplateHitFeaturizer):
"""A class for turning a3m hits from hhsearch to template features."""
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info('Searching for template for: %s', query_sequence)
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
num_hits = 0
errors = []
warnings = []
for hit in sorted(hits, key=lambda x: x.sum_probs, reverse=True):
# We got all the templates we wanted, stop processing hits.
if num_hits >= self._max_hits:
break
result = _process_single_hit(
query_sequence=query_sequence,
hit=hit,
mmcif_dir=self._mmcif_dir,
max_template_date=self._max_template_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path)
if result.error:
errors.append(result.error)
# There could be an error even if there are some results, e.g. thrown by
# other unparsable chains in the same mmCIF file.
if result.warning:
warnings.append(result.warning)
if result.features is None:
logging.info('Skipped invalid hit %s, error: %s, warning: %s',
hit.name, result.error, result.warning)
else:
# Increment the hit counter, since we got features out of this hit.
num_hits += 1
for k in template_features:
template_features[k].append(result.features[k])
for name in template_features:
if num_hits > 0:
template_features[name] = np.stack(
template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
else:
# Make sure the feature has correct dtype even if empty.
template_features[name] = np.array([], dtype=TEMPLATE_FEATURES[name])
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings)
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
"""A class for turning a3m hits from hmmsearch to template features."""
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info('Searching for template for: %s', query_sequence)
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
already_seen = set()
errors = []
warnings = []
if not hits or hits[0].sum_probs is None:
sorted_hits = hits
else:
sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)
for hit in sorted_hits:
# We got all the templates we wanted, stop processing hits.
if len(already_seen) >= self._max_hits:
break
result = _process_single_hit(
query_sequence=query_sequence,
hit=hit,
mmcif_dir=self._mmcif_dir,
max_template_date=self._max_template_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path)
if result.error:
errors.append(result.error)
# There could be an error even if there are some results, e.g. thrown by
# other unparsable chains in the same mmCIF file.
if result.warning:
warnings.append(result.warning)
if result.features is None:
logging.debug('Skipped invalid hit %s, error: %s, warning: %s',
hit.name, result.error, result.warning)
else:
already_seen_key = result.features['template_sequence']
if already_seen_key in already_seen:
continue
# Increment the hit counter, since we got features out of this hit.
already_seen.add(already_seen_key)
for k in template_features:
template_features[k].append(result.features[k])
if already_seen:
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
else:
num_res = len(query_sequence)
# Construct a default template with all zeros.
template_features = {
'template_aatype': np.zeros(
(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
np.float32),
'template_all_atom_masks': np.zeros(
(1, num_res, residue_constants.atom_type_num), np.float32),
'template_all_atom_positions': np.zeros(
(1, num_res, residue_constants.atom_type_num, 3), np.float32),
'template_domain_names': np.array([''.encode()], dtype=object),
'template_sequence': np.array([''.encode()], dtype=object),
'template_sum_probs': np.array([0], dtype=np.float32)
}
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python wrappers for third party tools."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHblits from Python."""
import glob
import os
import subprocess
from typing import Any, List, Mapping, Optional, Sequence
from absl import logging
from alphafold.data.tools import utils
# Internal import (7716).
_HHBLITS_DEFAULT_P = 20
_HHBLITS_DEFAULT_Z = 500
class HHBlits:
"""Python wrapper of the HHblits binary."""
def __init__(self,
*,
binary_path: str,
databases: Sequence[str],
n_cpu: int = 4,
n_iter: int = 3,
e_value: float = 0.001,
maxseq: int = 1_000_000,
realign_max: int = 100_000,
maxfilt: int = 100_000,
min_prefilter_hits: int = 1000,
all_seqs: bool = False,
alt: Optional[int] = None,
p: int = _HHBLITS_DEFAULT_P,
z: int = _HHBLITS_DEFAULT_Z):
"""Initializes the Python HHblits wrapper.
Args:
binary_path: The path to the HHblits executable.
databases: A sequence of HHblits database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to give HHblits.
n_iter: The number of HHblits iterations.
e_value: The E-value, see HHblits docs for more details.
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
maxfilt: Max number of hits allowed to pass the 2nd prefilter.
HHblits default: 20000.
min_prefilter_hits: Min number of hits to pass prefilter.
HHblits default: 100.
all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
HHblits default: False.
alt: Show up to this many alternative alignments.
p: Minimum Prob for a hit to be included in the output hhr file.
HHblits default: 20.
z: Hard cap on number of hits reported in the hhr file.
HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
Raises:
RuntimeError: If HHblits binary not found within the path.
"""
self.binary_path = binary_path
self.databases = databases
for database_path in self.databases:
if not glob.glob(database_path + '_*'):
logging.error('Could not find HHBlits database %s', database_path)
raise ValueError(f'Could not find HHBlits database {database_path}')
self.n_cpu = n_cpu
self.n_iter = n_iter
self.e_value = e_value
self.maxseq = maxseq
self.realign_max = realign_max
self.maxfilt = maxfilt
self.min_prefilter_hits = min_prefilter_hits
self.all_seqs = all_seqs
self.alt = alt
self.p = p
self.z = z
def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
"""Queries the database using HHblits."""
with utils.tmpdir_manager() as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
db_cmd = []
for db_path in self.databases:
db_cmd.append('-d')
db_cmd.append(db_path)
cmd = [
self.binary_path,
'-i', input_fasta_path,
'-cpu', str(self.n_cpu),
'-oa3m', a3m_path,
'-o', '/dev/null',
'-n', str(self.n_iter),
'-e', str(self.e_value),
'-maxseq', str(self.maxseq),
'-realign_max', str(self.realign_max),
'-maxfilt', str(self.maxfilt),
'-min_prefilter_hits', str(self.min_prefilter_hits)]
if self.all_seqs:
cmd += ['-all']
if self.alt:
cmd += ['-alt', str(self.alt)]
if self.p != _HHBLITS_DEFAULT_P:
cmd += ['-p', str(self.p)]
if self.z != _HHBLITS_DEFAULT_Z:
cmd += ['-Z', str(self.z)]
cmd += db_cmd
logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with utils.timing('HHblits query'):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
# Logs have a 15k character limit, so log HHblits error line by line.
logging.error('HHblits failed. HHblits stderr begin:')
for error_line in stderr.decode('utf-8').splitlines():
if error_line.strip():
logging.error(error_line.strip())
logging.error('HHblits stderr end')
raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))
with open(a3m_path) as f:
a3m = f.read()
raw_output = dict(
a3m=a3m,
output=stdout,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value)
return [raw_output]
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHsearch from Python."""
import glob
import os
import subprocess
from typing import Sequence
from absl import logging
from alphafold.data import parsers
from alphafold.data.tools import utils
# Internal import (7716).
class HHSearch:
"""Python wrapper of the HHsearch binary."""
def __init__(self,
*,
binary_path: str,
databases: Sequence[str],
maxseq: int = 1_000_000):
"""Initializes the Python HHsearch wrapper.
Args:
binary_path: The path to the HHsearch executable.
databases: A sequence of HHsearch database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
Raises:
RuntimeError: If HHsearch binary not found within the path.
"""
self.binary_path = binary_path
self.databases = databases
self.maxseq = maxseq
for database_path in self.databases:
if not glob.glob(database_path + '_*'):
logging.error('Could not find HHsearch database %s', database_path)
raise ValueError(f'Could not find HHsearch database {database_path}')
@property
def output_format(self) -> str:
return 'hhr'
@property
def input_format(self) -> str:
return 'a3m'
def query(self, a3m: str) -> str:
"""Queries the database using HHsearch using a given a3m."""
with utils.tmpdir_manager() as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, 'query.a3m')
hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
with open(input_path, 'w') as f:
f.write(a3m)
db_cmd = []
for db_path in self.databases:
db_cmd.append('-d')
db_cmd.append(db_path)
cmd = [self.binary_path,
'-i', input_path,
'-o', hhr_path,
'-maxseq', str(self.maxseq)
] + db_cmd
logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with utils.timing('HHsearch query'):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
# Stderr is truncated to prevent proto size errors in Beam.
raise RuntimeError(
'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))
with open(hhr_path) as f:
hhr = f.read()
return hhr
def get_template_hits(self,
output_string: str,
input_sequence: str) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
del input_sequence # Used by hmmseach but not needed for hhsearch.
return parsers.parse_hhr(output_string)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
import os
import re
import subprocess
from absl import logging
from alphafold.data.tools import utils
# Internal import (7716).
class Hmmbuild(object):
"""Python wrapper of the hmmbuild binary."""
def __init__(self,
*,
binary_path: str,
singlemx: bool = False):
"""Initializes the Python hmmbuild wrapper.
Args:
binary_path: The path to the hmmbuild executable.
singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
just use a common substitution score matrix.
Raises:
RuntimeError: If hmmbuild binary not found within the path.
"""
self.binary_path = binary_path
self.singlemx = singlemx
def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
"""Builds a HHM for the aligned sequences given as an A3M string.
Args:
sto: A string with the aligned sequences in the Stockholm format.
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
"""
return self._build_profile(sto, model_construction=model_construction)
def build_profile_from_a3m(self, a3m: str) -> str:
"""Builds a HHM for the aligned sequences given as an A3M string.
Args:
a3m: A string with the aligned sequences in the A3M format.
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
"""
lines = []
for line in a3m.splitlines():
if not line.startswith('>'):
line = re.sub('[a-z]+', '', line) # Remove inserted residues.
lines.append(line + '\n')
msa = ''.join(lines)
return self._build_profile(msa, model_construction='fast')
def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
"""Builds a HMM for the aligned sequences given as an MSA string.
Args:
msa: A string with the aligned sequences, in A3M or STO format.
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
ValueError: If unspecified arguments are provided.
"""
if model_construction not in {'hand', 'fast'}:
raise ValueError(f'Invalid model_construction {model_construction} - only'
'hand and fast supported.')
with utils.tmpdir_manager() as query_tmp_dir:
input_query = os.path.join(query_tmp_dir, 'query.msa')
output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
with open(input_query, 'w') as f:
f.write(msa)
cmd = [self.binary_path]
# If adding flags, we have to do so before the output and input:
if model_construction == 'hand':
cmd.append(f'--{model_construction}')
if self.singlemx:
cmd.append('--singlemx')
cmd.extend([
'--amino',
output_hmm_path,
input_query,
])
logging.info('Launching subprocess %s', cmd)
process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
with utils.timing('hmmbuild query'):
stdout, stderr = process.communicate()
retcode = process.wait()
logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
stdout.decode('utf-8'), stderr.decode('utf-8'))
if retcode:
raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
% (stdout.decode('utf-8'), stderr.decode('utf-8')))
with open(output_hmm_path, encoding='utf-8') as f:
hmm = f.read()
return hmm
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import os
import subprocess
from typing import Optional, Sequence
from absl import logging
from alphafold.data import parsers
from alphafold.data.tools import hmmbuild
from alphafold.data.tools import utils
# Internal import (7716).
class Hmmsearch(object):
"""Python wrapper of the hmmsearch binary."""
def __init__(self,
*,
binary_path: str,
hmmbuild_binary_path: str,
database_path: str,
flags: Optional[Sequence[str]] = None):
"""Initializes the Python hmmsearch wrapper.
Args:
binary_path: The path to the hmmsearch executable.
hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
an hmm from an input a3m.
database_path: The path to the hmmsearch database (FASTA format).
flags: List of flags to be used by hmmsearch.
Raises:
RuntimeError: If hmmsearch binary not found within the path.
"""
self.binary_path = binary_path
self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
self.database_path = database_path
if flags is None:
# Default hmmsearch run settings.
flags = ['--F1', '0.1',
'--F2', '0.1',
'--F3', '0.1',
'--incE', '100',
'-E', '100',
'--domE', '100',
'--incdomE', '100']
self.flags = flags
if not os.path.exists(self.database_path):
logging.error('Could not find hmmsearch database %s', database_path)
raise ValueError(f'Could not find hmmsearch database {database_path}')
@property
def output_format(self) -> str:
return 'sto'
@property
def input_format(self) -> str:
return 'sto'
def query(self, msa_sto: str) -> str:
"""Queries the database using hmmsearch using a given stockholm msa."""
hmm = self.hmmbuild_runner.build_profile_from_sto(msa_sto,
model_construction='hand')
return self.query_with_hmm(hmm)
def query_with_hmm(self, hmm: str) -> str:
"""Queries the database using hmmsearch using a given hmm."""
with utils.tmpdir_manager() as query_tmp_dir:
hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
out_path = os.path.join(query_tmp_dir, 'output.sto')
with open(hmm_input_path, 'w') as f:
f.write(hmm)
cmd = [
self.binary_path,
'--noali', # Don't include the alignment in stdout.
'--cpu', '8'
]
# If adding flags, we have to do so before the output and input:
if self.flags:
cmd.extend(self.flags)
cmd.extend([
'-A', out_path,
hmm_input_path,
self.database_path,
])
logging.info('Launching sub-process %s', cmd)
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with utils.timing(
f'hmmsearch ({os.path.basename(self.database_path)}) query'):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
raise RuntimeError(
'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr.decode('utf-8')))
with open(out_path) as f:
out_msa = f.read()
return out_msa
def get_template_hits(self,
output_string: str,
input_sequence: str) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
a3m_string = parsers.convert_stockholm_to_a3m(output_string,
remove_first_row_gaps=False)
template_hits = parsers.parse_hmmsearch_a3m(
query_sequence=input_sequence,
a3m_string=a3m_string,
skip_first=False)
return template_hits
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run Jackhmmer from Python."""
from concurrent import futures
import glob
import os
import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request
from absl import logging
from alphafold.data import parsers
from alphafold.data.tools import utils
# Internal import (7716).
class Jackhmmer:
"""Python wrapper of the Jackhmmer binary."""
def __init__(self,
*,
binary_path: str,
database_path: str,
n_cpu: int = 8,
n_iter: int = 1,
e_value: float = 0.0001,
z_value: Optional[int] = None,
get_tblout: bool = False,
filter_f1: float = 0.0005,
filter_f2: float = 0.00005,
filter_f3: float = 0.0000005,
incdom_e: Optional[float] = None,
dom_e: Optional[float] = None,
num_streamed_chunks: Optional[int] = None,
streaming_callback: Optional[Callable[[int], None]] = None):
"""Initializes the Python Jackhmmer wrapper.
Args:
binary_path: The path to the jackhmmer executable.
database_path: The path to the jackhmmer database (FASTA format).
n_cpu: The number of CPUs to give Jackhmmer.
n_iter: The number of Jackhmmer iterations.
e_value: The E-value, see Jackhmmer docs for more details.
z_value: The Z-value, see Jackhmmer docs for more details.
get_tblout: Whether to save tblout string.
filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
filter_f3: Forward pre-filter, set to >1.0 to turn off.
incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
round.
dom_e: Domain e-value criteria for inclusion in tblout.
num_streamed_chunks: Number of database chunks to stream over.
streaming_callback: Callback function run after each chunk iteration with
the iteration number as argument.
"""
self.binary_path = binary_path
self.database_path = database_path
self.num_streamed_chunks = num_streamed_chunks
if not os.path.exists(self.database_path) and num_streamed_chunks is None:
logging.error('Could not find Jackhmmer database %s', database_path)
raise ValueError(f'Could not find Jackhmmer database {database_path}')
self.n_cpu = n_cpu
self.n_iter = n_iter
self.e_value = e_value
self.z_value = z_value
self.filter_f1 = filter_f1
self.filter_f2 = filter_f2
self.filter_f3 = filter_f3
self.incdom_e = incdom_e
self.dom_e = dom_e
self.get_tblout = get_tblout
self.streaming_callback = streaming_callback
def _query_chunk(self,
input_fasta_path: str,
database_path: str,
max_sequences: Optional[int] = None) -> Mapping[str, Any]:
"""Queries the database chunk using Jackhmmer."""
with utils.tmpdir_manager() as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, 'output.sto')
# The F1/F2/F3 are the expected proportion to pass each of the filtering
# stages (which get progressively more expensive), reducing these
# speeds up the pipeline at the expensive of sensitivity. They are
# currently set very low to make querying Mgnify run in a reasonable
# amount of time.
cmd_flags = [
# Don't pollute stdout with Jackhmmer output.
'-o', '/dev/null',
'-A', sto_path,
'--noali',
'--F1', str(self.filter_f1),
'--F2', str(self.filter_f2),
'--F3', str(self.filter_f3),
'--incE', str(self.e_value),
# Report only sequences with E-values <= x in per-sequence output.
'-E', str(self.e_value),
'--cpu', str(self.n_cpu),
'-N', str(self.n_iter)
]
if self.get_tblout:
tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
cmd_flags.extend(['--tblout', tblout_path])
if self.z_value:
cmd_flags.extend(['-Z', str(self.z_value)])
if self.dom_e is not None:
cmd_flags.extend(['--domE', str(self.dom_e)])
if self.incdom_e is not None:
cmd_flags.extend(['--incdomE', str(self.incdom_e)])
cmd = [self.binary_path] + cmd_flags + [input_fasta_path,
database_path]
logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with utils.timing(
f'Jackhmmer ({os.path.basename(database_path)}) query'):
_, stderr = process.communicate()
retcode = process.wait()
if retcode:
raise RuntimeError(
'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))
# Get e-values for each target name
tbl = ''
if self.get_tblout:
with open(tblout_path) as f:
tbl = f.read()
if max_sequences is None:
with open(sto_path) as f:
sto = f.read()
else:
sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)
raw_output = dict(
sto=sto,
tbl=tbl,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value)
return raw_output
def query(self,
input_fasta_path: str,
max_sequences: Optional[int] = None) -> Sequence[Mapping[str, Any]]:
"""Queries the database using Jackhmmer."""
return self.query_multiple([input_fasta_path], max_sequences)[0]
def query_multiple(
self,
input_fasta_paths: Sequence[str],
max_sequences: Optional[int] = None,
) -> Sequence[Sequence[Mapping[str, Any]]]:
"""Queries the database for multiple queries using Jackhmmer."""
if self.num_streamed_chunks is None:
single_chunk_results = []
for input_fasta_path in input_fasta_paths:
single_chunk_results.append([self._query_chunk(
input_fasta_path, self.database_path, max_sequences)])
return single_chunk_results
db_basename = os.path.basename(self.database_path)
db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'
# Remove existing files to prevent OOM
for f in glob.glob(db_local_chunk('[0-9]*')):
try:
os.remove(f)
except OSError:
print(f'OSError while deleting {f}')
# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
with futures.ThreadPoolExecutor(max_workers=2) as executor:
chunked_outputs = [[] for _ in range(len(input_fasta_paths))]
for i in range(1, self.num_streamed_chunks + 1):
# Copy the chunk locally
if i == 1:
future = executor.submit(
request.urlretrieve, db_remote_chunk(i), db_local_chunk(i))
if i < self.num_streamed_chunks:
next_future = executor.submit(
request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1))
# Run Jackhmmer with the chunk
future.result()
for fasta_index, input_fasta_path in enumerate(input_fasta_paths):
chunked_outputs[fasta_index].append(self._query_chunk(
input_fasta_path, db_local_chunk(i), max_sequences))
# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
# Do not set next_future for the last chunk so that this works even for
# databases with only 1 chunk.
if i < self.num_streamed_chunks:
future = next_future
if self.streaming_callback:
self.streaming_callback(i)
return chunked_outputs
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for Kalign."""
import os
import subprocess
from typing import Sequence
from absl import logging
from alphafold.data.tools import utils
# Internal import (7716).
def _to_a3m(sequences: Sequence[str]) -> str:
"""Converts sequences to an a3m file."""
names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
a3m = []
for sequence, name in zip(sequences, names):
a3m.append(u'>' + name + u'\n')
a3m.append(sequence + u'\n')
return ''.join(a3m)
class Kalign:
"""Python wrapper of the Kalign binary."""
def __init__(self, *, binary_path: str):
"""Initializes the Python Kalign wrapper.
Args:
binary_path: The path to the Kalign binary.
Raises:
RuntimeError: If Kalign binary not found within the path.
"""
self.binary_path = binary_path
def align(self, sequences: Sequence[str]) -> str:
"""Aligns the sequences and returns the alignment in A3M string.
Args:
sequences: A list of query sequence strings. The sequences have to be at
least 6 residues long (Kalign requires this). Note that the order in
which you give the sequences might alter the output slightly as
different alignment tree might get constructed.
Returns:
A string with the alignment in a3m format.
Raises:
RuntimeError: If Kalign fails.
ValueError: If any of the sequences is less than 6 residues long.
"""
logging.info('Aligning %d sequences', len(sequences))
for s in sequences:
if len(s) < 6:
raise ValueError('Kalign requires all sequences to be at least 6 '
'residues long. Got %s (%d residues).' % (s, len(s)))
with utils.tmpdir_manager() as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
with open(input_fasta_path, 'w') as f:
f.write(_to_a3m(sequences))
cmd = [
self.binary_path,
'-i', input_fasta_path,
'-o', output_a3m_path,
'-format', 'fasta',
]
logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
with utils.timing('Kalign query'):
stdout, stderr = process.communicate()
retcode = process.wait()
logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n',
stdout.decode('utf-8'), stderr.decode('utf-8'))
if retcode:
raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n'
% (stdout.decode('utf-8'), stderr.decode('utf-8')))
with open(output_a3m_path) as f:
a3m = f.read()
return a3m
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for data pipeline tools."""
import contextlib
import shutil
import tempfile
import time
from typing import Optional
from absl import logging
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
"""Context manager that deletes a temporary directory on exit."""
tmpdir = tempfile.mkdtemp(dir=base_dir)
try:
yield tmpdir
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
@contextlib.contextmanager
def timing(msg: str):
logging.info('Started %s', msg)
tic = time.time()
yield
toc = time.time()
logging.info('Finished %s in %.3f seconds', msg, toc - tic)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alphafold model."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for all atom representations.
Generally we employ two different representations for all atom coordinates,
one is atom37 where each heavy atom corresponds to a given position in a 37
dimensional array, This mapping is non amino acid specific, but each slot
corresponds to an atom of a given name, for example slot 12 always corresponds
to 'C delta 1', positions that are not present for a given amino acid are
zeroed out and denoted by a mask.
The other representation we employ is called atom14, this is a more dense way
of representing atoms with 14 slots. Here a given slot will correspond to a
different kind of atom depending on amino acid type, for example slot 5
corresponds to 'N delta 2' for Aspargine, but to 'C delta 1' for Isoleucine.
14 is chosen because it is the maximum number of heavy atoms for any standard
amino acid.
The order of slots can be found in 'residue_constants.residue_atoms'.
Internally the model uses the atom14 representation because it is
computationally more efficient.
The internal atom14 representation is turned into the atom37 at the output of
the network to facilitate easier conversion to existing protein datastructures.
"""
from typing import Dict, Optional
from alphafold.common import residue_constants
from alphafold.model import r3
from alphafold.model import utils
import jax
import jax.numpy as jnp
import numpy as np
def squared_difference(x, y):
return jnp.square(x - y)
def get_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in residue_constants.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices = []
for residue_name in residue_constants.restypes:
residue_name = residue_constants.restype_1to3[residue_name]
residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[residue_constants.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return jnp.asarray(chi_atom_indices)
def atom14_to_atom37(atom14_data: jnp.ndarray, # (N, 14, ...)
batch: Dict[str, jnp.ndarray]
) -> jnp.ndarray: # (N, 37, ...)
"""Convert atom14 to atom37 representation."""
assert len(atom14_data.shape) in [2, 3]
assert 'residx_atom37_to_atom14' in batch
assert 'atom37_atom_exists' in batch
atom37_data = utils.batched_gather(atom14_data,
batch['residx_atom37_to_atom14'],
batch_dims=1)
if len(atom14_data.shape) == 2:
atom37_data *= batch['atom37_atom_exists']
elif len(atom14_data.shape) == 3:
atom37_data *= batch['atom37_atom_exists'][:, :,
None].astype(atom37_data.dtype)
return atom37_data
def atom37_to_atom14(
atom37_data: jnp.ndarray, # (N, 37, ...)
batch: Dict[str, jnp.ndarray]) -> jnp.ndarray: # (N, 14, ...)
"""Convert atom14 to atom37 representation."""
assert len(atom37_data.shape) in [2, 3]
assert 'residx_atom14_to_atom37' in batch
assert 'atom14_atom_exists' in batch
atom14_data = utils.batched_gather(atom37_data,
batch['residx_atom14_to_atom37'],
batch_dims=1)
if len(atom37_data.shape) == 2:
atom14_data *= batch['atom14_atom_exists'].astype(atom14_data.dtype)
elif len(atom37_data.shape) == 3:
atom14_data *= batch['atom14_atom_exists'][:, :,
None].astype(atom14_data.dtype)
return atom14_data
def atom37_to_frames(
aatype: jnp.ndarray, # (...)
all_atom_positions: jnp.ndarray, # (..., 37, 3)
all_atom_mask: jnp.ndarray, # (..., 37)
) -> Dict[str, jnp.ndarray]:
"""Computes the frames for the up to 8 rigid groups for each residue.
The rigid groups are defined by the possible torsions in a given amino acid.
We group the atoms according to their dependence on the torsion angles into
"rigid groups". E.g., the position of atoms in the chi2-group depend on
chi1 and chi2, but do not depend on chi3 or chi4.
Jumper et al. (2021) Suppl. Table 2 and corresponding text.
Args:
aatype: Amino acid type, given as array with integers.
all_atom_positions: atom37 representation of all atom coordinates.
all_atom_mask: atom37 representation of mask on all atom coordinates.
Returns:
Dictionary containing:
* 'rigidgroups_gt_frames': 8 Frames corresponding to 'all_atom_positions'
represented as flat 12 dimensional array.
* 'rigidgroups_gt_exists': Mask denoting whether the atom positions for
the given frame are available in the ground truth, e.g. if they were
resolved in the experiment.
* 'rigidgroups_group_exists': Mask denoting whether given group is in
principle present for given amino acid type.
* 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is
affected by naming ambiguity.
* 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom renaming
corresponding to 'all_atom_positions' represented as flat
12 dimensional array.
"""
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
aatype_in_shape = aatype.shape
# If there is a batch axis, just flatten it away, and reshape everything
# back at the end of the function.
aatype = jnp.reshape(aatype, [-1])
all_atom_positions = jnp.reshape(all_atom_positions, [-1, 37, 3])
all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])
# Create an array with the atom names.
# shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)
# 0: backbone frame
restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
# 3: 'psi-group'
restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']
# 4,5,6,7: 'chi1,2,3,4-group'
for restype, restype_letter in enumerate(residue_constants.restypes):
resname = residue_constants.restype_1to3[restype_letter]
for chi_idx in range(4):
if residue_constants.chi_angles_mask[restype][chi_idx]:
atom_names = residue_constants.chi_angles_atoms[resname][chi_idx]
restype_rigidgroup_base_atom_names[
restype, chi_idx + 4, :] = atom_names[1:]
# Create mask for existing rigid groups.
restype_rigidgroup_mask = np.zeros([21, 8], dtype=np.float32)
restype_rigidgroup_mask[:, 0] = 1
restype_rigidgroup_mask[:, 3] = 1
restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask
# Translate atom names into atom37 indices.
lookuptable = residue_constants.atom_order.copy()
lookuptable[''] = 0
restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])(
restype_rigidgroup_base_atom_names)
# Compute the gather indices for all residues in the chain.
# shape (N, 8, 3)
residx_rigidgroup_base_atom37_idx = utils.batched_gather(
restype_rigidgroup_base_atom37_idx, aatype)
# Gather the base atom positions for each rigid group.
base_atom_pos = utils.batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
batch_dims=1)
# Compute the Rigids.
gt_frames = r3.rigids_from_3_points(
point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]),
origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]),
point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :])
)
# Compute a mask whether the group exists.
# (N, 8)
group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype)
# Compute a mask whether ground truth exists for the group
gt_atoms_exist = utils.batched_gather( # shape (N, 8, 3)
all_atom_mask.astype(jnp.float32),
residx_rigidgroup_base_atom37_idx,
batch_dims=1)
gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists # (N, 8)
# Adapt backbone frame to old convention (mirror x-axis and z-axis).
rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
rots[0, 0, 0] = -1
rots[0, 2, 2] = -1
gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots))
# The frames for ambiguous rigid groups are just rotated by 180 degree around
# the x-axis. The ambiguous group is always the last chi-group.
restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1])
for resname, _ in residue_constants.residue_atom_renaming_swaps.items():
restype = residue_constants.restype_order[
residue_constants.restype_3to1[resname]]
chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1
# Gather the ambiguity information for each residue.
residx_rigidgroup_is_ambiguous = utils.batched_gather(
restype_rigidgroup_is_ambiguous, aatype)
residx_rigidgroup_ambiguity_rot = utils.batched_gather(
restype_rigidgroup_rots, aatype)
# Create the alternative ground truth frames.
alt_gt_frames = r3.rigids_mul_rots(
gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot))
gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames)
alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames)
# reshape back to original residue layout
gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12))
gt_exists = jnp.reshape(gt_exists, aatype_in_shape + (8,))
group_exists = jnp.reshape(group_exists, aatype_in_shape + (8,))
gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12))
residx_rigidgroup_is_ambiguous = jnp.reshape(residx_rigidgroup_is_ambiguous,
aatype_in_shape + (8,))
alt_gt_frames_flat12 = jnp.reshape(alt_gt_frames_flat12,
aatype_in_shape + (8, 12,))
return {
'rigidgroups_gt_frames': gt_frames_flat12, # (..., 8, 12)
'rigidgroups_gt_exists': gt_exists, # (..., 8)
'rigidgroups_group_exists': group_exists, # (..., 8)
'rigidgroups_group_is_ambiguous':
residx_rigidgroup_is_ambiguous, # (..., 8)
'rigidgroups_alt_gt_frames': alt_gt_frames_flat12, # (..., 8, 12)
}
def atom37_to_torsion_angles(
aatype: jnp.ndarray, # (B, N)
all_atom_pos: jnp.ndarray, # (B, N, 37, 3)
all_atom_mask: jnp.ndarray, # (B, N, 37)
placeholder_for_undefined=False,
) -> Dict[str, jnp.ndarray]:
"""Computes the 7 torsion angles (in sin, cos encoding) for each residue.
The 7 torsion angles are in the order
'[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]',
here pre_omega denotes the omega torsion angle between the given amino acid
and the previous amino acid.
Args:
aatype: Amino acid type, given as array with integers.
all_atom_pos: atom37 representation of all atom coordinates.
all_atom_mask: atom37 representation of mask on all atom coordinates.
placeholder_for_undefined: flag denoting whether to set masked torsion
angles to zero.
Returns:
Dict containing:
* 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final
2 dimensions denote sin and cos respectively
* 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but
with the angle shifted by pi for all chi angles affected by the naming
ambiguities.
* 'torsion_angles_mask': Mask for which chi angles are present.
"""
# Map aatype > 20 to 'Unknown' (20).
aatype = jnp.minimum(aatype, 20)
# Compute the backbone angles.
num_batch, num_res = aatype.shape
pad = jnp.zeros([num_batch, 1, 37, 3], jnp.float32)
prev_all_atom_pos = jnp.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1)
pad = jnp.zeros([num_batch, 1, 37], jnp.float32)
prev_all_atom_mask = jnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1)
# For each torsion angle collect the 4 atom positions that define this angle.
# shape (B, N, atoms=4, xyz=3)
pre_omega_atom_pos = jnp.concatenate(
[prev_all_atom_pos[:, :, 1:3, :], # prev CA, C
all_atom_pos[:, :, 0:2, :] # this N, CA
], axis=-2)
phi_atom_pos = jnp.concatenate(
[prev_all_atom_pos[:, :, 2:3, :], # prev C
all_atom_pos[:, :, 0:3, :] # this N, CA, C
], axis=-2)
psi_atom_pos = jnp.concatenate(
[all_atom_pos[:, :, 0:3, :], # this N, CA, C
all_atom_pos[:, :, 4:5, :] # this O
], axis=-2)
# Collect the masks from these atoms.
# Shape [batch, num_res]
pre_omega_mask = (
jnp.prod(prev_all_atom_mask[:, :, 1:3], axis=-1) # prev CA, C
* jnp.prod(all_atom_mask[:, :, 0:2], axis=-1)) # this N, CA
phi_mask = (
prev_all_atom_mask[:, :, 2] # prev C
* jnp.prod(all_atom_mask[:, :, 0:3], axis=-1)) # this N, CA, C
psi_mask = (
jnp.prod(all_atom_mask[:, :, 0:3], axis=-1) * # this N, CA, C
all_atom_mask[:, :, 4]) # this O
# Collect the atoms for the chi-angles.
# Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
chi_atom_indices = get_chi_atom_indices()
# Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
atom_indices = utils.batched_gather(
params=chi_atom_indices, indices=aatype, axis=0, batch_dims=0)
# Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
chis_atom_pos = utils.batched_gather(
params=all_atom_pos, indices=atom_indices, axis=-2,
batch_dims=2)
# Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
chi_angles_mask = list(residue_constants.chi_angles_mask)
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = jnp.asarray(chi_angles_mask)
# Compute the chi angle mask. I.e. which chis angles exist according to the
# aatype. Shape [batch, num_res, chis=4].
chis_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype,
axis=0, batch_dims=0)
# Constrain the chis_mask to those chis, where the ground truth coordinates of
# all defining four atoms are available.
# Gather the chi angle atoms mask. Shape: [batch, num_res, chis=4, atoms=4].
chi_angle_atoms_mask = utils.batched_gather(
params=all_atom_mask, indices=atom_indices, axis=-1,
batch_dims=2)
# Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4].
chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1])
chis_mask = chis_mask * (chi_angle_atoms_mask).astype(jnp.float32)
# Stack all torsion angle atom positions.
# Shape (B, N, torsions=7, atoms=4, xyz=3)
torsions_atom_pos = jnp.concatenate(
[pre_omega_atom_pos[:, :, None, :, :],
phi_atom_pos[:, :, None, :, :],
psi_atom_pos[:, :, None, :, :],
chis_atom_pos
], axis=2)
# Stack up masks for all torsion angles.
# shape (B, N, torsions=7)
torsion_angles_mask = jnp.concatenate(
[pre_omega_mask[:, :, None],
phi_mask[:, :, None],
psi_mask[:, :, None],
chis_mask
], axis=2)
# Create a frame from the first three atoms:
# First atom: point on x-y-plane
# Second atom: point on negative x-axis
# Third atom: origin
# r3.Rigids (B, N, torsions=7)
torsion_frames = r3.rigids_from_3_points(
point_on_neg_x_axis=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 1, :]),
origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]))
# Compute the position of the forth atom in this frame (y and z coordinate
# define the chi angle)
# r3.Vecs (B, N, torsions=7)
forth_atom_rel_pos = r3.rigids_mul_vecs(
r3.invert_rigids(torsion_frames),
r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))
# Normalize to have the sin and cos of the torsion angle.
# jnp.ndarray (B, N, torsions=7, sincos=2)
torsion_angles_sin_cos = jnp.stack(
[forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1)
torsion_angles_sin_cos /= jnp.sqrt(
jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True)
+ 1e-8)
# Mirror psi, because we computed it from the Oxygen-atom.
torsion_angles_sin_cos *= jnp.asarray(
[1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]
# Create alternative angles for ambiguous atom names.
chi_is_ambiguous = utils.batched_gather(
jnp.asarray(residue_constants.chi_pi_periodic), aatype)
mirror_torsion_angles = jnp.concatenate(
[jnp.ones([num_batch, num_res, 3]),
1.0 - 2.0 * chi_is_ambiguous], axis=-1)
alt_torsion_angles_sin_cos = (
torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])
if placeholder_for_undefined:
# Add placeholder torsions in place of undefined torsion angles
# (e.g. N-terminus pre-omega)
placeholder_torsions = jnp.stack([
jnp.ones(torsion_angles_sin_cos.shape[:-1]),
jnp.zeros(torsion_angles_sin_cos.shape[:-1])
], axis=-1)
torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[
..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[
..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
return {
'torsion_angles_sin_cos': torsion_angles_sin_cos, # (B, N, 7, 2)
'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos, # (B, N, 7, 2)
'torsion_angles_mask': torsion_angles_mask # (B, N, 7)
}
def torsion_angles_to_frames(
aatype: jnp.ndarray, # (N)
backb_to_global: r3.Rigids, # (N)
torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2)
) -> r3.Rigids: # (N, 8)
"""Compute rigid group frames from torsion angles.
Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" lines 2-10
Jumper et al. (2021) Suppl. Alg. 25 "makeRotX"
Args:
aatype: aatype for each residue
backb_to_global: Rigid transformations describing transformation from
backbone frame to global frame.
torsion_angles_sin_cos: sin and cosine of the 7 torsion angles
Returns:
Frames corresponding to all the Sidechain Rigid Transforms
"""
assert len(aatype.shape) == 1
assert len(backb_to_global.rot.xx.shape) == 1
assert len(torsion_angles_sin_cos.shape) == 3
assert torsion_angles_sin_cos.shape[1] == 7
assert torsion_angles_sin_cos.shape[2] == 2
# Gather the default frames for all rigid groups.
# r3.Rigids with shape (N, 8)
m = utils.batched_gather(residue_constants.restype_rigid_group_default_frame,
aatype)
default_frames = r3.rigids_from_tensor4x4(m)
# Create the rotation matrices according to the given angles (each frame is
# defined such that its rotation is around the x-axis).
sin_angles = torsion_angles_sin_cos[..., 0]
cos_angles = torsion_angles_sin_cos[..., 1]
# insert zero rotation for backbone group.
num_residues, = aatype.shape
sin_angles = jnp.concatenate([jnp.zeros([num_residues, 1]), sin_angles],
axis=-1)
cos_angles = jnp.concatenate([jnp.ones([num_residues, 1]), cos_angles],
axis=-1)
zeros = jnp.zeros_like(sin_angles)
ones = jnp.ones_like(sin_angles)
# all_rots are r3.Rots with shape (N, 8)
all_rots = r3.Rots(ones, zeros, zeros,
zeros, cos_angles, -sin_angles,
zeros, sin_angles, cos_angles)
# Apply rotations to the frames.
all_frames = r3.rigids_mul_rots(default_frames, all_rots)
# chi2, chi3, and chi4 frames do not transform to the backbone frame but to
# the previous frame. So chain them up accordingly.
chi2_frame_to_frame = jax.tree_map(lambda x: x[:, 5], all_frames)
chi3_frame_to_frame = jax.tree_map(lambda x: x[:, 6], all_frames)
chi4_frame_to_frame = jax.tree_map(lambda x: x[:, 7], all_frames)
chi1_frame_to_backb = jax.tree_map(lambda x: x[:, 4], all_frames)
chi2_frame_to_backb = r3.rigids_mul_rigids(chi1_frame_to_backb,
chi2_frame_to_frame)
chi3_frame_to_backb = r3.rigids_mul_rigids(chi2_frame_to_backb,
chi3_frame_to_frame)
chi4_frame_to_backb = r3.rigids_mul_rigids(chi3_frame_to_backb,
chi4_frame_to_frame)
# Recombine them to a r3.Rigids with shape (N, 8).
def _concat_frames(xall, x5, x6, x7):
return jnp.concatenate(
[xall[:, 0:5], x5[:, None], x6[:, None], x7[:, None]], axis=-1)
all_frames_to_backb = jax.tree_map(
_concat_frames,
all_frames,
chi2_frame_to_backb,
chi3_frame_to_backb,
chi4_frame_to_backb)
# Create the global frames.
# shape (N, 8)
all_frames_to_global = r3.rigids_mul_rigids(
jax.tree_map(lambda x: x[:, None], backb_to_global),
all_frames_to_backb)
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
aatype: jnp.ndarray, # (N)
all_frames_to_global: r3.Rigids # (N, 8)
) -> r3.Vecs: # (N, 14)
"""Put atom literature positions (atom14 encoding) in each rigid group.
Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11
Args:
aatype: aatype for each residue.
all_frames_to_global: All per residue coordinate frames.
Returns:
Positions of all atom coordinates in global frame.
"""
# Pick the appropriate transform for every atom.
residx_to_group_idx = utils.batched_gather(
residue_constants.restype_atom14_to_rigid_group, aatype)
group_mask = jax.nn.one_hot(
residx_to_group_idx, num_classes=8) # shape (N, 14, 8)
# r3.Rigids with shape (N, 14)
map_atoms_to_global = jax.tree_map(
lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1),
all_frames_to_global)
# Gather the literature atom positions for each residue.
# r3.Vecs with shape (N, 14)
lit_positions = r3.vecs_from_tensor(
utils.batched_gather(
residue_constants.restype_atom14_rigid_group_positions, aatype))
# Transform each atom from its local frame to the global frame.
# r3.Vecs with shape (N, 14)
pred_positions = r3.rigids_mul_vecs(map_atoms_to_global, lit_positions)
# Mask out non-existing atoms.
mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype)
pred_positions = jax.tree_map(lambda x: x * mask, pred_positions)
return pred_positions
def extreme_ca_ca_distance_violations(
pred_atom_positions: jnp.ndarray, # (N, 37(14), 3)
pred_atom_mask: jnp.ndarray, # (N, 37(14))
residue_index: jnp.ndarray, # (N)
max_angstrom_tolerance=1.5
) -> jnp.ndarray:
"""Counts residues whose Ca is a large distance from its neighbour.
Measures the fraction of CA-CA pairs between consecutive amino acids that are
more than 'max_angstrom_tolerance' apart.
Args:
pred_atom_positions: Atom positions in atom37/14 representation
pred_atom_mask: Atom mask in atom37/14 representation
residue_index: Residue index for given amino acid, this is assumed to be
monotonically increasing.
max_angstrom_tolerance: Maximum distance allowed to not count as violation.
Returns:
Fraction of consecutive CA-CA pairs with violation.
"""
this_ca_pos = pred_atom_positions[:-1, 1, :] # (N - 1, 3)
this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1)
next_ca_pos = pred_atom_positions[1:, 1, :] # (N - 1, 3)
next_ca_mask = pred_atom_mask[1:, 1] # (N - 1)
has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
jnp.float32)
ca_ca_distance = jnp.sqrt(
1e-6 + jnp.sum(squared_difference(this_ca_pos, next_ca_pos), axis=-1))
violations = (ca_ca_distance -
residue_constants.ca_ca) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
return utils.mask_mean(mask=mask, value=violations)
def between_residue_bond_loss(
pred_atom_positions: jnp.ndarray, # (N, 37(14), 3)
pred_atom_mask: jnp.ndarray, # (N, 37(14))
residue_index: jnp.ndarray, # (N)
aatype: jnp.ndarray, # (N)
tolerance_factor_soft=12.0,
tolerance_factor_hard=12.0
) -> Dict[str, jnp.ndarray]:
"""Flat-bottom loss to penalize structural violations between residues.
This is a loss penalizing any violation of the geometry around the peptide
bond between consecutive amino acids. This loss corresponds to
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.
Args:
pred_atom_positions: Atom positions in atom37/14 representation
pred_atom_mask: Atom mask in atom37/14 representation
residue_index: Residue index for given amino acid, this is assumed to be
monotonically increasing.
aatype: Amino acid type of given residue
tolerance_factor_soft: soft tolerance factor measured in standard deviations
of pdb distributions
tolerance_factor_hard: hard tolerance factor measured in standard deviations
of pdb distributions
Returns:
Dict containing:
* 'c_n_loss_mean': Loss for peptide bond length violations
* 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned
by CA, C, N
* 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned
by C, N, CA
* 'per_residue_loss_sum': sum of all losses for each residue
* 'per_residue_violation_mask': mask denoting all residues with violation
present.
"""
assert len(pred_atom_positions.shape) == 3
assert len(pred_atom_mask.shape) == 2
assert len(residue_index.shape) == 1
assert len(aatype.shape) == 1
# Get the positions of the relevant backbone atoms.
this_ca_pos = pred_atom_positions[:-1, 1, :] # (N - 1, 3)
this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1)
this_c_pos = pred_atom_positions[:-1, 2, :] # (N - 1, 3)
this_c_mask = pred_atom_mask[:-1, 2] # (N - 1)
next_n_pos = pred_atom_positions[1:, 0, :] # (N - 1, 3)
next_n_mask = pred_atom_mask[1:, 0] # (N - 1)
next_ca_pos = pred_atom_positions[1:, 1, :] # (N - 1, 3)
next_ca_mask = pred_atom_mask[1:, 1] # (N - 1)
has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
jnp.float32)
# Compute loss for the C--N bond.
c_n_bond_length = jnp.sqrt(
1e-6 + jnp.sum(squared_difference(this_c_pos, next_n_pos), axis=-1))
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline = (
aatype[1:] == residue_constants.resname_to_idx['PRO']).astype(jnp.float32)
gt_length = (
(1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
+ next_is_proline * residue_constants.between_res_bond_length_c_n[1])
gt_stddev = (
(1. - next_is_proline) *
residue_constants.between_res_bond_length_stddev_c_n[0] +
next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1])
c_n_bond_length_error = jnp.sqrt(1e-6 +
jnp.square(c_n_bond_length - gt_length))
c_n_loss_per_residue = jax.nn.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev)
mask = this_c_mask * next_n_mask * has_no_gap_mask
c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
c_n_violation_mask = mask * (
c_n_bond_length_error > (tolerance_factor_hard * gt_stddev))
# Compute loss for the angles.
ca_c_bond_length = jnp.sqrt(1e-6 + jnp.sum(
squared_difference(this_ca_pos, this_c_pos), axis=-1))
n_ca_bond_length = jnp.sqrt(1e-6 + jnp.sum(
squared_difference(next_n_pos, next_ca_pos), axis=-1))
c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[:, None]
c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[:, None]
n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[:, None]
ca_c_n_cos_angle = jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1)
gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
ca_c_n_cos_angle_error = jnp.sqrt(
1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle))
ca_c_n_loss_per_residue = jax.nn.relu(
ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev)
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
(tolerance_factor_hard * gt_stddev))
c_n_ca_cos_angle = jnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1)
gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
c_n_ca_cos_angle_error = jnp.sqrt(
1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle))
c_n_ca_loss_per_residue = jax.nn.relu(
c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev)
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6)
c_n_ca_violation_mask = mask * (
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev))
# Compute a per residue loss (equally distribute the loss to both
# neighbouring residues).
per_residue_loss_sum = (c_n_loss_per_residue +
ca_c_n_loss_per_residue +
c_n_ca_loss_per_residue)
per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) +
jnp.pad(per_residue_loss_sum, [[1, 0]]))
# Compute hard violations.
violation_mask = jnp.max(
jnp.stack([c_n_violation_mask,
ca_c_n_violation_mask,
c_n_ca_violation_mask]), axis=0)
violation_mask = jnp.maximum(
jnp.pad(violation_mask, [[0, 1]]),
jnp.pad(violation_mask, [[1, 0]]))
return {'c_n_loss_mean': c_n_loss, # shape ()
'ca_c_n_loss_mean': ca_c_n_loss, # shape ()
'c_n_ca_loss_mean': c_n_ca_loss, # shape ()
'per_residue_loss_sum': per_residue_loss_sum, # shape (N)
'per_residue_violation_mask': violation_mask # shape (N)
}
def between_residue_clash_loss(
atom14_pred_positions: jnp.ndarray, # (N, 14, 3)
atom14_atom_exists: jnp.ndarray, # (N, 14)
atom14_atom_radius: jnp.ndarray, # (N, 14)
residue_index: jnp.ndarray, # (N)
overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5
) -> Dict[str, jnp.ndarray]:
"""Loss to penalize steric clashes between residues.
This is a loss penalizing any steric clashes due to non bonded atoms in
different peptides coming too close. This loss corresponds to the part with
different residues of
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
Args:
atom14_pred_positions: Predicted positions of atoms in
global prediction frame
atom14_atom_exists: Mask denoting whether atom at positions exists for given
amino acid type
atom14_atom_radius: Van der Waals radius for each atom.
residue_index: Residue index for given amino acid.
overlap_tolerance_soft: Soft tolerance factor.
overlap_tolerance_hard: Hard tolerance factor.
Returns:
Dict containing:
* 'mean_loss': average clash loss
* 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
* 'per_atom_clash_mask': mask whether atom clashes with any other atom
shape (N, 14)
"""
assert len(atom14_pred_positions.shape) == 3
assert len(atom14_atom_exists.shape) == 2
assert len(atom14_atom_radius.shape) == 2
assert len(residue_index.shape) == 1
# Create the distance matrix.
# (N, N, 14, 14)
dists = jnp.sqrt(1e-10 + jnp.sum(
squared_difference(
atom14_pred_positions[:, None, :, None, :],
atom14_pred_positions[None, :, None, :, :]),
axis=-1))
# Create the mask for valid distances.
# shape (N, N, 14, 14)
dists_mask = (atom14_atom_exists[:, None, :, None] *
atom14_atom_exists[None, :, None, :])
# Mask out all the duplicate entries in the lower triangular matrix.
# Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
# are handled separately.
dists_mask *= (
residue_index[:, None, None, None] < residue_index[None, :, None, None])
# Backbone C--N bond between subsequent residues is no clash.
c_one_hot = jax.nn.one_hot(2, num_classes=14)
n_one_hot = jax.nn.one_hot(0, num_classes=14)
neighbour_mask = ((residue_index[:, None, None, None] +
1) == residue_index[None, :, None, None])
c_n_bonds = neighbour_mask * c_one_hot[None, None, :,
None] * n_one_hot[None, None, None, :]
dists_mask *= (1. - c_n_bonds)
# Disulfide bridge between two cysteines is no clash.
cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG')
cys_sg_one_hot = jax.nn.one_hot(cys_sg_idx, num_classes=14)
disulfide_bonds = (cys_sg_one_hot[None, None, :, None] *
cys_sg_one_hot[None, None, None, :])
dists_mask *= (1. - disulfide_bonds)
# Compute the lower bound for the allowed distances.
# shape (N, N, 14, 14)
dists_lower_bound = dists_mask * (atom14_atom_radius[:, None, :, None] +
atom14_atom_radius[None, :, None, :])
# Compute the error.
# shape (N, N, 14, 14)
dists_to_low_error = dists_mask * jax.nn.relu(
dists_lower_bound - overlap_tolerance_soft - dists)
# Compute the mean loss.
# shape ()
mean_loss = (jnp.sum(dists_to_low_error)
/ (1e-6 + jnp.sum(dists_mask)))
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = (jnp.sum(dists_to_low_error, axis=[0, 2]) +
jnp.sum(dists_to_low_error, axis=[1, 3]))
# Compute the hard clash mask.
# shape (N, N, 14, 14)
clash_mask = dists_mask * (
dists < (dists_lower_bound - overlap_tolerance_hard))
# Compute the per atom clash.
# shape (N, 14)
per_atom_clash_mask = jnp.maximum(
jnp.max(clash_mask, axis=[0, 2]),
jnp.max(clash_mask, axis=[1, 3]))
return {'mean_loss': mean_loss, # shape ()
'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14)
'per_atom_clash_mask': per_atom_clash_mask # shape (N, 14)
}
def within_residue_violations(
atom14_pred_positions: jnp.ndarray, # (N, 14, 3)
atom14_atom_exists: jnp.ndarray, # (N, 14)
atom14_dists_lower_bound: jnp.ndarray, # (N, 14, 14)
atom14_dists_upper_bound: jnp.ndarray, # (N, 14, 14)
tighten_bounds_for_loss=0.0,
) -> Dict[str, jnp.ndarray]:
"""Loss to penalize steric clashes within residues.
This is a loss penalizing any steric violations or clashes of non-bonded atoms
in a given peptide. This loss corresponds to the part with
the same residues of
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
Args:
atom14_pred_positions: Predicted positions of atoms in
global prediction frame
atom14_atom_exists: Mask denoting whether atom at positions exists for given
amino acid type
atom14_dists_lower_bound: Lower bound on allowed distances.
atom14_dists_upper_bound: Upper bound on allowed distances
tighten_bounds_for_loss: Extra factor to tighten loss
Returns:
Dict containing:
* 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
* 'per_atom_clash_mask': mask whether atom clashes with any other atom
shape (N, 14)
"""
assert len(atom14_pred_positions.shape) == 3
assert len(atom14_atom_exists.shape) == 2
assert len(atom14_dists_lower_bound.shape) == 3
assert len(atom14_dists_upper_bound.shape) == 3
# Compute the mask for each residue.
# shape (N, 14, 14)
dists_masks = (1. - jnp.eye(14, 14)[None])
dists_masks *= (atom14_atom_exists[:, :, None] *
atom14_atom_exists[:, None, :])
# Distance matrix
# shape (N, 14, 14)
dists = jnp.sqrt(1e-10 + jnp.sum(
squared_difference(
atom14_pred_positions[:, :, None, :],
atom14_pred_positions[:, None, :, :]),
axis=-1))
# Compute the loss.
# shape (N, 14, 14)
dists_to_low_error = jax.nn.relu(
atom14_dists_lower_bound + tighten_bounds_for_loss - dists)
dists_to_high_error = jax.nn.relu(
dists - (atom14_dists_upper_bound - tighten_bounds_for_loss))
loss = dists_masks * (dists_to_low_error + dists_to_high_error)
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = (jnp.sum(loss, axis=1) +
jnp.sum(loss, axis=2))
# Compute the violations mask.
# shape (N, 14, 14)
violations = dists_masks * ((dists < atom14_dists_lower_bound) |
(dists > atom14_dists_upper_bound))
# Compute the per atom violations.
# shape (N, 14)
per_atom_violations = jnp.maximum(
jnp.max(violations, axis=1), jnp.max(violations, axis=2))
return {'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14)
'per_atom_violations': per_atom_violations # shape (N, 14)
}
def find_optimal_renaming(
atom14_gt_positions: jnp.ndarray, # (N, 14, 3)
atom14_alt_gt_positions: jnp.ndarray, # (N, 14, 3)
atom14_atom_is_ambiguous: jnp.ndarray, # (N, 14)
atom14_gt_exists: jnp.ndarray, # (N, 14)
atom14_pred_positions: jnp.ndarray, # (N, 14, 3)
atom14_atom_exists: jnp.ndarray, # (N, 14)
) -> jnp.ndarray: # (N):
"""Find optimal renaming for ground truth that maximizes LDDT.
Jumper et al. (2021) Suppl. Alg. 26
"renameSymmetricGroundTruthAtoms" lines 1-5
Args:
atom14_gt_positions: Ground truth positions in global frame of ground truth.
atom14_alt_gt_positions: Alternate ground truth positions in global frame of
ground truth with coordinates of ambiguous atoms swapped relative to
'atom14_gt_positions'.
atom14_atom_is_ambiguous: Mask denoting whether atom is among ambiguous
atoms, see Jumper et al. (2021) Suppl. Table 3
atom14_gt_exists: Mask denoting whether atom at positions exists in ground
truth.
atom14_pred_positions: Predicted positions of atoms in
global prediction frame
atom14_atom_exists: Mask denoting whether atom at positions exists for given
amino acid type
Returns:
Float array of shape [N] with 1. where atom14_alt_gt_positions is closer to
prediction and 0. otherwise
"""
assert len(atom14_gt_positions.shape) == 3
assert len(atom14_alt_gt_positions.shape) == 3
assert len(atom14_atom_is_ambiguous.shape) == 2
assert len(atom14_gt_exists.shape) == 2
assert len(atom14_pred_positions.shape) == 3
assert len(atom14_atom_exists.shape) == 2
# Create the pred distance matrix.
# shape (N, N, 14, 14)
pred_dists = jnp.sqrt(1e-10 + jnp.sum(
squared_difference(
atom14_pred_positions[:, None, :, None, :],
atom14_pred_positions[None, :, None, :, :]),
axis=-1))
# Compute distances for ground truth with original and alternative names.
# shape (N, N, 14, 14)
gt_dists = jnp.sqrt(1e-10 + jnp.sum(
squared_difference(
atom14_gt_positions[:, None, :, None, :],
atom14_gt_positions[None, :, None, :, :]),
axis=-1))
alt_gt_dists = jnp.sqrt(1e-10 + jnp.sum(
squared_difference(
atom14_alt_gt_positions[:, None, :, None, :],
atom14_alt_gt_positions[None, :, None, :, :]),
axis=-1))
# Compute LDDT's.
# shape (N, N, 14, 14)
lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, gt_dists))
alt_lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists))
# Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms
# in cols.
# shape (N ,N, 14, 14)
mask = (atom14_gt_exists[:, None, :, None] * # rows
atom14_atom_is_ambiguous[:, None, :, None] * # rows
atom14_gt_exists[None, :, None, :] * # cols
(1. - atom14_atom_is_ambiguous[None, :, None, :])) # cols
# Aggregate distances for each residue to the non-amibuguous atoms.
# shape (N)
per_res_lddt = jnp.sum(mask * lddt, axis=[1, 2, 3])
alt_per_res_lddt = jnp.sum(mask * alt_lddt, axis=[1, 2, 3])
# Decide for each residue, whether alternative naming is better.
# shape (N)
alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).astype(jnp.float32)
return alt_naming_is_better # shape (N)
def frame_aligned_point_error(
pred_frames: r3.Rigids, # shape (num_frames)
target_frames: r3.Rigids, # shape (num_frames)
frames_mask: jnp.ndarray, # shape (num_frames)
pred_positions: r3.Vecs, # shape (num_positions)
target_positions: r3.Vecs, # shape (num_positions)
positions_mask: jnp.ndarray, # shape (num_positions)
length_scale: float,
l1_clamp_distance: Optional[float] = None,
epsilon=1e-4) -> jnp.ndarray: # shape ()
"""Measure point error under different alignments.
Jumper et al. (2021) Suppl. Alg. 28 "computeFAPE"
Computes error between two structures with B points under A alignments derived
from the given pairs of frames.
Args:
pred_frames: num_frames reference frames for 'pred_positions'.
target_frames: num_frames reference frames for 'target_positions'.
frames_mask: Mask for frame pairs to use.
pred_positions: num_positions predicted positions of the structure.
target_positions: num_positions target positions of the structure.
positions_mask: Mask on which positions to score.
length_scale: length scale to divide loss by.
l1_clamp_distance: Distance cutoff on error beyond which gradients will
be zero.
epsilon: small value used to regularize denominator for masked average.
Returns:
Masked Frame Aligned Point Error.
"""
assert pred_frames.rot.xx.ndim == 1
assert target_frames.rot.xx.ndim == 1
assert frames_mask.ndim == 1, frames_mask.ndim
assert pred_positions.x.ndim == 1
assert target_positions.x.ndim == 1
assert positions_mask.ndim == 1
# Compute array of predicted positions in the predicted frames.
# r3.Vecs (num_frames, num_positions)
local_pred_pos = r3.rigids_mul_vecs(
jax.tree_map(lambda r: r[:, None], r3.invert_rigids(pred_frames)),
jax.tree_map(lambda x: x[None, :], pred_positions))
# Compute array of target positions in the target frames.
# r3.Vecs (num_frames, num_positions)
local_target_pos = r3.rigids_mul_vecs(
jax.tree_map(lambda r: r[:, None], r3.invert_rigids(target_frames)),
jax.tree_map(lambda x: x[None, :], target_positions))
# Compute errors between the structures.
# jnp.ndarray (num_frames, num_positions)
error_dist = jnp.sqrt(
r3.vecs_squared_distance(local_pred_pos, local_target_pos)
+ epsilon)
if l1_clamp_distance:
error_dist = jnp.clip(error_dist, 0, l1_clamp_distance)
normed_error = error_dist / length_scale
normed_error *= jnp.expand_dims(frames_mask, axis=-1)
normed_error *= jnp.expand_dims(positions_mask, axis=-2)
normalization_factor = (
jnp.sum(frames_mask, axis=-1) *
jnp.sum(positions_mask, axis=-1))
return (jnp.sum(normed_error, axis=(-2, -1)) /
(epsilon + normalization_factor))
def _make_renaming_matrices():
"""Matrices to map atoms to symmetry partners in ambiguous case."""
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative groundtruth coordinates where the naming is swapped
restype_3 = [
residue_constants.restype_1to3[res] for res in residue_constants.restypes
]
restype_3 += ['UNK']
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = residue_constants.restype_name_to_atom14_names[
resname].index(source_atom_swap)
target_index = residue_constants.restype_name_to_atom14_names[
resname].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
return renaming_matrices
RENAMING_MATRICES = _make_renaming_matrices()
def get_alt_atom14(aatype, positions, mask):
"""Get alternative atom14 positions.
Constructs renamed atom positions for ambiguous residues.
Jumper et al. (2021) Suppl. Table 3 "Ambiguous atom names due to 180 degree-
rotation-symmetry"
Args:
aatype: Amino acid at given position
positions: Atom positions as r3.Vecs in atom14 representation, (N, 14)
mask: Atom masks in atom14 representation, (N, 14)
Returns:
renamed atom positions, renamed atom mask
"""
# pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14)
renaming_transform = utils.batched_gather(
jnp.asarray(RENAMING_MATRICES), aatype)
positions = jax.tree_map(lambda x: x[:, :, None], positions)
alternative_positions = jax.tree_map(
lambda x: jnp.sum(x, axis=1), positions * renaming_transform)
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position)
alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1)
return alternative_positions, alternative_mask
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment