Commit f44557ed authored by shenggan's avatar shenggan
Browse files

refactor softmax kernel

parent a65d5009
......@@ -87,7 +87,7 @@ if CUDA_HOME is None:
"Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc."
)
else:
check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)
# check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)
def cuda_ext_helper(name, sources, extra_cuda_flags):
return CUDAExtension(
......
......@@ -5,11 +5,11 @@ from fastfold.model.fastnn.kernel import softmax
def test_softmax():
# [batch, dim]
test_shape = [[64, 64], [64, 128], [64, 129], [64, 1024]]
test_shape = [[64, 64], [64, 128], [64, 129], [64, 2000]]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
tolerance_eps = {torch.float32: 10e-5, torch.float16: 10e-2, torch.bfloat16: 10e-2}
tolerance_eps = {torch.float32: 10e-4, torch.float16: 10e-2, torch.bfloat16: 10e-2}
for shape in test_shape:
for dtype in test_dtype:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment