Commit 97f19956 authored by mykg's avatar mykg
Browse files

Fix some build error for ROCM



1. Fix terrible logic in CMakeLists.txt
2. using the correct typedef for hip
Signed-off-by: mykg's avataronepick <jiajuku12@163.com>
parent 3efb6621
......@@ -248,9 +248,9 @@ if (WIN32)
include_directories("$ENV{CUDA_PATH}/include")
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX)
if (NOT KTRANSFORMERS_USE_MUSA)
# find_package(CUDA REQUIRED)
# include_directories("${CUDA_INCLUDE_DIRS}")
if (KTRANSFORMERS_USE_CUDA)
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
......@@ -324,7 +324,7 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX)
if(NOT KTRANSFORMERS_USE_MUSA)
if(KTRANSFORMERS_USE_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif()
if (KTRANSFORMERS_USE_ROCM)
......
......@@ -17,7 +17,7 @@
#include <c10/cuda/CUDAGuard.h>
#ifdef __HIP_PLATFORM_AMD__
typedef hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat16 nv_bfloat16;
#endif
__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
......@@ -879,4 +879,4 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i
}
cudaDeviceSynchronize();
return output;
}
\ No newline at end of file
}
......@@ -9,7 +9,9 @@
**/
// Python bindings
#include "cpu_backend/cpuinfer.h"
#ifndef KTRANSFORMERS_USE_ROCM
#include "device_launch_parameters.h"
#endif
#include "llamafile/flags.h"
#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment