Commit 8a32e428 authored by Michael Carilli's avatar Michael Carilli
Browse files

Merging in master

parents d9c887c2 18f2eaee
...@@ -75,7 +75,9 @@ if "--cuda_ext" in sys.argv: ...@@ -75,7 +75,9 @@ if "--cuda_ext" in sys.argv:
'csrc/multi_tensor_sgd_kernel.cu', 'csrc/multi_tensor_sgd_kernel.cu',
'csrc/multi_tensor_scale_kernel.cu', 'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_axpby_kernel.cu', 'csrc/multi_tensor_axpby_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu'], 'csrc/multi_tensor_l2norm_kernel.cu',
'csrc/multi_tensor_lamb_stage_1.cu',
'csrc/multi_tensor_lamb_stage_2.cu'],
extra_compile_args={'cxx': ['-O3'], extra_compile_args={'cxx': ['-O3'],
'nvcc':['-lineinfo', 'nvcc':['-lineinfo',
'-O3', '-O3',
......
...@@ -32,7 +32,7 @@ class TestMultiTensorL2Norm(unittest.TestCase): ...@@ -32,7 +32,7 @@ class TestMultiTensorL2Norm(unittest.TestCase):
pass pass
# The tensor creation here is written for convenience, not speed. # The tensor creation here is written for convenience, not speed.
def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type): def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type, per_tensor):
self.overflow_buf.zero_() self.overflow_buf.zero_()
a = torch.cuda.FloatTensor(sizea).fill_(self.val) a = torch.cuda.FloatTensor(sizea).fill_(self.val)
b = torch.cuda.FloatTensor(sizeb).fill_(self.val) b = torch.cuda.FloatTensor(sizeb).fill_(self.val)
...@@ -41,12 +41,18 @@ class TestMultiTensorL2Norm(unittest.TestCase): ...@@ -41,12 +41,18 @@ class TestMultiTensorL2Norm(unittest.TestCase):
for i in range(repeat_tensors): for i in range(repeat_tensors):
in_list += [a.clone().to(in_type), b.clone().to(in_type)] in_list += [a.clone().to(in_type), b.clone().to(in_type)]
if per_tensor:
norm = applier(multi_tensor_l2norm, self.overflow_buf, [in_list]) norm, norm_per_tensor = applier(multi_tensor_l2norm, self.overflow_buf, [in_list], True)
normab = torch.cat((a.norm().view(1), b.norm().view(1)))
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
norm, _ = applier(multi_tensor_l2norm, self.overflow_buf, [in_list], True)
reference = torch.cuda.FloatTensor((sizea + sizeb)*repeat_tensors).fill_(self.val).norm() reference = torch.cuda.FloatTensor((sizea + sizeb)*repeat_tensors).fill_(self.val).norm()
self.assertTrue(torch.allclose(norm, reference)) self.assertTrue(torch.allclose(norm, reference))
if per_tensor:
self.assertTrue(torch.allclose(norm_per_tensor, normab))
self.assertTrue(self.overflow_buf.item() == 0) self.assertTrue(self.overflow_buf.item() == 0)
@unittest.skipIf(disabled, "amp_C is unavailable") @unittest.skipIf(disabled, "amp_C is unavailable")
...@@ -72,7 +78,8 @@ class TestMultiTensorL2Norm(unittest.TestCase): ...@@ -72,7 +78,8 @@ class TestMultiTensorL2Norm(unittest.TestCase):
for applier in appliers: for applier in appliers:
for repeat in repeat_tensors: for repeat in repeat_tensors:
for in_type in (torch.float32, torch.float16): for in_type in (torch.float32, torch.float16):
self.l2norm(sizea, sizeb, applier, repeat, in_type, ) for per_tensor in (False, True):
self.l2norm(sizea, sizeb, applier, repeat, in_type, per_tensor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment