Commit 4f662f83 authored by Jose Duarte's avatar Jose Duarte
Browse files

Now should work for multi-chain. But it is untested.

Also fixed important bug: atom names were padded
parent ba8de687
...@@ -403,7 +403,6 @@ def to_modelcif(prot: Protein) -> str: ...@@ -403,7 +403,6 @@ def to_modelcif(prot: Protein) -> str:
""" """
restypes = residue_constants.restypes + ["X"] restypes = residue_constants.restypes + ["X"]
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types atom_types = residue_constants.atom_types
atom_mask = prot.atom_mask atom_mask = prot.atom_mask
...@@ -413,71 +412,67 @@ def to_modelcif(prot: Protein) -> str: ...@@ -413,71 +412,67 @@ def to_modelcif(prot: Protein) -> str:
b_factors = prot.b_factors b_factors = prot.b_factors
chain_index = prot.chain_index chain_index = prot.chain_index
n = aatype.shape[0]
if chain_index is None:
chain_index = [0 for i in range(n)]
system = modelcif.System(title='OpenFold prediction') system = modelcif.System(title='OpenFold prediction')
# sequence into entity # Finding chains and creating entities
n = aatype.shape[0] seqs = {}
seq = [restypes[aatype[i]] for i in range(n)] seq = []
model_e = modelcif.Entity("".join(seq), description='Model subunit') last_chain_idx = None
for i in range(n):
if last_chain_idx is not None and last_chain_idx != chain_index[i]:
seqs[last_chain_idx] = seq
seq = []
seq.append(restypes[aatype[i]])
last_chain_idx = chain_index[i]
# finally add the last chain
if last_chain_idx not in seqs:
seqs[last_chain_idx] = seq
# now reduce sequences to unique ones (note this won't work if different asyms have different unmodelled regions)
unique_seqs = {}
for chain_idx in seqs.keys():
seq = "".join(seqs[chain_idx])
if seq in unique_seqs:
unique_seqs[seq].append(chain_idx)
else:
unique_seqs[seq] = [chain_idx]
# adding 1 entity per unique sequence
entities_map = {}
for key, value in unique_seqs.items():
model_e = modelcif.Entity("".join(key), description='Model subunit')
for chain_idx in value:
entities_map[chain_idx] = model_e
chain_tags = string.ascii_uppercase
asym_unit_map = {}
for chain_idx in set(chain_index):
# Define the model assembly # Define the model assembly
asymA = modelcif.AsymUnit(model_e, details='Model subunit A', id='A') chain_id = chain_tags[chain_idx]
modeled_assembly = modelcif.Assembly((asymA,), name='Modeled assembly') asym = modelcif.AsymUnit(entities_map[chain_idx], details='Model subunit %s' % chain_id, id=chain_id)
asym_unit_map[chain_idx] = asym
modeled_assembly = modelcif.Assembly(asym_unit_map.values(), name='Modeled assembly')
class MyModel(modelcif.model.HomologyModel): class MyModel(modelcif.model.HomologyModel):
asym_unit_map = {'A': asymA}
def get_atoms(self): def get_atoms(self):
prev_chain_index = 0
chain_tags = string.ascii_uppercase
# Add all atom sites. # Add all atom sites.
for i in range(n): for i in range(n):
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip( for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i] atom_types, atom_positions[i], atom_mask[i], b_factors[i]
): ):
if mask < 0.5: if mask < 0.5:
continue continue
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
element = atom_name[0] # Protein supports only C, N, O, S, this works. element = atom_name[0] # Protein supports only C, N, O, S, this works.
chain_tag = "A"
if chain_index is not None:
chain_tag = chain_tags[chain_index[i]]
# TODO check that residue indices are ok
# TODO how to set residue types? do they come from the entity sequence set above?
yield modelcif.model.Atom( yield modelcif.model.Atom(
asym_unit=asymA, type_symbol=element, asym_unit=asym_unit_map[chain_index[i]], type_symbol=element,
seq_id=residue_index[i], atom_id=name, seq_id=residue_index[i], atom_id=atom_name,
x=pos[0], y=pos[1], z=pos[2], x=pos[0], y=pos[1], z=pos[2],
het=False, biso=b_factor, occupancy=1.00) het=False, biso=b_factor, occupancy=1.00)
# TODO multiple chains
# should_terminate = (i == n - 1)
# if(chain_index is not None):
# if(i != n - 1 and chain_index[i + 1] != prev_chain_index):
# should_terminate = True
# prev_chain_index = chain_index[i + 1]
#
# if(should_terminate):
# # Close the chain.
# chain_end = "TER"
# chain_termination_line = (
# f"{chain_end:<6}{atom_index:>5} "
# f"{res_1to3(aatype[i]):>3} "
# f"{chain_tag:>1}{residue_index[i]:>4}"
# )
# pdb_lines.append(chain_termination_line)
# atom_index += 1
#
# if(i != n - 1):
# # "prev" is a misnomer here. This happens at the beginning of
# # each new chain.
# pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))
# Add the model and modeling protocol to the file and write them out: # Add the model and modeling protocol to the file and write them out:
model = MyModel(assembly=modeled_assembly, name='Best scoring model') model = MyModel(assembly=modeled_assembly, name='Best scoring model')
...@@ -548,9 +543,15 @@ def from_prediction( ...@@ -548,9 +543,15 @@ def from_prediction(
if __name__ == "__main__": if __name__ == "__main__":
with open('/home/jose/Downloads/171l.pdb', 'r') as file: pdb_file = '/home/jose/Downloads/171l.pdb'
# pdb_file = '/home/jose/Downloads/2trx.pdb'
cif_file = '/home/jose/test.cif'
with open(pdb_file, 'r') as file:
pdbstr = file.read() pdbstr = file.read()
prot = from_pdb_string(pdbstr) prot = from_pdb_string(pdbstr, "A")
cifstr = to_modelcif(prot) cifstr = to_modelcif(prot)
print(cifstr) print(cifstr)
with open(cif_file, 'w') as fw:
fw.write(cifstr)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment