Unverified commit 1f2aa915 authored by ptrblck, committed by GitHub

CUDAGenerator fix for #36026 (#801)



* add CUDAGenerator guard

* fix generator_flag

* add guards for gen pointer/ref issue

* change mutex_ to mutex()

* add check_generator

Co-authored-by: pbialecki <pbialecki@nvidia.com>
parent 71511faf
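
The core of the fix, side by side: before pytorch/pytorch#36026, extensions received a raw `at::CUDAGenerator*` whose public `mutex_` member they locked directly; afterwards they receive a type-erased `at::Generator` by value, lock it through `mutex()`, and recover the typed impl with `at::check_generator`. Below is a minimal sketch of the dual-API pattern this commit applies, with a hypothetical helper name (`get_philox_inputs`) standing in for apex's inlined call sites:

```cuda
#include <cstdint>
#include <mutex>
#include <utility>
#include <ATen/ATen.h>
#ifdef OLD_GENERATOR
#include <ATen/CUDAGenerator.h>
#else
#include <ATen/CUDAGeneratorImpl.h>
#endif

// Hypothetical helper; apex's real call sites inline this logic.
#ifdef OLD_GENERATOR
std::pair<uint64_t, uint64_t> get_philox_inputs(at::CUDAGenerator* gen,
                                                uint64_t counter_offset) {
  // See Note [Acquire lock when using random generators]: the engine state
  // is shared, so advancing it by counter_offset must be serialized.
  std::lock_guard<std::mutex> lock(gen->mutex_);
  return gen->philox_engine_inputs(counter_offset);
}
#else
std::pair<uint64_t, uint64_t> get_philox_inputs(at::Generator gen,
                                                uint64_t counter_offset) {
  std::lock_guard<std::mutex> lock(gen.mutex());
  // check_generator verifies the type-erased Generator really wraps a
  // CUDAGeneratorImpl before handing back a typed pointer.
  return at::check_generator<at::CUDAGeneratorImpl>(gen)
      ->philox_engine_inputs(counter_offset);
}
#endif
```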
 #include <ATen/ATen.h>
+#ifdef OLD_GENERATOR
 #include <ATen/CUDAGenerator.h>
+#else
+#include <ATen/CUDAGeneratorImpl.h>
+#endif
 #include <ATen/cuda/CUDAContext.h>
 #include <curand_kernel.h>
@@ -206,8 +212,13 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
   std::pair<uint64_t, uint64_t> rng_engine_inputs;
   {
     // See Note [Acquire lock when using random generators]
+#ifdef OLD_GENERATOR
     std::lock_guard<std::mutex> lock(gen->mutex_);
     rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
+#else
+    std::lock_guard<std::mutex> lock(gen.mutex());
+    rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);
+#endif
   }
   apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, p, rng_engine_inputs);
@@ -239,8 +250,13 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
   std::pair<uint64_t, uint64_t> rng_engine_inputs;
   {
     // See Note [Acquire lock when using random generators]
+#ifdef OLD_GENERATOR
     std::lock_guard<std::mutex> lock(gen->mutex_);
     rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
+#else
+    std::lock_guard<std::mutex> lock(gen.mutex());
+    rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);
+#endif
   }
   apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, mask, totalElements, p, rng_engine_inputs);
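Both hunks pass `rng_engine_inputs` straight through to the kernel launch. The pair is the Philox (seed, offset): `counter_offset` is roughly how far each thread's counter will advance during the launch, so the generator can hand the next caller an offset past it. Below is a minimal sketch of how a kernel consumes the pair, assuming that (seed, offset) convention; the hypothetical `philox_uniform_kernel` is far simpler than apex's fused dropout kernels:

```cuda
#include <cstdint>
#include <curand_kernel.h>

// Hypothetical kernel: one uniform per element, seeded from the
// (seed, offset) pair produced by philox_engine_inputs on the host.
__global__ void philox_uniform_kernel(float* out, int64_t n,
                                      uint64_t seed, uint64_t offset) {
  int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n) return;
  curandStatePhilox4_32_10_t state;
  // subsequence = idx gives each thread an independent stream; offset
  // skips past numbers that earlier launches already drew from it.
  curand_init(seed, idx, offset, &state);
  out[idx] = curand_uniform(&state);
}
```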
@@ -225,6 +225,12 @@ if "--deprecated_fused_lamb" in sys.argv:
                                           'nvcc':['-O3',
                                                   '--use_fast_math'] + version_dependent_macros}))
+# Check if ATen/CUDAGenerator.h is found; otherwise use the new ATen/CUDAGeneratorImpl.h, due to the breaking change in https://github.com/pytorch/pytorch/pull/36026
+generator_flag = []
+torch_dir = torch.__path__[0]
+if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
+    generator_flag = ['-DOLD_GENERATOR']
+
 if "--fast_multihead_attn" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
     sys.argv.remove("--fast_multihead_attn")
@@ -240,7 +246,7 @@ if "--fast_multihead_attn" in sys.argv:
         CUDAExtension(name='fast_self_multihead_attn',
                       sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
                                'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
-                      extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
+                      extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                           'nvcc':['-O3',
                                                   '-gencode', 'arch=compute_70,code=sm_70',
                                                   '-I./apex/contrib/csrc/multihead_attn/cutlass/',
@@ -248,12 +254,12 @@ if "--fast_multihead_attn" in sys.argv:
                                                   '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                   '--expt-relaxed-constexpr',
                                                   '--expt-extended-lambda',
-                                                  '--use_fast_math'] + version_dependent_macros}))
+                                                  '--use_fast_math'] + version_dependent_macros + generator_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_norm_add',
                       sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
                                'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
-                      extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
+                      extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                           'nvcc':['-O3',
                                                   '-gencode', 'arch=compute_70,code=sm_70',
                                                   '-I./apex/contrib/csrc/multihead_attn/cutlass/',
@@ -261,12 +267,12 @@ if "--fast_multihead_attn" in sys.argv:
                                                   '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                   '--expt-relaxed-constexpr',
                                                   '--expt-extended-lambda',
-                                                  '--use_fast_math'] + version_dependent_macros}))
+                                                  '--use_fast_math'] + version_dependent_macros + generator_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_encdec_multihead_attn',
                       sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
                                'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
-                      extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
+                      extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                           'nvcc':['-O3',
                                                   '-gencode', 'arch=compute_70,code=sm_70',
                                                   '-I./apex/contrib/csrc/multihead_attn/cutlass/',
@@ -274,12 +280,12 @@ if "--fast_multihead_attn" in sys.argv:
                                                   '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                   '--expt-relaxed-constexpr',
                                                   '--expt-extended-lambda',
-                                                  '--use_fast_math'] + version_dependent_macros}))
+                                                  '--use_fast_math'] + version_dependent_macros + generator_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
                       sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
                                'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
-                      extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
+                      extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                           'nvcc':['-O3',
                                                   '-gencode', 'arch=compute_70,code=sm_70',
                                                   '-I./apex/contrib/csrc/multihead_attn/cutlass/',
@@ -287,7 +293,7 @@ if "--fast_multihead_attn" in sys.argv:
                                                   '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                   '--expt-relaxed-constexpr',
                                                   '--expt-extended-lambda',
-                                                  '--use_fast_math'] + version_dependent_macros}))
+                                                  '--use_fast_math'] + version_dependent_macros + generator_flag}))
 setup(
     name='apex',
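A caveat on the probe above: it keys off nothing but the presence of a header in the installed torch tree. If a mismatch were to slip through (say, a stale build directory), the C++ side could be made to fail fast at compile time. A sketch of such a guard, not part of this commit, using C++17's `__has_include`:

```cuda
// Hypothetical guard, not part of this commit: fail the build loudly if
// -DOLD_GENERATOR disagrees with the headers the installed torch provides.
#if defined(OLD_GENERATOR) && !__has_include(<ATen/CUDAGenerator.h>)
#error "-DOLD_GENERATOR set, but ATen/CUDAGenerator.h is missing; rebuild without the flag"
#endif
#if !defined(OLD_GENERATOR) && !__has_include(<ATen/CUDAGeneratorImpl.h>)
#error "ATen/CUDAGeneratorImpl.h not found; this torch predates pytorch/pytorch#36026"
#endif
```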