Unverified Commit 8770b6d5 authored by ZiWei Yuan's avatar ZiWei Yuan Committed by GitHub
Browse files

Merge pull request #1159 from onepick/fix-rocm-build-error

Fix some build error for ROCM
parents 6e4da83d 6a7624fe
...@@ -248,30 +248,13 @@ if (WIN32) ...@@ -248,30 +248,13 @@ if (WIN32)
include_directories("$ENV{CUDA_PATH}/include") include_directories("$ENV{CUDA_PATH}/include")
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1) add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX) elseif (UNIX)
if (NOT KTRANSFORMERS_USE_MUSA)
# find_package(CUDA REQUIRED)
# include_directories("${CUDA_INCLUDE_DIRS}")
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA detected")
find_package(CUDAToolkit REQUIRED)
include_directories(${CUDAToolkit_INCLUDE_DIRS})
endif()
message(STATUS "enabling CUDA")
enable_language(CUDA)
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
endif()
if (KTRANSFORMERS_USE_ROCM) if (KTRANSFORMERS_USE_ROCM)
find_package(HIP REQUIRED) find_package(HIP REQUIRED)
if(HIP_FOUND) if(HIP_FOUND)
include_directories("${HIP_INCLUDE_DIRS}") include_directories("${HIP_INCLUDE_DIRS}")
add_compile_definitions(KTRANSFORMERS_USE_ROCM=1) add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
endif() endif()
endif() elseif (KTRANSFORMERS_USE_MUSA)
if (KTRANSFORMERS_USE_MUSA)
if (NOT EXISTS $ENV{MUSA_PATH}) if (NOT EXISTS $ENV{MUSA_PATH})
if (NOT EXISTS /opt/musa) if (NOT EXISTS /opt/musa)
set(MUSA_PATH /usr/local/musa) set(MUSA_PATH /usr/local/musa)
...@@ -289,7 +272,19 @@ elseif (UNIX) ...@@ -289,7 +272,19 @@ elseif (UNIX)
message(STATUS "MUSA Toolkit found") message(STATUS "MUSA Toolkit found")
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1) add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
endif() endif()
else()
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA detected")
find_package(CUDAToolkit REQUIRED)
include_directories(${CUDAToolkit_INCLUDE_DIRS})
endif() endif()
message(STATUS "enabling CUDA")
enable_language(CUDA)
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
endif() endif()
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
...@@ -324,17 +319,14 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama) ...@@ -324,17 +319,14 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if(WIN32) if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX) elseif(UNIX)
if(NOT KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif()
if (KTRANSFORMERS_USE_ROCM) if (KTRANSFORMERS_USE_ROCM)
add_compile_definitions(USE_HIP=1) add_compile_definitions(USE_HIP=1)
target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so") target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
message(STATUS "Building for HIP") message(STATUS "Building for HIP")
endif() elseif(KTRANSFORMERS_USE_MUSA)
if(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart) target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif() else()
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif() endif()
# Define the USE_NUMA option # Define the USE_NUMA option
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
typedef hip_bfloat16 nv_bfloat16; typedef __hip_bfloat16 nv_bfloat16;
#endif #endif
__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { __global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
......
...@@ -9,7 +9,9 @@ ...@@ -9,7 +9,9 @@
**/ **/
// Python bindings // Python bindings
#include "cpu_backend/cpuinfer.h" #include "cpu_backend/cpuinfer.h"
#ifndef KTRANSFORMERS_USE_ROCM
#include "device_launch_parameters.h" #include "device_launch_parameters.h"
#endif
#include "llamafile/flags.h" #include "llamafile/flags.h"
#include "operators/kvcache/kvcache.h" #include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h" #include "operators/llamafile/linear.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment