"tools/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "707a861a572dc465e2ff2bc1c84c75bec6bdb3db"
Commit a17a9777 authored by Gustaf Ahdritz

Fix precision bug

parent 48670cfc
@@ -150,15 +150,15 @@ class DistogramHead(nn.Module):
         logits = logits + logits.transpose(-2, -3)
         return logits
 
     def forward(self, z):
         float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
         if float16_enabled and torch.is_autocast_enabled():
             with torch.cuda.amp.autocast(enabled=False):
                 return self._forward(z.float())
         else:
             return self._forward(z)
 
 
 class TMScoreHead(nn.Module):
     """
     For use in computation of TM-score, subsection 1.9.7
...
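The guard in this hunk distinguishes autocast being configured for float16 from autocast being active at call time: torch.get_autocast_gpu_dtype() returns torch.float16 by default even when no autocast region is open, so it has to be paired with torch.is_autocast_enabled(). A minimal standalone sketch of the behavior the check relies on (assumes a CUDA build of PyTorch):

import torch

x = torch.randn(4, 4, device="cuda")

print(torch.is_autocast_enabled())     # False: no autocast region is open
print(torch.get_autocast_gpu_dtype())  # torch.float16 (the default), even so

with torch.cuda.amp.autocast():
    print(torch.is_autocast_enabled()) # True inside the region
    y = torch.mm(x, x)
    print(y.dtype)                     # torch.float16: mm runs in half precision
    # The pattern used by DistogramHead.forward: locally disable autocast
    # and upcast the operand so the op runs in full precision.
    with torch.cuda.amp.autocast(enabled=False):
        z = torch.mm(x.float(), x.float())
        print(z.dtype)                 # torch.float32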
@@ -480,8 +480,9 @@ class Attention(nn.Module):
         # [*, Q, H, C_hidden]
         float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
-        if float16_enabled:
+        if float16_enabled and torch.is_autocast_enabled():
             use_memory_efficient_kernel = False
 
         if(use_memory_efficient_kernel):
             if(len(biases) > 2):
                 raise ValueError(
...
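This one-line change is the actual bug fix: checking the autocast dtype alone would disable the memory-efficient attention kernel even when autocast is not active, because the configured GPU autocast dtype is float16 by default. A hedged sketch of the corrected guard in isolation (pick_kernel is a hypothetical stand-in for the logic inside Attention.forward):

import torch

def pick_kernel(use_memory_efficient_kernel: bool) -> bool:
    float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
    # Fall back to the standard kernel only when fp16 autocast is
    # actually active, not merely configured as the default dtype.
    if float16_enabled and torch.is_autocast_enabled():
        use_memory_efficient_kernel = False
    return use_memory_efficient_kernel

print(pick_kernel(True))      # True: no autocast region, kernel stays enabled
with torch.cuda.amp.autocast():
    print(pick_kernel(True))  # False: fp16 autocast is live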
@@ -324,6 +324,7 @@ class InvariantPointAttention(nn.Module):
             permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
             permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
         )
         a *= math.sqrt(1.0 / (3 * self.c_hidden))
         a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
...
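For context on the scaling: standard dot-product attention divides the q·k logits by sqrt(c_hidden), and invariant point attention sums three logit terms (the scalar q·k term, the pair bias b, and the point term), each additionally down-weighted by sqrt(1/3) so the summed variance matches a single term, consistent with the per-term weighting in the AlphaFold 2 supplement. A quick arithmetic check of the factor used above:

import math

c_hidden = 16
# sqrt(1 / (3 * c)) is the usual 1/sqrt(c) attention scale combined
# with the per-term weight sqrt(1/3) for the three summed logit terms.
assert math.isclose(
    math.sqrt(1.0 / (3 * c_hidden)),
    (1.0 / math.sqrt(c_hidden)) * math.sqrt(1.0 / 3),
)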
@@ -391,12 +391,14 @@ class TriangleMultiplicativeUpdate(nn.Module):
         b = mask
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)
 
         float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
         if float16_enabled and torch.is_autocast_enabled():
             with torch.cuda.amp.autocast(enabled=False):
                 x = self._combine_projections(a.float(), b.float())
         else:
             x = self._combine_projections(a, b)
 
         del a, b
         x = self.layer_norm_out(x)
         x = self.linear_z(x)
...
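Same pattern as the distogram head: when fp16 autocast is live, the two gated projections are upcast and combined in float32, since their products can lose accuracy or overflow in half precision. A minimal sketch of the wrapper, with combine as a hypothetical stand-in for _combine_projections:

import torch

def combine(a, b):
    # Hypothetical stand-in for _combine_projections (an einsum over
    # the two gated projections); the exact contraction doesn't matter
    # for the precision-handling pattern shown here.
    return torch.einsum("...ikc,...jkc->...ijc", a, b)

def combine_full_precision_if_needed(a, b):
    float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
    if float16_enabled and torch.is_autocast_enabled():
        # Upcast the operands and keep autocast from re-casting the einsum.
        with torch.cuda.amp.autocast(enabled=False):
            return combine(a.float(), b.float())
    return combine(a, b)

with torch.cuda.amp.autocast():
    a = torch.randn(8, 8, 4, device="cuda")
    b = torch.randn(8, 8, 4, device="cuda")
    print(combine_full_precision_if_needed(a, b).dtype)  # torch.float32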