Commit 6a7624fe authored by mykg's avatar mykg
Browse files

Change the logic to build device since cuda is as default


Signed-off-by: mykg's avataronepick <jiajuku12@163.com>
parent 97f19956
...@@ -248,30 +248,13 @@ if (WIN32) ...@@ -248,30 +248,13 @@ if (WIN32)
include_directories("$ENV{CUDA_PATH}/include") include_directories("$ENV{CUDA_PATH}/include")
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1) add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX) elseif (UNIX)
if (KTRANSFORMERS_USE_CUDA)
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA detected")
find_package(CUDAToolkit REQUIRED)
include_directories(${CUDAToolkit_INCLUDE_DIRS})
endif()
message(STATUS "enabling CUDA")
enable_language(CUDA)
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
endif()
if (KTRANSFORMERS_USE_ROCM) if (KTRANSFORMERS_USE_ROCM)
find_package(HIP REQUIRED) find_package(HIP REQUIRED)
if(HIP_FOUND) if(HIP_FOUND)
include_directories("${HIP_INCLUDE_DIRS}") include_directories("${HIP_INCLUDE_DIRS}")
add_compile_definitions(KTRANSFORMERS_USE_ROCM=1) add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
endif() endif()
endif() elseif (KTRANSFORMERS_USE_MUSA)
if (KTRANSFORMERS_USE_MUSA)
if (NOT EXISTS $ENV{MUSA_PATH}) if (NOT EXISTS $ENV{MUSA_PATH})
if (NOT EXISTS /opt/musa) if (NOT EXISTS /opt/musa)
set(MUSA_PATH /usr/local/musa) set(MUSA_PATH /usr/local/musa)
...@@ -289,7 +272,19 @@ elseif (UNIX) ...@@ -289,7 +272,19 @@ elseif (UNIX)
message(STATUS "MUSA Toolkit found") message(STATUS "MUSA Toolkit found")
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1) add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
endif() endif()
endif() else()
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA detected")
find_package(CUDAToolkit REQUIRED)
include_directories(${CUDAToolkit_INCLUDE_DIRS})
endif()
message(STATUS "enabling CUDA")
enable_language(CUDA)
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
endif() endif()
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
...@@ -324,17 +319,14 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama) ...@@ -324,17 +319,14 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if(WIN32) if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX) elseif(UNIX)
if(KTRANSFORMERS_USE_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif()
if (KTRANSFORMERS_USE_ROCM) if (KTRANSFORMERS_USE_ROCM)
add_compile_definitions(USE_HIP=1) add_compile_definitions(USE_HIP=1)
target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so") target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
message(STATUS "Building for HIP") message(STATUS "Building for HIP")
endif() elseif(KTRANSFORMERS_USE_MUSA)
if(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart) target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif() else()
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif() endif()
# Define the USE_NUMA option # Define the USE_NUMA option
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment