"git@developer.sourcefind.cn:OpenDAS/Uni-Core.git" did not exist on "97fd9948f163c0f81390fe80a14d660409013c75"
Commit 023596d2 authored by Nikita Smetanin's avatar Nikita Smetanin
Browse files

Minor updates & optimizations to support ESMFold

parent 998ee79b
......@@ -35,8 +35,30 @@ def rot_matmul(
Returns:
The product ab
"""
with torch.autocast(a.device.type, enabled=False):
return a.to(torch.float32) @ b.to(torch.float32)
def row_mul(i):
return torch.stack(
[
a[..., i, 0] * b[..., 0, 0]
+ a[..., i, 1] * b[..., 1, 0]
+ a[..., i, 2] * b[..., 2, 0],
a[..., i, 0] * b[..., 0, 1]
+ a[..., i, 1] * b[..., 1, 1]
+ a[..., i, 2] * b[..., 2, 1],
a[..., i, 0] * b[..., 0, 2]
+ a[..., i, 1] * b[..., 1, 2]
+ a[..., i, 2] * b[..., 2, 2],
],
dim=-1,
)
return torch.stack(
[
row_mul(0),
row_mul(1),
row_mul(2),
],
dim=-2
)
def rot_vec_mul(
......@@ -53,11 +75,15 @@ def rot_vec_mul(
Returns:
[*, 3] rotated coordinates
"""
with torch.autocast(r.device.type, enabled=False):
r = r.to(torch.float32)
t = t.to(torch.float32)
return (r @ t.unsqueeze(-1)).squeeze(-1)
x, y, z = torch.unbind(t, dim=-1)
return torch.stack(
[
r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
],
dim=-1,
)
@lru_cache(maxsize=None)
def identity_rot_mats(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment