"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "b0655a3465904ff265bdc1e1ccdcff009a448bb0"
Commit 64493e08 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

reduce asym_id in chain_center_of_mass_loss by 1 to avoid torch one_hot error

parent 2c4d4183
...@@ -1667,7 +1667,7 @@ def chain_center_of_mass_loss( ...@@ -1667,7 +1667,7 @@ def chain_center_of_mass_loss(
all_atom_positions = all_atom_positions[..., ca_pos, :] all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
chains, _ = asym_id.unique(return_counts=True) chains, _ = asym_id.unique(return_counts=True)
one_hot = torch.nn.functional.one_hot(asym_id.to(torch.int64), one_hot = torch.nn.functional.one_hot(asym_id.to(torch.int64)-1, # have to reduce asym_id by one because class values must be smaller than num_classes
num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) # make sure asym_id dtype is int num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) # make sure asym_id dtype is int
one_hot = one_hot * all_atom_mask one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1) chain_pos_mask = one_hot.transpose(-2, -1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment