Unverified commit de6378f5, authored by mcarilli and committed by GitHub

NHWC support for multi tensor apply (#732)

* NHWC support for multi tensor apply

* Compilation fix for torch versions <= 1.4
parent 92b3b9a9
@@ -56,7 +56,11 @@ void multi_tensor_apply(
     for(int t = 0; t < tensor_lists[l].size(); t++)
     {
       // TODO: Print which tensor fails.
-      TORCH_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous.");
+      bool contiguous_memory = tensor_lists[l][t].is_contiguous();
+#ifdef VERSION_GE_1_5
+      contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
+#endif
+      TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
       TORCH_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
       TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
     }
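Note: the updated check above accepts a tensor that is either contiguous in the default (NCHW) sense or contiguous in channels-last (NHWC) layout. A minimal Python sketch of the same acceptance logic, using the public PyTorch API; the helper name is_acceptable is illustrative, and the hasattr guard stands in for the VERSION_GE_1_5 macro:

import torch

def is_acceptable(t):
    # Mirror of the C++ check: plain contiguous, or contiguous in channels-last (NHWC) layout.
    ok = t.is_contiguous()
    if hasattr(torch, "channels_last"):  # guard for older torch builds, like VERSION_GE_1_5 above
        ok = ok or t.is_contiguous(memory_format=torch.channels_last)
    return ok

nchw = torch.randn(8, 3, 32, 32)                    # default (contiguous) layout
nhwc = nchw.to(memory_format=torch.channels_last)   # same values, NHWC strides
assert is_acceptable(nchw) and is_acceptable(nhwc)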
@@ -91,7 +91,10 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
 version_ge_1_3 = []
 if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
     version_ge_1_3 = ['-DVERSION_GE_1_3']
-version_dependent_macros = version_ge_1_1 + version_ge_1_3
+version_ge_1_5 = []
+if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
+    version_ge_1_5 = ['-DVERSION_GE_1_5']
+version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5

 if "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
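Note: the new -DVERSION_GE_1_5 define rides along with the other version macros when the CUDA extensions are built, which is what lets the C++ hunk above compile against torch <= 1.4. A hedged sketch of how such flags typically reach both the host compiler and nvcc through torch.utils.cpp_extension; the extension name and source list below are illustrative, not the exact setup.py contents:

from torch.utils.cpp_extension import CUDAExtension

version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

# Illustrative extension; the real setup.py builds several of these.
ext = CUDAExtension(
    name='amp_C',
    sources=['csrc/amp_C_frontend.cpp', 'csrc/multi_tensor_axpby_kernel.cu'],
    extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                        'nvcc': ['-O3'] + version_dependent_macros})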
@@ -7,6 +7,7 @@ from apex import amp
 import torch
 from torch import nn
 import torch.nn.functional as F
+from math import floor
 from utils import common_init, HALF, FLOAT,\
     ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
@@ -20,6 +21,10 @@ except ImportError as err:
     print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
     disabled = True

+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+try_nhwc = (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4)
+
 class TestMultiTensorAxpby(unittest.TestCase):
@@ -31,28 +36,36 @@ class TestMultiTensorAxpby(unittest.TestCase):
         self.xval = 4.0
         self.yval = 16.0
         self.overflow_buf = torch.cuda.IntTensor(1).zero_()
-        self.ref = torch.cuda.FloatTensor([136.0])
+        self.ref = torch.full((1,), 136.0, device="cuda", dtype=torch.float32)

     def tearDown(self):
         pass

     # The tensor creation here is written for convenience, not speed.
     def axpby(self, sizea, sizeb, applier, repeat_tensors,
-              x_type, y_type, out_type, inplace=False):
+              x_type, y_type, out_type, inplace=False, nhwc=False):
         self.overflow_buf.zero_()
-        t1 = torch.cuda.FloatTensor(sizea).fill_(1.0)
-        t2 = torch.cuda.FloatTensor(sizeb).fill_(1.0)
+        sizea = sizea if isinstance(sizea, tuple) else (sizea,)
+        sizeb = sizeb if isinstance(sizeb, tuple) else (sizeb,)
+        t1 = torch.full(sizea, 1.0, device="cuda", dtype=torch.float32)
+        t2 = torch.full(sizeb, 1.0, device="cuda", dtype=torch.float32)
+
+        def to_fmt(t, tp):
+            if nhwc:
+                return t.clone().to(tp, memory_format=torch.channels_last)
+            else:
+                return t.clone().to(tp)

         y_list = []
         for i in range(repeat_tensors):
-            y_list += [t1.clone().to(y_type)*self.yval, t2.clone().to(y_type)*self.yval]
+            y_list += [to_fmt(t1, y_type)*self.yval, to_fmt(t2, y_type)*self.yval]

-        x_list = [x.clone().to(x_type)*(self.xval/self.yval) for x in y_list]
+        x_list = [to_fmt(x, x_type)*(self.xval/self.yval) for x in y_list]

         if inplace:
             out_list = y_list
         else:
-            out_list = [out.clone().to(out_type)*3.0 for out in y_list]
+            out_list = [to_fmt(out, out_type)*3.0 for out in y_list]

         applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b, -1)
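Note: the to_fmt helper above is where NHWC enters the test. With nhwc=True each copy keeps its logical shape and values but carries channels-last strides, which is exactly the case the relaxed contiguity check in multi_tensor_apply now admits. A small CPU-only sketch of that conversion:

import torch

t = torch.ones(2, 3, 4, 5)                                      # logical NCHW shape
nhwc = t.clone().to(torch.float16, memory_format=torch.channels_last)

print(t.shape == nhwc.shape)                                    # True: same logical shape
print(t.stride())                                               # (60, 20, 5, 1)  NCHW order in memory
print(nhwc.stride())                                            # (60, 1, 15, 3)  channel stride is 1 (NHWC)
print(nhwc.is_contiguous())                                     # False
print(nhwc.is_contiguous(memory_format=torch.channels_last))    # True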
@@ -122,6 +135,45 @@ class TestMultiTensorAxpby(unittest.TestCase):
                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                    #               2*(repeat//2), sizea//2, float('inf'), inplace=inplace)

+    @unittest.skipIf(disabled, "amp_C is unavailable")
+    @unittest.skipIf(not try_nhwc, "torch version is 1.4 or earlier, may not support nhwc")
+    def test_fuzz_nhwc(self):
+        input_size_pairs = (
+            ((7, 77, 7, 77), (5, 55, 5, 55)),
+            ((1, 1, 777, 1), (1, 1, 555, 1)),
+            ((5, 47, 5, 55), (1, 1, 1, 2048*32 + 1)),
+            ((1, 1, 1, 2048*32 + 1), (55, 47, 5, 55)),
+            ((555, 1, 1, 1), (32, 8, 32, 8)),
+            ((32, 8, 32, 8), (55, 47, 5, 55)),
+            ((1, 1, 33333, 1), (55, 47, 55, 5)),
+            ((55, 47, 55, 5), (1, 1, 33333, 1)))
+        appliers = (
+            MultiTensorApply(2048*32),
+            MultiTensorApply(333),
+            MultiTensorApply(33333))
+        repeat_tensors = (
+            1,
+            55)
+
+        for sizea, sizeb in input_size_pairs:
+            for applier in appliers:
+                for repeat in repeat_tensors:
+                    for x_type in (torch.float32, torch.float16):
+                        for y_type in (torch.float32, torch.float16):
+                            for out_type in (torch.float32, torch.float16):
+                                for inplace in (True, False):
+                                    if inplace is True and (y_type is not out_type):
+                                        continue
+                                    else:
+                                        self.axpby(sizea, sizeb, applier, repeat,
+                                                   x_type, y_type, out_type, inplace=inplace, nhwc=True)
+                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
+                                    #               0, 0, float('nan'), inplace=inplace)
+                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
+                                    #               2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
+                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
+                                    #               2*(repeat//2), sizea//2, float('inf'), inplace=inplace)
+
 if __name__ == '__main__':
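Note: the fused op under test computes out = a*x + b*y element-wise over every tensor in the lists, so with x filled with xval and y filled with yval every output element should equal a*xval + b*yval, the value cached in self.ref. A hedged standalone sketch of one NHWC case outside the fuzz loop; the coefficient values are illustrative (the test's self.a and self.b are set elsewhere in setUp), and the imports assume a CUDA build of apex with the amp_C extension:

import torch
import amp_C
from apex.multi_tensor_apply import MultiTensorApply

overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")
applier = MultiTensorApply(2048 * 32)

a, b, xval, yval = 2.0, 8.0, 4.0, 16.0   # illustrative coefficients
y = torch.full((5, 55, 5, 55), yval, device="cuda").to(memory_format=torch.channels_last)
x = torch.full_like(y, xval)             # full_like preserves the channels-last layout
out = torch.empty_like(y)

applier(amp_C.multi_tensor_axpby, overflow_buf, [[x], [y], [out]], a, b, -1)
assert torch.all(out == a * xval + b * yval)   # 2.0*4.0 + 8.0*16.0 == 136.0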