Unverified Commit 9082c254 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #222 from beiwang2003/main

FP16 fixes 
parents 499b9a84 4d5fa31c
...@@ -137,7 +137,7 @@ class DistogramHead(nn.Module): ...@@ -137,7 +137,7 @@ class DistogramHead(nn.Module):
self.linear = Linear(self.c_z, self.no_bins, init="final") self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self, z): # [*, N, N, C_z] def _forward(self, z): # [*, N, N, C_z]
""" """
Args: Args:
z: z:
...@@ -150,6 +150,14 @@ class DistogramHead(nn.Module): ...@@ -150,6 +150,14 @@ class DistogramHead(nn.Module):
logits = logits + logits.transpose(-2, -3) logits = logits + logits.transpose(-2, -3)
return logits return logits
def forward(self, z):
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
with torch.cuda.amp.autocast(enabled=False):
return self._forward(z.float())
else:
return self._forward(z)
class TMScoreHead(nn.Module): class TMScoreHead(nn.Module):
""" """
......
...@@ -93,7 +93,7 @@ class OuterProductMean(nn.Module): ...@@ -93,7 +93,7 @@ class OuterProductMean(nn.Module):
return outer return outer
def forward(self, def _forward(self,
m: torch.Tensor, m: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
...@@ -143,3 +143,18 @@ class OuterProductMean(nn.Module): ...@@ -143,3 +143,18 @@ class OuterProductMean(nn.Module):
outer = outer / norm outer = outer / norm
return outer return outer
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
with torch.cuda.amp.autocast(enabled=False):
return self._forward(m.float(), mask, chunk_size, inplace_safe)
else:
return self._forward(m, mask, chunk_size, inplace_safe)
...@@ -479,6 +479,9 @@ class Attention(nn.Module): ...@@ -479,6 +479,9 @@ class Attention(nn.Module):
q, k, v = self._prep_qkv(q_x, kv_x) q, k, v = self._prep_qkv(q_x, kv_x)
# [*, Q, H, C_hidden] # [*, Q, H, C_hidden]
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled:
use_memory_efficient_kernel = False
if(use_memory_efficient_kernel): if(use_memory_efficient_kernel):
if(len(biases) > 2): if(len(biases) > 2):
raise ValueError( raise ValueError(
......
...@@ -312,6 +312,14 @@ class InvariantPointAttention(nn.Module): ...@@ -312,6 +312,14 @@ class InvariantPointAttention(nn.Module):
z[0] = z[0].cpu() z[0] = z[0].cpu()
# [*, H, N_res, N_res] # [*, H, N_res, N_res]
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
with torch.cuda.amp.autocast(enabled=False):
a = torch.matmul(
permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]
)
else:
a = torch.matmul( a = torch.matmul(
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
......
...@@ -391,6 +391,11 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -391,6 +391,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
b = mask b = mask
b = b * self.sigmoid(self.linear_b_g(z)) b = b * self.sigmoid(self.linear_b_g(z))
b = b * self.linear_b_p(z) b = b * self.linear_b_p(z)
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
x = self._combine_projections(a, b) x = self._combine_projections(a, b)
del a, b del a, b
x = self.layer_norm_out(x) x = self.layer_norm_out(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment