Commit 340e71a4 authored by Michael Carilli's avatar Michael Carilli
Browse files

Tests for the fused downscale kernel

parent 8818ba9e
import unittest
import functools as ft
import itertools as it
from apex import amp
import torch
from torch import nn
import torch.nn.functional as F
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
try:
import amp_C
scale_check_overflow = amp_C.scale_check_overflow
disabled = False
except ImportError as err:
print("amp_C fused kernel unavailable, disabling TestScale. ImportError was ", err)
disabled = True
class TestScale(unittest.TestCase):
def setUp(self):
self.scale = 128.0
self.nx = 999
self.ny = 888
self.overflow_buf = torch.cuda.IntTensor([0])
self.fp16 = torch.ones((self.ny, self.nx), device='cuda', dtype=torch.float16)
self.fp32 = torch.ones((self.ny, self.nx), device='cuda', dtype=torch.float32)
self.fp16_ref = torch.ones((1, 1), device='cuda', dtype=torch.float16)
self.fp32_ref = torch.ones((1, 1), device='cuda', dtype=torch.float32)
common_init(self)
def tearDown(self):
pass
def downscale_test(self, input, output, ref):
self.overflow_buf.zero_()
input.fill_(1.0)
if input is not output:
output.fill_(3.0)
input.mul_(self.scale)
scale_check_overflow(input, 1./self.scale, self.overflow_buf, output)
self.assertTrue(torch.allclose(output, ref))
self.assertTrue(self.overflow_buf.item() == 0)
def find_inf_test(self, input, output, ref, x, y, val):
self.overflow_buf.zero_()
input.fill_(1.0)
if input is not output:
output.fill_(3.0)
input[x,y] = val
scale_check_overflow(input, 1./self.scale, self.overflow_buf, output)
self.assertTrue(self.overflow_buf.item())
# Currently, the fused kernel gives a hard error if you attempt to downscale
# into fp16 output, which imo is the desired behavior. Maybe someday we
# will learn otherwise.
# @unittest.skipIf(disabled, "amp_C is unavailable")
# def test_fp16_to_fp16(self):
# self.downscale_test(self.fp16, self.fp16, self.fp16_ref)
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_fp16_to_fp32(self):
self.downscale_test(self.fp16, self.fp32, self.fp32_ref)
# @unittest.skipIf(disabled, "amp_C is unavailable")
# def test_fp32_to_fp16(self):
# self.downscale_test(self.fp32, self.fp16, self.fp16_ref)
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_fp32_to_fp32(self):
self.downscale_test(self.fp32, self.fp32, self.fp32_ref)
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_fp16_to_fp32_find_inf_nan(self):
self.find_inf_test(self.fp16, self.fp32, self.fp32_ref, 0, 0, float('nan'))
self.find_inf_test(self.fp16, self.fp32, self.fp32_ref, self.ny//2, self.nx//2, float('inf'))
self.find_inf_test(self.fp16, self.fp32, self.fp32_ref, self.ny-1, self.nx-1, float('nan'))
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_fp32_to_fp32_find_inf_nan(self):
self.find_inf_test(self.fp32, self.fp32, self.fp32_ref, 0, 0, float('inf'))
self.find_inf_test(self.fp32, self.fp32, self.fp32_ref, self.ny//2, self.nx//2, float('nan'))
self.find_inf_test(self.fp32, self.fp32, self.fp32_ref, self.ny-1, self.nx-1, float('inf'))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment