Commit 6f7a8b39 authored by lcskrishna

Merge remote-tracking branch 'rocm_upstream/master' into ifu_07272020

parents 459de22d 9c80f6d3
@@ -105,6 +105,66 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
// TODO: We might have come up with an optimal set of dispatch macros by
// changing the signature to have an integer suffix of number of types
// to dispatch for as defined in upstream (e.g AT_DISPATCH_FLOATING_TYPES_AND2)
// Refactor once all the extension ops are enabled.
#define DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
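For context, here is a minimal usage sketch of the new macro; it is not part of this commit, and the launcher, kernel, and op name below are hypothetical. The LEVEL argument only selects the suffix of the type alias the macro defines (scalar_t_0 here), so a nested dispatch over a second dtype can use LEVEL 1 without the aliases colliding.

#include <ATen/ATen.h>

// Hypothetical elementwise kernel, used only for illustration.
template <typename scalar_t>
__global__ void copy_kernel(const scalar_t* in, scalar_t* out, int64_t n)
{
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if(i < n)
    out[i] = in[i];
}

// Hypothetical launcher showing how such a dispatch macro is typically invoked:
// the switch over in.scalar_type() defines scalar_t_0 and runs the body once.
void copy_dispatch(const at::Tensor& in, at::Tensor& out)
{
  const int64_t n = in.numel();
  const int threads = 256;
  const int blocks = (int)((n + threads - 1) / threads);
  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(in.scalar_type(), 0, "copy_dispatch",
    copy_kernel<scalar_t_0><<<blocks, threads>>>(
      in.data_ptr<scalar_t_0>(), out.data_ptr<scalar_t_0>(), n);)
}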
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
@@ -141,8 +201,13 @@ __device__ __forceinline__ T reduce_block_into_lanes
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1) {
#ifdef __HIP_PLATFORM_HCC__
final = final + __shfl_down(0xffffffff, final, i);
#else
final = final + __shfl_down_sync(0xffffffff, final, i);
#endif
}
}
if(share_result)
@@ -191,8 +256,13 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1) {
#ifdef __HIP_PLATFORM_HCC__
final = fmaxf(fabsf(final), fabsf(__shfl_down(0xffffffff, final, i)));
#else
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
#endif
}
}
if(share_result)
...
@@ -11,6 +11,11 @@
#include "type_shim.h"
#include "compat.h"
#if defined __HIP_PLATFORM_HCC__
#define SHFL_DOWN __shfl_down
#else
#define SHFL_DOWN __shfl_down_sync
#endif
__device__ __forceinline__ int lastpow2(int n)
{
@@ -47,7 +52,7 @@ __device__ __forceinline__ T warp_reduce_sum(T val)
{
#pragma unroll
for(int i = WARP_SIZE/2; i > 0; i >>= 1)
val = val + SHFL_DOWN(0xffffffff, val, i);
return val;
}
@@ -129,10 +134,14 @@ __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
{
#pragma unroll
for(int i = WARP_SIZE/2; i > 0; i >>= 1) {
auto num_new = SHFL_DOWN(0xffffffff, num, i);
auto mean_new = SHFL_DOWN(0xffffffff, mean, i);
auto m2n_new = SHFL_DOWN(0xffffffff, m2n, i);
#if defined __HIP_PLATFORM_HCC__
welford_merge_element<T, int>(num, mean, m2n, num_new, mean_new, m2n_new);
#else
welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);
#endif
}
}
...
@@ -6,6 +6,8 @@ import sys
import warnings
import os
from torch.utils.hipify import hipify_python
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -124,51 +126,104 @@ if "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--cuda_ext")
is_rocm_pytorch = False
if torch.__version__ >= '1.5':
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
if not is_rocm_pytorch:
check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
if is_rocm_pytorch:
import shutil
with hipify_python.GeneratedFileCleaner(keep_intermediates=True) as clean_ctx:
hipify_python.hipify(project_directory=this_dir, output_directory=this_dir, includes="csrc/*",
show_detailed=True, is_pytorch_extension=True, clean_ctx=clean_ctx)
shutil.copy("csrc/compat.h", "csrc/hip/compat.h")
shutil.copy("csrc/type_shim.h", "csrc/hip/type_shim.h")
if not is_rocm_pytorch:
ext_modules.append(
CUDAExtension(name='amp_C',
sources=['csrc/amp_C_frontend.cpp',
'csrc/multi_tensor_sgd_kernel.cu',
'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_axpby_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu',
'csrc/multi_tensor_lamb_stage_1.cu',
'csrc/multi_tensor_lamb_stage_2.cu',
'csrc/multi_tensor_adam.cu',
'csrc/multi_tensor_adagrad.cu',
'csrc/multi_tensor_novograd.cu',
'csrc/multi_tensor_lamb.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-lineinfo',
'-O3',
# '--resource-usage',
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building Multitensor apply extension")
ext_modules.append(
CUDAExtension(name='amp_C',
sources=['csrc/amp_C_frontend.cpp',
'csrc/hip/multi_tensor_sgd_kernel.hip',
'csrc/hip/multi_tensor_scale_kernel.hip',
'csrc/hip/multi_tensor_axpby_kernel.hip',
'csrc/hip/multi_tensor_l2norm_kernel.hip',
'csrc/hip/multi_tensor_lamb_stage_1.hip',
'csrc/hip/multi_tensor_lamb_stage_2.hip',
'csrc/hip/multi_tensor_adam.hip',
'csrc/hip/multi_tensor_adagrad.hip',
'csrc/hip/multi_tensor_novograd.hip',
'csrc/hip/multi_tensor_lamb.hip'],
extra_compile_args=['-O3'] + version_dependent_macros))
if not is_rocm_pytorch:
ext_modules.append(
CUDAExtension(name='syncbn',
sources=['csrc/syncbn.cpp',
'csrc/welford.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
else:
print ("INFO: Building syncbn extension.")
ext_modules.append(
CUDAExtension(name='syncbn',
sources=['csrc/syncbn.cpp',
'csrc/hip/welford.hip'],
extra_compile_args=['-O3'] + version_dependent_macros))
if not is_rocm_pytorch:
ext_modules.append(
CUDAExtension(name='fused_layer_norm_cuda',
sources=['csrc/layer_norm_cuda.cpp',
'csrc/layer_norm_cuda_kernel.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-maxrregcount=50',
'-O3',
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building FusedLayerNorm extension.")
ext_modules.append(
CUDAExtension(name='fused_layer_norm_cuda',
sources=['csrc/layer_norm_cuda.cpp',
'csrc/hip/layer_norm_hip_kernel.hip'],
extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
'nvcc' : []}))
if not is_rocm_pytorch:
ext_modules.append(
CUDAExtension(name='mlp_cuda',
sources=['csrc/mlp.cpp',
'csrc/mlp_cuda.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
else:
print ("INFO: Skipping MLP extension")
if "--bnp" in sys.argv: if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
...@@ -250,7 +305,7 @@ if "--deprecated_fused_lamb" in sys.argv: ...@@ -250,7 +305,7 @@ if "--deprecated_fused_lamb" in sys.argv:
'nvcc':['-O3', 'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros})) '--use_fast_math'] + version_dependent_macros}))
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026 # Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
generator_flag = [] generator_flag = []
torch_dir = torch.__path__[0] torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')): if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
......
@@ -14,11 +14,11 @@ from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
class MyModel(torch.nn.Module):
def __init__(self, unique, dtype=torch.float16):
super(MyModel, self).__init__()
self.weight0 = Parameter(unique +
torch.arange(2, device='cuda', dtype=torch.float32))
self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=dtype))
@staticmethod
def ops(input, weight0, weight1):
@@ -51,11 +51,15 @@ class TestAddParamGroup(unittest.TestCase):
optimizer.zero_grad()
def test_add_param_group(self):
for opt_level in ("O0", "O1", "O2", "O3", "O4", "O5"):
for zero_before_add in (True, False):
for try_accumulation in (True, False):
if opt_level in {"O4", "O5"}:
model0 = MyModel(1, torch.bfloat16)
model1 = MyModel(2, torch.bfloat16)
else:
model0 = MyModel(1)
model1 = MyModel(2)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125)
@@ -89,8 +93,12 @@ class TestAddParamGroup(unittest.TestCase):
[param.data.clone() for param in model1.parameters()]
for how_to_zero in "none", "model", "optimizer":
if opt_level in {"O4", "O5"}:
model0 = MyModel(1, torch.bfloat16)
model1 = MyModel(2, torch.bfloat16)
else:
model0 = MyModel(1)
model1 = MyModel(2)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125)
@@ -139,6 +147,9 @@ class TestAddParamGroup(unittest.TestCase):
[param.data.clone() for param in model1.parameters()]
for reference, final in zip(reference_params, final_params):
# TODO: remove the conversion once allclose supports bfloat16 type.
if final.dtype == torch.bfloat16:
final = final.float()
self.assertTrue(torch.allclose(reference.to(final.dtype), final),
"opt_level = {}, how_to_zero = {}, zero_before_add = {}".format(
opt_level, how_to_zero, zero_before_add))
...
@@ -9,7 +9,9 @@ from torch import nn
import torch.nn.functional as F
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_BFLOAT16, ALWAYS_FLOAT, MATCH_INPUT
from apex.testing.common_utils import skipIfRocm
def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
for fn, typ in it.product(fns, expected.keys()):
@@ -20,124 +22,237 @@ def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
y.float().sum().backward()
test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])
class _TestBasicCasts(unittest.TestCase):
def _test_linear(self, expected):
m = nn.Linear(self.h, self.h)
f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
run_layer_test(self, [m, f], expected, (self.b, self.h))
def _test_conv2d(self, expected):
m = nn.Conv2d(self.c, self.c, self.k)
f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h))
def _test_softmax(self, expected):
m = nn.Softmax(dim=1)
f = ft.partial(F.softmax, dim=1)
run_layer_test(self, [m, f], expected, (self.b, self.h))
def _test_group_norm(self, expected):
m = nn.GroupNorm(num_groups=4, num_channels=self.c)
run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h))
def _test_mse_loss(self, expected):
shape = (self.b, self.h)
target = torch.randn(shape)
mod = nn.MSELoss()
m = lambda x: mod(x, target)
f = ft.partial(F.mse_loss, target=target)
run_layer_test(self, [m], expected, shape)
def _test_relu(self, expected):
run_layer_test(self, [nn.ReLU(), F.relu], expected, (self.b, self.h))
def _test_batch_norm(self, expected):
m = nn.BatchNorm2d(num_features=self.c)
f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
weight=m.weight, bias=m.bias, training=True)
run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h))
# Test forward-only for BN inference
m.eval()
f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
weight=m.weight, bias=m.bias, training=False)
run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h),
test_backward=False)
class TestBasicCastsHalf(_TestBasicCasts):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.half)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_linear_is_half(self):
self._test_linear(ALWAYS_HALF)
def test_conv2d_is_half(self):
self._test_conv2d(ALWAYS_HALF)
def test_softmax_is_float(self):
self._test_softmax(ALWAYS_FLOAT)
def test_group_norm_is_float(self):
self._test_group_norm(ALWAYS_FLOAT)
def test_mse_loss_is_float(self):
self._test_mse_loss(ALWAYS_FLOAT)
def test_relu_is_match(self):
self._test_relu(MATCH_INPUT)
def test_batch_norm_is_match(self):
self._test_batch_norm(MATCH_INPUT)
class TestBasicCastsBFloat16(_TestBasicCasts):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
common_init(self)
def tearDown(self):
self.handle._deactivate()
@skipIfRocm
def test_linear_is_bfloat16(self):
self._test_linear(ALWAYS_BFLOAT16)
@skipIfRocm
def test_conv2d_is_bfloat16(self):
self._test_conv2d(ALWAYS_BFLOAT16)
def test_softmax_is_float(self):
self._test_softmax(ALWAYS_FLOAT)
def test_group_norm_is_float(self):
self._test_group_norm(ALWAYS_FLOAT)
def test_mse_loss_is_float(self):
self._test_mse_loss(ALWAYS_FLOAT)
def test_relu_is_match(self):
self._test_relu(MATCH_INPUT)
def test_batch_norm_is_match(self):
self._test_batch_norm(MATCH_INPUT)
class TestBannedMethods(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.half)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def bce_common(self, assertion, dtype=torch.half):
shape = (self.b, self.h)
target = torch.rand(shape)
mod = nn.BCELoss()
m = lambda x: mod(x, target)
f = ft.partial(F.binary_cross_entropy, target=target)
for fn in [m, f]:
x = torch.rand(shape, dtype=dtype)
assertion(fn, x)
def test_bce_raises_by_default(self):
assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
self.bce_common(assertion, dtype=torch.half)
# handle with bfloat16 as patch_type
self.handle._deactivate()
self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
self.bce_common(assertion, dtype=torch.bfloat16)
def test_bce_is_float_with_allow_banned(self):
self.handle._deactivate()
self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.half)
assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
self.bce_common(assertion, dtype=torch.half)
# handle with bfloat16 as patch_type
self.handle._deactivate()
self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.bfloat16)
self.bce_common(assertion, dtype=torch.bfloat16)
class _TestTensorCasts(unittest.TestCase):
def _test_matmul_method(self, expected):
other = torch.randn(self.h, self.h)
lhs = lambda x: x.matmul(other)
rhs = lambda x: other.matmul(x)
run_layer_test(self, [lhs, rhs], expected, (self.h, self.h))
def _test_matmul_op(self, expected):
other = torch.randn(self.h, self.h)
lhs = lambda x: x @ other
rhs = lambda x: other @ x
run_layer_test(self, [lhs, rhs], expected, (self.h, self.h))
def _test_pow_method(self, expected):
fn = lambda x: x.pow(2.)
run_layer_test(self, [fn], expected, (self.b, self.h))
def _test_pow_op(self, expected):
fn = lambda x: x ** 2.
run_layer_test(self, [fn], expected, (self.b, self.h))
def _test_cpu(self, expected):
fn = lambda x: x.cpu()
run_layer_test(self, [fn], expected, (self.b, self.h))
def _test_sum(self, expected):
fn = lambda x: x.sum()
run_layer_test(self, [fn], expected, (self.b, self.h))
class TestTensorCastsHalf(_TestTensorCasts):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.half)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_matmul_method_is_half(self):
self._test_matmul_method(ALWAYS_HALF)
def test_matmul_op_is_half(self):
self._test_matmul_op(ALWAYS_HALF)
def test_pow_method_is_float(self):
self._test_pow_method(ALWAYS_FLOAT)
def test_pow_op_is_float(self):
self._test_pow_op(ALWAYS_FLOAT)
def test_cpu_is_float(self):
always_cpu_float = {torch.float: 'torch.FloatTensor',
torch.half: 'torch.FloatTensor'}
self._test_cpu(always_cpu_float)
def test_sum_is_float(self):
self._test_sum(ALWAYS_FLOAT)
class TestTensorCastsBFloat16(_TestTensorCasts):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
common_init(self)
def tearDown(self):
self.handle._deactivate()
@skipIfRocm
def test_matmul_method_is_bfloat16(self):
self._test_matmul_method(ALWAYS_BFLOAT16)
@skipIfRocm
def test_matmul_op_is_bfloat16(self):
self._test_matmul_op(ALWAYS_BFLOAT16)
def test_pow_method_is_float(self):
self._test_pow_method(ALWAYS_FLOAT)
def test_pow_op_is_float(self):
self._test_pow_op(ALWAYS_FLOAT)
def test_cpu_is_float(self):
always_cpu_float = {torch.float: 'torch.FloatTensor',
torch.bfloat16: 'torch.FloatTensor'}
self._test_cpu(always_cpu_float)
def test_sum_is_float(self):
self._test_sum(ALWAYS_FLOAT)
# TODO: maybe more tests on disabled casting?
if __name__ == '__main__':
unittest.main()
@@ -67,12 +67,12 @@ class TestCache(unittest.TestCase):
def tearDown(self):
pass
def train_eval_train_test(self, module, t, opt_level):
model = module(t).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
_amp_state.allow_incoming_model_not_fp32 = True
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, verbosity=0)
_amp_state.allow_incoming_model_not_fp32 = False
def training_step():
@@ -93,6 +93,8 @@ class TestCache(unittest.TestCase):
# but I'm keeping this in case we want different tolerances for fp16 and fp32 checks.
if model.weight.grad.type() == "torch.cuda.HalfTensor":
self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
elif model.weight.grad.type() == "torch.cuda.BFloat16Tensor":
self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
elif model.weight.grad.type() == "torch.cuda.FloatTensor":
self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
else:
@@ -115,22 +117,41 @@ class TestCache(unittest.TestCase):
# I could easily have these as a set of for loops in a single test,
# instead of going for granularity.
def test_whitelist_module_fp16_weight(self):
self.train_eval_train_test(WhitelistModule, torch.float16, "O1")
def test_whitelist_module_fp32_weight(self):
self.train_eval_train_test(WhitelistModule, torch.float32, "O1")
def test_blacklist_module_fp16_weight(self):
self.train_eval_train_test(BlacklistModule, torch.float16, "O1")
def test_blacklist_module_fp32_weight(self):
self.train_eval_train_test(BlacklistModule, torch.float32, "O1")
def test_promote_module_fp16_weight(self):
self.train_eval_train_test(PromoteModule, torch.float16, "O1")
def test_promote_module_fp32_weight(self):
self.train_eval_train_test(PromoteModule, torch.float32, "O1")
# opt_level = O4
def test_whitelist_module_bfp16_weight(self):
self.train_eval_train_test(WhitelistModule, torch.bfloat16, "O4")
def test_whitelist_module_fp32_weight(self):
self.train_eval_train_test(WhitelistModule, torch.float32, "O4")
def test_blacklist_module_bfp16_weight(self):
self.train_eval_train_test(BlacklistModule, torch.bfloat16, "O4")
def test_blacklist_module_fp32_weight(self):
self.train_eval_train_test(BlacklistModule, torch.float32, "O4")
def test_promote_module_bfp16_weight(self):
self.train_eval_train_test(PromoteModule, torch.bfloat16, "O4")
def test_promote_module_fp32_weight(self):
self.train_eval_train_test(PromoteModule, torch.float32, "O4")
if __name__ == '__main__':
...
@@ -6,7 +6,7 @@ import torch.nn.functional as F
import torch.optim as optim
from apex import amp
from apex.testing.common_utils import skipIfRocm
from utils import common_init, FLOAT
@@ -28,7 +28,7 @@ class MyModel(torch.nn.Module):
class TestCheckpointing(unittest.TestCase):
def setUp(self):
self.initial_lr = 1e-3
self.test_opt_levels = ("O0", "O1", "O2", "O3", "O4", "O5")
def seed(self):
torch.manual_seed(2809)
@@ -161,6 +161,7 @@ class TestCheckpointing(unittest.TestCase):
# skip tests for different opt_levels
continue
@skipIfRocm
def test_loss_scale_decrease(self):
num_losses = 3
nb_decrease_loss_scales = [0, 1, 2]
@@ -236,6 +237,7 @@ class TestCheckpointing(unittest.TestCase):
state_dict = model.state_dict()
for key in state_dict:
self.assertFalse('Half' in state_dict[key].type())
self.assertFalse('BFloat16' in state_dict[key].type())
# Check, if model is still trainable
# Create dummy data
...
@@ -13,6 +13,7 @@ from torch.nn import Parameter
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
from apex.testing.common_utils import skipIfRocm
try:
import amp_C
...
@@ -12,6 +12,8 @@ from math import floor
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
from apex.testing.common_utils import skipIfRocm
try:
import amp_C
from amp_C import multi_tensor_axpby
@@ -69,7 +71,10 @@ class TestMultiTensorAxpby(unittest.TestCase):
applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b, -1)
# TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16
if out_type == torch.bfloat16:
out_list = [out.float() for out in out_list]
self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list]),
msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
x_type, y_type, out_type, inplace))
self.assertTrue(self.overflow_buf.item() == 0,
@@ -119,9 +124,9 @@ class TestMultiTensorAxpby(unittest.TestCase):
for sizea, sizeb in input_size_pairs:
for applier in appliers:
for repeat in repeat_tensors:
for x_type in (torch.float32, torch.float16, torch.bfloat16):
for y_type in (torch.float32, torch.float16, torch.bfloat16):
for out_type in (torch.float32, torch.float16, torch.bfloat16):
for inplace in (True, False):
if inplace is True and (y_type is not out_type):
continue
@@ -137,6 +142,7 @@ class TestMultiTensorAxpby(unittest.TestCase):
@unittest.skipIf(disabled, "amp_C is unavailable")
@unittest.skipIf(not try_nhwc, "torch version is 1.4 or earlier, may not support nhwc")
@skipIfRocm
def test_fuzz_nhwc(self):
input_size_pairs = (
((7, 77, 7, 77), (5, 55, 5, 55)),
...
@@ -11,6 +11,8 @@ import torch.nn.functional as F
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
from apex.testing.common_utils import skipIfRocm
try:
import amp_C
from amp_C import multi_tensor_l2norm
@@ -56,6 +58,7 @@ class TestMultiTensorL2Norm(unittest.TestCase):
self.assertTrue(self.overflow_buf.item() == 0)
@unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_fuzz(self):
input_size_pairs = (
(7777*77, 555*555),
...
@@ -49,7 +49,10 @@ class TestMultiTensorScale(unittest.TestCase):
applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
# TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16
if out_type == torch.bfloat16:
out_list = [out.float() for out in out_list]
self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list]))
self.assertTrue(self.overflow_buf.item() == 0)
def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
@@ -106,8 +109,8 @@ class TestMultiTensorScale(unittest.TestCase):
for sizea, sizeb in input_size_pairs:
for applier in appliers:
for repeat in repeat_tensors:
for in_type in (torch.float32, torch.float16, torch.bfloat16):
for out_type in (torch.float32, torch.float16, torch.bfloat16):
for inplace in (True, False):
if inplace is True and (out_type is not in_type):
continue
...
@@ -13,6 +13,8 @@ from torch.nn import Parameter
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
from apex.testing.common_utils import skipIfRocm
class MyModel(torch.nn.Module):
def __init__(self, unique):
super(MyModel, self).__init__()
@@ -41,7 +43,7 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
def tearDown(self):
pass
def test_2models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
...
@@ -7,18 +7,18 @@ import torch
from torch import nn
import torch.nn.functional as F
from utils import common_init, HALF, FLOAT, DTYPES, DTYPES2, MATCH_INPUT
class _TestPromotion(unittest.TestCase):
def run_binary_promote_test(self, fns, input_shape, lp_type, x_inplace=False):
if lp_type == torch.half:
dtypes = DTYPES
elif lp_type == torch.bfloat16:
dtypes = DTYPES2
else:
raise RuntimeError("Creating test class with invalid low_precision type. \
Supported types are torch.half and torch.bfloat16")
type_pairs = it.product(dtypes, dtypes)
for fn, (xtype, ytype) in it.product(fns, type_pairs):
x = torch.randn(input_shape, dtype=xtype).requires_grad_()
x_leaf = x
@@ -35,41 +35,78 @@ class TestPromotion(unittest.TestCase):
if xtype == torch.float or ytype == torch.float:
self.assertEqual(out.type(), FLOAT)
else:
self.assertEqual(out.type(), MATCH_INPUT[lp_type])
out.float().sum().backward()
self.assertEqual(x_leaf.grad.dtype, xtype)
def _test_cat_matches_widest(self, lp_type):
shape = self.b
ys = [torch.randn(shape, dtype=lp_type) for _ in range(5)]
x_float = torch.randn(shape)
out = torch.cat(ys + [x_float])
self.assertEqual(out.type(), FLOAT)
x_lp = torch.randn(shape, dtype=lp_type)
out = torch.cat(ys + [x_lp])
self.assertEqual(out.type(), MATCH_INPUT[lp_type])
def _test_inplace_exp_is_error_for_lp(self, lp_type):
xs = torch.randn(self.b)
xs.exp_()
self.assertEqual(xs.type(), FLOAT)
xs = torch.randn(self.b, dtype=lp_type)
with self.assertRaises(NotImplementedError):
xs.exp_()
class TestPromotionHalf(_TestPromotion):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.half)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_atan2_matches_widest(self):
fns = [lambda x, y : torch.atan2(x, y),
lambda x, y : x.atan2(y)]
self.run_binary_promote_test(fns, (self.b,), torch.half)
def test_mul_matches_widest(self):
fns = [lambda x, y : torch.mul(x, y),
lambda x, y: x.mul(y)]
self.run_binary_promote_test(fns, (self.b,), torch.half)
def test_cat_matches_widest(self):
self._test_cat_matches_widest(torch.half)
def test_inplace_exp_is_error_for_half(self):
self._test_inplace_exp_is_error_for_lp(torch.half)
def test_inplace_add_matches_self(self):
fn = lambda x, y: x.add_(y)
self.run_binary_promote_test([fn], (self.b,), torch.half, x_inplace=True)
class TestPromotionBFloat16(_TestPromotion):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_mul_matches_widest(self):
fns = [lambda x, y : torch.mul(x, y),
lambda x, y: x.mul(y)]
self.run_binary_promote_test(fns, (self.b,), torch.bfloat16)
def test_cat_matches_widest(self):
self._test_cat_matches_widest(torch.bfloat16)
def test_inplace_exp_is_error_for_bfloat16(self):
self._test_inplace_exp_is_error_for_lp(torch.bfloat16)
def test_inplace_add_matches_self(self):
fn = lambda x, y: x.add_(y)
self.run_binary_promote_test([fn], (self.b,), torch.bfloat16, x_inplace=True)
if __name__ == '__main__':
unittest.main()
@@ -6,6 +6,7 @@ import torch
from torch import nn
from utils import common_init, HALF
from apex.testing.common_utils import skipIfRocm
class TestRnnCells(unittest.TestCase):
def setUp(self):
@@ -73,6 +74,7 @@ class TestRnns(unittest.TestCase):
output[-1, :, :].float().sum().backward()
self.assertEqual(x.grad.dtype, x.dtype)
@skipIfRocm
def test_rnn_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
@@ -80,6 +82,7 @@ class TestRnns(unittest.TestCase):
nonlinearity='relu', bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir)
@skipIfRocm
def test_gru_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
@@ -87,6 +90,7 @@ class TestRnns(unittest.TestCase):
bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir)
@skipIfRocm
def test_lstm_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
@@ -94,6 +98,7 @@ class TestRnns(unittest.TestCase):
bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir, state_tuple=True)
@skipIfRocm
def test_rnn_packed_sequence(self):
num_layers = 2
rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers)
...
@@ -2,15 +2,21 @@ import torch
HALF = 'torch.cuda.HalfTensor'
FLOAT = 'torch.cuda.FloatTensor'
BFLOAT16 = 'torch.cuda.BFloat16Tensor'
DTYPES = [torch.half, torch.float]
DTYPES2 = [torch.bfloat16, torch.float]
ALWAYS_HALF = {torch.float: HALF,
torch.half: HALF}
ALWAYS_BFLOAT16 = {torch.bfloat16: BFLOAT16,
torch.float: BFLOAT16}
ALWAYS_FLOAT = {torch.float: FLOAT,
torch.half: FLOAT}
MATCH_INPUT = {torch.float: FLOAT,
torch.half: HALF,
torch.bfloat16: BFLOAT16}
def common_init(test_case):
test_case.h = 64
...
@@ -14,22 +14,28 @@ class TestFusedAdagrad(unittest.TestCase):
def tearDown(self):
pass
def gen_param_optim(self, tensors, adagrad_option, apex_only=False):
ref_param = []
tst_param = []
for tensor in tensors:
if apex_only:
ref_param.append(torch.nn.Parameter(tensor.clone().float()))
else:
ref_param.append(torch.nn.Parameter(tensor.clone()))
tst_param.append(torch.nn.Parameter(tensor.clone()))
if apex_only:
ref_optim = apex.optimizers.FusedAdagrad(ref_param, **adagrad_option)
else:
ref_optim = torch.optim.Adagrad(ref_param, **adagrad_option)
tst_optim = apex.optimizers.FusedAdagrad(tst_param, **adagrad_option)
return (ref_param, tst_param, ref_optim, tst_optim)
def gen_grad(self, ref_param, tst_param, apex_only=False):
for p_ref, p_tst in zip(ref_param, tst_param):
p_tst.grad = torch.rand_like(p_tst)
p_ref.grad = p_tst.grad.detach().float() if apex_only else p_tst.grad
def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
half_grads = []
@@ -38,9 +44,11 @@ class TestFusedAdagrad(unittest.TestCase):
p_ref.grad = half_grads[-1].float() / scale
return half_grads
def get_max_diff(self, ref_param, tst_param, apex_only=False):
max_abs_diff = max_rel_diff = 0
for p_ref, p_tst in zip(ref_param, tst_param):
if apex_only:
p_tst = p_tst.float()
max_abs_diff_p = (p_ref - p_tst).abs().max().item()
max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
@@ -51,23 +59,24 @@ class TestFusedAdagrad(unittest.TestCase):
return max_abs_diff, max_rel_diff
def gen_single_type_test(self, param_type=torch.float, apex_only=False):
nelem = 278011
adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5}
tensor = torch.rand(nelem, dtype=param_type, device="cuda")
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
[tensor], adagrad_option, apex_only=apex_only
)
for _ in range(self.iters):
self.gen_grad(ref_param, tst_param, apex_only=apex_only)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param, apex_only=apex_only)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
if not apex_only:
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
@@ -76,6 +85,14 @@ class TestFusedAdagrad(unittest.TestCase):
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
# Compares bfloat16 computation against float32 as gold standard.
# Uses apex optimizers(controlled by apex_only flag) for both types.
# Doesn't use upstream optimizer like other tests as they seem to be
# numerically unstable for half types(see skip note for test above).
def test_bfloat16(self):
self.max_abs_diff = 1e-2
self.gen_single_type_test(param_type=torch.bfloat16, apex_only=True)
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}
...
@@ -15,22 +15,28 @@ class TestFusedAdam(unittest.TestCase):
def tearDown(self):
pass
def gen_param_optim(self, tensors, adam_option, apex_only=False):
ref_param = []
tst_param = []
for tensor in tensors:
if apex_only:
ref_param.append(torch.nn.Parameter(tensor.clone().float()))
else:
ref_param.append(torch.nn.Parameter(tensor.clone()))
tst_param.append(torch.nn.Parameter(tensor.clone()))
if apex_only:
ref_optim = apex.optimizers.FusedAdam(ref_param, **adam_option)
else:
ref_optim = torch.optim.Adam(ref_param, **adam_option)
tst_optim = apex.optimizers.FusedAdam(tst_param, **adam_option)
return (ref_param, tst_param, ref_optim, tst_optim)
def gen_grad(self, ref_param, tst_param, apex_only=False):
for p_ref, p_tst in zip(ref_param, tst_param):
p_tst.grad = torch.rand_like(p_tst)
p_ref.grad = p_tst.grad.detach().float() if apex_only else p_tst.grad
def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
half_grads = []
@@ -39,9 +45,11 @@ class TestFusedAdam(unittest.TestCase):
p_ref.grad = half_grads[-1].float() / scale
return half_grads
def get_max_diff(self, ref_param, tst_param, apex_only=False):
max_abs_diff = max_rel_diff = 0
for p_ref, p_tst in zip(ref_param, tst_param):
if apex_only:
p_tst = p_tst.float()
max_abs_diff_p = (p_ref - p_tst).abs().max().item()
max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
@@ -50,23 +58,24 @@ class TestFusedAdam(unittest.TestCase):
return max_abs_diff, max_rel_diff
def gen_single_type_test(self, param_type=torch.float, apex_only=False):
nelem = 278011
adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
'weight_decay':0, 'amsgrad':False}
tensor = torch.rand(nelem, dtype=param_type, device='cuda')
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim([tensor], adam_option, apex_only=apex_only)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param, apex_only=apex_only)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param, apex_only=apex_only)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
if not apex_only:
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
@@ -74,6 +83,14 @@ class TestFusedAdam(unittest.TestCase):
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
# Compares bfloat16 computation against float32 as gold standard.
# Uses apex optimizers(controlled by apex_only flag) for both types.
# Doesn't use upstream optimizer like other tests as they seem to be
# numerically unstable for half types
def test_bfloat16(self):
self.max_abs_diff = 1e-2
self.gen_single_type_test(param_type=torch.bfloat16, apex_only=True)
@unittest.skip('Disable until 8/1/2019 adam/adamw upstream picked')
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
...
@@ -5,6 +5,7 @@ import torch
from torch.optim import Optimizer
import apex
from apex.multi_tensor_apply import multi_tensor_applier
from apex.testing.common_utils import skipIfRocm
class RefLAMB(Optimizer):
r"""Implements Lamb algorithm.
@@ -207,6 +208,7 @@ class TestFusedLAMB(unittest.TestCase):
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
@skipIfRocm
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
@@ -214,6 +216,7 @@ class TestFusedLAMB(unittest.TestCase):
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@skipIfRocm
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
weight_decay = [0, 0.01]
@@ -234,6 +237,7 @@ class TestFusedLAMB(unittest.TestCase):
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
@skipIfRocm
def test_lamb_option(self):
nelem = 1
tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
...
#!/bin/bash
APEX_TEST_WITH_ROCM=1 python3.6 run_test.py
import unittest
import sys
from apex.testing.common_utils import TEST_WITH_ROCM, skipIfRocm
test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"]
ROCM_BLACKLIST = [
'run_fused_layer_norm',
'run_pyprof_nvtx',
'run_pyprof_data',
'run_mlp'
]
runner = unittest.TextTestRunner(verbosity=2)
errcode = 0
for test_dir in test_dirs:
if (test_dir in ROCM_BLACKLIST) and TEST_WITH_ROCM:
continue
suite = unittest.TestLoader().discover(test_dir)
print("\nExecuting tests from " + test_dir)
...