Unverified commit 274cc063, authored by ngimel, committed by GitHub

set device guard for multi tensor optimizer implementations (#927)

* add device guards to the optimizers

* add untracked file

* set deviceGuard in multi_tensor_apply

* address review comments; fix lamb

* indent

* typo
parent 5b53121a
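
The recurring change below is where the `_dummy_overflow_buf` tensor is allocated. A minimal sketch of the failure mode it addresses, assuming at least two visible GPUs: `torch.cuda.IntTensor([0])` always allocates on the *current* device, so when the optimizer's parameters live on another GPU the overflow buffer and the parameters end up on different devices, while the new form pins the buffer to the parameters' device.

import torch

# Sketch of the allocation difference (assumes >= 2 GPUs, current device cuda:0).
params = [torch.nn.Parameter(torch.rand(16, device="cuda:1"))]

buf_old = torch.cuda.IntTensor([0])                                     # cuda:0, the current device
buf_new = torch.tensor([0], dtype=torch.int, device=params[0].device)   # cuda:1, the params' device
print(buf_old.device, buf_new.device)
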
@@ -76,7 +76,7 @@ class FusedLAMB(torch.optim.Optimizer):
             import amp_C
             self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
             # Skip buffer
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
             self.multi_tensor_lamb = amp_C.multi_tensor_lamb
         else:
             raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')

@@ -117,7 +117,8 @@ class FusedLAMB(torch.optim.Optimizer):
                 else:
                     raise RuntimeError('FusedLAMB only support fp16 and fp32.')

-        g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
+        device = self.param_groups[0]["params"][0].device
+        g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
         # compute grad norm for two lists
         if len(g_all_32) > 0:
             g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
...
@@ -98,7 +98,7 @@ class FusedSGD(Optimizer):
         if multi_tensor_applier.available:
             import amp_C
             # Skip buffer
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
             self.multi_tensor_sgd = amp_C.multi_tensor_sgd
         else:
             raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
...
@@ -2,6 +2,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
+#include <c10/cuda/CUDAGuard.h>
 #include "compat.h"

 #include <assert.h>

@@ -49,8 +50,9 @@ void multi_tensor_apply(
   TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
   int len0 = tensor_lists[0].size();
   TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
-  for(int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
+  auto ref_device = tensor_lists[0][0].device();
+  TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
+  for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
   {
     TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
     for(int t = 0; t < tensor_lists[l].size(); t++)

@@ -61,7 +63,7 @@ void multi_tensor_apply(
       contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
 #endif
       TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
-      TORCH_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
+      TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
       TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
     }
   }

@@ -70,6 +72,7 @@ void multi_tensor_apply(
   TensorListMetadata<depth> tl;

+  const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
   auto stream = at::cuda::getCurrentCUDAStream();

   tl.start_tensor_this_launch = 0;
...
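
The `OptionalCUDAGuard` added to `multi_tensor_apply` makes the current device match the input tensors' device before the kernel launch. A rough Python analogue of what the guard does, assuming two GPUs; the helper name here is purely illustrative:

import torch

def scale_on_tensor_device(t, alpha):
    # Analogue of the C++ OptionalCUDAGuard above: make t's device the current
    # CUDA device for whatever runs inside this block, then restore the previous
    # device on exit. PyTorch's own ops already guard internally; raw extension
    # kernels such as the ones in this file do not.
    with torch.cuda.device(t.device):
        return t.mul_(alpha)

x = torch.ones(4, device="cuda:1")
with torch.cuda.device("cuda:0"):   # current device differs from x's device
    scale_on_tensor_device(x, 2.0)  # the work still targets cuda:1
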
@@ -2,6 +2,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
+#include <c10/cuda/CUDAGuard.h>
 // Another possibility:
 // #include <torch/all.h>

@@ -335,13 +336,13 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
         max_chunks_per_tensor);)

   AT_CUDA_CHECK(cudaGetLastError());
   // AT_CUDA_CHECK(cudaDeviceSynchronize());

   // This involves one more small kernel launches, but will be negligible end to end.
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now

   auto ret = at::empty({1}, output.options());
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
   auto stream = at::cuda::getCurrentCUDAStream();

   cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
     output.DATA_PTR<float>(),

@@ -369,7 +370,7 @@ void multi_tensor_norm_out_cuda(
   const int norm_type)
 {
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
+  TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors");
   // we don't need global thus uses empty here
   auto output = at::empty({320}, float_options);
...
@@ -160,6 +160,8 @@ void multi_tensor_sgd_cuda(
       TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
         "Additional output tensors should always be fp16.");

+  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");
+
   // We have 3 possibilities to handle here, in terms of
   // grad_type, param_type, momentum_type, requires_fp16_copy
   // 1. fp16, fp16, fp16, No
...
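
With the new device checks in place, tensor lists that span devices fail with a clear error instead of silently launching on the wrong GPU. A hedged sketch of that error path, mirroring the test_multi_params_different_devices_throws test added below but using FusedSGD instead of FusedAdagrad (assumes apex built with CUDA extensions and at least two GPUs):

import torch
import apex

# Parameters deliberately placed on two different devices.
params = [torch.nn.Parameter(torch.rand(8, device="cuda:0")),
          torch.nn.Parameter(torch.rand(8, device="cuda:1"))]
opt = apex.optimizers.FusedSGD(params, lr=0.25, momentum=0.125)
for p in params:
    p.grad = torch.rand_like(p)
try:
    opt.step()
except RuntimeError as err:
    # Expected to be one of the device-mismatch messages added in this commit,
    # e.g. "... not on the same device ...".
    print(err)
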
import unittest

import apex
import torch


class TestFusedAdagrad(unittest.TestCase):
    def setUp(self, max_abs_diff=1e-6, max_rel_diff=1, iters=7):
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.cuda.manual_seed(9876)

    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, adagrad_option):
        ref_param = []
        tst_param = []
        for tensor in tensors:
            ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))

        ref_optim = torch.optim.Adagrad(ref_param, **adagrad_option)
        tst_optim = apex.optimizers.FusedAdagrad(tst_param, **adagrad_option)

        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param):
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_ref.grad = torch.rand_like(p_ref)
            p_tst.grad = p_ref.grad

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        half_grads = []
        for p_ref, _ in zip(ref_param, tst_param):
            half_grads.append(torch.rand_like(p_ref).half())
            p_ref.grad = half_grads[-1].float() / scale
        return half_grads

    def get_max_diff(self, ref_param, tst_param):
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()

            if max_abs_diff_p > max_abs_diff:
                max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff:
                max_rel_diff = max_rel_diff_p

        return max_abs_diff, max_rel_diff

    def gen_single_type_test(self, param_type=torch.float):
        nelem = 278011
        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5}

        tensor = torch.rand(nelem, dtype=param_type, device="cuda")
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], adagrad_option
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}

        tensors = []
        for size in sizes:
            tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            tensors, adagrad_option
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_adagrad_option(self):
        nelem = 1
        adagrad_option = {"lr": 0.01, "eps": 3e-06, "weight_decay": 0}

        tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], adagrad_option
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
@@ -4,8 +4,9 @@ import random
 import torch
 import apex
+from itertools import product

-class TestFusedAdam(unittest.TestCase):
+class TestFusedOptimizer(unittest.TestCase):
     def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
         self.max_abs_diff = max_abs_diff
         self.max_rel_diff = max_rel_diff

@@ -15,15 +16,15 @@ class TestFusedAdam(unittest.TestCase):
     def tearDown(self):
         pass

-    def gen_param_optim(self, tensors, adam_option):
+    def gen_param_optim(self, tensors, options):
         ref_param = []
         tst_param = []
         for tensor in tensors:
             ref_param.append(torch.nn.Parameter(tensor.clone()))
             tst_param.append(torch.nn.Parameter(tensor.clone()))

-        ref_optim = torch.optim.Adam(ref_param, **adam_option)
-        tst_optim = apex.optimizers.FusedAdam(tst_param, **adam_option)
+        ref_optim = self.ref_optim(ref_param, **options)
+        tst_optim = self.fused_optim(tst_param, **options)

         return (ref_param, tst_param, ref_optim, tst_optim)
@@ -50,41 +51,54 @@ class TestFusedAdam(unittest.TestCase):
         return max_abs_diff, max_rel_diff

-    def gen_single_type_test(self, param_type=torch.float):
+    def gen_single_type_test(self, param_type=torch.float, device='cuda'):
         nelem = 278011
-        adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
-                       'weight_decay':0, 'amsgrad':False}

-        tensor = torch.rand(nelem, dtype=param_type, device='cuda')
+        tensor = torch.rand(nelem, dtype=param_type, device=device)
         ref_param, tst_param, ref_optim, tst_optim = \
-            self.gen_param_optim([tensor], adam_option)
+            self.gen_param_optim([tensor], self.options)

         for i in range(self.iters):
             self.gen_grad(ref_param, tst_param)
             ref_optim.step()
             tst_optim.step()
             max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

             self.assertLessEqual(max_abs_diff, self.max_abs_diff)
             self.assertLessEqual(max_rel_diff, self.max_rel_diff)

+
+class TestFusedAdam(TestFusedOptimizer):
+    def __init__(self, *args, **kwargs):
+        super(TestFusedAdam, self).__init__(*args, **kwargs)
+        self.options = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
+                        'weight_decay': 0, 'amsgrad': False}
+        self.ref_optim = torch.optim.Adam
+        self.fused_optim = apex.optimizers.FusedAdam
+
     def test_float(self):
         self.gen_single_type_test(param_type=torch.float)

     def test_half(self):
         self.gen_single_type_test(param_type=torch.float16)

+    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
+    def test_multi_device(self):
+        devices = ("cuda:0", "cuda:1")
+        for current_dev, tensor_dev in product(devices, devices):
+            with torch.cuda.device(current_dev):
+                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
+
     @unittest.skip('Disable until 8/1/2019 adam/adamw upstream picked')
     def test_multi_params(self):
         sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
-        adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
-                       'weight_decay':0, 'amsgrad':False}

         tensors = []
         for size in sizes:
             tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
         ref_param, tst_param, ref_optim, tst_optim = \
-            self.gen_param_optim(tensors, adam_option)
+            self.gen_param_optim(tensors, self.options)

         for i in range(self.iters):
             self.gen_grad(ref_param, tst_param)
@@ -97,12 +111,9 @@ class TestFusedAdam(unittest.TestCase):
     @unittest.skip('No longer support fuse scaling')
     def test_scale(self):
         nelem = 278011
-        adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
-                       'weight_decay':0, 'amsgrad':False}

         tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
         ref_param, tst_param, ref_optim, tst_optim = \
-            self.gen_param_optim([tensor], adam_option)
+            self.gen_param_optim([tensor], self.options)

         for i in range(self.iters):
             scale = random.random() * 1000

@@ -117,12 +128,10 @@ class TestFusedAdam(unittest.TestCase):
     @unittest.skip('No longer support output fp16 param')
     def test_fp16_output(self):
         nelem = 278011
-        adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
-                       'weight_decay':0, 'amsgrad':False}

         tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
         ref_param, tst_param, ref_optim, tst_optim = \
-            self.gen_param_optim([tensor], adam_option)
+            self.gen_param_optim([tensor], self.options)

         fp16_param = torch.nn.Parameter(tensor.clone().half())
@@ -159,6 +168,103 @@ class TestFusedAdam(unittest.TestCase):
             self.assertLessEqual(max_rel_diff, self.max_rel_diff)

+
+class TestFusedAdagrad(TestFusedOptimizer):
+    def __init__(self, *args, **kwargs):
+        super(TestFusedAdagrad, self).__init__(*args, **kwargs)
+        self.options = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5}
+        self.ref_optim = torch.optim.Adagrad
+        self.fused_optim = apex.optimizers.FusedAdagrad
+
+    def test_float(self):
+        self.gen_single_type_test(param_type=torch.float)
+
+    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
+    def test_half(self):
+        self.gen_single_type_test(param_type=torch.float16)
+
+    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
+    def test_multi_device(self):
+        devices = ("cuda:0", "cuda:1")
+        for current_dev, tensor_dev in product(devices, devices):
+            with torch.cuda.device(current_dev):
+                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
+
+    def test_multi_params(self):
+        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
+        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}
+
+        tensors = []
+        for size in sizes:
+            tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
+        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
+            tensors, adagrad_option
+        )
+
+        for _ in range(self.iters):
+            self.gen_grad(ref_param, tst_param)
+            ref_optim.step()
+            tst_optim.step()
+            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
+
+            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
+            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+
+    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
+    def test_multi_params_different_devices_throws(self):
+        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
+        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}
+
+        tensors = []
+        for i, size in enumerate(sizes):
+            tensors.append(torch.rand(size, dtype=torch.float, device="cuda:"+str(i % 2)))
+        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
+            tensors, adagrad_option
+        )
+        self.gen_grad(ref_param, tst_param)
+        with self.assertRaisesRegex(RuntimeError, "not on the same device"):
+            tst_optim.step()
+
+    def test_adagrad_option(self):
+        nelem = 1
+        adagrad_option = {"lr": 0.01, "eps": 3e-06, "weight_decay": 0}
+
+        tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
+        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
+            [tensor], adagrad_option
+        )
+
+        for _ in range(self.iters):
+            self.gen_grad(ref_param, tst_param)
+            ref_optim.step()
+            tst_optim.step()
+            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
+
+            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
+            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+
+
+class TestFusedSGD(TestFusedOptimizer):
+    def __init__(self, *args, **kwargs):
+        super(TestFusedSGD, self).__init__(*args, **kwargs)
+        self.options = {"lr": .25, "momentum": .125}
+        self.ref_optim = torch.optim.SGD
+        self.fused_optim = apex.optimizers.FusedSGD
+
+    def test_float(self):
+        self.gen_single_type_test(param_type=torch.float)
+
+    def test_half(self):
+        self.gen_single_type_test(param_type=torch.float16)
+
+    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
+    def test_multi_device(self):
+        devices = ("cuda:0", "cuda:1")
+        for current_dev, tensor_dev in product(devices, devices):
+            with torch.cuda.device(current_dev):
+                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
+

 if __name__ == '__main__':
     script_path = os.path.dirname(os.path.realpath(__file__))
     unittest.main()
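
The test_multi_device cases above sweep every combination of current device and tensor device. A rough end-to-end usage sketch of what they exercise, assuming apex is built with CUDA extensions and two GPUs are available:

import torch
import apex

# Parameters live on cuda:1 while the current device is cuda:0; with the device
# guard and the param-device overflow buffer, the fused step targets cuda:1.
with torch.cuda.device("cuda:0"):
    p = torch.nn.Parameter(torch.rand(278011, device="cuda:1"))
    opt = apex.optimizers.FusedAdam([p], lr=5e-4)
    p.grad = torch.rand_like(p)
    opt.step()
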
@@ -5,6 +5,7 @@ import torch
 from torch.optim import Optimizer
 import apex
 from apex.multi_tensor_apply import multi_tensor_applier
+from itertools import product

 class RefLAMB(Optimizer):
     r"""Implements Lamb algorithm.

@@ -40,7 +41,7 @@ class RefLAMB(Optimizer):
             import amp_C
             self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
             # Skip buffer
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
             self.multi_tensor_lamb = amp_C.multi_tensor_lamb
         else:
             raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')

@@ -68,7 +69,8 @@ class RefLAMB(Optimizer):
                 else:
                     raise RuntimeError('FusedLAMB only support fp16 and fp32.')

-        g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
+        device = self.param_groups[0]["params"][0].device
+        g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
         # compute grad norm for two lists
         if len(g_all_32) > 0:
             g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,

@@ -188,9 +190,9 @@ class TestFusedLAMB(unittest.TestCase):
         return max_abs_diff, max_rel_diff

-    def gen_single_type_test(self, param_type=torch.float):
+    def gen_single_type_test(self, param_type=torch.float, device="cuda"):
         nelem = 278011
-        tensor = torch.rand(nelem, dtype=param_type, device='cuda')
+        tensor = torch.rand(nelem, dtype=param_type, device=device)
         weight_decay = [0, 0.01]

         for wd in weight_decay:

@@ -201,7 +203,9 @@ class TestFusedLAMB(unittest.TestCase):
             for i in range(self.iters):
                 self.gen_grad(ref_param, tst_param)
                 ref_optim.step()
+                torch.cuda.synchronize()
                 tst_optim.step()
+                torch.cuda.synchronize()
                 max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

                 self.assertLessEqual(max_abs_diff, self.max_abs_diff)

@@ -214,6 +218,13 @@ class TestFusedLAMB(unittest.TestCase):
     def test_half(self):
         self.gen_single_type_test(param_type=torch.float16)

+    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
+    def test_multi_device(self):
+        devices = ("cuda:0", "cuda:1")
+        for current_dev, tensor_dev in product(devices, devices):
+            with torch.cuda.device(current_dev):
+                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
+
     def test_multi_params(self):
         sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
         weight_decay = [0, 0.01]
...