Commit 97f19956 authored by mykg's avatar mykg
Browse files

Fix some build error for ROCM



1. Fix terrible logic in CMakeLists.txt
2. using the correct typedef for hip
Signed-off-by: mykg's avataronepick <jiajuku12@163.com>
parent 3efb6621
...@@ -248,9 +248,9 @@ if (WIN32) ...@@ -248,9 +248,9 @@ if (WIN32)
include_directories("$ENV{CUDA_PATH}/include") include_directories("$ENV{CUDA_PATH}/include")
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1) add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX) elseif (UNIX)
if (NOT KTRANSFORMERS_USE_MUSA) if (KTRANSFORMERS_USE_CUDA)
# find_package(CUDA REQUIRED) find_package(CUDA REQUIRED)
# include_directories("${CUDA_INCLUDE_DIRS}") include_directories("${CUDA_INCLUDE_DIRS}")
include(CheckLanguage) include(CheckLanguage)
check_language(CUDA) check_language(CUDA)
if(CMAKE_CUDA_COMPILER) if(CMAKE_CUDA_COMPILER)
...@@ -324,7 +324,7 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama) ...@@ -324,7 +324,7 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if(WIN32) if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX) elseif(UNIX)
if(NOT KTRANSFORMERS_USE_MUSA) if(KTRANSFORMERS_USE_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so") target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif() endif()
if (KTRANSFORMERS_USE_ROCM) if (KTRANSFORMERS_USE_ROCM)
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
typedef hip_bfloat16 nv_bfloat16; typedef __hip_bfloat16 nv_bfloat16;
#endif #endif
__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { __global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
......
...@@ -9,7 +9,9 @@ ...@@ -9,7 +9,9 @@
**/ **/
// Python bindings // Python bindings
#include "cpu_backend/cpuinfer.h" #include "cpu_backend/cpuinfer.h"
#ifndef KTRANSFORMERS_USE_ROCM
#include "device_launch_parameters.h" #include "device_launch_parameters.h"
#endif
#include "llamafile/flags.h" #include "llamafile/flags.h"
#include "operators/kvcache/kvcache.h" #include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h" #include "operators/llamafile/linear.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment