Commit 8d9c0ce9 authored by Karthik Kashyap Thatipamula
Browse files

Add rocm support. Fix issues in upstream repo.

parents f2a32f9d 08ce574a
...@@ -285,6 +285,9 @@ if(USE_ROCM) ...@@ -285,6 +285,9 @@ if(USE_ROCM)
endif() endif()
message(STATUS "CMAKE_HIP_FLAGS: ${CMAKE_HIP_FLAGS}") message(STATUS "CMAKE_HIP_FLAGS: ${CMAKE_HIP_FLAGS}")
# Building for ROCm almost always means USE_CUDA.
# Exceptions to this will be guarded by USE_ROCM.
add_definitions(-DUSE_CUDA)
add_definitions(-DUSE_ROCM) add_definitions(-DUSE_ROCM)
endif() endif()
...@@ -471,10 +474,21 @@ set( ...@@ -471,10 +474,21 @@ set(
src/cuda/cuda_algorithms.cu src/cuda/cuda_algorithms.cu
) )
if(USE_CUDA) if(USE_CUDA OR USE_ROCM)
list(APPEND LGBM_SOURCES ${LGBM_CUDA_SOURCES}) list(APPEND LGBM_SOURCES ${LGBM_CUDA_SOURCES})
endif() endif()
if(USE_ROCM)
    # CMake compiles .cu files as CUDA by default; mark the .cu sources so
    # they are compiled by the HIP toolchain instead (requires the HIP
    # language to be enabled, which USE_ROCM configuration does above).
    set(CU_FILES ${LGBM_CUDA_SOURCES})
    list(FILTER CU_FILES INCLUDE REGEX "\\.cu$")
    set_source_files_properties(${CU_FILES} PROPERTIES LANGUAGE HIP)
endif()
add_library(lightgbm_objs OBJECT ${LGBM_SOURCES}) add_library(lightgbm_objs OBJECT ${LGBM_SOURCES})
if(BUILD_CLI) if(BUILD_CLI)
...@@ -623,6 +637,10 @@ if(USE_CUDA) ...@@ -623,6 +637,10 @@ if(USE_CUDA)
endif() endif()
endif() endif()
if(USE_ROCM)
    # hip::host supplies the HIP runtime headers and host-side runtime
    # library needed by the hipified CUDA sources.
    target_link_libraries(lightgbm_objs PUBLIC hip::host)
endif()
if(WIN32) if(WIN32)
if(MINGW OR CYGWIN) if(MINGW OR CYGWIN)
target_link_libraries(lightgbm_objs PUBLIC ws2_32 iphlpapi) target_link_libraries(lightgbm_objs PUBLIC ws2_32 iphlpapi)
......
...@@ -54,6 +54,8 @@ ...@@ -54,6 +54,8 @@
# --precompile # --precompile
# Use precompiled library. # Use precompiled library.
# Only used with 'install' command. # Only used with 'install' command.
# --rocm
# Compile ROCm version.
# --time-costs # --time-costs
# Compile version that outputs time costs for different internal routines. # Compile version that outputs time costs for different internal routines.
# --user # --user
...@@ -142,6 +144,9 @@ while [ $# -gt 0 ]; do ...@@ -142,6 +144,9 @@ while [ $# -gt 0 ]; do
--cuda) --cuda)
BUILD_ARGS="${BUILD_ARGS} --config-setting=cmake.define.USE_CUDA=ON" BUILD_ARGS="${BUILD_ARGS} --config-setting=cmake.define.USE_CUDA=ON"
;; ;;
--rocm)
BUILD_ARGS="${BUILD_ARGS} --config-setting=cmake.define.USE_ROCM=ON"
;;
--gpu) --gpu)
BUILD_ARGS="${BUILD_ARGS} --config-setting=cmake.define.USE_GPU=ON" BUILD_ARGS="${BUILD_ARGS} --config-setting=cmake.define.USE_GPU=ON"
;; ;;
......
...@@ -742,6 +742,65 @@ macOS ...@@ -742,6 +742,65 @@ macOS
The CUDA version is not supported on macOS. The CUDA version is not supported on macOS.
Build ROCm Version
~~~~~~~~~~~~~~~~~~
The `original GPU version <#build-gpu-version>`__ of LightGBM (``device_type=gpu``) is based on OpenCL.
The ROCm-based version is a separate implementation that deliberately reuses ``device_type=cuda``, so existing CUDA-oriented configuration continues to work unchanged. Use this version in Linux environments with an AMD GPU.
Windows
^^^^^^^
The ROCm version is not supported on Windows.
Use the `GPU version <#build-gpu-version>`__ (``device_type=gpu``) for GPU acceleration on Windows.
Linux
^^^^^
On Linux, a ROCm version of LightGBM can be built using
- **CMake**, **gcc** and **ROCm**;
- **CMake**, **Clang** and **ROCm**.
Please refer to `the ROCm docs`_ for **ROCm** libraries installation.
After compilation the executable and ``.so`` files will be in ``LightGBM/`` folder.
gcc
***
1. Install `CMake`_, **gcc** and **ROCm**.
2. Run the following commands:
.. code:: sh
git clone --recursive https://github.com/microsoft/LightGBM
cd LightGBM
cmake -B build -S . -DUSE_ROCM=ON
cmake --build build -j4
Clang
*****
1. Install `CMake`_, **Clang**, **OpenMP** and **ROCm**.
2. Run the following commands:
.. code:: sh
git clone --recursive https://github.com/microsoft/LightGBM
cd LightGBM
export CXX=clang++-14 CC=clang-14 # replace "14" with version of Clang installed on your machine
cmake -B build -S . -DUSE_ROCM=ON
cmake --build build -j4
macOS
^^^^^
The ROCm version is not supported on macOS.
Build Java Wrapper Build Java Wrapper
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
...@@ -1051,6 +1110,8 @@ gcc ...@@ -1051,6 +1110,8 @@ gcc
.. _this detailed guide: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html .. _this detailed guide: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html
.. _the ROCm docs: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/
.. _following docs: https://github.com/google/sanitizers/wiki .. _following docs: https://github.com/google/sanitizers/wiki
.. _Ninja: https://ninja-build.org .. _Ninja: https://ninja-build.org
...@@ -264,7 +264,7 @@ Core Parameters ...@@ -264,7 +264,7 @@ Core Parameters
- ``cpu`` supports all LightGBM functionality and is portable across the widest range of operating systems and hardware - ``cpu`` supports all LightGBM functionality and is portable across the widest range of operating systems and hardware
- ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA - ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA or ROCm
- ``gpu`` can be faster than ``cpu`` and works on a wider range of GPUs than CUDA - ``gpu`` can be faster than ``cpu`` and works on a wider range of GPUs than CUDA
...@@ -272,7 +272,7 @@ Core Parameters ...@@ -272,7 +272,7 @@ Core Parameters
- **Note**: for the faster speed, GPU uses 32-bit float point to sum up by default, so this may affect the accuracy for some tasks. You can set ``gpu_use_dp=true`` to enable 64-bit float point, but it will slow down the training - **Note**: for the faster speed, GPU uses 32-bit float point to sum up by default, so this may affect the accuracy for some tasks. You can set ``gpu_use_dp=true`` to enable 64-bit float point, but it will slow down the training
- **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU or CUDA support - **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU, CUDA, or ROCm support
- ``seed`` :raw-html:`<a id="seed" title="Permalink to this parameter" href="#seed">&#x1F517;&#xFE0E;</a>`, default = ``None``, type = int, aliases: ``random_seed``, ``random_state`` - ``seed`` :raw-html:`<a id="seed" title="Permalink to this parameter" href="#seed">&#x1F517;&#xFE0E;</a>`, default = ``None``, type = int, aliases: ``random_seed``, ``random_state``
......
...@@ -22,6 +22,7 @@ $(() => { ...@@ -22,6 +22,7 @@ $(() => {
"#build-mpi-version", "#build-mpi-version",
"#build-gpu-version", "#build-gpu-version",
"#build-cuda-version", "#build-cuda-version",
"#build-rocm-version",
"#build-java-wrapper", "#build-java-wrapper",
"#build-python-package", "#build-python-package",
"#build-r-package", "#build-r-package",
......
...@@ -245,11 +245,11 @@ struct Config { ...@@ -245,11 +245,11 @@ struct Config {
// alias = device // alias = device
// desc = device for the tree learning // desc = device for the tree learning
// desc = ``cpu`` supports all LightGBM functionality and is portable across the widest range of operating systems and hardware // desc = ``cpu`` supports all LightGBM functionality and is portable across the widest range of operating systems and hardware
// desc = ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA // desc = ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA or ROCm
// desc = ``gpu`` can be faster than ``cpu`` and works on a wider range of GPUs than CUDA // desc = ``gpu`` can be faster than ``cpu`` and works on a wider range of GPUs than CUDA
// desc = **Note**: it is recommended to use the smaller ``max_bin`` (e.g. 63) to get the better speed up // desc = **Note**: it is recommended to use the smaller ``max_bin`` (e.g. 63) to get the better speed up
// desc = **Note**: for the faster speed, GPU uses 32-bit float point to sum up by default, so this may affect the accuracy for some tasks. You can set ``gpu_use_dp=true`` to enable 64-bit float point, but it will slow down the training // desc = **Note**: for the faster speed, GPU uses 32-bit float point to sum up by default, so this may affect the accuracy for some tasks. You can set ``gpu_use_dp=true`` to enable 64-bit float point, but it will slow down the training
// desc = **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU or CUDA support // desc = **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU, CUDA, or ROCm support
std::string device_type = "cpu"; std::string device_type = "cpu";
// [no-automatically-extract] // [no-automatically-extract]
......
...@@ -9,8 +9,10 @@ ...@@ -9,8 +9,10 @@
#ifdef USE_CUDA #ifdef USE_CUDA
#ifndef USE_ROCM
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#endif
#include <stdio.h> #include <stdio.h>
#include <LightGBM/bin.h> #include <LightGBM/bin.h>
......
...@@ -7,8 +7,10 @@ ...@@ -7,8 +7,10 @@
#ifdef USE_CUDA #ifdef USE_CUDA
#ifndef USE_ROCM
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#endif
namespace LightGBM { namespace LightGBM {
......
...@@ -7,16 +7,59 @@ ...@@ -7,16 +7,59 @@
#ifdef USE_CUDA #ifdef USE_CUDA
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP__) #if defined(__HIP_PLATFORM_AMD__)
// Older ROCm releases don't provide __shfl_down_sync, only __shfl_down without a mask.
// ROCm doesn't have atomicAdd_block, but it should be semantically the same as atomicAdd
#define atomicAdd_block atomicAdd
// hipify
#include <hip/hip_runtime.h>
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaFree hipFree
#define cudaFreeHost hipFreeHost
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorName hipGetErrorName
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostAlloc hipHostAlloc
#define cudaHostAllocPortable hipHostAllocPortable
#define cudaMalloc hipMalloc
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemoryTypeHost hipMemoryTypeHost
#define cudaMemset hipMemset
#define cudaPointerAttributes hipPointerAttribute_t
#define cudaPointerGetAttributes hipPointerGetAttributes
#define cudaSetDevice hipSetDevice
#define cudaStreamCreate hipStreamCreate
#define cudaStreamDestroy hipStreamDestroy
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
// ROCm 7.0 did add __shfl_down_sync et al, but the following hack still works.
// Since mask is full 0xffffffff, we can use __shfl_down instead. // Since mask is full 0xffffffff, we can use __shfl_down instead.
#define __shfl_down_sync(mask, val, offset) __shfl_down(val, offset) #define __shfl_down_sync(mask, val, offset) __shfl_down(val, offset)
#define __shfl_up_sync(mask, val, offset) __shfl_up(val, offset) #define __shfl_up_sync(mask, val, offset) __shfl_up(val, offset)
// ROCm warpSize is constexpr and is either 32 or 64 depending on gfx arch.
#define WARPSIZE warpSize // warpSize is only allowed for device code.
// ROCm doesn't have atomicAdd_block, but it should be semantically the same as atomicAdd // HIP header used to define warpSize as a constexpr that was either 32 or 64
#define atomicAdd_block atomicAdd // depending on the target device, and then always set it to 64 for host code.
#else static inline constexpr int WARP_SIZE_INTERNAL() {
#if defined(__GFX9__)
return 64;
#else // __GFX9__
return 32;
#endif // __GFX9__
}
#define WARPSIZE (WARP_SIZE_INTERNAL())
#else // __HIP_PLATFORM_AMD__
// CUDA warpSize is not a constexpr, but always 32 // CUDA warpSize is not a constexpr, but always 32
#define WARPSIZE 32 #define WARPSIZE 32
#endif // defined(__HIP_PLATFORM_AMD__) || defined(__HIP__) #endif // defined(__HIP_PLATFORM_AMD__) || defined(__HIP__)
......
...@@ -8,8 +8,12 @@ ...@@ -8,8 +8,12 @@
#ifdef USE_CUDA #ifdef USE_CUDA
#if defined(USE_ROCM)
#include <LightGBM/cuda/cuda_rocm_interop.h>
#else
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#endif
#include <stdio.h> #include <stdio.h>
#include <LightGBM/utils/log.h> #include <LightGBM/utils/log.h>
......
...@@ -9,9 +9,12 @@ ...@@ -9,9 +9,12 @@
#include <LightGBM/utils/common.h> #include <LightGBM/utils/common.h>
#ifdef USE_CUDA #ifdef USE_CUDA
#ifndef USE_ROCM
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#endif #endif // USE_ROCM
#include <LightGBM/cuda/cuda_utils.hu>
#endif // USE_CUDA
#include <stdio.h> #include <stdio.h>
enum LGBM_Device { enum LGBM_Device {
...@@ -66,14 +69,14 @@ struct CHAllocator { ...@@ -66,14 +69,14 @@ struct CHAllocator {
#ifdef USE_CUDA #ifdef USE_CUDA
if (LGBM_config_::current_device == lgbm_device_cuda) { if (LGBM_config_::current_device == lgbm_device_cuda) {
cudaPointerAttributes attributes; cudaPointerAttributes attributes;
cudaPointerGetAttributes(&attributes, p); CUDASUCCESS_OR_FATAL(cudaPointerGetAttributes(&attributes, p));
#if CUDA_VERSION >= 10000 #if CUDA_VERSION >= 10000 || defined(USE_ROCM)
if ((attributes.type == cudaMemoryTypeHost) && (attributes.devicePointer != NULL)) { if ((attributes.type == cudaMemoryTypeHost) && (attributes.devicePointer != NULL)) {
cudaFreeHost(p); CUDASUCCESS_OR_FATAL(cudaFreeHost(p));
} }
#else #else
if ((attributes.memoryType == cudaMemoryTypeHost) && (attributes.devicePointer != NULL)) { if ((attributes.memoryType == cudaMemoryTypeHost) && (attributes.devicePointer != NULL)) {
cudaFreeHost(p); CUDASUCCESS_OR_FATAL(cudaFreeHost(p));
} }
#endif #endif
} else { } else {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#ifdef USE_CUDA #ifdef USE_CUDA
#include <LightGBM/cuda/cuda_rocm_interop.h>
#include <LightGBM/cuda/cuda_utils.hu> #include <LightGBM/cuda/cuda_utils.hu>
namespace LightGBM { namespace LightGBM {
......
...@@ -934,7 +934,11 @@ __global__ void FindBestSplitsDiscretizedForLeafKernel( ...@@ -934,7 +934,11 @@ __global__ void FindBestSplitsDiscretizedForLeafKernel(
if (is_feature_used_bytree[inner_feature_index]) { if (is_feature_used_bytree[inner_feature_index]) {
if (task->is_categorical) { if (task->is_categorical) {
__threadfence(); // ensure store issued before trap __threadfence(); // ensure store issued before trap
#if defined(USE_ROCM)
__builtin_trap();
#else
asm("trap;"); asm("trap;");
#endif
} else { } else {
if (!task->reverse) { if (!task->reverse) {
if (use_16bit_bin) { if (use_16bit_bin) {
......
...@@ -155,7 +155,7 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner { ...@@ -155,7 +155,7 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
#pragma warning(disable : 4702) #pragma warning(disable : 4702)
explicit CUDASingleGPUTreeLearner(const Config* tree_config, const bool /*boosting_on_cuda*/) : SerialTreeLearner(tree_config) { explicit CUDASingleGPUTreeLearner(const Config* tree_config, const bool /*boosting_on_cuda*/) : SerialTreeLearner(tree_config) {
Log::Fatal("CUDA Tree Learner was not enabled in this build.\n" Log::Fatal("CUDA Tree Learner was not enabled in this build.\n"
"Please recompile with CMake option -DUSE_CUDA=1"); "Please recompile with CMake option -DUSE_CUDA=1 (NVIDIA GPUs) or -DUSE_ROCM=1 (AMD GPUs)");
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment