Commit f84eb0a1 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Refactor FP16 check

parent 2fc72a5f
......@@ -22,6 +22,7 @@ from openfold.utils.loss import (
compute_tm,
compute_predicted_aligned_error,
)
from openfold.utils.precision_utils import is_fp16_enabled
class AuxiliaryHeads(nn.Module):
......@@ -151,8 +152,7 @@ class DistogramHead(nn.Module):
return logits
def forward(self, z):
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(z.float())
else:
......
......@@ -21,6 +21,7 @@ import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.precision_utils import is_fp16_enabled
class OuterProductMean(nn.Module):
......@@ -150,9 +151,7 @@ class OuterProductMean(nn.Module):
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(m.float(), mask, chunk_size, inplace_safe)
else:
......
......@@ -33,6 +33,7 @@ from openfold.utils.feats import (
frames_and_literature_positions_to_atom14_pos,
torsion_angles_to_frames,
)
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
dict_multimap,
......@@ -312,8 +313,7 @@ class InvariantPointAttention(nn.Module):
z[0] = z[0].cpu()
# [*, H, N_res, N_res]
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
a = torch.matmul(
permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
......
......@@ -21,6 +21,7 @@ import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import add, permute_final_dims
......@@ -392,8 +393,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
b = b * self.sigmoid(self.linear_b_g(z))
b = b * self.linear_b_p(z)
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled and torch.is_autocast_enabled():
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
......
......@@ -13,10 +13,6 @@
# limitations under the License.
import importlib
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed):
import deepspeed
import torch
def is_fp16_enabled():
    """Return True iff CUDA autocast is currently active with float16 dtype.

    Callers (e.g. DistogramHead, OuterProductMean, InvariantPointAttention,
    TriangleMultiplicativeUpdate) use this to decide whether to locally
    disable autocast and run a numerically sensitive op in float32.

    Returns:
        bool: True only when autocast is enabled AND its GPU dtype is
        torch.float16; False in every other situation (including outside
        any autocast context).
    """
    # Autocast world: both conditions must hold — the configured autocast
    # dtype is fp16, and an autocast region is actually active right now.
    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
    fp16_enabled = fp16_enabled and torch.is_autocast_enabled()

    return fp16_enabled
......@@ -567,6 +567,9 @@ if __name__ == "__main__":
(args.num_nodes is not None and args.num_nodes > 1))):
raise ValueError("For distributed training, --seed must be specified")
if(args.precision == "16" and args.deepspeed_config_path is not None):
raise ValueError("DeepSpeed and FP16 training are not compatible")
# This re-applies the training-time filters at the beginning of every epoch
args.reload_dataloaders_every_n_epochs = 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment