Unverified Commit 8770b6d5 authored by ZiWei Yuan's avatar ZiWei Yuan Committed by GitHub
Browse files

Merge pull request #1159 from onepick/fix-rocm-build-error

Fix some build error for ROCM
parents 6e4da83d 6a7624fe
......@@ -248,30 +248,13 @@ if (WIN32)
include_directories("$ENV{CUDA_PATH}/include")
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX)
if (NOT KTRANSFORMERS_USE_MUSA)
# find_package(CUDA REQUIRED)
# include_directories("${CUDA_INCLUDE_DIRS}")
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA detected")
find_package(CUDAToolkit REQUIRED)
include_directories(${CUDAToolkit_INCLUDE_DIRS})
endif()
message(STATUS "enabling CUDA")
enable_language(CUDA)
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
endif()
if (KTRANSFORMERS_USE_ROCM)
find_package(HIP REQUIRED)
if(HIP_FOUND)
include_directories("${HIP_INCLUDE_DIRS}")
add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
endif()
endif()
if (KTRANSFORMERS_USE_MUSA)
elseif (KTRANSFORMERS_USE_MUSA)
if (NOT EXISTS $ENV{MUSA_PATH})
if (NOT EXISTS /opt/musa)
set(MUSA_PATH /usr/local/musa)
......@@ -289,7 +272,19 @@ elseif (UNIX)
message(STATUS "MUSA Toolkit found")
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
endif()
endif()
else()
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA detected")
find_package(CUDAToolkit REQUIRED)
include_directories(${CUDAToolkit_INCLUDE_DIRS})
endif()
message(STATUS "enabling CUDA")
enable_language(CUDA)
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
endif()
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
......@@ -324,17 +319,14 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX)
if(NOT KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif()
if (KTRANSFORMERS_USE_ROCM)
add_compile_definitions(USE_HIP=1)
target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
message(STATUS "Building for HIP")
endif()
if(KTRANSFORMERS_USE_MUSA)
elseif(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif()
else()
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif()
# Define the USE_NUMA option
......
......@@ -17,7 +17,7 @@
#include <c10/cuda/CUDAGuard.h>
#ifdef __HIP_PLATFORM_AMD__
typedef hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat16 nv_bfloat16;
#endif
__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
......@@ -879,4 +879,4 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i
}
cudaDeviceSynchronize();
return output;
}
\ No newline at end of file
}
......@@ -9,7 +9,9 @@
**/
// Python bindings
#include "cpu_backend/cpuinfer.h"
#ifndef KTRANSFORMERS_USE_ROCM
#include "device_launch_parameters.h"
#endif
#include "llamafile/flags.h"
#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment