"pcdet/ops/pointnet2/vscode:/vscode.git/clone" did not exist on "183d353a05005a0eb625b8ee69d3aea503a46b81"
Commit 5e1af36e authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update the compue_tm function

parent c69053e5
......@@ -694,8 +694,11 @@ def compute_tm(
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
n = residue_weights.shape[-1]
pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
if interface and (asym_id is not None):
if len(asym_id.shape)>1:
assert len(asym_id.shape)<=2
pair_mask = residue_weights.new_ones((1,n, n), dtype=torch.int32)
if interface:
pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
predicted_tm_term *= pair_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment