"docs/vscode:/vscode.git/clone" did not exist on "a315988baeddd64da3eb4f030ca804ef92a73d1f"
Unverified Commit 68187c46 authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix arg names for our models (#20166)

* Fix arg names for our models

* Clean out the other uses of "residx" in infer()

* make fixup
parent 6dda14dc
...@@ -2248,8 +2248,7 @@ class EsmForProteinFolding(EsmPreTrainedModel): ...@@ -2248,8 +2248,7 @@ class EsmForProteinFolding(EsmPreTrainedModel):
def infer( def infer(
self, self,
seqs: Union[str, List[str]], seqs: Union[str, List[str]],
residx=None, position_ids=None,
with_mask: Optional[torch.Tensor] = None,
): ):
if type(seqs) is str: if type(seqs) is str:
lst = [seqs] lst = [seqs]
...@@ -2272,17 +2271,17 @@ class EsmForProteinFolding(EsmPreTrainedModel): ...@@ -2272,17 +2271,17 @@ class EsmForProteinFolding(EsmPreTrainedModel):
] ]
) # B=1 x L ) # B=1 x L
mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst]) mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
residx = ( position_ids = (
torch.arange(aatype.shape[1], device=device).expand(len(lst), -1) if residx is None else residx.to(device) torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
if position_ids is None
else position_ids.to(device)
) )
if residx.ndim == 1: if position_ids.ndim == 1:
residx = residx.unsqueeze(0) position_ids = position_ids.unsqueeze(0)
return self.forward( return self.forward(
aatype, aatype,
mask, mask,
mask_aa=with_mask is not None, position_ids=position_ids,
masking_pattern=with_mask,
residx=residx,
) )
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment