Commit 2fc72a5f authored by Gustaf Ahdritz

Refactor FP16 code a little (WIP)

parent a17a9777
openfold/model/primitives.py

@@ -35,6 +35,7 @@ from scipy.stats import truncnorm
 from openfold.utils.checkpointing import get_checkpoint_fn
 from openfold.utils.chunk_utils import _chunk_slice
 from openfold.utils.kernel.attention_core import attention_core
+from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,
@@ -479,8 +480,7 @@ class Attention(nn.Module):
         q, k, v = self._prep_qkv(q_x, kv_x)
 
         # [*, Q, H, C_hidden]
-        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
-        if float16_enabled and torch.is_autocast_enabled():
+        if is_fp16_enabled():
             use_memory_efficient_kernel = False
 
         if(use_memory_efficient_kernel):
...
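
The hunk replaces an inline autocast probe with the shared helper and keeps the memory-efficient attention kernel disabled under fp16 autocast. A minimal sketch of that guard in isolation, where choose_attention_path and its return strings are hypothetical stand-ins for the surrounding Attention.forward logic; only is_fp16_enabled and use_memory_efficient_kernel come from the diff above:

from openfold.utils.precision_utils import is_fp16_enabled

def choose_attention_path(use_memory_efficient_kernel: bool) -> str:
    # Same guard as in the hunk above: the low-memory attention kernel
    # is forced off whenever fp16 autocast is active.
    if is_fp16_enabled():
        use_memory_efficient_kernel = False
    return "attention_core" if use_memory_efficient_kernel else "eager attention"
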
openfold/utils/precision_utils.py (new file)

# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util

import torch

deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed):
    import deepspeed


def is_fp16_enabled():
    # Autocast world
    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
    fp16_enabled = fp16_enabled and torch.is_autocast_enabled()

    # DeepSpeed world: detect an initialized DeepSpeed runtime. The flag
    # is computed here but not yet folded into the return value (WIP).
    deepspeed_is_initialized = (
        deepspeed_is_installed and
        deepspeed.utils.is_initialized()
    )

    return fp16_enabled
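
As a quick sanity check of the autocast branch, the helper should track the enclosing autocast context. A small sketch assuming PyTorch >= 1.10 (torch.autocast only sets thread-local state, so no GPU is required just to enter the context):

import torch

from openfold.utils.precision_utils import is_fp16_enabled

# Outside any autocast region, torch.is_autocast_enabled() is False,
# so the helper reports False regardless of the default GPU dtype.
assert not is_fp16_enabled()

# Under fp16 autocast, both conditions hold.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    assert is_fp16_enabled()

# bfloat16 autocast is autocast, but not fp16.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    assert not is_fp16_enabled()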