Commit 50078a62 authored by Gustaf Ahdritz

Add type check to attention kernel

parent aada2a46
@@ -20,11 +20,16 @@ import torch
 attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
 
+SUPPORTED_DTYPES = [torch.float32, torch.bfloat16]
+
 class AttentionCoreFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, bias_1=None, bias_2=None):
         if(bias_1 is None and bias_2 is not None):
             raise ValueError("bias_1 must be specified before bias_2")
 
+        if(q.dtype not in SUPPORTED_DTYPES):
+            raise ValueError("Unsupported datatype")
+
         q = q.contiguous()
         k = k.contiguous()
...
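Note on the new guard: it inspects only q.dtype (k and v are assumed to match) and raises before any CUDA work is attempted, so unsupported inputs fail fast with a clear error instead of misbehaving inside the kernel. A minimal sketch of the expected behavior follows; it assumes AttentionCoreFunction has been imported from the module above (file path truncated in the diff), that the compiled attn_core_inplace_cuda extension is available so the module imports, and the tensor shapes are arbitrary illustrations:

    import torch

    # float64 is not in SUPPORTED_DTYPES, so forward() should raise
    # immediately, before the CUDA kernel is invoked.
    q = torch.randn(1, 8, 16, 32, dtype=torch.float64)
    k = torch.randn(1, 8, 16, 32, dtype=torch.float64)
    v = torch.randn(1, 8, 16, 32, dtype=torch.float64)

    try:
        AttentionCoreFunction.apply(q, k, v)
    except ValueError as e:
        print(e)  # "Unsupported datatype"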
@@ -671,7 +671,7 @@ class TestLoss(unittest.TestCase):
         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
 
     @compare_utils.skip_unless_alphafold_installed()
-    def test_backbone_loss(self):
+    def test_backbone_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_sm = config.model.heads.structure_module
...