Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
a6cdd6c7
Unverified
Commit
a6cdd6c7
authored
Apr 02, 2021
by
Michael Melesse
Committed by
GitHub
Apr 02, 2021
Browse files
[ROCM] Add ROCm support to source build (#1411)
parent
404fa12a
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
259 additions
and
2 deletions
+259
-2
CMakeLists.txt
CMakeLists.txt
+8
-0
build_tools/setup_helpers/extension.py
build_tools/setup_helpers/extension.py
+2
-0
cmake/LoadHIP.cmake
cmake/LoadHIP.cmake
+234
-0
test/torchaudio_unittest/backend/soundfile/save_test.py
test/torchaudio_unittest/backend/soundfile/save_test.py
+2
-0
test/torchaudio_unittest/common_utils/__init__.py
test/torchaudio_unittest/common_utils/__init__.py
+3
-2
test/torchaudio_unittest/common_utils/case_utils.py
test/torchaudio_unittest/common_utils/case_utils.py
+2
-0
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
...audio_unittest/functional/torchscript_consistency_impl.py
+4
-0
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
...audio_unittest/transforms/torchscript_consistency_impl.py
+4
-0
No files found.
CMakeLists.txt
View file @
a6cdd6c7
...
...
@@ -18,6 +18,14 @@ endif()
project
(
torchaudio
)
# Find the HIP package, set the HIP paths, load the HIP CMake.
if
(
USE_ROCM
)
include
(
cmake/LoadHIP.cmake
)
if
(
NOT PYTORCH_FOUND_HIP
)
set
(
USE_ROCM OFF
)
endif
()
endif
()
# check and set CMAKE_CXX_STANDARD
string
(
FIND
"
${
CMAKE_CXX_FLAGS
}
"
"-std=c++"
env_cxx_standard
)
if
(
env_cxx_standard GREATER -1
)
...
...
build_tools/setup_helpers/extension.py
View file @
a6cdd6c7
...
...
@@ -37,6 +37,7 @@ def _get_build(var, default=False):
_BUILD_SOX
=
False
if
platform
.
system
()
==
'Windows'
else
_get_build
(
"BUILD_SOX"
)
_BUILD_KALDI
=
False
if
platform
.
system
()
==
'Windows'
else
_get_build
(
"BUILD_KALDI"
,
True
)
_BUILD_TRANSDUCER
=
_get_build
(
"BUILD_TRANSDUCER"
)
_USE_ROCM
=
_get_build
(
"USE_ROCM"
)
def
get_ext_modules
():
...
...
@@ -74,6 +75,7 @@ class CMakeBuild(build_ext):
f
"-DBUILD_TRANSDUCER:BOOL=
{
'ON'
if
_BUILD_TRANSDUCER
else
'OFF'
}
"
,
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON"
,
"-DBUILD_LIBTORCHAUDIO:BOOL=OFF"
,
f
"-DUSE_ROCM:BOOL=
{
'ON'
if
_USE_ROCM
else
'OFF'
}
"
,
]
build_args
=
[
'--target'
,
'install'
...
...
cmake/LoadHIP.cmake
0 → 100644
View file @
a6cdd6c7
set
(
PYTORCH_FOUND_HIP FALSE
)
if
(
NOT DEFINED ENV{ROCM_PATH}
)
set
(
ROCM_PATH /opt/rocm
)
else
()
set
(
ROCM_PATH $ENV{ROCM_PATH}
)
endif
()
# HIP_PATH
if
(
NOT DEFINED ENV{HIP_PATH}
)
set
(
HIP_PATH
${
ROCM_PATH
}
/hip
)
else
()
set
(
HIP_PATH $ENV{HIP_PATH}
)
endif
()
if
(
NOT EXISTS
${
HIP_PATH
}
)
return
()
endif
()
# HCC_PATH
if
(
NOT DEFINED ENV{HCC_PATH}
)
set
(
HCC_PATH
${
ROCM_PATH
}
/hcc
)
else
()
set
(
HCC_PATH $ENV{HCC_PATH}
)
endif
()
# HSA_PATH
if
(
NOT DEFINED ENV{HSA_PATH}
)
set
(
HSA_PATH
${
ROCM_PATH
}
/hsa
)
else
()
set
(
HSA_PATH $ENV{HSA_PATH}
)
endif
()
# ROCBLAS_PATH
if
(
NOT DEFINED ENV{ROCBLAS_PATH}
)
set
(
ROCBLAS_PATH
${
ROCM_PATH
}
/rocblas
)
else
()
set
(
ROCBLAS_PATH $ENV{ROCBLAS_PATH}
)
endif
()
# ROCFFT_PATH
if
(
NOT DEFINED ENV{ROCFFT_PATH}
)
set
(
ROCFFT_PATH
${
ROCM_PATH
}
/rocfft
)
else
()
set
(
ROCFFT_PATH $ENV{ROCFFT_PATH}
)
endif
()
# HIPFFT_PATH
if
(
NOT DEFINED ENV{HIPFFT_PATH}
)
set
(
HIPFFT_PATH
${
ROCM_PATH
}
/hipfft
)
else
()
set
(
HIPFFT_PATH $ENV{HIPFFT_PATH}
)
endif
()
# HIPSPARSE_PATH
if
(
NOT DEFINED ENV{HIPSPARSE_PATH}
)
set
(
HIPSPARSE_PATH
${
ROCM_PATH
}
/hipsparse
)
else
()
set
(
HIPSPARSE_PATH $ENV{HIPSPARSE_PATH}
)
endif
()
# THRUST_PATH
if
(
DEFINED ENV{THRUST_PATH}
)
set
(
THRUST_PATH $ENV{THRUST_PATH}
)
else
()
set
(
THRUST_PATH
${
ROCM_PATH
}
/include
)
endif
()
# HIPRAND_PATH
if
(
NOT DEFINED ENV{HIPRAND_PATH}
)
set
(
HIPRAND_PATH
${
ROCM_PATH
}
/hiprand
)
else
()
set
(
HIPRAND_PATH $ENV{HIPRAND_PATH}
)
endif
()
# ROCRAND_PATH
if
(
NOT DEFINED ENV{ROCRAND_PATH}
)
set
(
ROCRAND_PATH
${
ROCM_PATH
}
/rocrand
)
else
()
set
(
ROCRAND_PATH $ENV{ROCRAND_PATH}
)
endif
()
# MIOPEN_PATH
if
(
NOT DEFINED ENV{MIOPEN_PATH}
)
set
(
MIOPEN_PATH
${
ROCM_PATH
}
/miopen
)
else
()
set
(
MIOPEN_PATH $ENV{MIOPEN_PATH}
)
endif
()
# RCCL_PATH
if
(
NOT DEFINED ENV{RCCL_PATH}
)
set
(
RCCL_PATH
${
ROCM_PATH
}
/rccl
)
else
()
set
(
RCCL_PATH $ENV{RCCL_PATH}
)
endif
()
# ROCPRIM_PATH
if
(
NOT DEFINED ENV{ROCPRIM_PATH}
)
set
(
ROCPRIM_PATH
${
ROCM_PATH
}
/rocprim
)
else
()
set
(
ROCPRIM_PATH $ENV{ROCPRIM_PATH}
)
endif
()
# HIPCUB_PATH
if
(
NOT DEFINED ENV{HIPCUB_PATH}
)
set
(
HIPCUB_PATH
${
ROCM_PATH
}
/hipcub
)
else
()
set
(
HIPCUB_PATH $ENV{HIPCUB_PATH}
)
endif
()
# ROCTHRUST_PATH
if
(
NOT DEFINED ENV{ROCTHRUST_PATH}
)
set
(
ROCTHRUST_PATH
${
ROCM_PATH
}
/rocthrust
)
else
()
set
(
ROCTHRUST_PATH $ENV{ROCTHRUST_PATH}
)
endif
()
# ROCTRACER_PATH
if
(
NOT DEFINED ENV{ROCTRACER_PATH}
)
set
(
ROCTRACER_PATH
${
ROCM_PATH
}
/roctracer
)
else
()
set
(
ROCTRACER_PATH $ENV{ROCTRACER_PATH}
)
endif
()
if
(
NOT DEFINED ENV{PYTORCH_ROCM_ARCH}
)
set
(
PYTORCH_ROCM_ARCH gfx803;gfx900;gfx906;gfx908
)
else
()
set
(
PYTORCH_ROCM_ARCH $ENV{PYTORCH_ROCM_ARCH}
)
endif
()
# Add HIP to the CMAKE Module Path
set
(
CMAKE_MODULE_PATH
${
HIP_PATH
}
/cmake
${
CMAKE_MODULE_PATH
}
)
# Disable Asserts In Code (Can't use asserts on HIP stack.)
add_definitions
(
-DNDEBUG
)
macro
(
find_package_and_print_version PACKAGE_NAME
)
find_package
(
"
${
PACKAGE_NAME
}
"
${
ARGN
}
)
message
(
"
${
PACKAGE_NAME
}
VERSION:
${${
PACKAGE_NAME
}
_VERSION
}
"
)
endmacro
()
# Find the HIP Package
find_package_and_print_version
(
HIP 1.0
)
if
(
HIP_FOUND
)
set
(
PYTORCH_FOUND_HIP TRUE
)
# Find ROCM version for checks
file
(
READ
"
${
ROCM_PATH
}
/.info/version-dev"
ROCM_VERSION_DEV_RAW
)
string
(
REGEX MATCH
"^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$"
ROCM_VERSION_DEV_MATCH
${
ROCM_VERSION_DEV_RAW
}
)
if
(
ROCM_VERSION_DEV_MATCH
)
set
(
ROCM_VERSION_DEV_MAJOR
${
CMAKE_MATCH_1
}
)
set
(
ROCM_VERSION_DEV_MINOR
${
CMAKE_MATCH_2
}
)
set
(
ROCM_VERSION_DEV_PATCH
${
CMAKE_MATCH_3
}
)
set
(
ROCM_VERSION_DEV
"
${
ROCM_VERSION_DEV_MAJOR
}
.
${
ROCM_VERSION_DEV_MINOR
}
.
${
ROCM_VERSION_DEV_PATCH
}
"
)
endif
()
message
(
"
\n
***** ROCm version from
${
ROCM_PATH
}
/.info/version-dev ****
\n
"
)
message
(
"ROCM_VERSION_DEV:
${
ROCM_VERSION_DEV
}
"
)
message
(
"ROCM_VERSION_DEV_MAJOR:
${
ROCM_VERSION_DEV_MAJOR
}
"
)
message
(
"ROCM_VERSION_DEV_MINOR:
${
ROCM_VERSION_DEV_MINOR
}
"
)
message
(
"ROCM_VERSION_DEV_PATCH:
${
ROCM_VERSION_DEV_PATCH
}
"
)
message
(
"
\n
***** Library versions from dpkg *****
\n
"
)
execute_process
(
COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk
"{print $2
\"
VERSION:
\"
$3}"
)
execute_process
(
COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk
"{print $2
\"
VERSION:
\"
$3}"
)
execute_process
(
COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk
"{print $2
\"
VERSION:
\"
$3}"
)
execute_process
(
COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk
"{print $2
\"
VERSION:
\"
$3}"
)
execute_process
(
COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk
"{print $2
\"
VERSION:
\"
$3}"
)
execute_process
(
COMMAND dpkg -l COMMAND grep hip_base COMMAND awk
"{print $2
\"
VERSION:
\"
$3}"
)
execute_process
(
COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk
"{print $2
\"
VERSION:
\"
$3}"
)
message
(
"
\n
***** Library versions from cmake find_package *****
\n
"
)
set
(
CMAKE_HCC_FLAGS_DEBUG
${
CMAKE_CXX_FLAGS_DEBUG
}
)
set
(
CMAKE_HCC_FLAGS_RELEASE
${
CMAKE_CXX_FLAGS_RELEASE
}
)
### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###
set
(
hip_DIR
${
HIP_PATH
}
/lib/cmake/hip
)
set
(
hsa-runtime64_DIR
${
ROCM_PATH
}
/lib/cmake/hsa-runtime64
)
set
(
AMDDeviceLibs_DIR
${
ROCM_PATH
}
/lib/cmake/AMDDeviceLibs
)
set
(
amd_comgr_DIR
${
ROCM_PATH
}
/lib/cmake/amd_comgr
)
set
(
rocrand_DIR
${
ROCRAND_PATH
}
/lib/cmake/rocrand
)
set
(
hiprand_DIR
${
HIPRAND_PATH
}
/lib/cmake/hiprand
)
set
(
rocblas_DIR
${
ROCBLAS_PATH
}
/lib/cmake/rocblas
)
set
(
miopen_DIR
${
MIOPEN_PATH
}
/lib/cmake/miopen
)
set
(
rocfft_DIR
${
ROCFFT_PATH
}
/lib/cmake/rocfft
)
set
(
hipfft_DIR
${
HIPFFT_PATH
}
/lib/cmake/hipfft
)
set
(
hipsparse_DIR
${
HIPSPARSE_PATH
}
/lib/cmake/hipsparse
)
set
(
rccl_DIR
${
RCCL_PATH
}
/lib/cmake/rccl
)
set
(
rocprim_DIR
${
ROCPRIM_PATH
}
/lib/cmake/rocprim
)
set
(
hipcub_DIR
${
HIPCUB_PATH
}
/lib/cmake/hipcub
)
set
(
rocthrust_DIR
${
ROCTHRUST_PATH
}
/lib/cmake/rocthrust
)
find_package_and_print_version
(
hip REQUIRED
)
find_package_and_print_version
(
hsa-runtime64 REQUIRED
)
find_package_and_print_version
(
amd_comgr REQUIRED
)
find_package_and_print_version
(
rocrand REQUIRED
)
find_package_and_print_version
(
hiprand REQUIRED
)
find_package_and_print_version
(
rocblas REQUIRED
)
find_package_and_print_version
(
miopen REQUIRED
)
find_package_and_print_version
(
rocfft REQUIRED
)
if
(
ROCM_VERSION_DEV VERSION_GREATER_EQUAL
"4.1.0"
)
find_package_and_print_version
(
hipfft REQUIRED
)
endif
()
find_package_and_print_version
(
hipsparse REQUIRED
)
find_package_and_print_version
(
rccl
)
find_package_and_print_version
(
rocprim REQUIRED
)
find_package_and_print_version
(
hipcub REQUIRED
)
find_package_and_print_version
(
rocthrust REQUIRED
)
if
(
HIP_COMPILER STREQUAL clang
)
set
(
hip_library_name amdhip64
)
else
()
set
(
hip_library_name hip_hcc
)
endif
()
message
(
"HIP library name:
${
hip_library_name
}
"
)
# TODO: hip_hcc has an interface include flag "-hc" which is only
# recognizable by hcc, but not gcc and clang. Right now in our
# setup, hcc is only used for linking, but it should be used to
# compile the *_hip.cc files as well.
find_library
(
PYTORCH_HIP_HCC_LIBRARIES
${
hip_library_name
}
HINTS
${
HIP_PATH
}
/lib
)
# TODO: miopen_LIBRARIES should return fullpath to the library file,
# however currently it's just the lib name
find_library
(
PYTORCH_MIOPEN_LIBRARIES
${
miopen_LIBRARIES
}
HINTS
${
MIOPEN_PATH
}
/lib
)
# TODO: rccl_LIBRARIES should return fullpath to the library file,
# however currently it's just the lib name
find_library
(
PYTORCH_RCCL_LIBRARIES
${
rccl_LIBRARIES
}
HINTS
${
RCCL_PATH
}
/lib
)
# hiprtc is part of HIP
find_library
(
ROCM_HIPRTC_LIB
${
hip_library_name
}
HINTS
${
HIP_PATH
}
/lib
)
# roctx is part of roctracer
find_library
(
ROCM_ROCTX_LIB roctx64 HINTS
${
ROCTRACER_PATH
}
/lib
)
set
(
roctracer_INCLUDE_DIRS
${
ROCTRACER_PATH
}
/include
)
endif
()
test/torchaudio_unittest/backend/soundfile/save_test.py
View file @
a6cdd6c7
...
...
@@ -11,6 +11,7 @@ from torchaudio_unittest.common_utils import (
get_wav_data
,
load_wav
,
nested_params
,
skipIfRocm
,
)
from
.common
import
(
fetch_wav_subtype
,
...
...
@@ -280,6 +281,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
self
.
_test_fileobj
(
'wav'
)
@
skipIfFormatNotSupported
(
"FLAC"
)
@
skipIfRocm
def
test_fileobj_flac
(
self
):
"""Saving audio via file-like object works"""
self
.
_test_fileobj
(
'flac'
)
...
...
test/torchaudio_unittest/common_utils/__init__.py
View file @
a6cdd6c7
...
...
@@ -17,6 +17,7 @@ from .case_utils import (
skipIfNoModule
,
skipIfNoKaldi
,
skipIfNoSox
,
skipIfRocm
,
)
from
.wav_utils
import
(
get_wav_data
,
...
...
@@ -32,5 +33,5 @@ from .parameterized_utils import (
__all__
=
[
'get_asset_path'
,
'get_whitenoise'
,
'get_sinusoid'
,
'set_audio_backend'
,
'TempDirMixin'
,
'HttpServerMixin'
,
'TestBaseMixin'
,
'PytorchTestCase'
,
'TorchaudioTestCase'
,
'skipIfNoCuda'
,
'skipIfNoExec'
,
'skipIfNoModule'
,
'skipIfNoKaldi'
,
'skipIfNoSox'
,
'skipIfNoSoxBackend'
,
'get_wav_data'
,
'normalize_wav'
,
'load_wav'
,
'save_wav'
,
'load_params'
,
'nested_params'
]
'skipIfNoSoxBackend'
,
'skipIfRocm'
,
'get_wav_data'
,
'normalize_wav'
,
'load_wav'
,
'save_wav'
,
'load_params'
,
'nested_params'
]
test/torchaudio_unittest/common_utils/case_utils.py
View file @
a6cdd6c7
...
...
@@ -98,3 +98,5 @@ def skipIfNoModule(module, display_name=None):
skipIfNoCuda
=
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
skipIfNoSox
=
unittest
.
skipIf
(
not
is_sox_available
(),
reason
=
'Sox not available'
)
skipIfNoKaldi
=
unittest
.
skipIf
(
not
is_kaldi_available
(),
reason
=
'Kaldi not available'
)
skipIfRocm
=
unittest
.
skipIf
(
os
.
getenv
(
'TORCHAUDIO_TEST_WITH_ROCM'
,
'0'
)
==
'1'
,
reason
=
"test doesn't currently work on the ROCm stack"
)
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
View file @
a6cdd6c7
...
...
@@ -5,6 +5,9 @@ import torch
import
torchaudio.functional
as
F
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest.common_utils
import
(
skipIfRocm
,
)
class
Functional
(
common_utils
.
TestBaseMixin
):
...
...
@@ -34,6 +37,7 @@ class Functional(common_utils.TestBaseMixin):
tensor
=
common_utils
.
get_whitenoise
()
self
.
_assert_consistency
(
func
,
tensor
)
@
skipIfRocm
def
test_griffinlim
(
self
):
def
func
(
tensor
):
n_fft
=
400
...
...
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
View file @
a6cdd6c7
...
...
@@ -4,6 +4,9 @@ import torch
import
torchaudio.transforms
as
T
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest.common_utils
import
(
skipIfRocm
,
)
class
Transforms
(
common_utils
.
TestBaseMixin
):
...
...
@@ -21,6 +24,7 @@ class Transforms(common_utils.TestBaseMixin):
tensor
=
torch
.
rand
((
1
,
1000
))
self
.
_assert_consistency
(
T
.
Spectrogram
(),
tensor
)
@
skipIfRocm
def
test_GriffinLim
(
self
):
tensor
=
torch
.
rand
((
1
,
201
,
6
))
self
.
_assert_consistency
(
T
.
GriffinLim
(
length
=
1000
,
rand_init
=
False
),
tensor
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment