OpenDAS / apex · Commit 9b4c68c7

Commit 9b4c68c7 authored Dec 08, 2020 by lcskrishna
updated hipify changes for apex contrib
parent ef209a74
Showing 3 changed files with 51 additions and 13 deletions (+51 -13)

apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu   +5  -0
apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu   +4  -1
setup.py                                                 +42 -12
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu
@@ -9,7 +9,12 @@
// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#if HIP_VERSION >= 310
#include "multi_tensor_apply_hip.cuh"
#else
#include "multi_tensor_apply.cuh"
#endif
#define BLOCK_SIZE 512
#define ILP 4
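Note: the new #if HIP_VERSION >= 310 guard in this hunk (and the matching one in fused_lamb_cuda_kernel.cu below) picks between the two header names that different hipify versions generate. The sketch below is hypothetical and not part of this commit: it shows one way a build script could supply such a macro from torch.version.hip, assuming the common HIP encoding of major * 100 + minor (so 310 would correspond to HIP/ROCm 3.10); the helper name hip_version_macro is made up for illustration.

import torch

def hip_version_macro():
    # torch.version.hip is set on ROCm builds of PyTorch (e.g. '3.10.xxxxx'), None on CUDA builds.
    hip = getattr(torch.version, 'hip', None)
    if not hip:
        return []
    major, minor = (int(part) for part in hip.split('.')[:2])
    # Assumed encoding: HIP_VERSION = major * 100 + minor, matching the ">= 310" check above.
    return ['-DHIP_VERSION=%d' % (major * 100 + minor)]

# Example use inside a setup script:
#   extra_compile_args=['-O3'] + hip_version_macro() + version_dependent_macros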
apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu
@@ -8,8 +8,11 @@
#include <assert.h>
#include "type_shim.h"
#if HIP_VERSION >= 310
#include "multi_tensor_apply_hip.cuh"
#else
#include "multi_tensor_apply.cuh"
#endif
#define BLOCK_SIZE 512
#define ILP 4
setup.py
@@ -333,12 +333,20 @@ if "--xentropy" in sys.argv:
                           extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                               'nvcc':['-O3'] + version_dependent_macros}))
     else:
+        xentropy_sources_v1_8 = ['apex/contrib/csrc/xentropy/interface.cpp',
+                                 'apex/contrib/csrc/xentropy/xentropy_kernel.hip']
+        xentropy_sources_other = ['apex/contrib/csrc/xentropy/interface.cpp',
+                                  'apex/contrib/csrc/xentropy/hip/xentropy_kernel.hip']
         ext_modules.append(
-            CUDAExtension(name='xentropy_cuda',
-                          sources=['apex/contrib/csrc/xentropy/interface.cpp',
-                                   'apex/contrib/csrc/xentropy/hip/xentropy_kernel.hip'],
-                          include_dirs=[os.path.join(this_dir, 'csrc/hip')],
-                          extra_compile_args=['-O3'] + version_dependent_macros))
+            CUDAExtension(name='xentropy_cuda',
+                          sources=xentropy_sources_v1_8 if torch.__version__ >= '1.8' else xentropy_sources_other,
+                          include_dirs=[os.path.join(this_dir, 'csrc') if torch.__version__ >= '1.8' else os.path.join(this_dir, 'csrc/hip')],
+                          extra_compile_args=['-O3'] + version_dependent_macros))
+        #ext_modules.append(
+        #    CUDAExtension(name='xentropy_cuda',
+        #                  sources=['apex/contrib/csrc/xentropy/interface.cpp',
+        #                           'apex/contrib/csrc/xentropy/hip/xentropy_kernel.hip'],
+        #                  include_dirs=[os.path.join(this_dir, 'csrc/hip')],
+        #                  extra_compile_args=['-O3'] + version_dependent_macros))
 if "--deprecated_fused_adam" in sys.argv:
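Note: the source and include-dir selection above relies on a plain string comparison of torch.__version__ against '1.8'. That distinguishes 1.7 from 1.8, but lexicographic comparison misorders two-digit minors ('1.10.0' >= '1.8' evaluates to False). The following is a hedged sketch of a numeric comparison a build script could use instead; torch_version_at_least is a hypothetical helper, not part of this diff.

import re
import torch

def torch_version_at_least(major, minor):
    # Strip local/suffix parts such as '1.8.0a0+ae5c2fe' before comparing numerically.
    numbers = re.findall(r'\d+', torch.__version__.split('+')[0])[:2]
    return tuple(int(n) for n in numbers) >= (major, minor)

# Usage, assuming the source lists defined in the hunk above:
#   sources = xentropy_sources_v1_8 if torch_version_at_least(1, 8) else xentropy_sources_other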
@@ -364,12 +372,23 @@ if "--deprecated_fused_adam" in sys.argv:
                                               '--use_fast_math'] + version_dependent_macros}))
     else:
         print("INFO: Building deprecated fused adam.")
+        fused_adam_sources_v1_8 = ['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
+                                   'apex/contrib/csrc/optimizers/fused_adam_hip_kernel.hip']
+        fused_adam_sources_other = ['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
+                                    'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip']
         ext_modules.append(
             CUDAExtension(name='fused_adam_cuda',
-                          sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
-                                   'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip'],
-                          include_dirs=[os.path.join(this_dir, 'csrc/hip')],
+                          sources=fused_adam_sources_v1_8 if torch.__version__ >= '1.8' else fused_adam_sources_other,
+                          include_dirs=[os.path.join(this_dir, 'csrc') if torch.__version__ >= '1.8' else os.path.join(this_dir, 'csrc/hip')],
                           extra_compile_args=['-O3'] + version_dependent_macros))
+        #ext_modules.append(
+        #    CUDAExtension(name='fused_adam_cuda',
+        #                  sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
+        #                           'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip'],
+        #                  include_dirs=[os.path.join(this_dir, 'csrc/hip')],
+        #                  extra_compile_args=['-O3'] + version_dependent_macros))
 if "--deprecated_fused_lamb" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
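Note: because the hipified kernel lands at a different path depending on which hipify produced it ('optimizers/fused_adam_hip_kernel.hip' for torch >= 1.8 versus 'optimizers/hip/fused_adam_hip_kernel.hip' otherwise), a build script can sanity-check that the layout it selected actually exists before registering the extension. A minimal sketch under that assumption; pick_existing is a hypothetical helper, not part of this commit.

import os

def pick_existing(candidates):
    # Return the first candidate path that exists on disk, or fail loudly.
    for path in candidates:
        if os.path.exists(path):
            return path
    raise RuntimeError('no hipified kernel found, tried: %s' % ', '.join(candidates))

fused_adam_kernel = pick_existing([
    'apex/contrib/csrc/optimizers/fused_adam_hip_kernel.hip',      # torch >= 1.8 hipify layout
    'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip',  # older hipify layout
])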
@@ -395,13 +414,24 @@ if "--deprecated_fused_lamb" in sys.argv:
                                               '--use_fast_math'] + version_dependent_macros}))
     else:
         print("INFO: Building deprecated fused lamb.")
+        fused_lamb_sources_v1_8 = ['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
+                                   'apex/contrib/csrc/optimizers/fused_lamb_hip_kernel.hip']
+        fused_lamb_sources_other = ['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
+                                    'apex/contrib/csrc/optimizers/hip/fused_lamb_hip_kernel.hip']
         ext_modules.append(
             CUDAExtension(name='fused_lamb_cuda',
-                          sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
-                                   'apex/contrib/csrc/optimizers/hip/fused_lamb_hip_kernel.hip',
-                                   'csrc/hip/multi_tensor_l2norm_kernel.hip'],
-                          include_dirs=[os.path.join(this_dir, 'csrc/hip')],
+                          sources=fused_lamb_sources_v1_8 if torch.__version__ >= '1.8' else fused_lamb_sources_other,
+                          include_dirs=[os.path.join(this_dir, 'csrc') if torch.__version__ >= '1.8' else os.path.join(this_dir, 'csrc/hip')],
                           extra_compile_args=['-O3'] + version_dependent_macros))
+        #ext_modules.append(
+        #    CUDAExtension(name='fused_lamb_cuda',
+        #                  sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
+        #                           'apex/contrib/csrc/optimizers/hip/fused_lamb_hip_kernel.hip',
+        #                           'csrc/hip/multi_tensor_l2norm_kernel.hip'],
+        #                  include_dirs=[os.path.join(this_dir, 'csrc/hip')],
+        #                  extra_compile_args=['-O3'] + version_dependent_macros))
 # Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
 generator_flag = []
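Note: the trailing context above ends at the generator_flag probe described in that comment. A hedged sketch of how such a probe is typically written: look for the old header under torch's bundled include tree and define a macro for the C++ sources when it is present. The macro name OLD_GENERATOR is an assumption for illustration, not taken from this diff.

import os
import torch

torch_include_dir = os.path.join(os.path.dirname(torch.__file__), 'include')

generator_flag = []
if os.path.exists(os.path.join(torch_include_dir, 'ATen', 'CUDAGenerator.h')):
    # Older PyTorch still ships ATen/CUDAGenerator.h; newer builds only provide
    # ATen/CUDAGeneratorImpl.h (pytorch/pytorch#36026), so the sources branch on this macro.
    generator_flag = ['-DOLD_GENERATOR']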