Commit e4119508 authored by zhuww

fix multimer bug

parent 614e2763
@@ -23,10 +23,10 @@ FastFold provides a **high-performance implementation of Evoformer** with the fo

## Installation

To install FastFold, you will need:
+ Python 3.8 or 3.9.
+ [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.1 or above
+ PyTorch 1.12 or above

For now, you can install FastFold:
@@ -45,14 +45,10 @@ python setup.py install

#### Advanced

To leverage the full power of FastFold, we recommend installing [Triton](https://github.com/openai/triton).

**[NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.4 or above is needed.**

```bash
pip install triton==2.0.0.dev20221005
```
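If you want to confirm the optional Triton path is active, a quick check (this snippet is illustrative and not part of the FastFold docs) is:

```python
# Sanity check: the fused Triton kernels (e.g. the fused LayerNorm further down in this
# commit) are only used when this import succeeds; otherwise the PyTorch fallback is used.
import triton
print(triton.__version__)
```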
...
@@ -26,6 +26,9 @@ FeatureDict = Mapping[str, np.ndarray]

ModelOutput = Mapping[str, Any]  # Is a nested dict.
PICO_TO_ANGSTROM = 0.01

PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
assert(PDB_MAX_CHAINS == 62)
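These new constants cap multimer output at the 62 chain IDs the PDB format can express; as a quick illustration (not part of the diff), an integer chain index simply indexes into this alphabet:

```python
# Illustrative only: integer chain indices map positionally onto one-character PDB chain IDs.
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
print(PDB_CHAIN_IDS[0], PDB_CHAIN_IDS[25], PDB_CHAIN_IDS[26], PDB_CHAIN_IDS[61])  # A Z a 9
```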

@dataclasses.dataclass(frozen=True)
class Protein:
@@ -45,11 +48,22 @@ class Protein:
    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # 0-indexed number corresponding to the chain in the protein that this
    # residue belongs to.
    chain_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors, of each residue (in sq. angstroms units),
    # representing the displacement of the residue from its ground truth mean
    # value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    def __post_init__(self):
        if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
            raise ValueError(
                f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
                "chains because these cannot be written to PDB format."
            )

def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
@@ -60,9 +74,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    Args:
        pdb_str: The contents of the pdb file
        chain_id: If chain_id is specified (e.g. A), then only that chain is
            parsed. Else, all chains are parsed.

    Returns:
        A new `Protein` parsed from the pdb contents.
@@ -72,62 +85,75 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    structure = parser.get_structure("none", pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f"Only single model PDBs are supported. Found {len(models)} models."
        )
    model = models[0]

    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        if(chain_id is not None and chain.id != chain_id):
            continue
        for res in chain:
            if res.id[2] != " ":
                raise ValueError(
                    f"PDB contains an insertion code at chain {chain.id} and residue "
                    f"index {res.id[1]}. These are not supported."
                )
            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num
            )
            pos = np.zeros((residue_constants.atom_type_num, 3))
            mask = np.zeros((residue_constants.atom_type_num,))
            res_b_factors = np.zeros((residue_constants.atom_type_num,))
            for atom in res:
                if atom.name not in residue_constants.atom_types:
                    continue
                pos[residue_constants.atom_order[atom.name]] = atom.coord
                mask[residue_constants.atom_order[atom.name]] = 1.0
                res_b_factors[
                    residue_constants.atom_order[atom.name]
                ] = atom.bfactor
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue then skip it.
                continue

            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Chain IDs are usually characters so map these to ints.
    unique_chain_ids = np.unique(chain_ids)
    chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors),
    )
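A minimal usage sketch of the updated parser (the file name and the `protein` module alias here are assumptions, not part of the commit):

```python
# Sketch: parse every chain of a multimer PDB and inspect the new chain_index field.
# Assumes the module above is importable as `protein`; "complex.pdb" is a hypothetical input.
import protein

with open("complex.pdb") as f:
    prot = protein.from_pdb_string(f.read())  # chain_id=None -> all chains are parsed

print(prot.chain_index[:5])  # 0-indexed chain number per residue, e.g. array([0, 0, 0, 1, 1])
```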

def from_proteinnet_string(proteinnet_str: str) -> Protein:
    tag_re = r'(\[[A-Z]+\]\n)'
    tags = [
        tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
    ]
    groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])

    atoms = ['N', 'CA', 'C']
@@ -141,26 +167,28 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
                if (seq[i] not in residue_constants.restypes):
                    seq[i] = 'X'
            aatype = np.array([
                residue_constants.restype_order.get(
                    res_symbol, residue_constants.restype_num
                ) for res_symbol in seq
            ])
        elif("[TERTIARY]" == g[0]):
            tertiary = []
            for axis in range(3):
                tertiary.append(list(map(float, g[1][axis].split())))
            tertiary_np = np.array(tertiary)
            atom_positions = np.zeros(
                (len(tertiary[0])//3, residue_constants.atom_type_num, 3)
            ).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_positions[:, residue_constants.atom_order[atom], :] = (
                    np.transpose(tertiary_np[:, i::3])
                )
            atom_positions *= PICO_TO_ANGSTROM
        elif("[MASK]" == g[0]):
            mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
            atom_mask = np.zeros(
                (len(mask), residue_constants.atom_type_num,)
            ).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_mask[:, residue_constants.atom_order[atom]] = 1
            atom_mask *= mask[..., None]
@@ -174,6 +202,14 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
    )

def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
    chain_end = 'TER'
    return(
        f'{chain_end:<6}{atom_index:>5}      {end_resname:>3} '
        f'{chain_name:>1}{residue_index:>4}'
    )
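For reference (illustrative values, not from the diff), the helper emits a column-aligned TER record for the end of a chain:

```python
# Illustrative call: serial number right-aligned in columns 7-11, residue name in 18-20,
# chain ID in column 22, residue number in 23-26, matching the PDB TER record layout.
ter_line = _chain_end(atom_index=1234, end_resname="ALA", chain_name="A", residue_index=99)
```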

def to_pdb(prot: Protein) -> str:
    """Converts a `Protein` instance to a PDB string.
@@ -193,19 +229,43 @@ def to_pdb(prot: Protein) -> str:
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    chain_index = prot.chain_index.astype(np.int32)
    b_factors = prot.b_factors

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    # Construct a mapping from chain integer indices to chain ID strings.
    chain_ids = {}
    for i in np.unique(chain_index):  # np.unique gives sorted output.
        if i >= PDB_MAX_CHAINS:
            raise ValueError(
                f"The PDB format supports at most {PDB_MAX_CHAINS} chains."
            )
        chain_ids[i] = PDB_CHAIN_IDS[i]

    pdb_lines.append("MODEL     1")
    atom_index = 1
    last_chain_index = chain_index[0]
    # Add all atom sites.
    for i in range(aatype.shape[0]):
        # Close the previous chain if in a multichain PDB.
        if last_chain_index != chain_index[i]:
            pdb_lines.append(
                _chain_end(
                    atom_index,
                    res_1to3(aatype[i - 1]),
                    chain_ids[chain_index[i - 1]],
                    residue_index[i - 1]
                )
            )
            last_chain_index = chain_index[i]
            atom_index += 1  # Atom index increases at the TER symbol.

        res_name_3 = res_1to3(aatype[i])
        for atom_name, pos, mask, b_factor in zip(
            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            if mask < 0.5:
                continue
@@ -214,28 +274,38 @@ def to_pdb(prot: Protein) -> str:
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_name[
                0
            ]  # Protein supports only C, N, O, S, this works.
            charge = ""
            # PDB is a columnar format, every space matters here!
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

    # Close the final chain.
    pdb_lines.append(
        _chain_end(
            atom_index,
            res_1to3(aatype[-1]),
            chain_ids[chain_index[-1]],
            residue_index[-1]
        )
    )
    pdb_lines.append("ENDMDL")
    pdb_lines.append("END")

    # Pad all lines to 80 characters.
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.

def ideal_atom_mask(prot: Protein) -> np.ndarray:
@@ -258,24 +328,36 @@ def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    remove_leading_feature_dimension: bool = False,
) -> Protein:
    """Assembles a protein from a prediction.

    Args:
        features: Dictionary holding model inputs.
        result: Dictionary holding model outputs.
        b_factors: (Optional) B-factors to use for the protein.
        remove_leading_feature_dimension: Whether to remove the leading dimension
            of the `features` values.

    Returns:
        A protein instance.
    """
    def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
        return arr[0] if remove_leading_feature_dimension else arr

    if 'asym_id' in features:
        chain_index = _maybe_remove_leading_dim(features["asym_id"])
    else:
        chain_index = np.zeros_like(
            _maybe_remove_leading_dim(features["aatype"])
        )

    if b_factors is None:
        b_factors = np.zeros_like(result["final_atom_mask"])

    return Protein(
        aatype=_maybe_remove_leading_dim(features["aatype"]),
        atom_positions=result["final_atom_positions"],
        atom_mask=result["final_atom_mask"],
        residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
        chain_index=chain_index,
        b_factors=b_factors,
    )
\ No newline at end of file
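A hedged sketch of how the multimer path is expected to flow through the updated `from_prediction` (the feature and result dictionaries below are placeholders, not objects from the commit):

```python
# Sketch: multimer feature dicts carry an "asym_id" entry, which from_prediction
# now turns into Protein.chain_index; to_pdb then writes one TER record per chain.
prot = protein.from_prediction(
    features=processed_features,   # hypothetical dict with "aatype", "residue_index", "asym_id"
    result=model_output,           # hypothetical dict with "final_atom_positions", "final_atom_mask"
    b_factors=plddt_b_factors,     # optional per-atom B-factors
    remove_leading_feature_dimension=False,
)
pdb_string = protein.to_pdb(prot)
```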
@@ -249,7 +249,7 @@ def run_msa_tool(
    max_sto_sequences: Optional[int] = None,
) -> Mapping[str, Any]:
    """Runs an MSA tool, checking if output already exists first."""
    if(msa_format == "sto"):
        result = msa_runner.query(fasta_path, max_sto_sequences)[0]
    else:
        result = msa_runner.query(fasta_path)
...
@@ -866,7 +866,10 @@ def _process_single_hit(
        kalign_binary_path=kalign_binary_path,
        _zero_center_positions=_zero_center_positions,
    )

    if hit.sum_probs is None:
        features['template_sum_probs'] = [0]
    else:
        features["template_sum_probs"] = [hit.sum_probs]

    # It is possible there were some errors when parsing the other chains in the
    # mmCIF file, but the template features for the chain we want were still
...
@@ -70,7 +70,7 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
    return output


def _chunk_gather(tensor: Tensor, dim=-1, chunk_size=1) -> Tensor:
    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
        return tensor
@@ -82,12 +82,12 @@ def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
        world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
        tensor_list = []
        for t in world_list:
            tensor_list.extend(t.chunk(chunk_size, dim=1))
        chunk_tensor = tensor.chunk(chunk_size, dim=1)
        for i in range(chunk_size):
            _chunk_list = [tensor_list[j*chunk_size+i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
            _chunk_tensor = chunk_tensor[i]
            dist.all_gather(list(_chunk_list),
@@ -103,12 +103,12 @@ def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
        world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=0)
        tensor_list = []
        for t in world_list:
            tensor_list.extend(t.chunk(chunk_size, dim=0))
        chunk_tensor = tensor.chunk(chunk_size, dim=0)
        for i in range(chunk_size):
            _chunk_list = [tensor_list[j*chunk_size+i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
            _chunk_tensor = chunk_tensor[i]
            dist.all_gather(list(_chunk_list),
@@ -176,14 +176,14 @@ class Reduce(torch.autograd.Function):
        return grad_output


def gather(input: Tensor, dim: int = -1, chunk_size: int = None) -> Tensor:
    if torch.is_grad_enabled() and input.requires_grad:
        input = Gather.apply(input, dim)
    else:
        if chunk_size is None:
            input = _gather(input, dim=dim)
        else:
            input = _chunk_gather(input, dim=dim, chunk_size=chunk_size)
    return input
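The rename from `chunks` to `chunk_size` only touches the chunked all-gather path; a single-process sketch of its reassembly order follows (the names here are illustrative, not the FastFold API):

```python
# Sketch: each "rank" holds one shard; _chunk_gather splits every shard into chunk_size
# pieces and gathers them piece by piece (lower peak memory), but the final layout is
# still a plain rank-major concatenation along `dim`.
import torch

def chunked_gather_sim(shards, dim=0, chunk_size=2):
    pieces = [list(s.chunk(chunk_size, dim=dim)) for s in shards]   # per-rank pieces
    # Gather piece i from every rank before moving to piece i + 1.
    gathered = [[None] * chunk_size for _ in shards]
    for i in range(chunk_size):
        for rank, rank_pieces in enumerate(pieces):
            gathered[rank][i] = rank_pieces[i]
    return torch.cat([torch.cat(p, dim=dim) for p in gathered], dim=dim)

shards = [torch.arange(8).reshape(4, 2) + 10 * r for r in range(2)]
print(torch.equal(chunked_gather_sim(shards), torch.cat(shards, dim=0)))  # True
```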
...
@@ -300,6 +300,7 @@ class TemplateEmbedderMultimer(nn.Module):
    ):
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        template_pair_embeddings = torch.zeros((z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
@@ -336,7 +337,7 @@ class TemplateEmbedderMultimer(nn.Module):
            rigid_vec = rigid[..., None].inverse().apply_to_point(points)

            unit_vector = rigid_vec.normalized()
            pair_embedding = self.template_pair_embedder(
                template_dgram,
                aatype_one_hot,
                z,
@@ -346,7 +347,23 @@ class TemplateEmbedderMultimer(nn.Module):
                unit_vector,
            )

            if not inplace:
                # [*, S_t, N, N, C_z]
                template_pair_embeddings = template_pair_embeddings + self.template_pair_stack(
                    pair_embedding,
                    padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
                    chunk_size=chunk_size,
                    _mask_trans=False,
                ).squeeze(0)
            else:
                # [*, S_t, N, N, C_z]
                template_pair_embeddings += self.template_pair_stack.inplace(
                    [pair_embedding],
                    padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
                    chunk_size=chunk_size,
                    _mask_trans=False,
                )[0].squeeze(0)

            single_template_embeds.update(
                self.template_single_embedder(
                    single_template_feats,
@@ -361,27 +378,11 @@ class TemplateEmbedderMultimer(nn.Module):
            template_embeds,
        )

        # [*, N, N, C_z]
        template_pair_embeddings = template_pair_embeddings / n_templ
        template_pair_embeddings = torch.nn.functional.relu(template_pair_embeddings)
        template_pair_embeddings = self.linear_t(template_pair_embeddings)
        template_embeds["template_pair_embedding"] = template_pair_embeddings

        return template_embeds
@@ -147,8 +147,8 @@ class Evoformer(nn.Module):
            if self.is_multimer:
                m[0] = gather(m[0], dim=1)
            else:
                m[0] = gather(m[0], dim=0, chunk_size=chunk_size)
                z[0] = gather(z[0], dim=0, chunk_size=chunk_size)
            m[0] = m[0][:, :-padding_size, :]
            z[0] = z[0][:-padding_size, :-padding_size, :]
...
@@ -34,6 +34,24 @@ class FusedLayerNorm(torch.nn.Module):
        torch.nn.init.zeros_(self.bias)

    def forward(self, input):
        if len(input.shape) >= 3 and input.shape[-3] > 4000:
            out = torch.empty_like(input)
            # Cap the chunk at 4000*4000 elements along dim -3, and at half of that dim.
            chunk_size = min(4000 * 4000 // input.shape[-3], (input.shape[-3] + 1) // 2)
            if len(input.shape) == 3:
                for i in range(0, input.shape[-3], chunk_size):
                    out[i:i + chunk_size] = self.kernel_forward(input[i:i + chunk_size])
            elif len(input.shape) == 4:
                for j in range(input.shape[-4]):
                    for i in range(0, input.shape[-3], chunk_size):
                        out[j, i:i + chunk_size] = self.kernel_forward(input[j, i:i + chunk_size])
            else:
                raise RuntimeError(f"Shape {input.shape} not implemented for layernorm yet!")
            return out
        else:
            return self.kernel_forward(input)

    def kernel_forward(self, input):
        if _triton_available:
            return LayerNormTritonFunc.apply(input, self.normalized_shape, self.weight, self.bias,
                                             self.eps)
...
@@ -303,7 +303,7 @@ class ExtraMSABlock(nn.Module):
            )
            torch.cuda.empty_cache()
            m[0] = scatter(m[0], dim=1, drop_unused=True) if not self.is_multimer else scatter(m[0], dim=2, drop_unused=True)
            torch.cuda.empty_cache()
            z[0] = scatter(z[0], dim=1, drop_unused=True)
            torch.cuda.empty_cache()
@@ -339,9 +339,9 @@ class ExtraMSABlock(nn.Module):
        if self.last_block:
            m[0] = gather(m[0], dim=1, chunk_size=chunk_size) if not self.is_multimer else gather(m[0], dim=2)
            torch.cuda.empty_cache()
            z[0] = gather(z[0], dim=1, chunk_size=chunk_size)

            m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
            z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
...
@@ -91,20 +91,19 @@ class ChunkTransition(nn.Module):
        self.linear2 = Linear(n * d, d, initializer='zeros')

    def forward(self, src):
        if CHUNK_SIZE == None:
            out = self.norm(src)
            out = self.linear2(F.relu(self.linear1(out)))
        else:
            chunk_size = CHUNK_SIZE * 48
            para_dim = src.shape[1]
            out = torch.empty_like(src)
            for ax in range(0, para_dim, chunk_size):
                if DEBUG and ax > 10:
                    break
                x = self.norm(src[:, ax:ax + chunk_size, :, :])
                x = self.linear2(F.relu(self.linear1(x)))
                out[:, ax:ax + chunk_size, :, :] = x
        out.add_(src)
        return out
@@ -155,18 +154,21 @@ class OutProductMean(nn.Module):
        right_act_all = gather_async_opp(right_act_all, work, dim=2)

        right_act_all = M_mask * right_act_all

        if CHUNK_SIZE == None:
            out = torch.einsum('bsid, bsje->bijde', left_act, right_act_all)
            out = rearrange(out, 'b i j d e -> b i j (d e)')
            out = self.o_linear(out)
            Z = out / norm
        else:
            para_dim = left_act.shape[2]
            chunk_size = CHUNK_SIZE
            for ax in range(0, para_dim, chunk_size):
                left_act_part = left_act[:, :, ax:ax + chunk_size, :]
                O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
                O = rearrange(O, 'b i j d e -> b i j (d e)')
                O = self.o_linear(O)
                norm0 = norm[:, ax:ax + chunk_size, :, :]
                Z[:, ax:ax + chunk_size, :, :] = O / norm0

        return Z + Z_raw
@@ -293,11 +295,6 @@ class SelfAttention(nn.Module):
        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
        """

        if nonbatched_bias is not None:
            if nonbatched_bias[-1] == -1:
                bias = nonbatched_bias[0]
@@ -306,7 +303,33 @@
                bias = gather_async_opp(*nonbatched_bias, dim=1)
            bias = rearrange(bias, 'b q k h -> b h q k')

        if CHUNK_SIZE == None:
            qkv = self.to_qkv(in_data).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)

            q = q * self.scaling

            logits = torch.matmul(q, k.transpose(-1, -2))

            if nonbatched_bias is not None:
                weights = fused_softmax(logits, mask, bias.unsqueeze(1))
            else:
                weights = fused_softmax(logits, mask)

            weighted_avg = torch.matmul(weights, v)
            weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')

            if self.gating:
                gate_values = self.gating_linear(in_data)
                weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)

            output = self.o_linear(weighted_avg)
        else:
            para_dim = in_data.shape[1]
            chunk_size = CHUNK_SIZE
            output = []
            for ax in range(0, para_dim, chunk_size):
                in_data_part = in_data[:, ax:ax + chunk_size, :, :]
@@ -983,16 +1006,17 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
        )

    def forward(self, M_raw, M_mask):
        if CHUNK_SIZE is None:
            m = self.layernormM(M_raw.transpose(-2, -3))
            m = self.global_attention(m, M_mask.transpose(-1, -2))
            m = m.transpose(-2, -3)
            M_raw = M_raw + m
        else:
            chunk_size = CHUNK_SIZE
            para_dim = M_raw.shape[2]
            for i in range(0, para_dim, chunk_size):
                if DEBUG and i > 10:
                    break
                m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
                m = self.layernormM(m)
                m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
@@ -1111,12 +1135,12 @@ class RecyclingEmbedder(nn.Module):
        # [*, N, N, no_bins]
        d = ((d > squared_bins) * (d < upper)).type(x.dtype)

        if CHUNK_SIZE == None:
            d = self.linear(d)
            z = d + self.layer_norm_z(z)
        else:
            chunk_size = CHUNK_SIZE * 48
            para_dim = d.shape[1]
            for i in range(0, para_dim, chunk_size):
                di = self.linear(d[i:i + chunk_size, :, :])
@@ -1154,41 +1178,64 @@ class GlobalAttention(nn.Module):
    def forward(self, m, mask):
        if CHUNK_SIZE == None:
            q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
                torch.sum(mask, dim=-1)[..., None] + self.eps
            )
            q = q * self.scaling
            q = self.to_q(q)
            q = q.view(q.shape[:-1] + (self.n_head, -1))
            k, v = self.to_kv(m).chunk(2, dim=-1)
            logits = torch.matmul(q, k.transpose(-1, -2))
            weights = fused_softmax(logits, mask)
            weighted_avg = torch.matmul(weights, v)
            weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
            gate_values = self.gating_linear(m)
            weighted_avg = bias_sigmod_ele(
                gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
            )
            m = self.o_linear(weighted_avg)
        else:
            para_dim = m.shape[1]
            chunk_size = CHUNK_SIZE
            output = []
            for ax in range(0, para_dim, chunk_size):
                m_part = m[:, ax : ax + chunk_size, :, :]
                mask_part = mask[:, ax : ax + chunk_size, :]
                q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
                    torch.sum(mask_part, dim=-1)[..., None] + self.eps
                )
                q = q * self.scaling
                q = self.to_q(q)
                q = q.view(q.shape[:-1] + (self.n_head, -1))
                k, v = self.to_kv(m_part).chunk(2, dim=-1)
                logits = torch.matmul(q, k.transpose(-1, -2))
                weights = fused_softmax(logits, mask_part)
                weighted_avg = torch.matmul(weights, v)
                weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
                gate_values = self.gating_linear(m_part)
                weighted_avg = bias_sigmod_ele(
                    gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
                )
                output.append(self.o_linear(weighted_avg))
            m = torch.cat(output, dim=1)

        return m
...
@@ -241,25 +241,18 @@ class TemplatePairBlock(nn.Module):
        mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))

        single_mask_row = scatter(mask, dim=1)
        single_mask_col = scatter(mask, dim=2)

        z = self.TriangleAttentionStartingNode(z, single_mask_row)
        z = row_to_col(z)
        z = self.TriangleAttentionEndingNode(z, single_mask_col)
        z = col_to_row(z)
        z = self.TriangleMultiplicationOutgoing(z, single_mask_row)
        z = row_to_col(z)
        z = self.TriangleMultiplicationIncoming(z, single_mask_col)
        z = self.PairTransition(z)
        z = col_to_row(z)
        # z = torch.cat(single_templates, dim=-4)

        if self.last_block:
@@ -275,8 +268,6 @@
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ):
        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_length = mask.size(-1)
@@ -290,32 +281,24 @@
        mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))

        single_mask_row = scatter(mask, dim=1, drop_unused=True)
        single_mask_col = scatter(mask, dim=2, drop_unused=True)
        torch.cuda.empty_cache()

        z = self.TriangleAttentionStartingNode.inplace(z, single_mask_row)
        z[0] = row_to_col(z[0])
        z = self.TriangleAttentionEndingNode.inplace(z, single_mask_col)
        z[0] = col_to_row(z[0])
        z[0] = self.TriangleMultiplicationOutgoing(z[0], single_mask_row)
        z[0] = row_to_col(z[0])
        z[0] = self.TriangleMultiplicationIncoming(z[0], single_mask_col)
        z = self.PairTransition.inplace(z)
        z[0] = col_to_row(z[0])
        # z = torch.cat(single_templates, dim=-4)

        if self.last_block:
            z[0] = gather(z[0], dim=1, chunk_size=chunk_size)

            z[0] = z[0][:, :-padding_size, :-padding_size, :]

        return z
@@ -411,15 +394,8 @@ class TemplatePairStack(nn.Module):
                args=(t,),
                blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
            )

        for i in range(0, t.shape[0]):
            t[i] = self.layer_norm(t[i])

        return t

    def inplace(
@@ -456,13 +432,6 @@
                args=(t,),
                blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
            )

        for i in range(0, t[0].shape[0]):
            t[0][i] = self.layer_norm(t[0][i].to(mask.device)).to(t[0].device)

        return t
@@ -56,7 +56,7 @@ class Dropout(nn.Module):
            shape[bd] = 1
        mask = x.new_ones(shape)
        mask = self.dropout(mask)
        x = x * mask
        return x
...
@@ -264,11 +264,6 @@ class EvoformerBlock(nn.Module):
            eps=eps,
        )

        self.is_multimer = is_multimer

    def forward(self,
...
@@ -227,9 +227,11 @@ def inference_multimer_model(args):
    )
    output_dir_base = args.output_dir
    random_seed = args.data_random_seed

    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)
    # seed_torch(seed=1029)

    feature_processor = feature_pipeline.FeaturePipeline(
        config.data
...
@@ -15,8 +15,8 @@ this_dir = os.path.dirname(os.path.abspath(__file__))

def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    torch_binary_major = torch.version.hip.split(".")[0]
    torch_binary_minor = torch.version.hip.split(".")[1]

    print("\nCompiling cuda extensions with")
...
@@ -73,4 +73,4 @@ def _test_msa_att_col(rank, world_size, chunk_size, get_openfold_module_and_data
    m_fast = m_fast[:, :-padding_size, :]

    error = torch.max(torch.abs(m_out.cuda() - m_fast))
    assert error < 1e-4, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
@@ -46,7 +46,7 @@ def get_openfold_module_and_data():
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 4])  # should set 4 to test offload
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
    run_func = partial(_test_template_embedder, world_size=world_size, chunk_size=chunk_size,
...