Commit f79993d9 authored by hubertlu-tw

Merge remote-tracking branch 'upstream/master' into IFU-master-2021-10-15

parents 297ab210 1d5f7e55
@@ -34,6 +34,32 @@
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
@@ -166,6 +192,160 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_in = double; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_out = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
...
@@ -81,7 +81,7 @@ def parse():
help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true')
parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.')
...
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages
import subprocess
@@ -46,7 +46,7 @@ if not torch.cuda.is_available() and not IS_ROCM_PYTORCH:
'If you wish to cross-compile for a single specific architecture,\n'
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
else:
@@ -85,11 +85,8 @@ if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
if TORCH_MAJOR == 0:
raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
"found torch.__version__ = {}".format(torch.__version__))
if "--cpp_ext" in sys.argv:
sys.argv.remove("--cpp_ext")
ext_modules.append(
CppExtension('apex_C',
@@ -138,13 +135,9 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
if "--distributed_adam" in sys.argv:
sys.argv.remove("--distributed_adam")
if CUDA_HOME is None:
raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
@@ -157,7 +150,6 @@ if "--distributed_adam" in sys.argv:
'--use_fast_math'] + version_dependent_macros}))
if "--distributed_lamb" in sys.argv:
sys.argv.remove("--distributed_lamb")
from torch.utils.cpp_extension import BuildExtension
@@ -178,7 +170,6 @@ if "--distributed_lamb" in sys.argv:
'nvcc': nvcc_args_distributed_lamb if not IS_ROCM_PYTORCH else hipcc_args_distributed_lamb}))
if "--cuda_ext" in sys.argv:
sys.argv.remove("--cuda_ext")
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
@@ -197,6 +188,7 @@ if "--cuda_ext" in sys.argv:
'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_axpby_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu',
'csrc/multi_tensor_l2norm_scale_kernel.cu',
'csrc/multi_tensor_lamb_stage_1.cu',
'csrc/multi_tensor_lamb_stage_2.cu',
'csrc/multi_tensor_adam.cu',
@@ -235,9 +227,38 @@ if "--cuda_ext" in sys.argv:
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fused_dense_cuda',
sources=['csrc/fused_dense.cpp',
'csrc/fused_dense_cuda.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='scaled_upper_triang_masked_softmax_cuda',
sources=['csrc/megatron/scaled_upper_triang_masked_softmax.cpp',
'csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='scaled_masked_softmax_cuda',
sources=['csrc/megatron/scaled_masked_softmax.cpp',
'csrc/megatron/scaled_masked_softmax_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda'] + version_dependent_macros}))
if "--bnp" in sys.argv: if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--bnp") sys.argv.remove("--bnp")
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
...@@ -261,7 +282,6 @@ if "--bnp" in sys.argv: ...@@ -261,7 +282,6 @@ if "--bnp" in sys.argv:
'-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros})) '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
if "--xentropy" in sys.argv: if "--xentropy" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--xentropy") sys.argv.remove("--xentropy")
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
...@@ -281,7 +301,6 @@ if "--xentropy" in sys.argv: ...@@ -281,7 +301,6 @@ if "--xentropy" in sys.argv:
if "--deprecated_fused_adam" in sys.argv: if "--deprecated_fused_adam" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--deprecated_fused_adam") sys.argv.remove("--deprecated_fused_adam")
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
...@@ -302,7 +321,6 @@ if "--deprecated_fused_adam" in sys.argv: ...@@ -302,7 +321,6 @@ if "--deprecated_fused_adam" in sys.argv:
'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam})) 'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))
if "--deprecated_fused_lamb" in sys.argv: if "--deprecated_fused_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--deprecated_fused_lamb") sys.argv.remove("--deprecated_fused_lamb")
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
...@@ -329,18 +347,14 @@ if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')) ...@@ -329,18 +347,14 @@ if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h'))
generator_flag = ['-DOLD_GENERATOR'] generator_flag = ['-DOLD_GENERATOR']
if "--fast_layer_norm" in sys.argv: if "--fast_layer_norm" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_layer_norm") sys.argv.remove("--fast_layer_norm")
if CUDA_HOME is None:
raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
@@ -356,24 +370,57 @@ if "--fast_layer_norm" in sys.argv:
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")]))
if "--fmha" in sys.argv:
sys.argv.remove("--fmha")
if CUDA_HOME is None:
raise RuntimeError("--fmha was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) < 11:
raise RuntimeError("--fmha only supported on SM80")
ext_modules.append(
CUDAExtension(name='fmhalib',
sources=[
'apex/contrib/csrc/fmha/fmha_api.cpp',
'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
],
extra_compile_args={'cxx': ['-O3',
] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_80,code=sm_80',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc"), os.path.join(this_dir, "apex/contrib/csrc/fmha/src")]))
if "--fast_multihead_attn" in sys.argv: if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_multihead_attn") sys.argv.remove("--fast_multihead_attn")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11: if int(bare_metal_major) >= 11:
cc_flag.append('-gencode') cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80') cc_flag.append('arch=compute_80,code=sm_80')
...@@ -386,12 +433,12 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -386,12 +433,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_mask_softmax_dropout',
sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
@@ -399,12 +446,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
@@ -412,12 +459,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
@@ -425,12 +472,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
@@ -438,12 +485,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
@@ -451,12 +498,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
@@ -464,12 +511,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
@@ -477,12 +524,47 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
if "--transducer" in sys.argv:
sys.argv.remove("--transducer")
if CUDA_HOME is None:
raise RuntimeError("--transducer was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
CUDAExtension(name='transducer_joint_cuda',
sources=['apex/contrib/csrc/transducer/transducer_joint.cpp',
'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc': ['-O3'] + version_dependent_macros},
include_dirs=[os.path.join(this_dir, 'csrc'), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")]))
ext_modules.append(
CUDAExtension(name='transducer_loss_cuda',
sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
'apex/contrib/csrc/transducer/transducer_loss_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
if "--fast_bottleneck" in sys.argv:
sys.argv.remove("--fast_bottleneck")
if CUDA_HOME is None:
raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
ext_modules.append(
CUDAExtension(name='fast_bottleneck',
sources=['apex/contrib/csrc/bottleneck/bottleneck.cpp'],
include_dirs=[os.path.join(this_dir, 'apex/contrib/csrc/cudnn-frontend/include')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag}))
setup(
name='apex',
@@ -498,6 +580,6 @@ setup(
'apex.egg-info',)),
description='PyTorch Extensions written by NVIDIA',
ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension} if ext_modules else {},
extras_require=extras,
)
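# Editor's aside (illustrative, not part of the diff): the optional extensions above are
# selected by passing flags such as --cpp_ext, --cuda_ext, --fmha, or --fast_multihead_attn
# to setup.py, which pops them from sys.argv before calling setup(). A hypothetical helper
# sketching one way to drive such a build from Python, assuming it runs from the apex root:
def _build_apex_with_extensions_sketch():
    import subprocess
    import sys
    # Any of the other flags checked above ("--transducer", "--fast_bottleneck", ...) can be
    # appended the same way; setup.py removes them before handing control to setuptools.
    subprocess.check_call([sys.executable, "setup.py", "install", "--cpp_ext", "--cuda_ext"])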
import itertools
import unittest
import os
import random
import torch
import apex
class TestFusedLayerNorm(unittest.TestCase):
dtype = torch.float
elementwise_affine = False
normalized_shape = [32, 16]
rtol, atol = None, None
fwd_thresholds = dict(rtol=None, atol=None)
bwd_thresholds = dict(rtol=None, atol=None)
def setUp(self):
# bias and weight are set to 0 and 1 respectively, so no need to copy parameters from cpu module to the gpu one
self.module_cpu_ = apex.normalization.FusedLayerNorm(
normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).cpu()
self.module_cuda_ = apex.normalization.FusedLayerNorm(
normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).to(device="cuda", dtype=self.dtype)
def _check_same_output(self, batch_size, contiguous):
torch.cuda.manual_seed(42)
if contiguous:
input_shape = [batch_size] + self.normalized_shape
input_ = torch.randn(input_shape, device="cpu").requires_grad_(True)
input_cuda_ = input_.to(device="cuda", dtype=self.dtype).detach().requires_grad_(True)
self.assertTrue(input_.is_contiguous())
self.assertTrue(input_cuda_.is_contiguous())
else:
input_shape = [batch_size] + self.normalized_shape
input_shape = [batch_size * 3] + [self.normalized_shape[0] * 5, self.normalized_shape[1] * 3]
input_src_ = torch.randn(input_shape, device="cpu")
input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True)
input_cuda_ = input_src_.to(device="cuda", dtype=self.dtype)[::3, ::5, ::3].detach().requires_grad_(True)
# make sure that tensors are NOT contiguous.
self.assertFalse(input_.is_contiguous())
self.assertFalse(input_cuda_.is_contiguous())
out_cpu_ = self.module_cpu_(input_)
gO = torch.rand_like(out_cpu_)
out_cpu_.backward(gO)
out_cuda_ = self.module_cuda_(input_cuda_)
gO = gO.to(device="cuda", dtype=self.dtype)
out_cuda_.backward(gO)
self.assertFalse(out_cpu_.is_cuda)
self.assertTrue(out_cuda_.is_cuda)
# TODO (mkozuki): `torch.testing.assert_allclose` is deprecated.
# Use `torch.testing.assert_close`.
# See https://github.com/pytorch/pytorch/issues/61844
torch.testing.assert_allclose(
out_cpu_.to(device="cuda", dtype=self.dtype), out_cuda_, **self.fwd_thresholds)
torch.testing.assert_allclose(
input_.grad.to(device="cuda", dtype=self.dtype), input_cuda_.grad, **self.bwd_thresholds)
def _test_same_output(self, batch_size):
for contiguous in (True, False):
with self.subTest(contiguous=contiguous):
self._check_same_output(batch_size, contiguous)
def test_layer_norm(self):
self._test_same_output(16)
@@ -36,10 +67,105 @@ class TestFusedLayerNorm(unittest.TestCase):
class TestFusedLayerNormElemWise(TestFusedLayerNorm):
elementwise_affine = True
class TestFusedLayerNormElemWiseHalf(TestFusedLayerNormElemWise):
dtype = torch.half
def test_large_batch(self):
self.skipTest("Skip to save time")
# Megatron style Layer Norm
class TestFusedLayerNormElemWiseMixedDtypes(TestFusedLayerNorm):
def setUp(self):
self.module_cpu_ = apex.normalization.MixedFusedLayerNorm(
normalized_shape=self.normalized_shape, elementwise_affine=True).cpu()
self.module_cuda_ = apex.normalization.MixedFusedLayerNorm(
normalized_shape=self.normalized_shape, elementwise_affine=True).to(device="cuda", dtype=self.dtype)
def test_init_exception(self):
with self.assertRaisesRegex(RuntimeError, "MixedFusedLayerNorm does not support `elementwise_affine = False`"):
apex.normalization.MixedFusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cuda()
class TestFusedLayerNormElemWiseMixedDtypesHalf(TestFusedLayerNormElemWiseMixedDtypes):
dtype = torch.half
def test_large_batch(self):
self.skipTest("Skip to save time")
# NOTE (mkozuki): With the larger threshold values, still flaky.
class TestFusedLayerNormElemWiseMixedDtypesBFloat16(TestFusedLayerNormElemWiseMixedDtypesHalf):
dtype = torch.bfloat16
# NOTE (mkozuki): [BFloat16 Layer Norm flakiness]
# Use thresholds larger than those used in pytorch, see
# https://github.com/pytorch/pytorch/blob/72274e2a2fd55019ec860e1743dbdc5b0c5a5624/torch/testing/_asserts.py#L26
fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise):
dtype = torch.bfloat16
# See [BFloat16 Layer Norm flakiness]
fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
def test_large_batch(self):
self.skipTest("Skip to save time")
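# Editor's aside (illustrative, not part of the diff): the test classes above are
# parametrized entirely through class attributes, so covering another configuration only
# needs a small subclass. A hypothetical example; the class name and thresholds are
# placeholders, and it is skipped so it never runs as a real test.
@unittest.skip("editor's sketch only; thresholds are placeholders")
class TestFusedLayerNormHalfSketch(TestFusedLayerNorm):
    # Exercise the non-affine path in fp16 by overriding the dtype attribute.
    dtype = torch.half
    fwd_thresholds = dict(rtol=1e-3, atol=1e-3)
    bwd_thresholds = dict(rtol=1e-3, atol=1e-3)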
def _prep_layers(normalized_shape, elementwise_affine, dtype):
native = torch.nn.LayerNorm(
normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
).to(device="cuda", dtype=dtype)
fused = apex.normalization.FusedLayerNorm(
normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
).cuda()
return native, fused
def _prep_inputs(batch_size, normalized_shape, dtype):
shape = (batch_size, *normalized_shape)
fused = torch.randn(shape).cuda().requires_grad_(True)
with torch.no_grad():
native = fused.clone().to(dtype).requires_grad_(True)
return native, fused
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
class TestAutocastFusedLayerNorm(unittest.TestCase):
bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
def setUp(self):
self.batch_size = 16
self.normalized_shape = [32, 16]
def _run_test(self, dtype, elementwise_affine):
native, fused = _prep_layers(self.normalized_shape, elementwise_affine, dtype)
native_x, fused_x = _prep_inputs(self.batch_size, self.normalized_shape, dtype)
expected = native(native_x)
with torch.cuda.amp.autocast(dtype=dtype):
actual = fused(fused_x)
tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedLayerNorm.bf16_fwd_thresholds
torch.testing.assert_allclose(actual, expected, **tols)
g_native = torch.rand_like(expected)
with torch.no_grad():
g_fused = g_native.clone()
expected.backward(g_native)
actual.backward(g_fused)
tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedLayerNorm.bf16_bwd_thresholds
torch.testing.assert_allclose(native_x.grad, fused_x.grad, **tols)
def test_autocast(self):
for (dtype, elementwise_affine) in itertools.product(autocast_dtypes, (True, False)):
with self.subTest(f"{dtype}-{elementwise_affine}"):
self._run_test(dtype, elementwise_affine)
import torch
from torch.optim import Optimizer
import math
import apex
import unittest
from test_fused_optimizer import TestFusedOptimizer
from itertools import product
class Novograd(Optimizer):
"""
Implements Novograd algorithm.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.95, 0))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging: gradient averaging
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
"""
def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8,
weight_decay=0, grad_averaging=False, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
amsgrad=amsgrad)
super(Novograd, self).__init__(params, defaults)
def __setstate__(self, state):
super(Novograd, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Sparse gradients are not supported.')
amsgrad = group['amsgrad']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
norm = torch.sum(torch.pow(grad, 2))
if exp_avg_sq == 0:
exp_avg_sq.copy_(norm)
else:
exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
grad.div_(denom)
if group['weight_decay'] != 0:
grad.add_(p.data, alpha=group['weight_decay'])
if group['grad_averaging']:
grad.mul_(1 - beta1)
exp_avg.mul_(beta1).add_(grad)
p.data.add_(exp_avg, alpha=-group['lr'])
return loss
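# Editor's aside (illustrative, not part of the diff): the reference Novograd above is a
# plain torch.optim.Optimizer, so it can be smoke-tested directly. A minimal usage sketch;
# the model, data, and hyperparameters below are made up for illustration.
def _novograd_usage_sketch():
    model = torch.nn.Linear(8, 4)
    opt = Novograd(model.parameters(), lr=1e-3, betas=(0.95, 0), weight_decay=0)
    loss = model(torch.randn(2, 8)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    return loss.item()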
class TestFusedNovoGrad(TestFusedOptimizer):
def __init__(self, *args, **kwargs):
super(TestFusedNovoGrad, self).__init__(*args, **kwargs)
# The options for NovoGrad and FusedNovoGrad are very specific if they
# are expected to behave the same.
self.options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8,
'weight_decay':0, 'grad_averaging':False, 'amsgrad':False}
self.tst_options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8,
'weight_decay':0, 'grad_averaging':False, 'amsgrad':False,
'bias_correction':False, 'reg_inside_moment':True,
'norm_type':2, 'init_zero':False, 'set_grad_none':True}
self.ref_optim = Novograd
self.fused_optim = apex.optimizers.FusedNovoGrad
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:1", "cuda:0")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
torch.cuda.synchronize()
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
tensors = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
tensors, self.options, self.tst_options
)
for _ in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
if __name__ == '__main__':
unittest.main()
@@ -2,9 +2,11 @@ import unittest
import os
import random
import math
import torch
import apex
from itertools import product
from torch.optim import Optimizer
class TestFusedOptimizer(unittest.TestCase):
def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
@@ -16,7 +18,14 @@ class TestFusedOptimizer(unittest.TestCase):
def tearDown(self):
pass
def gen_param_optim(self, tensors, options, tst_options=None):
# Adding this to make backward compatible with existing tests. Just in
# case "tst_options" are not provided, it gets a copy of options
# which contains the parameters for the reference optimizer
if tst_options == None:
tst_options = options
ref_param = []
tst_param = []
for tensor in tensors:
@@ -26,11 +35,8 @@ class TestFusedOptimizer(unittest.TestCase):
ref_param.append(torch.nn.Parameter(tensor.clone()))
tst_param.append(torch.nn.Parameter(tensor.clone()))
ref_optim = self.ref_optim(ref_param, **options)
tst_optim = self.fused_optim(tst_param, **tst_options)
return (ref_param, tst_param, ref_optim, tst_optim)
@@ -62,9 +68,18 @@ class TestFusedOptimizer(unittest.TestCase):
def gen_single_type_test(self, param_type=torch.float, apex_only=False, device='cuda'):
nelem = 278011
# Some ref and test optimizers may require different set of options.
# This is a quick workaround to add that functionality while making
# minimum changes in existing code.
# If there is no "tst_options" field provided, safe to initialize
# the test optimizer with the parameters of reference optimizer.
if not hasattr(self, 'tst_options'):
self.tst_options = self.options
tensor = torch.rand(nelem, dtype=param_type, device=device)
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim([tensor], self.options, self.tst_options)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param, apex_only=apex_only)
@@ -279,8 +294,5 @@ class TestFusedSGD(TestFusedOptimizer):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
if __name__ == '__main__':
unittest.main()
@@ -26,4 +26,4 @@ for test_dir in test_dirs:
if not result.wasSuccessful():
errcode = 1
sys.exit(errcode)
\ No newline at end of file
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from apex.transformer.tensor_parallel.tests.commons import set_random_seed
from apex.transformer.tensor_parallel.tests.commons import IdentityLayer
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
from apex.transformer.tensor_parallel.tests import global_vars
global_vars.set_global_variables()
def torch_cross_entropy(batch_size, seq_length, vocab_size,
logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda()
logits = identity()
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
target.view(-1),
reduction='none').view_as(target).mean()
loss.backward()
return loss, identity.weight.grad
def tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size), scale=logits_scale).cuda()
logits = identity()
logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
loss.backward()
return loss, identity.weight.grad
def test_cross_entropy(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cross entropy with model parallel size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
batch_size = 13
seq_length = 17
vocab_size_per_partition = 11
logits_scale = 1000.0
vocab_size = vocab_size_per_partition * tensor_model_parallel_size
seed = 1234
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)
loss_mpu, grad_mpu = tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)
error = loss_torch.sub_(loss_mpu).abs().max()
print(' max error in loss on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = grad_torch.sub_(grad_mpu).abs().max()
print(' max error in grad on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cross entropy')
test_cross_entropy(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_broadcast_data(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing broadcast_data with model parallel size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
torch.manual_seed(1234 + parallel_state.get_data_parallel_rank())
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
key_size_t = {
'key1': [7, 11],
'key2': [8, 2, 1],
'key3': [13],
'key4': [5, 1, 2],
'key5': [5, 12],
}
keys = list(key_size_t.keys())
data = {}
data_t = {}
for key in key_size_t:
data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
data_t[key] = data[key].clone()
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
data_t['keyX'] = data['keyX'].clone()
if parallel_state.get_tensor_model_parallel_rank() != 0:
data = None
data_utils._check_data_types(keys, data_t, torch.int64)
key_size, key_numel, \
total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
for key in keys:
assert key_size[key] == key_size_t[key]
total_numel_t = 0
for key in keys:
target_size = functools.reduce(operator.mul, key_size_t[key], 1)
assert key_numel[key] == target_size
total_numel_t += target_size
assert total_numel == total_numel_t
data_b = data_utils.broadcast_data(keys, data, torch.int64)
for key in keys:
tensor = data_t[key].cuda()
assert data_b[key].sub(tensor).abs().max() == 0
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test broadcast data')
test_broadcast_data(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_initialize_model_parallel(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing initialize_model_parallel with size {} ...'.format(
tensor_model_parallel_size))
tensor_model_parallel_size_ = min(
tensor_model_parallel_size,
torch.distributed.get_world_size(),
)
assert not parallel_state.model_parallel_is_initialized()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
assert parallel_state.model_parallel_is_initialized()
# Checks.
def check(group, world_size, rank):
assert world_size == torch.distributed.get_world_size(group=group)
assert rank == torch.distributed.get_rank(group=group)
# Model parallel.
world_size = tensor_model_parallel_size_
rank = torch.distributed.get_rank() % tensor_model_parallel_size_
assert world_size == parallel_state.get_tensor_model_parallel_world_size()
assert rank == parallel_state.get_tensor_model_parallel_rank()
check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)
# Data parallel.
world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
rank = torch.distributed.get_rank() // tensor_model_parallel_size_
assert world_size == parallel_state.get_data_parallel_world_size()
assert rank == parallel_state.get_data_parallel_rank()
check(parallel_state.get_data_parallel_group(), world_size, rank)
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
if torch.distributed.get_rank() == 0:
print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
tensor_model_parallel_size_))
tensor_model_parallel_size = min(
tensor_model_parallel_size_,
torch.distributed.get_world_size(),
)
assert not parallel_state.model_parallel_is_initialized()
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
assert parallel_state.model_parallel_is_initialized()
# Checks
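    # The tensor-model-parallel source rank is expected to be the lowest global
    # rank in this rank's tensor-parallel group, i.e. the global rank minus this
    # rank's position within the group.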
src_rank = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank()
assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test initialize model parallel')
test_initialize_model_parallel(tensor_model_parallel_size)
print_separator('test model parallel source rank')
test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.init as init
from torch.nn.parameter import Parameter
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import set_random_seed
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
class IdentityLayer3D(torch.nn.Module):
def __init__(self, m, n, k):
super(IdentityLayer3D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n, k))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
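# IdentityLayer3D simply exposes a learnable (m, n, k) tensor as its "output",
# so the tests below can treat it as a differentiable input and inspect the
# gradient that flows back into it.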
def test_parallel_embedding(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing parallel embedding with model parallel size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
batch_size = 17
seq_length = 23
vocab_size = 48
hidden_size = 16
seed = 1236
set_random_seed(123)
input_data = torch.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()
set_random_seed(seed)
embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
output = embedding_original(input_data)
loss_original = torch.mul(output, loss_weight).sum()
loss_original.backward()
set_random_seed(seed)
embedding_parallel = layers.ParallelEmbedding(
vocab_size, hidden_size, init_method=init.normal_).cuda()
output = embedding_parallel(input_data)
loss_parallel = torch.mul(output, loss_weight).sum()
loss_parallel.backward()
set_random_seed(seed)
embedding_vocab_parallel = layers.VocabParallelEmbedding(
vocab_size, hidden_size, init_method=init.normal_).cuda()
output = embedding_vocab_parallel(input_data)
loss_vocab_parallel = torch.mul(output, loss_weight).sum()
loss_vocab_parallel.backward()
torch.distributed.barrier()
error = loss_parallel.sub(loss_original).abs()
print(' error in loss (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
torch.distributed.barrier()
error = loss_vocab_parallel.sub(loss_original).abs()
print(' error in loss (vocab parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
hidden_size // tensor_model_parallel_size,
1)[parallel_state.get_tensor_model_parallel_rank()]
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
print(' error in grad (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
vocab_size // tensor_model_parallel_size,
0)[parallel_state.get_tensor_model_parallel_rank()]
error = embedding_vocab_parallel.weight.grad.sub(
weight_grad_orig).abs().max()
print(' error in grad (vocab parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_initialize_affine_weight(tensor_model_parallel_size, device):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing initialize_affine_weight with model parallel '
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
# ---------------
# Column parallel
# ---------------
weight = torch.empty(output_size_coeff, input_size)
set_random_seed(seed)
if device == 'cpu':
layers._initialize_affine_weight_cpu(weight, output_size, input_size,
output_size_coeff, 0,
torch.nn.init.normal_,
params_dtype=global_vars.get_args().params_dtype,
)
else:
layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 0)
# Target.
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = parallel_state.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, output_size_coeff,
dim=0)[rank].contiguous().clone()
# Compare.
error = weight.sub(my_weight).abs().max()
torch.distributed.barrier()
print(' column parallel max error (should be zero) on global rank '
'{}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# ------------
# Row parallel
# ------------
weight = torch.empty(output_size, input_size_coeff)
set_random_seed(seed)
if device == 'cpu':
layers._initialize_affine_weight_cpu(
weight, output_size, input_size, input_size_coeff, 1, torch.nn.init.normal_,
params_dtype=global_vars.get_args().params_dtype)
else:
layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 1)
# Target.
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = parallel_state.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, input_size_coeff,
dim=1)[rank].contiguous().clone()
# Compare.
error = weight.sub(my_weight).abs().max()
torch.distributed.barrier()
print(' row parallel max error (should be zero) on global rank '
'{}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
class IdentityLayer2D(torch.nn.Module):
def __init__(self, m, n):
super(IdentityLayer2D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
def test_column_parallel_linear(tensor_model_parallel_size):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing ColumnParallelLinear with model parallel '
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).cuda()
loss_weight = torch.randn([batch_size, output_size]).cuda()
# Forward
input_ = identity_layer()
output, _ = linear_layer(input_)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
# Values.
dLdY = loss_weight
X = identity_layer.weight
A = linear_layer.master_weight.cuda()
dLdA = torch.matmul(dLdY.t(), X)
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
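    # A sketch of the reference math for Y = X @ A^T + b that is checked below:
    #   dL/dA = (dL/dY)^T @ X   -> ColumnParallelLinear shards A along dim 0,
    #                              so each rank holds output_size_coeff rows
    #   dL/db = 1^T @ dL/dY     -> sharded the same way as A
    #   dL/dX = dL/dY @ A       -> full gradient, identical on every rank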
rank = parallel_state.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdA on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
my_dLdb = torch.split(dLdb, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdb on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdX.sub(identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdX on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_column_parallel_linear_with_async_allreduce_autocast(tensor_model_parallel_size):
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer3D(batch_size, batch_size, input_size).cuda()
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).cuda()
assert linear_layer.async_tensor_model_parallel_allreduce or tensor_model_parallel_size == 1
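    # When the tensor-parallel size is 1 there is nothing to all-reduce, so the
    # async all-reduce path may legitimately be disabled; otherwise the layer is
    # expected to take the async code path exercised by this test.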
# Forward
for dtype in autocast_dtypes:
loss_weight = torch.randn([batch_size, output_size]).cuda()
with torch.cuda.amp.autocast(dtype=dtype):
output, _ = linear_layer(identity_layer())
loss = torch.mul(output, loss_weight).sum()
assert output.dtype == dtype
# Backward
loss.backward()
torch.distributed.barrier()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_column_parallel_linear_with_async_allreduce_custom_amp(tensor_model_parallel_size):
dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
for dtype in dtypes:
# Network
identity_layer = IdentityLayer3D(batch_size, batch_size, input_size).to(device="cuda", dtype=dtype)
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).to(device="cuda", dtype=dtype)
# Forward
loss_weight = torch.randn([batch_size, output_size]).cuda()
output, _ = linear_layer(identity_layer())
loss = torch.mul(output, loss_weight).sum()
loss.backward()
torch.distributed.barrier()
assert output.dtype == dtype
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_row_parallel_linear(tensor_model_parallel_size):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing RowParallelLinear with model parallel '
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
linear_layer = layers.RowParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).cuda()
loss_weight = torch.randn([batch_size, output_size]).cuda()
# Forward
input_ = identity_layer()
output, _ = linear_layer(input_)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
# Values.
dLdY = loss_weight
X = identity_layer.weight
A = linear_layer.master_weight.cuda()
dLdA = torch.matmul(dLdY.t(), X)
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
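    # Same reference math as the column-parallel case, except that
    # RowParallelLinear shards A along dim 1 (the input dimension), so only the
    # weight gradient is split per rank; the bias lives on every rank and its
    # gradient is compared in full.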
rank = parallel_state.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, input_size_coeff,
dim=1)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdA on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdb.sub(linear_layer.bias.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdb on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdX.sub(identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdX on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads
# Network
identity_layer = IdentityLayer3D(batch_size, sequence_length,
hidden_size).cuda()
attention_layer = parallel_state.BertParallelSelfAttention(hidden_size, num_att_heads,
dropout_prob).cuda()
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
# Forward
input_ = identity_layer()
output = attention_layer(input_, attention_mask)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
rank = parallel_state.get_tensor_model_parallel_rank()
parallel_state.destroy_model_parallel()
return rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer
def test_parallel_self_attention(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelSelfAttention with model parallel '
'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
dropout_prob = 0.0 # has to be zero
batch_size = 5
sequence_length = 13
    rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 = parallel_self_attention(
1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer = parallel_self_attention(
tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
    assert hidden_size_1 == hidden_size
error = loss_1.sub(loss).abs().max()
torch.distributed.barrier()
print(' loss error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
my_lin_grad_list = torch.split(
attention_layer_1.query_key_value.weight.grad,
hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
error = my_lin_grad.sub(
attention_layer.query_key_value.weight.grad).abs().max()
torch.distributed.barrier()
print(' weight gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
error = identity_layer_1.weight.grad.sub(
identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' input gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads
intermediate_size = 4 * hidden_size
# Network
identity_layer = IdentityLayer3D(batch_size, sequence_length,
hidden_size).cuda()
transformer_layer = parallel_state.BertParallelTransformerLayer(
hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
torch.nn.functional.relu, 1.0e-5).cuda()
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
# Forward
input_ = identity_layer()
output = transformer_layer(input_, attention_mask)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
rank = parallel_state.get_tensor_model_parallel_rank()
parallel_state.destroy_model_parallel()
return rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer
def test_parallel_transformer_layer(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelTransformerLayer with model parallel '
'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
batch_size = 5
sequence_length = 13
rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
transformer_layer_1, identity_layer_1 = parallel_transformer(
1, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer = parallel_transformer(
tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
error = loss_1.sub(loss).abs().max()
torch.distributed.barrier()
print(' loss error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-5, 'error: {}'.format(error)
error = identity_layer_1.weight.grad.sub(
identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' input gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-5, 'error: {}'.format(error)
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
exceptions = []
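    # Each sweep below doubles the tensor-parallel size until it exceeds the
    # world size; on the first failure the exception is recorded, the parallel
    # groups are torn down, and the sweep for that test stops.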
print_separator('test initialize affine weight cpu')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_initialize_affine_weight(tensor_model_parallel_size, 'cpu')
except Exception as e:
exceptions.append(f"test_initialize_affine_weight-cpu with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
# Reset groups
parallel_state.destroy_model_parallel()
print_separator('test initialize affine weight gpu')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_initialize_affine_weight(tensor_model_parallel_size, 'gpu')
except Exception as e:
exceptions.append(f"test_initialize_affine_weight-gpu with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
# Deleted, replaced with vocab parallel embedding?
#tensor_model_parallel_size = 1
#while tensor_model_parallel_size <= world_size:
# print_separator('test parallel embedding')
# test_parallel_embedding(tensor_model_parallel_size)
# tensor_model_parallel_size *= 2
print_separator('test column-parallel linear')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_column_parallel_linear(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator('test row-parallel linear')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_row_parallel_linear(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_row_parallel_linear with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator("test ColumnParallelLinearWithAsyncAllreduce - autocast")
tensor_model_parallel_size = 2
while tensor_model_parallel_size <= world_size:
try:
test_column_parallel_linear_with_async_allreduce_autocast(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear_with_async_allreduce_autocast with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator("test ColumnParallelLinearWithAsyncAllreduce - custom AMP")
tensor_model_parallel_size = 2
while tensor_model_parallel_size <= world_size:
try:
test_column_parallel_linear_with_async_allreduce_custom_amp(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear_with_async_allreduce_custom_amp with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
if exceptions:
raise RuntimeError("\n".join(exceptions))
# Deleted
#print_separator('test parallel self-attention')
#tensor_model_parallel_size = 1
#while tensor_model_parallel_size <= world_size:
# test_parallel_self_attention(tensor_model_parallel_size)
# tensor_model_parallel_size *= 2
    # Deleted because ParallelTransformerLayer no longer exists
# print_separator('test parallel transformer')
# tensor_model_parallel_size = 1
# while tensor_model_parallel_size <= world_size:
# test_parallel_transformer_layer(tensor_model_parallel_size)
# tensor_model_parallel_size *= 2
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel import mappings
from apex.transformer.tensor_parallel.tests import global_vars
global_vars.set_global_variables()
def test__reduce(args, tensor_model_parallel_size):
print("Testing reduction size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
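    # _reduce is expected to all-reduce (sum) across the tensor-model-parallel
    # group, so a tensor filled with 50 on each of the N ranks should come back
    # filled with 50 * N.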
assert torch.equal(
mappings._reduce(torch.full((10, 10, 10, 10), (50))),
torch.full((10, 10, 10, 10), 50 * tensor_model_parallel_size),
)
parallel_state.destroy_model_parallel()
print("Passed!")
def test__split(args, tensor_model_parallel_size):
print("Testing splitting size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
listy = []
for i in range(tensor_model_parallel_size):
listy.append(torch.randn(10, 1))
x = torch.cat(tuple(listy), 1)
out = mappings._split(x)
assert torch.equal(out, listy[parallel_state.get_tensor_model_parallel_rank()])
parallel_state.destroy_model_parallel()
print("Passed!")
def test__gather(args, tensor_model_parallel_size):
print("Testing gathering size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
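    # _gather is expected to concatenate along the last dimension across the
    # tensor-model-parallel group, so the per-rank tensors [rank] should gather
    # into [0, 1, ..., N - 1].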
assert torch.equal(
mappings._gather(torch.tensor([parallel_state.get_tensor_model_parallel_rank()])),
torch.tensor(list(range(tensor_model_parallel_size))),
)
parallel_state.destroy_model_parallel()
print("Passed!")
if __name__ == "__main__":
initialize_distributed()
world_size = torch.distributed.get_world_size()
args = global_vars.get_args()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test__reduce(args, tensor_model_parallel_size)
test__split(args, tensor_model_parallel_size)
test__gather(args, tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print(">> passed the test :-)")
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_set_cuda_rng_state(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
size = 123
seed = 1234
torch.cuda.manual_seed(seed)
tensor = torch.cuda.FloatTensor(size)
# Get the state
rng_state = torch.cuda.get_rng_state()
rng_state_copy = rng_state.clone()
# Do some stuff.
for _ in range(5):
torch.randn(size, out=tensor)
result_1 = tensor.clone()
assert rng_state.sub(rng_state_copy).max() == 0
assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0
# State should be different.
new_rng_state = torch.cuda.get_rng_state()
max_diff = new_rng_state.sub(rng_state).max()
print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), max_diff))
assert max_diff > 0
# Reset the rng state and do the same stuff.
tensor_parallel.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
tensor_parallel.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
result_2 = tensor.clone()
# Results should be the same
error = result_2.sub(result_1).abs().max()
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Input state should have remained intact.
error = rng_state.sub(rng_state_copy).max()
print(' max error in rng state (should be zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), error))
assert error == 0
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_cuda_rng_tracker(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed_1 = 1234
seed_2 = 4321
size = [12, 21]
tensor = torch.cuda.FloatTensor(size)
# Set to seed_1 and generate two tensors.
torch.cuda.manual_seed(seed_1)
torch.randn(size, out=tensor)
target_11 = tensor.clone()
torch.randn(size, out=tensor)
target_12 = tensor.clone()
# Set to seed_2 and generate two tensors.
torch.cuda.manual_seed(seed_2)
torch.randn(size, out=tensor)
target_21 = tensor.clone()
torch.randn(size, out=tensor)
target_22 = tensor.clone()
# Now if we interleave seed_1 and seed_2,
# we should still get the same tensors
torch.cuda.manual_seed(seed_1)
tensor_parallel.random.get_cuda_rng_tracker().add('test', seed_2)
torch.randn(size, out=tensor)
result_11 = tensor.clone()
with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_21 = tensor.clone()
torch.randn(size, out=tensor)
result_12 = tensor.clone()
with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_22 = tensor.clone()
diff = result_11.sub(result_21).abs().max()
diff = min(diff, result_12.sub(result_22).abs().max())
print(' max diff in generated tensors (should be non-zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
assert diff > 1.0e-6
error = max(result_11.sub(target_11).abs().max(),
result_12.sub(target_12).abs().max())
error = max(error, result_21.sub(target_21).abs().max())
error = max(error, result_22.sub(target_22).abs().max())
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset the tracker
tensor_parallel.random.get_cuda_rng_tracker().reset()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print(
'> testing model parallel cuda manual seed with size {} ...'.format(
tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
tensor_parallel.random.model_parallel_cuda_manual_seed(12345)
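    # model_parallel_cuda_manual_seed is expected to seed the default (data
    # parallel) generator with the given seed and the tracked model-parallel
    # generator with seed + 2718 + tensor_model_parallel_rank, which is what the
    # two assertions below check.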
assert torch.cuda.initial_seed() == 12345
with tensor_parallel.random.get_cuda_rng_tracker().fork():
assert (
torch.cuda.initial_seed() ==
12345 + 2718 + parallel_state.get_tensor_model_parallel_rank()
)
# Reset the tracker
tensor_parallel.random.get_cuda_rng_tracker().reset()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test set rng state')
test_set_cuda_rng_state(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cuda rng tracker')
test_cuda_rng_tracker(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed')
test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
import torch
from apex.transformer.tensor_parallel import utils
def test_divide():
assert utils.divide(8, 4) == 2
def test_split_tensor_along_last_dim():
inputy = torch.randn((100, 100, 100))
splits = utils.split_tensor_along_last_dim(inputy, 10)
last_dim_shapes = torch.tensor([int(split.size()[-1]) for split in splits])
assert torch.equal(last_dim_shapes, torch.full((10,), 10))
if __name__ == "__main__":
test_divide()
test_split_tensor_along_last_dim()
print(">> passed the test :-)")
"""Test for fused softmax functions.
Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
""" # NOQA
import itertools
import unittest
import torch
from apex.transformer import AttnMaskType
from apex.transformer.functional import FusedScaleMaskSoftmax
def attention_mask_func(attention_scores, attention_mask):
return attention_scores.masked_fill(attention_mask, -10000.0)
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
class TestFusedScaleMaskSoftmax(unittest.TestCase):
def _setup_fused_softmax(self, input_in_fp16, input_in_bf16, scale=None, softmax_in_fp32=False, attn_mask_type=AttnMaskType.padding):
fused_fn = FusedScaleMaskSoftmax(
input_in_fp16=input_in_fp16,
input_in_bf16=input_in_bf16,
mask_func=attention_mask_func,
scale=scale,
softmax_in_fp32=softmax_in_fp32,
attn_mask_type=attn_mask_type,
scaled_masked_softmax_fusion=True,
)
torch_fn = FusedScaleMaskSoftmax(
input_in_fp16=input_in_fp16,
input_in_bf16=input_in_bf16,
mask_func=attention_mask_func,
scale=scale,
softmax_in_fp32=softmax_in_fp32,
attn_mask_type=attn_mask_type,
scaled_masked_softmax_fusion=False,
)
return fused_fn, torch_fn
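    # The two softmax modules built above are expected to compute, up to numerics,
    #   softmax(mask_func(scale * scores, mask), dim=-1)
    # (the scale is applied only when it is not None); the fused variant dispatches
    # to the CUDA kernel while the other falls back to the plain PyTorch path.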
def test_fused_scale_mask_softmax(self):
"""
attention_scores.shape = [4, 12, 24, 24]
mask.shape = [4, 1, 24, 24]
"""
for (dtype, scale, softmax_in_fp32) in itertools.product(
(torch.half, torch.bfloat16),
(None, 2.0),
(False, True),
):
with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16
if not (scale is None or softmax_in_fp32):
with self.assertRaises(RuntimeError):
self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding)
                    continue  # try the next (dtype, scale, softmax_in_fp32) combination
fused_fn, torch_fn = self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding)
attention_scores_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
with torch.no_grad():
attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()
expected = fused_fn(attention_scores_0, mask)
actual = torch_fn(attention_scores_1, mask)
torch.testing.assert_allclose(actual, expected)
g0 = torch.rand_like(actual)
with torch.no_grad():
g1 = g0.clone()
expected.backward(g0)
actual.backward(g1)
def test_autocast_fused_scale_mask_softmax(self):
for dtype in autocast_dtypes:
with self.subTest(f"{dtype}"):
input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16
fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding)
attention_scores_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
with torch.no_grad():
attention_scores_1 = attention_scores_0.clone().to(dtype).requires_grad_(True)
mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda()
expected = torch_fn(attention_scores_1, mask)
with torch.cuda.amp.autocast(dtype=dtype):
actual = fused_fn(attention_scores_0, mask)
self.assertEqual(actual.dtype, dtype)
torch.testing.assert_allclose(actual, expected)
g0 = torch.rand_like(actual)
with torch.no_grad():
g1 = g0.clone()
expected.backward(g0)
actual.backward(g1)
def test_fused_upper_triangle_mask_softmax(self):
"""
attn_weights.shape: [4, 12, 24, 24]
total_mask.shape: [4, 1, 24, 24]
        total_mask[0, 0] is a 24x24 matrix whose strictly upper-triangular
        elements are True and whose lower-triangular elements, including the
        diagonal, are False.
"""
for (dtype, scale, softmax_in_fp32) in itertools.product(
(torch.half, torch.bfloat16),
(None, 2.0),
(False, True),
):
with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16
if not (scale is None or softmax_in_fp32):
with self.assertRaises(RuntimeError):
self._setup_fused_softmax(
input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal)
                    continue  # try the next (dtype, scale, softmax_in_fp32) combination
fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal)
attn_weights_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
with torch.no_grad():
attn_weights_1 = attn_weights_0.clone().requires_grad_(True)
total_mask = (~(
torch.tril(torch.randn((24, 24), device="cuda")).bool()
).unsqueeze(0).unsqueeze(0))
total_mask = total_mask.repeat((4, 1, 1, 1))
expected = fused_fn(attn_weights_0, total_mask)
actual = torch_fn(attn_weights_1, total_mask)
torch.testing.assert_allclose(actual, expected)
g0 = torch.randn_like(actual)
with torch.no_grad():
g1 = g0.clone()
actual.backward(g0)
expected.backward(g1)
def test_autocast_fused_upper_triangle_mask_softmax(self):
for dtype in autocast_dtypes:
with self.subTest(f"{dtype}"):
input_in_fp16 = dtype == torch.half
input_in_bf16 = dtype == torch.bfloat16
fused_fn, torch_fn = self._setup_fused_softmax(
input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal)
attn_weights_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
with torch.no_grad():
attn_weights_1 = attn_weights_0.clone().to(dtype).requires_grad_(True)
total_mask = (~(
torch.tril(torch.randn((24, 24), device="cuda")).bool()
).unsqueeze(0).unsqueeze(0))
with torch.cuda.amp.autocast(dtype=dtype):
actual = fused_fn(attn_weights_0, total_mask)
self.assertEqual(actual.dtype, dtype)
expected = torch_fn(attn_weights_1, total_mask)
torch.testing.assert_allclose(actual, expected)
g0 = torch.randn_like(actual)
with torch.no_grad():
g1 = g0.clone()
actual.backward(g0)
expected.backward(g1)
import os
import subprocess
import sys
import unittest
def run_mpu_tests():
python_executable_path = sys.executable
# repository_root = os.path.join(os.path.dirname(__file__), "../../../")
# directory = os.path.abspath(os.path.join(repository_root, "tests/mpu"))
directory = os.path.dirname(__file__)
files = [
os.path.join(directory, f) for f in os.listdir(directory)
if f.startswith("run_") and os.path.isfile(os.path.join(directory, f))
]
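    # Each discovered run_* script is launched in its own process below; it only
    # counts as passed if its output contains the success marker printed by the
    # individual tests.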
print("#######################################################")
print(f"# Python executable path: {python_executable_path}")
print(f"# {len(files)} tests: {files}")
print("#######################################################")
errors = []
for i, test_file in enumerate(files, 1):
test_run_cmd = f"NVIDIA_TF32_OVERRIDE=0 {python_executable_path} {test_file} --micro-batch-size 2 --num-layers 1 --hidden-size 256 --num-attention-heads 8 --max-position-embeddings 32 --encoder-seq-length 32 --use-cpu-initialization" # NOQA
print(f"### {i} / {len(files)}: cmd: {test_run_cmd}")
try:
output = subprocess.check_output(
test_run_cmd, shell=True
).decode(sys.stdout.encoding).strip()
except Exception as e:
errors.append((test_file, str(e)))
else:
if '>> passed the test :-)' not in output:
                errors.append((test_file, output))
else:
if not errors:
print("### PASSED")
else:
print("### FAILED")
short_msg = f"{len(errors)} out of {len(files)} tests failed"
print(short_msg)
for (filename, log) in errors:
print(f"File: {filename}\nLog: {log}")
raise RuntimeError(short_msg)
class TestMPU(unittest.TestCase):
def test_mpu(self):
run_mpu_tests()
if __name__ == '__main__':
unittest.main()