Commit f84eb0a1 authored by Gustaf Ahdritz

Refactor FP16 check

parent 2fc72a5f
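For orientation: before this commit, several modules repeated the same inline check of the autocast state before upcasting a numerically sensitive computation to fp32. The commit moves that check into openfold.utils.precision_utils.is_fp16_enabled() and has every call site use it. Below is a minimal sketch of the call-site pattern using a hypothetical ToyHead module rather than the real DistogramHead/OuterProductMean/InvariantPointAttention classes; the body of the else branch is inferred, since the hunks below truncate it.

import torch
import torch.nn as nn

from openfold.utils.precision_utils import is_fp16_enabled


class ToyHead(nn.Module):
    # Hypothetical stand-in for DistogramHead and the other refactored modules.
    def __init__(self, c_in, c_out):
        super().__init__()
        self.linear = nn.Linear(c_in, c_out)

    def _forward(self, z):
        return self.linear(z)

    def forward(self, z):
        # New pattern: one shared helper instead of repeating the
        # torch.get_autocast_gpu_dtype() / torch.is_autocast_enabled() pair inline.
        if is_fp16_enabled():
            # Step out of autocast, upcast the input, and run the sensitive op in fp32.
            with torch.cuda.amp.autocast(enabled=False):
                return self._forward(z.float())
        else:
            return self._forward(z)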
@@ -22,6 +22,7 @@ from openfold.utils.loss import (
     compute_tm,
     compute_predicted_aligned_error,
 )
+from openfold.utils.precision_utils import is_fp16_enabled
 class AuxiliaryHeads(nn.Module):
@@ -151,8 +152,7 @@ class DistogramHead(nn.Module):
         return logits
     def forward(self, z):
-        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
-        if float16_enabled and torch.is_autocast_enabled():
+        if(is_fp16_enabled()):
             with torch.cuda.amp.autocast(enabled=False):
                 return self._forward(z.float())
         else:
@@ -21,6 +21,7 @@ import torch.nn as nn
 from openfold.model.primitives import Linear
 from openfold.utils.chunk_utils import chunk_layer
+from openfold.utils.precision_utils import is_fp16_enabled
 class OuterProductMean(nn.Module):
@@ -150,9 +151,7 @@ class OuterProductMean(nn.Module):
         chunk_size: Optional[int] = None,
         inplace_safe: bool = False,
     ) -> torch.Tensor:
-        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
-        if float16_enabled and torch.is_autocast_enabled():
+        if(is_fp16_enabled()):
             with torch.cuda.amp.autocast(enabled=False):
                 return self._forward(m.float(), mask, chunk_size, inplace_safe)
         else:
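Why these blocks leave autocast at all: fp16 tops out at 65504 and flushes very small values to zero, so large accumulations such as the outer-product mean over many sequences can overflow or lose precision if carried out in half precision, which is presumably why these modules upcast with .float() inside autocast(enabled=False). A quick illustration of the dtype limits, in plain PyTorch and independent of OpenFold:

import torch

# fp16 overflows just above 65504 and underflows tiny values to zero.
print(torch.finfo(torch.float16).max)           # 65504.0
big = torch.tensor(60000.0, dtype=torch.float16)
print(big + big)                                # inf (overflow in fp16)
print(torch.tensor(1e-8, dtype=torch.float16))  # 0.0  (underflow in fp16)

# The same arithmetic is unremarkable in fp32, hence the explicit upcast.
print(big.float() + big.float())                # 120000.0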
@@ -33,6 +33,7 @@ from openfold.utils.feats import (
     frames_and_literature_positions_to_atom14_pos,
     torsion_angles_to_frames,
 )
+from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.rigid_utils import Rotation, Rigid
 from openfold.utils.tensor_utils import (
     dict_multimap,
@@ -312,8 +313,7 @@ class InvariantPointAttention(nn.Module):
             z[0] = z[0].cpu()
         # [*, H, N_res, N_res]
-        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
-        if float16_enabled and torch.is_autocast_enabled():
+        if(is_fp16_enabled()):
             with torch.cuda.amp.autocast(enabled=False):
                 a = torch.matmul(
                     permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
@@ -21,6 +21,7 @@ import torch.nn as nn
 from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.chunk_utils import chunk_layer
+from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.tensor_utils import add, permute_final_dims
@@ -392,8 +393,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)
-        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
-        if float16_enabled and torch.is_autocast_enabled():
+        if(is_fp16_enabled()):
             with torch.cuda.amp.autocast(enabled=False):
                 x = self._combine_projections(a.float(), b.float())
         else:
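For context, is_fp16_enabled() only reports True inside an active autocast region whose GPU dtype is float16, so all four call sites above take their explicit fp32 branch exactly when fp16 mixed precision is running. A small usage sketch; it assumes a CUDA device, and the model and batch are placeholders rather than OpenFold objects:

import torch

# Placeholder model/input; any nn.Module on a CUDA device would do.
model = torch.nn.Linear(8, 8).cuda()
batch = torch.randn(4, 8, device="cuda")

with torch.cuda.amp.autocast(dtype=torch.float16):
    # Here torch.is_autocast_enabled() is True and the autocast GPU dtype is
    # float16, so is_fp16_enabled() returns True inside the modules above.
    out = model(batch)

# Outside the block the helper returns False and the default path is used.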
@@ -13,10 +13,6 @@
 # limitations under the License.
 import importlib
-deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
-if(deepspeed_is_installed):
-    import deepspeed
 import torch
 def is_fp16_enabled():
@@ -24,11 +20,4 @@ def is_fp16_enabled():
     fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
     fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
-    # DeepSpeed world
-    deepspeed_is_initialized = (
-        deepspeed_is_installed and
-        deepspeed.utils.is_initialized()
-    )
-    print(dir(deepspeed))
     return fp16_enabled
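Pieced together from the two hunks above, the helper after this commit reduces to roughly the following (a reconstruction, omitting the license header). The removed DeepSpeed block computed deepspeed_is_initialized without ever using it in the returned value and left a stray debug print, so dropping it does not change behavior; presumably the new guard in the training script below also makes a DeepSpeed-specific branch here unnecessary.

import importlib  # still imported in the file, though no longer used here

import torch


def is_fp16_enabled():
    # Autocast must be active and its GPU dtype set to half precision.
    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
    fp16_enabled = fp16_enabled and torch.is_autocast_enabled()

    return fp16_enabled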
@@ -567,6 +567,9 @@ if __name__ == "__main__":
         (args.num_nodes is not None and args.num_nodes > 1))):
         raise ValueError("For distributed training, --seed must be specified")
+    if(args.precision == "16" and args.deepspeed_config_path is not None):
+        raise ValueError("DeepSpeed and FP16 training are not compatible")
     # This re-applies the training-time filters at the beginning of every epoch
     args.reload_dataloaders_every_n_epochs = 1
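The new guard rejects requesting mixed precision through Lightning's --precision 16 while also passing --deepspeed_config_path. If half precision is wanted under DeepSpeed, it would presumably be enabled on DeepSpeed's side through its JSON config instead; DeepSpeed exposes an "fp16" section for this. A hedged sketch of writing such a config from Python; only the fp16 block is the point, and the other key is an illustrative placeholder:

import json

# Illustrative DeepSpeed config enabling fp16 on DeepSpeed's side rather than
# via --precision 16; "train_micro_batch_size_per_gpu" is a placeholder value.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "fp16": {
        "enabled": True,
    },
}

with open("deepspeed_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)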