Unverified Commit 9ebc53e5 authored by Hubert Lu's avatar Hubert Lu Committed by GitHub
Browse files

Consider both contiguous and channels_last tensors for FusedSGD (#97)

* Consider both contiguous and channels_last tensors for FusedSGD

* Consider all the memory formats in fused_sgd

* Add a unit test script for nhwc fused_sgd
parent 719215bd
......@@ -175,15 +175,33 @@ class FusedSGD(Optimizer):
if self.materialize_master_grads:
    fp16_model_params = [p for i, p in enumerate(
        stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
    fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
    # Convert each grad to its param's memory format so the fused kernel
    # receives layout-matched tensor lists (multi_tensor_applier gained
    # nhwc support in NVIDIA/apex#732).
    fp32_from_fp16_grads = []
    for p in fp32_from_fp16_params:
        if p.is_contiguous(memory_format=torch.contiguous_format):
            fp32_from_fp16_grads.append(p.grad)
        elif p.is_contiguous(memory_format=torch.channels_last):
            fp32_from_fp16_grads.append(p.grad.to(memory_format=torch.channels_last))
        # fixed: `torch.channel_last_3d` does not exist (AttributeError at
        # runtime); the PyTorch name is `torch.channels_last_3d`.
        elif p.is_contiguous(memory_format=torch.channels_last_3d):
            fp32_from_fp16_grads.append(p.grad.to(memory_format=torch.channels_last_3d))
        else:
            # raise instead of `assert False`: asserts are stripped under -O.
            raise RuntimeError("Unsupported memory format. Supports only contiguous_format, channels_last, or channels_last_3d.")
    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
    fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,
                fp32_from_fp16_momentums, fp16_model_params]
else:
    fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
    # Same memory-format matching for the fp16 model grads.
    fp16_model_grads = []
    for p in fp16_model_params:
        if p.is_contiguous(memory_format=torch.contiguous_format):
            fp16_model_grads.append(p.grad)
        elif p.is_contiguous(memory_format=torch.channels_last):
            fp16_model_grads.append(p.grad.to(memory_format=torch.channels_last))
        # fixed: was `torch.channel_last_3d` (nonexistent attribute).
        elif p.is_contiguous(memory_format=torch.channels_last_3d):
            fp16_model_grads.append(p.grad.to(memory_format=torch.channels_last_3d))
        else:
            raise RuntimeError("Unsupported memory format. Supports only contiguous_format, channels_last, or channels_last_3d.")
    fp32_from_fp16_params = [p for i, p in enumerate(
        stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
......@@ -194,11 +212,29 @@ class FusedSGD(Optimizer):
launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
else:
fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
# Match each grad's memory format to its param's so the fused kernel sees
# layout-consistent tensor lists (nhwc support: NVIDIA/apex#732).
fp16_grads = []
for p in fp16_params:
    if p.is_contiguous(memory_format=torch.contiguous_format):
        fp16_grads.append(p.grad)
    elif p.is_contiguous(memory_format=torch.channels_last):
        fp16_grads.append(p.grad.to(memory_format=torch.channels_last))
    # fixed: `torch.channel_last_3d` does not exist (AttributeError at
    # runtime); the PyTorch name is `torch.channels_last_3d`.
    elif p.is_contiguous(memory_format=torch.channels_last_3d):
        fp16_grads.append(p.grad.to(memory_format=torch.channels_last_3d))
    else:
        # raise instead of `assert False`: asserts are stripped under -O.
        raise RuntimeError("Unsupported memory format. Supports only contiguous_format, channels_last, or channels_last_3d.")
fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)
fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
fp32_grads = []
for p in fp32_params:
    if p.is_contiguous(memory_format=torch.contiguous_format):
        fp32_grads.append(p.grad)
    elif p.is_contiguous(memory_format=torch.channels_last):
        fp32_grads.append(p.grad.to(memory_format=torch.channels_last))
    # fixed: was `torch.channel_last_3d` (nonexistent attribute).
    elif p.is_contiguous(memory_format=torch.channels_last_3d):
        fp32_grads.append(p.grad.to(memory_format=torch.channels_last_3d))
    else:
        raise RuntimeError("Unsupported memory format. Supports only contiguous_format, channels_last, or channels_last_3d.")
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
......@@ -208,6 +244,7 @@ class FusedSGD(Optimizer):
assert len(launch_set[0]) == len(launch_set[1])
assert len(launch_set[0]) == len(launch_set[2])
if len(launch_set[0]) > 0:
# multi_tensor_applier has nhwc support: https://github.com/NVIDIA/apex/pull/732
multi_tensor_applier(
self.multi_tensor_sgd,
self._dummy_overflow_buf,
......
from itertools import product
import random
import unittest
import torch
import apex
# NHWC
class TestFusedOptimizerChannelsLast(unittest.TestCase):
    """Base harness comparing a fused optimizer running on channels_last
    parameters against a reference optimizer running on contiguous ones.

    Subclasses must set ``self.options``, ``self.ref_optim`` and
    ``self.fused_optim`` (and optionally ``self.tst_options``).
    """

    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
        # Tolerances and iteration count for the ref-vs-fused comparison.
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.manual_seed(9876)

    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, options, device, tst_options=None):
        """Build (ref_param, tst_param, ref_optim, tst_optim).

        The test params are channels_last copies of *tensors*; the reference
        params are plain contiguous copies.
        """
        # Adding this to make backward compatible with existing tests. Just in
        # case "tst_options" are not provided, it gets a copy of options
        # which contains the parameters for the reference optimizer.
        if tst_options is None:  # fixed: identity comparison, was `== None`
            tst_options = options

        ref_param = []
        tst_param = []
        for tensor in tensors:
            # renamed from `input` to avoid shadowing the builtin
            tst_input = tensor.clone().contiguous(memory_format=torch.channels_last).to(device)  # channels_last
            ref_input = tensor.clone().contiguous().to(device)
            self.assertTrue(tst_input.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_input.is_contiguous(memory_format=torch.contiguous_format))
            tst_param.append(torch.nn.Parameter(tst_input))
            ref_param.append(torch.nn.Parameter(ref_input))

        ref_optim = self.ref_optim(ref_param, **options)
        tst_optim = self.fused_optim(tst_param, **tst_options)
        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param):
        """Assign identical random grads to both parameter sets.

        Note: p_tst is channels_last but p_tst.grad is deliberately left in
        torch.contiguous_format — the fused optimizer must handle the mismatch.
        """
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_ref.grad = torch.rand_like(p_ref)
            p_tst.grad = p_ref.grad.clone()
            self.assertTrue(p_tst.grad.is_contiguous(memory_format=torch.contiguous_format))
            self.assertTrue(p_ref.grad.is_contiguous(memory_format=torch.contiguous_format))

    def get_max_diff(self, ref_param, tst_param):
        """Return (max absolute, max relative) elementwise difference."""
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            # The optimizers must not have changed each side's memory format.
            self.assertTrue(p_ref.is_contiguous(memory_format=torch.contiguous_format))
            self.assertTrue(p_tst.is_contiguous(memory_format=torch.channels_last))
            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
            if max_abs_diff_p > max_abs_diff:
                max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff:
                max_rel_diff = max_rel_diff_p
        return max_abs_diff, max_rel_diff

    def gen_single_type_test(self, param_type=torch.float, device='cuda', *, skip_assert: bool = False):
        """Run `iters` optimizer steps and compare ref vs fused params."""
        # nelem = 278011
        # Some ref and test optimizers may require different set of options.
        # This is a quick workaround to add that functionality while making
        # minimum changes in existing code.
        # If there is no "tst_options" field provided, safe to initialize
        # the test optimizer with the parameters of reference optimizer.
        if not hasattr(self, 'tst_options'):
            self.tst_options = self.options

        tensor = torch.rand([3, 4, 2, 3], dtype=param_type, device=device)
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim([tensor], self.options, device, self.tst_options)

        for _ in range(self.iters):  # loop index was unused
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()

        if skip_assert:
            return
        max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
        self.assertLessEqual(max_abs_diff, self.max_abs_diff)
        self.assertLessEqual(max_rel_diff, self.max_rel_diff)
class TestFusedSGDChannelLast(TestFusedOptimizerChannelsLast):
    """Compare apex FusedSGD against torch.optim.SGD on channels_last params."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Shared SGD hyperparameters for both the reference and fused optimizer.
        self.options = {"lr": .25, "momentum": .125}
        self.ref_optim = torch.optim.SGD
        self.fused_optim = apex.optimizers.FusedSGD

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        # Exercise every (current device, tensor device) pairing.
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment