Commit e38ee081 authored by xiabo's avatar xiabo
Browse files

Adapt to rocm

parent 56942c43
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "src/turbomind/models/llama/llama_decoder_kernels.h" #include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/cuda_utils.h"
#include <cooperative_groups.h> #include <cooperative_groups.h>
#include <cooperative_groups/reduce.h> // #include <cooperative_groups/reduce.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
...@@ -83,22 +83,57 @@ struct res_norm_ops_t<float> { ...@@ -83,22 +83,57 @@ struct res_norm_ops_t<float> {
} }
}; };
template<typename T> // template<typename T>
__device__ T blockReduceSum(const cg::thread_block& block, T value) // __device__ T blockReduceSum(const cg::thread_block& block, T value)
{ // {
__shared__ float partial[32]; // __shared__ float partial[32];
auto tile = cg::tiled_partition<32>(block); // auto tile = cg::tiled_partition<32>(block);
value = cg::reduce(tile, value, cg::plus<float>{}); // value = cg::reduce(tile, value, cg::plus<float>{});
if (tile.thread_rank() == 0) { // if (tile.thread_rank() == 0) {
partial[tile.meta_group_rank()] = value; // partial[tile.meta_group_rank()] = value;
} // }
// block.sync();
block.sync(); // value = tile.thread_rank() < tile.meta_group_size() ? partial[tile.thread_rank()] : T{};
// return cg::reduce(tile, value, cg::plus<float>{});
// }
#define WARPSIZE 64
value = tile.thread_rank() < tile.meta_group_size() ? partial[tile.thread_rank()] : T{}; template<typename T>
return cg::reduce(tile, value, cg::plus<float>{}); __inline__ __device__ T warpReduceSum_xiabo(T value)
{
#pragma unroll
for (int offset = WARPSIZE / 2; offset > 0; offset >>= 1)
value += __shfl_down_sync(0xffffffff, value, offset);
return value;
}
template<typename T>
__inline__ __device__ T blockReduceSum_xiabo(T val)
{
T sum = (T)(0.0f);
__shared__ T shared[WARPSIZE];
sum = warpReduceSum_xiabo(val);
__syncthreads();
int tid = threadIdx.x + threadIdx.y * blockDim.x;
if (tid % WARPSIZE == 0) {
shared[tid / WARPSIZE] = sum;
}
if (tid >= blockDim.x * blockDim.y / WARPSIZE && tid < WARPSIZE) {
shared[tid] = (T)(0.0f);
}
__syncthreads();
if (tid / WARPSIZE == 0) {
sum = warpReduceSum_xiabo(shared[tid]);
if (tid == 0) {
shared[0] = sum;
}
}
__syncthreads();
return shared[0];
} }
template<typename T> template<typename T>
...@@ -111,7 +146,7 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data, ...@@ -111,7 +146,7 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
int n_dims) int n_dims)
{ {
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto grid = cg::this_grid(); // auto grid = cg::this_grid();
constexpr int PACK_DIM = sizeof(uint4) / sizeof(T); constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
...@@ -131,7 +166,8 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data, ...@@ -131,7 +166,8 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
r_ptr[i] = r; r_ptr[i] = r;
} }
auto total_sum = blockReduceSum(block, thread_sum); // auto total_sum = blockReduceSum(block, thread_sum);
auto total_sum = blockReduceSum_xiabo(thread_sum);
float s_inv_mean = rsqrt(total_sum / n_dims + eps); float s_inv_mean = rsqrt(total_sum / n_dims + eps);
......
...@@ -315,7 +315,8 @@ static inline __device__ half4 char4_scale_to_half4(char4 value, const float sca ...@@ -315,7 +315,8 @@ static inline __device__ half4 char4_scale_to_half4(char4 value, const float sca
static inline __device__ uint32_t float4_to_char4(float x, float y, float z, float w) static inline __device__ uint32_t float4_to_char4(float x, float y, float z, float w)
{ {
uint32_t dst; uint32_t dst;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 720 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 720
#if 0
uint32_t a; uint32_t a;
asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x)); asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x));
uint32_t b; uint32_t b;
......
#include "src/turbomind/kernels/gemm_s_f16/format.h" // #include "src/turbomind/kernels/gemm_s_f16/format.h"
#include "src/turbomind/python/dlpack.h" #include "src/turbomind/python/dlpack.h"
#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h" #include "src/turbomind/triton_backend/llama/LlamaTritonModel.h"
#include "src/turbomind/triton_backend/transformer_triton_backend.hpp" #include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
...@@ -47,7 +47,8 @@ DLDevice getDLDevice(triton::Tensor& tensor) ...@@ -47,7 +47,8 @@ DLDevice getDLDevice(triton::Tensor& tensor)
case triton::MEMORY_CPU_PINNED: case triton::MEMORY_CPU_PINNED:
device.device_type = DLDeviceType::kDLCUDAHost; device.device_type = DLDeviceType::kDLCUDAHost;
case triton::MEMORY_GPU: case triton::MEMORY_GPU:
device.device_type = DLDeviceType::kDLCUDA; // device.device_type = DLDeviceType::kDLCUDA;
device.device_type = DLDeviceType::kDLROCM;
break; break;
default: default:
break; break;
...@@ -415,15 +416,15 @@ PYBIND11_MODULE(_turbomind, m) ...@@ -415,15 +416,15 @@ PYBIND11_MODULE(_turbomind, m)
auto src_tensor = GetDLTensor(src); auto src_tensor = GetDLTensor(src);
auto dst_tensor = GetDLTensor(dst); auto dst_tensor = GetDLTensor(dst);
turbomind::transpose_qk_s4_k_m8_hf( // turbomind::transpose_qk_s4_k_m8_hf(
(uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, size_per_head, nullptr); // (uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, size_per_head, nullptr);
}); });
m.def("fuse_w1_w3_s4_k_m8", [](py::object src, py::object dst, int m, int k) { m.def("fuse_w1_w3_s4_k_m8", [](py::object src, py::object dst, int m, int k) {
auto src_tensor = GetDLTensor(src); auto src_tensor = GetDLTensor(src);
auto dst_tensor = GetDLTensor(dst); auto dst_tensor = GetDLTensor(dst);
turbomind::fuse_w1_w3_s4_k_m8((uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, nullptr); // turbomind::fuse_w1_w3_s4_k_m8((uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, nullptr);
}); });
m.def("convert_s4_k_m8", m.def("convert_s4_k_m8",
...@@ -443,16 +444,16 @@ PYBIND11_MODULE(_turbomind, m) ...@@ -443,16 +444,16 @@ PYBIND11_MODULE(_turbomind, m)
auto s = GetDLTensor(scales); auto s = GetDLTensor(scales);
auto qz = GetDLTensor(qzeros); auto qz = GetDLTensor(qzeros);
turbomind::convert_s4_k_m8((uint32_t*)a_dst.data, // turbomind::convert_s4_k_m8((uint32_t*)a_dst.data,
(half2*)q_dst.data, // (half2*)q_dst.data,
(half*)w.data, // (half*)w.data,
(const uint32_t*)a_src.data, // (const uint32_t*)a_src.data,
(const half*)s.data, // (const half*)s.data,
(const uint32_t*)qz.data, // (const uint32_t*)qz.data,
m, // m,
k, // k,
group_size, // group_size,
nullptr); // nullptr);
}); });
m.def("dequantize_s4", [](py::object src, py::object dst) { m.def("dequantize_s4", [](py::object src, py::object dst) {
......
...@@ -24,13 +24,17 @@ ...@@ -24,13 +24,17 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required (VERSION 3.18) #cmake_minimum_required (VERSION 3.18)
cmake_minimum_required (VERSION 3.16)
project(tritonturbomindbackend LANGUAGES C CXX) project(tritonturbomindbackend LANGUAGES C CXX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
add_library(TransformerTritonBackend STATIC transformer_triton_backend.cpp) add_library(TransformerTritonBackend STATIC transformer_triton_backend.cpp)
target_link_libraries(TransformerTritonBackend PUBLIC nccl_utils) target_link_libraries(TransformerTritonBackend PUBLIC nccl_utils)
set_property(TARGET TransformerTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET TransformerTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON)
install(TARGETS TransformerTritonBackend DESTINATION ${CMAKE_INSTALL_LIBDIR}) install(TARGETS TransformerTritonBackend DESTINATION ${CMAKE_INSTALL_LIBDIR})
add_subdirectory(llama) add_subdirectory(llama)
...@@ -70,21 +74,24 @@ include(FetchContent) ...@@ -70,21 +74,24 @@ include(FetchContent)
FetchContent_Declare( FetchContent_Declare(
repo-common repo-common
GIT_REPOSITORY https://github.com/triton-inference-server/common.git URL ../../../3rdparty/common-r22.12
GIT_TAG ${TRITON_COMMON_REPO_TAG} #GIT_REPOSITORY https://github.com/triton-inference-server/common.git
GIT_SHALLOW ON #GIT_TAG ${TRITON_COMMON_REPO_TAG}
#GIT_SHALLOW ON
) )
FetchContent_Declare( FetchContent_Declare(
repo-core repo-core
GIT_REPOSITORY https://github.com/triton-inference-server/core.git URL ../../../3rdparty/core-r22.12
GIT_TAG ${TRITON_CORE_REPO_TAG} #GIT_REPOSITORY https://github.com/triton-inference-server/core.git
GIT_SHALLOW ON #GIT_TAG ${TRITON_CORE_REPO_TAG}
#GIT_SHALLOW ON
) )
FetchContent_Declare( FetchContent_Declare(
repo-backend repo-backend
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git URL ../../../3rdparty/backend-r22.12
GIT_TAG ${TRITON_BACKEND_REPO_TAG} #GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
GIT_SHALLOW ON #GIT_TAG ${TRITON_BACKEND_REPO_TAG}
#GIT_SHALLOW ON
) )
FetchContent_MakeAvailable(repo-common repo-core repo-backend) FetchContent_MakeAvailable(repo-common repo-core repo-backend)
...@@ -92,7 +99,8 @@ FetchContent_MakeAvailable(repo-common repo-core repo-backend) ...@@ -92,7 +99,8 @@ FetchContent_MakeAvailable(repo-common repo-core repo-backend)
# CUDA # CUDA
# #
if(${TRITON_ENABLE_GPU}) if(${TRITON_ENABLE_GPU})
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
endif() # TRITON_ENABLE_GPU endif() # TRITON_ENABLE_GPU
# #
...@@ -109,7 +117,8 @@ add_library( ...@@ -109,7 +117,8 @@ add_library(
TritonTurboMindBackend::triton-turbomind-backend ALIAS triton-turbomind-backend TritonTurboMindBackend::triton-turbomind-backend ALIAS triton-turbomind-backend
) )
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
find_package(CUDA 10.1 REQUIRED) find_package(CUDA 10.1 REQUIRED)
if (${CUDA_VERSION} GREATER_EQUAL 11.0) if (${CUDA_VERSION} GREATER_EQUAL 11.0)
message(STATUS "Add DCUDA11_MODE") message(STATUS "Add DCUDA11_MODE")
...@@ -158,10 +167,14 @@ if(${TRITON_ENABLE_GPU}) ...@@ -158,10 +167,14 @@ if(${TRITON_ENABLE_GPU})
) )
endif() # TRITON_ENABLE_GPU endif() # TRITON_ENABLE_GPU
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
set_target_properties( set_target_properties(
triton-turbomind-backend triton-turbomind-backend
PROPERTIES PROPERTIES
POSITION_INDEPENDENT_CODE ON # POSITION_INDEPENDENT_CODE ON
POSITION_INDEPENDENT_CODE OFF
OUTPUT_NAME triton_turbomind OUTPUT_NAME triton_turbomind
SKIP_BUILD_RPATH TRUE SKIP_BUILD_RPATH TRUE
BUILD_WITH_INSTALL_RPATH TRUE BUILD_WITH_INSTALL_RPATH TRUE
...@@ -194,7 +207,7 @@ target_link_libraries( ...@@ -194,7 +207,7 @@ target_link_libraries(
transformer-shared # from repo-ft transformer-shared # from repo-ft
${TRITON_PYTORCH_LDFLAGS} ${TRITON_PYTORCH_LDFLAGS}
-lcublas -lcublas
-lcublasLt # -lcublasLt
-lcudart -lcudart
-lcurand -lcurand
) )
...@@ -228,7 +241,8 @@ if(${TRITON_ENABLE_GPU}) ...@@ -228,7 +241,8 @@ if(${TRITON_ENABLE_GPU})
target_link_libraries( target_link_libraries(
triton-turbomind-backend triton-turbomind-backend
PRIVATE PRIVATE
CUDA::cudart # CUDA::cudart
cudart
) )
endif() # TRITON_ENABLE_GPU endif() # TRITON_ENABLE_GPU
......
...@@ -22,8 +22,10 @@ set(llama_triton_backend_files ...@@ -22,8 +22,10 @@ set(llama_triton_backend_files
LlamaTritonModelInstance.cc LlamaTritonModelInstance.cc
) )
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
add_library(LlamaTritonBackend STATIC ${llama_triton_backend_files}) add_library(LlamaTritonBackend STATIC ${llama_triton_backend_files})
set_property(TARGET LlamaTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET LlamaTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(LlamaTritonBackend PUBLIC TransformerTritonBackend Llama tensor memory_utils CUDA::cublasLt) #target_link_libraries(LlamaTritonBackend PUBLIC TransformerTritonBackend Llama tensor memory_utils CUDA::cublasLt)
target_link_libraries(LlamaTritonBackend PUBLIC TransformerTritonBackend Llama tensor memory_utils)
target_compile_features(LlamaTritonBackend PRIVATE cxx_std_14) target_compile_features(LlamaTritonBackend PRIVATE cxx_std_14)
...@@ -197,7 +197,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh ...@@ -197,7 +197,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
cublasLtHandle_t cublaslt_handle; cublasLtHandle_t cublaslt_handle;
cublasCreate(&cublas_handle); cublasCreate(&cublas_handle);
cublasLtCreate(&cublaslt_handle); // cublasLtCreate(&cublaslt_handle);
cublasSetStream(cublas_handle, stream); cublasSetStream(cublas_handle, stream);
std::unique_ptr<ft::cublasAlgoMap> cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); std::unique_ptr<ft::cublasAlgoMap> cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in"));
......
...@@ -14,98 +14,105 @@ ...@@ -14,98 +14,105 @@
cmake_minimum_required(VERSION 3.8) cmake_minimum_required(VERSION 3.8)
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
add_subdirectory(gemm_test) add_subdirectory(gemm_test)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
add_library(cuda_utils STATIC cuda_utils.cc) add_library(cuda_utils STATIC cuda_utils.cc)
set_property(TARGET cuda_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cuda_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cuda_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cuda_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(cuda_utils PUBLIC CUDA::cudart) target_link_libraries(cuda_utils PUBLIC cudart)
add_library(logger STATIC logger.cc) add_library(logger STATIC logger.cc)
set_property(TARGET logger PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET logger PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET logger PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET logger PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(logger PUBLIC CUDA::cudart) target_link_libraries(logger PUBLIC cudart)
add_library(cublasAlgoMap STATIC cublasAlgoMap.cc) add_library(cublasAlgoMap STATIC cublasAlgoMap.cc)
set_property(TARGET cublasAlgoMap PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cublasAlgoMap PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cublasAlgoMap PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cublasAlgoMap PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(cublasAlgoMap PUBLIC CUDA::cublas CUDA::cudart CUDA::curand cuda_utils logger) target_link_libraries(cublasAlgoMap PUBLIC cublas cudart curand cuda_utils logger)
add_library(cublasMMWrapper STATIC cublasMMWrapper.cc) add_library(cublasMMWrapper STATIC cublasMMWrapper.cc)
set_property(TARGET cublasMMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cublasMMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cublasMMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cublasMMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(cublasMMWrapper PUBLIC CUDA::cublas CUDA::cudart CUDA::curand cublasAlgoMap cuda_utils logger) target_link_libraries(cublasMMWrapper PUBLIC cublas cudart curand cublasAlgoMap cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(cublasMMWrapper PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(cublasMMWrapper PUBLIC cusparse -lcusparseLt)
endif() endif()
add_library(word_list STATIC word_list.cc) add_library(word_list STATIC word_list.cc)
set_property(TARGET word_list PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET word_list PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET word_list PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET word_list PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(nvtx_utils STATIC nvtx_utils.cc) add_library(nvtx_utils STATIC nvtx_utils.cc)
set_property(TARGET nvtx_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET nvtx_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET nvtx_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET nvtx_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
if(${CMAKE_VERSION} VERSION_LESS "3.25") if(${CMAKE_VERSION} VERSION_LESS "3.25")
target_link_libraries(nvtx_utils PUBLIC CUDA::nvToolsExt -ldl) # target_link_libraries(nvtx_utils PUBLIC nvToolsExt -ldl)
else() else()
target_link_libraries(nvtx_utils PUBLIC CUDA::nvtx3 -ldl) # target_link_libraries(nvtx_utils PUBLIC nvtx3 -ldl)
endif() endif()
add_library(memory_utils STATIC memory_utils.cu) add_library(memory_utils STATIC memory_utils.cu)
set_property(TARGET memory_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET memory_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET memory_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET memory_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(memory_utils PUBLIC cuda_utils logger tensor) target_link_libraries(memory_utils PUBLIC cuda_utils logger tensor)
add_library(mpi_utils STATIC mpi_utils.cc) add_library(mpi_utils STATIC mpi_utils.cc)
set_property(TARGET mpi_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET mpi_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET mpi_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET mpi_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
if (BUILD_MULTI_GPU) if (BUILD_MULTI_GPU)
target_link_libraries(mpi_utils PUBLIC mpi logger) target_link_libraries(mpi_utils PUBLIC mpi logger)
endif() endif()
add_library(nccl_utils STATIC nccl_utils.cc) add_library(nccl_utils STATIC nccl_utils.cc)
set_property(TARGET nccl_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET nccl_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET nccl_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET nccl_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
if (BUILD_MULTI_GPU) if (BUILD_MULTI_GPU)
target_link_libraries(nccl_utils PUBLIC ${NCCL_LIBRARIES} logger) target_link_libraries(nccl_utils PUBLIC ${NCCL_LIBRARIES} logger)
endif() endif()
add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc) add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc)
set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cublasINT8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cublasINT8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(cublasINT8MMWrapper PUBLIC CUDA::cublasLt CUDA::cudart CUDA::curand cublasAlgoMap cublasMMWrapper cuda_utils logger) #target_link_libraries(cublasINT8MMWrapper PUBLIC cublasLt cudart curand cublasAlgoMap cublasMMWrapper cuda_utils logger)
target_link_libraries(cublasINT8MMWrapper PUBLIC cudart curand cublasAlgoMap cublasMMWrapper cuda_utils logger)
if(ENABLE_FP8) if(ENABLE_FP8)
add_library(cublasFP8MMWrapper STATIC cublasFP8MMWrapper.cu) add_library(cublasFP8MMWrapper STATIC cublasFP8MMWrapper.cu)
set_property(TARGET cublasFP8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cublasFP8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cublasFP8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cublasFP8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(cublasFP8MMWrapper PUBLIC CUDA::cublasLt CUDA::cudart CUDA::curand #target_link_libraries(cublasFP8MMWrapper PUBLIC cublasLt cudart curand
target_link_libraries(cublasFP8MMWrapper PUBLIC cudart curand
cublasAlgoMap cublasMMWrapper nvtx_utils fp8_qgmma_1x1_utils) cublasAlgoMap cublasMMWrapper nvtx_utils fp8_qgmma_1x1_utils)
endif() endif()
add_library(custom_ar_comm STATIC custom_ar_comm.cc) add_library(custom_ar_comm STATIC custom_ar_comm.cc)
set_property(TARGET custom_ar_comm PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET custom_ar_comm PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET custom_ar_comm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET custom_ar_comm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(custom_ar_comm PUBLIC custom_ar_kernels memory_utils cuda_utils logger) target_link_libraries(custom_ar_comm PUBLIC custom_ar_kernels memory_utils cuda_utils logger)
add_library(gemm STATIC gemm.cc) add_library(gemm STATIC gemm.cc)
set_property(TARGET gemm PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET gemm PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET gemm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET gemm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(gemm PUBLIC target_link_libraries(gemm PUBLIC
CUDA::cublas CUDA::cublasLt CUDA::cudart CUDA::curand # cublas cublasLt cudart curand
cublas cudart curand
cublasAlgoMap memory_utils cuda_utils logger) cublasAlgoMap memory_utils cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(gemm PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(gemm PUBLIC cusparse -lcusparseLt)
endif() endif()
add_library(cuda_fp8_utils STATIC cuda_fp8_utils.cu) add_library(cuda_fp8_utils STATIC cuda_fp8_utils.cu)
set_property(TARGET cuda_fp8_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cuda_fp8_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cuda_fp8_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cuda_fp8_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(tensor STATIC Tensor.cc) add_library(tensor STATIC Tensor.cc)
set_property(TARGET tensor PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET tensor PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET tensor PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET tensor PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(tensor PUBLIC cuda_utils logger) target_link_libraries(tensor PUBLIC cuda_utils logger)
...@@ -158,36 +158,36 @@ public: ...@@ -158,36 +158,36 @@ public:
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
pointer_mapping_ = new std::unordered_map<void*, std::pair<size_t, MemoryType>>(); pointer_mapping_ = new std::unordered_map<void*, std::pair<size_t, MemoryType>>();
#if defined(CUDA_MEMORY_POOL_DISABLED) // #if defined(CUDA_MEMORY_POOL_DISABLED)
TM_LOG_WARNING( // TM_LOG_WARNING(
"Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free." // "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free."
"Note this may lead to hang with NCCL kernels launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP"); // "Note this may lead to hang with NCCL kernels launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP");
#else // #else
int device_count = 1; // int device_count = 1;
check_cuda_error(cudaGetDeviceCount(&device_count)); // check_cuda_error(cudaGetDeviceCount(&device_count));
cudaMemPool_t mempool; // cudaMemPool_t mempool;
check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id)); // check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id));
cudaMemAccessDesc desc = {}; // cudaMemAccessDesc desc = {};
int peer_access_available = 0; // int peer_access_available = 0;
for (int i = 0; i < device_count; i++) { // for (int i = 0; i < device_count; i++) {
if (i == device_id) { // if (i == device_id) {
continue; // continue;
} // }
check_cuda_error(cudaDeviceCanAccessPeer(&peer_access_available, device_id, i)); // check_cuda_error(cudaDeviceCanAccessPeer(&peer_access_available, device_id, i));
if (!peer_access_available) { // if (!peer_access_available) {
TM_LOG_WARNING("Device " + std::to_string(device_id) + " peer access Device " + std::to_string(i) // TM_LOG_WARNING("Device " + std::to_string(device_id) + " peer access Device " + std::to_string(i)
+ " is not available."); // + " is not available.");
continue; // continue;
} // }
desc.location.type = cudaMemLocationTypeDevice; // desc.location.type = cudaMemLocationTypeDevice;
desc.location.id = i; // desc.location.id = i;
desc.flags = cudaMemAccessFlagsProtReadWrite; // desc.flags = cudaMemAccessFlagsProtReadWrite;
check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1)); // check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1));
} // }
// set memory pool threshold to avoid shrinking the pool // // set memory pool threshold to avoid shrinking the pool
uint64_t setVal = UINT64_MAX; // uint64_t setVal = UINT64_MAX;
check_cuda_error(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &setVal)); // check_cuda_error(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &setVal));
#endif // #endif
} }
virtual ~Allocator() virtual ~Allocator()
......
...@@ -139,7 +139,8 @@ cublasAlgoMap::getAlgo(const int batch_count, const int m, const int n, const in ...@@ -139,7 +139,8 @@ cublasAlgoMap::getAlgo(const int batch_count, const int m, const int n, const in
else { else {
cublasLtMatmulAlgo_info tmp_algo; cublasLtMatmulAlgo_info tmp_algo;
tmp_algo.algoId = tmp_algo.algoId =
static_cast<int>(data_type == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP); // static_cast<int>(data_type == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP);
static_cast<int>(data_type == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT);
tmp_algo.customOption = -1; tmp_algo.customOption = -1;
tmp_algo.tile = -1; tmp_algo.tile = -1;
tmp_algo.splitK_val = -1; tmp_algo.splitK_val = -1;
......
...@@ -192,7 +192,8 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, ...@@ -192,7 +192,8 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
} }
} }
if (using_cublasLt) { // if (using_cublasLt) {
if (0) {
cublasLtMatmulDesc_t operationDesc = NULL; cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cudaDataType_t scaleType; cudaDataType_t scaleType;
...@@ -279,22 +280,22 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, ...@@ -279,22 +280,22 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
} }
} }
cublasLtMatmul(cublaslt_handle_, // cublasLtMatmul(cublaslt_handle_,
operationDesc, // operationDesc,
alpha, // alpha,
A, // A,
Adesc, // Adesc,
B, // B,
Bdesc, // Bdesc,
beta, // beta,
C, // C,
Cdesc, // Cdesc,
C, // C,
Cdesc, // Cdesc,
(findAlgo == 1 ? (&algo) : NULL), // (findAlgo == 1 ? (&algo) : NULL),
workSpace, // workSpace,
workspaceSize, // workspaceSize,
stream_); // stream_);
cublasLtMatmulDescDestroy(operationDesc); cublasLtMatmulDescDestroy(operationDesc);
cublasLtMatrixLayoutDestroy(Adesc); cublasLtMatrixLayoutDestroy(Adesc);
...@@ -448,8 +449,8 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, ...@@ -448,8 +449,8 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)); cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t)); cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*)); cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
check_cuda_error(cublasLtMatmul( // check_cuda_error(cublasLtMatmul(
cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_)); // cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, NULL, NULL, 0, stream_));
cublasLtMatrixLayoutDestroy(Adesc); cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatrixLayoutDestroy(Bdesc); cublasLtMatrixLayoutDestroy(Bdesc);
cublasLtMatrixLayoutDestroy(Cdesc); cublasLtMatrixLayoutDestroy(Cdesc);
...@@ -985,7 +986,8 @@ void cublasMMWrapper::_Int8Gemm(const int m, ...@@ -985,7 +986,8 @@ void cublasMMWrapper::_Int8Gemm(const int m,
* - 0: int8 * int8 -> int32 -> int8 * - 0: int8 * int8 -> int32 -> int8
* - 1: int8 * int8 -> int32 -> int32 * - 1: int8 * int8 -> int32 -> int32
*/ */
#if (CUBLAS_VERSION) <= 11601 // #if (CUBLAS_VERSION) <= 11601
#if 1
FT_CHECK_WITH_INFO(false, "CUBLAS version too low."); FT_CHECK_WITH_INFO(false, "CUBLAS version too low.");
#else #else
......
...@@ -322,7 +322,7 @@ __device__ inline int8_t cuda_cast<int8_t, half>(half val) ...@@ -322,7 +322,7 @@ __device__ inline int8_t cuda_cast<int8_t, half>(half val)
int16_t int16_in; int16_t int16_in;
}; };
fp16 = val; fp16 = val;
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); // asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0]; return int8[0];
} }
...@@ -333,20 +333,31 @@ __device__ inline int16_t cuda_cast<int16_t, half2>(half2 val) ...@@ -333,20 +333,31 @@ __device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
int8_t int8[2]; int8_t int8[2];
int16_t int16; int16_t int16;
}; };
int8[0] = cuda_cast<int8_t>(val.x); // int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y); // int8[1] = cuda_cast<int8_t>(val.y);
int8[0] = cuda_cast<int8_t>((val.data[0]));
int8[1] = cuda_cast<int8_t>((val.data[1]));
return int16; return int16;
} }
template<> template<>
__device__ inline int8_t cuda_cast<int8_t, float>(float val) __device__ inline int8_t cuda_cast<int8_t, float>(float val)
{ {
union { // union {
int8_t int8[2]; // int8_t int8[2];
int16_t int16; // int16_t int16;
}; // };
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); // asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return int8[0]; // return int8[0];
int8_t dst;
if (val >= 128){
dst = 127;
}else if (val < -128){
dst = -128;
}else{
dst = static_cast<int8_t>(val);
}
return dst;
} }
template<> template<>
...@@ -528,13 +539,15 @@ __device__ inline To cuda_max(Ti val) ...@@ -528,13 +539,15 @@ __device__ inline To cuda_max(Ti val)
template<> template<>
__device__ inline half cuda_max(half2 val) __device__ inline half cuda_max(half2 val)
{ {
return (val.x > val.y) ? val.x : val.y; // return (val.x > val.y) ? val.x : val.y;
return (val.data[0] > val.data[1]) ? val.data[0] : val.data[1];
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template<> template<>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val) __device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{ {
return (val.x > val.y) ? val.x : val.y; // return (val.x > val.y) ? val.x : val.y;
return (val.data[0] > val.data[1]) ? val.data[0] : val.data[1];
} }
#endif #endif
......
...@@ -26,7 +26,7 @@ Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file) ...@@ -26,7 +26,7 @@ Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file)
stream_ = stream; stream_ = stream;
mutex_ = new std::mutex(); // mutex per process mutex_ = new std::mutex(); // mutex per process
check_cuda_error(cublasCreate(&cublas_handle_)); check_cuda_error(cublasCreate(&cublas_handle_));
check_cuda_error(cublasLtCreate(&cublaslt_handle_)); // check_cuda_error(cublasLtCreate(&cublaslt_handle_));
check_cuda_error(cublasSetStream(cublas_handle_, stream)); check_cuda_error(cublasSetStream(cublas_handle_, stream));
if (allocator_ != nullptr) { if (allocator_ != nullptr) {
...@@ -41,7 +41,7 @@ Gemm::~Gemm() ...@@ -41,7 +41,7 @@ Gemm::~Gemm()
allocator_->free((void**)(&workspace_)); allocator_->free((void**)(&workspace_));
allocator_ = nullptr; allocator_ = nullptr;
} }
cublasLtDestroy(cublaslt_handle_); // cublasLtDestroy(cublaslt_handle_);
cublasDestroy(cublas_handle_); cublasDestroy(cublas_handle_);
delete cublas_algo_map_; delete cublas_algo_map_;
delete mutex_; delete mutex_;
...@@ -248,7 +248,8 @@ void Gemm::gemm(const GemmOp transa, ...@@ -248,7 +248,8 @@ void Gemm::gemm(const GemmOp transa,
mutex_->lock(); mutex_->lock();
// Use cublas as default in FP32 and cublasLt as default in FP16 // Use cublas as default in FP32 and cublasLt as default in FP16
bool is_fp16_compute_type = compute_type_ == TYPE_FP16; bool is_fp16_compute_type = compute_type_ == TYPE_FP16;
bool using_cublasLt = Atype == TYPE_FP16; // bool using_cublasLt = Atype == TYPE_FP16;
bool using_cublasLt = (Atype == TYPE_FP16) ? false : false;
int batch_count = 1; int batch_count = 1;
half h_alpha = (half)alpha; half h_alpha = (half)alpha;
...@@ -267,7 +268,8 @@ void Gemm::gemm(const GemmOp transa, ...@@ -267,7 +268,8 @@ void Gemm::gemm(const GemmOp transa,
using_cublasLt = (info.stages != -1); using_cublasLt = (info.stages != -1);
} }
if (using_cublasLt) { // if (using_cublasLt) {
if(0) {
const size_t a_rows = (a_op == getCublasOperation(GEMM_OP_N)) ? _m : k; const size_t a_rows = (a_op == getCublasOperation(GEMM_OP_N)) ? _m : k;
const size_t a_cols = (a_op == getCublasOperation(GEMM_OP_N)) ? k : _m; const size_t a_cols = (a_op == getCublasOperation(GEMM_OP_N)) ? k : _m;
const size_t b_rows = (b_op == getCublasOperation(GEMM_OP_N)) ? k : _n; const size_t b_rows = (b_op == getCublasOperation(GEMM_OP_N)) ? k : _n;
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
cmake_minimum_required(VERSION 3.8) cmake_minimum_required(VERSION 3.8)
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
set(gemm_func_files set(gemm_func_files
gemm_func.cc gemm_func.cc
...@@ -51,59 +52,71 @@ set(swin_gemm_func_files ...@@ -51,59 +52,71 @@ set(swin_gemm_func_files
swin_gemm_func.cc swin_gemm_func.cc
) )
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
add_library(gemm_func STATIC ${gemm_func_files}) add_library(gemm_func STATIC ${gemm_func_files})
target_link_libraries(gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart cuda_utils logger) #target_link_libraries(gemm_func PUBLIC cublas cublasLt cudart cuda_utils logger)
set_property(TARGET gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(gemm_func PUBLIC cublas cudart cuda_utils logger)
set_property(TARGET gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(encoder_gemm_func STATIC ${encoder_gemm_func_files}) add_library(encoder_gemm_func STATIC ${encoder_gemm_func_files})
target_link_libraries(encoder_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(encoder_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
target_link_libraries(encoder_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(encoder_gemm_func PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(encoder_gemm_func PUBLIC cusparse -lcusparseLt)
endif() endif()
set_property(TARGET encoder_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET encoder_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET encoder_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET encoder_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(encoder_igemm_func STATIC ${encoder_igemm_func_files}) add_library(encoder_igemm_func STATIC ${encoder_igemm_func_files})
target_link_libraries(encoder_igemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart cuda_utils logger) #target_link_libraries(encoder_igemm_func PUBLIC cublas cublasLt cudart cuda_utils logger)
target_link_libraries(encoder_igemm_func PUBLIC cublas cudart cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(encoder_igemm_func PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(encoder_igemm_func PUBLIC cusparse -lcusparseLt)
endif() endif()
set_property(TARGET encoder_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET encoder_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET encoder_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET encoder_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(decoding_gemm_func STATIC ${decoding_gemm_func_files}) add_library(decoding_gemm_func STATIC ${decoding_gemm_func_files})
target_link_libraries(decoding_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(decoding_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
set_property(TARGET decoding_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(decoding_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
set_property(TARGET decoding_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET decoding_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET decoding_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(gpt_gemm_func STATIC ${gpt_gemm_func_files}) add_library(gpt_gemm_func STATIC ${gpt_gemm_func_files})
target_link_libraries(gpt_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(gpt_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
target_link_libraries(gpt_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(gpt_gemm_func PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(gpt_gemm_func PUBLIC cusparse -lcusparseLt)
endif() endif()
set_property(TARGET gpt_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET gpt_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET gpt_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET gpt_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(xlnet_gemm_func STATIC ${xlnet_gemm_func_files}) add_library(xlnet_gemm_func STATIC ${xlnet_gemm_func_files})
target_link_libraries(xlnet_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(xlnet_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
set_property(TARGET xlnet_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(xlnet_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
set_property(TARGET xlnet_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET xlnet_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET xlnet_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(t5_gemm_func STATIC ${t5_gemm_func_files}) add_library(t5_gemm_func STATIC ${t5_gemm_func_files})
target_link_libraries(t5_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(t5_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
target_link_libraries(t5_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(t5_gemm_func PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(t5_gemm_func PUBLIC cusparse -lcusparseLt)
endif() endif()
set_property(TARGET t5_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET t5_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET t5_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET t5_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(swin_igemm_func STATIC ${swin_igemm_func_files}) add_library(swin_igemm_func STATIC ${swin_igemm_func_files})
target_link_libraries(swin_igemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func encoder_igemm_func cuda_utils logger) #target_link_libraries(swin_igemm_func PUBLIC cublas cublasLt cudart gemm_func encoder_igemm_func cuda_utils logger)
set_property(TARGET swin_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(swin_igemm_func PUBLIC cublas cudart gemm_func encoder_igemm_func cuda_utils logger)
set_property(TARGET swin_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET swin_igemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET swin_igemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(swin_gemm_func STATIC ${swin_gemm_func_files}) add_library(swin_gemm_func STATIC ${swin_gemm_func_files})
target_link_libraries(swin_gemm_func PUBLIC CUDA::cublas CUDA::cublasLt CUDA::cudart gemm_func cuda_utils logger) #target_link_libraries(swin_gemm_func PUBLIC cublas cublasLt cudart gemm_func cuda_utils logger)
set_property(TARGET swin_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(swin_gemm_func PUBLIC cublas cudart gemm_func cuda_utils logger)
set_property(TARGET swin_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET swin_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET swin_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
...@@ -130,8 +130,8 @@ void generate_decoding_gemm_config(int batch_size, ...@@ -130,8 +130,8 @@ void generate_decoding_gemm_config(int batch_size,
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -156,8 +156,10 @@ void generate_decoding_gemm_config(int batch_size, ...@@ -156,8 +156,10 @@ void generate_decoding_gemm_config(int batch_size,
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -166,8 +168,10 @@ void generate_decoding_gemm_config(int batch_size, ...@@ -166,8 +168,10 @@ void generate_decoding_gemm_config(int batch_size,
BType = CUDA_R_16BF; BType = CUDA_R_16BF;
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
using scaleT = typename ScaleTypeConverter<T>::Type; using scaleT = typename ScaleTypeConverter<T>::Type;
...@@ -241,38 +245,39 @@ void generate_decoding_gemm_config(int batch_size, ...@@ -241,38 +245,39 @@ void generate_decoding_gemm_config(int batch_size,
const int ALGO_COMBINATIONS = 5000; const int ALGO_COMBINATIONS = 5000;
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
LtHgemmCustomFind<T, scaleT>(ltHandle, // LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size * beam_width, // batch_size * beam_width,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&alpha, // &alpha,
d_B, // d_B,
d_A, // d_A,
&beta, // &beta,
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure(batch_size * beam_width, // printPerfStructure(batch_size * beam_width,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
perfResults[0], // perfResults[0],
fd, // fd,
data_type, // data_type,
0); // 0);
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
...@@ -127,8 +127,8 @@ void generate_encoder_gemm_config( ...@@ -127,8 +127,8 @@ void generate_encoder_gemm_config(
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -153,8 +153,10 @@ void generate_encoder_gemm_config( ...@@ -153,8 +153,10 @@ void generate_encoder_gemm_config(
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -163,8 +165,10 @@ void generate_encoder_gemm_config( ...@@ -163,8 +165,10 @@ void generate_encoder_gemm_config(
BType = CUDA_R_16BF; BType = CUDA_R_16BF;
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
using scaleT = typename ScaleTypeConverter<T, false>::Type; using scaleT = typename ScaleTypeConverter<T, false>::Type;
...@@ -331,30 +335,31 @@ void generate_encoder_gemm_config( ...@@ -331,30 +335,31 @@ void generate_encoder_gemm_config(
// Let try a fixed number of combinations // Let try a fixed number of combinations
const int ALGO_COMBINATIONS = 5000; const int ALGO_COMBINATIONS = 5000;
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
LtHgemmCustomFind<T, scaleT>(ltHandle, // LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size, // batch_size,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&alpha, // &alpha,
d_B, // d_B,
d_A, // d_A,
&beta, // &beta,
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure( // printPerfStructure(
batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0); // batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
exec_time = perfResults[0].time; // exec_time = perfResults[0].time;
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
...@@ -234,22 +234,22 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, // ...@@ -234,22 +234,22 @@ static cublasStatus_t customMatmulRun(cublasLtHandle_t ltHandle, //
cudaDeviceSynchronize(); cudaDeviceSynchronize();
auto start = std::chrono::high_resolution_clock::now(); auto start = std::chrono::high_resolution_clock::now();
for (int loop = 0; loop < repeats; loop++) { for (int loop = 0; loop < repeats; loop++) {
oneRunStatus = cublasLtMatmul(ltHandle, // oneRunStatus = cublasLtMatmul(ltHandle,
operationDesc, // operationDesc,
alpha, // alpha,
A, // A,
Adesc, // Adesc,
B, // B,
Bdesc, // Bdesc,
beta, // beta,
C, // C,
Cdesc, // Cdesc,
D, // D,
Ddesc, // Ddesc,
&algo, // &algo,
workSpace, // workSpace,
workSpaceSizeInBytes, // workSpaceSizeInBytes,
stream); // stream);
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
auto end = std::chrono::high_resolution_clock::now(); auto end = std::chrono::high_resolution_clock::now();
......
...@@ -223,8 +223,8 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -223,8 +223,8 @@ void generate_gpt_gemm_config(int batch_size,
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -253,8 +253,10 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -253,8 +253,10 @@ void generate_gpt_gemm_config(int batch_size,
CType = CUDA_R_16F; CType = CUDA_R_16F;
DType = CUDA_R_16F; DType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -264,8 +266,10 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -264,8 +266,10 @@ void generate_gpt_gemm_config(int batch_size,
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
DType = CUDA_R_16BF; DType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
...@@ -293,8 +297,10 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -293,8 +297,10 @@ void generate_gpt_gemm_config(int batch_size,
DType_FP8[9] = CUDA_R_16BF; DType_FP8[9] = CUDA_R_16BF;
#endif #endif
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
float alpha = (float)1.0f; float alpha = (float)1.0f;
...@@ -456,44 +462,45 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -456,44 +462,45 @@ void generate_gpt_gemm_config(int batch_size,
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
// for gpt, computeType & scaleType should be FP32 // for gpt, computeType & scaleType should be FP32
LtHgemmCustomFind<T, float>(ltHandle, // LtHgemmCustomFind<T, float>(ltHandle,
batch_size * beam_width, // batch_size * beam_width,
i == 1 || i == 2 ? max_input_len : 1, // i == 1 || i == 2 ? max_input_len : 1,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&alpha, // &alpha,
d_B, // d_B,
d_A, // d_A,
&beta, // &beta,
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS, // ALGO_COMBINATIONS,
DType_FP8[i], // DType_FP8[i],
batchCount[i], // batchCount[i],
strideA[i], // strideA[i],
strideB[i], // strideB[i],
strideD[i]); // strideD[i]);
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure(batch_size * beam_width, // printPerfStructure(batch_size * beam_width,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
perfResults[0], // perfResults[0],
fd, // fd,
data_type, // data_type,
0, // 0,
batchCount[i]); // batchCount[i]);
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
...@@ -133,8 +133,8 @@ void generate_swin_gemm_config( ...@@ -133,8 +133,8 @@ void generate_swin_gemm_config(
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -159,8 +159,10 @@ void generate_swin_gemm_config( ...@@ -159,8 +159,10 @@ void generate_swin_gemm_config(
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -169,8 +171,10 @@ void generate_swin_gemm_config( ...@@ -169,8 +171,10 @@ void generate_swin_gemm_config(
BType = CUDA_R_16BF; BType = CUDA_R_16BF;
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
using scaleT = typename ScaleTypeConverter<T, false>::Type; using scaleT = typename ScaleTypeConverter<T, false>::Type;
...@@ -309,30 +313,31 @@ void generate_swin_gemm_config( ...@@ -309,30 +313,31 @@ void generate_swin_gemm_config(
const int ALGO_COMBINATIONS = 5000; const int ALGO_COMBINATIONS = 5000;
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
LtHgemmCustomFind<T, scaleT>(ltHandle, // LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size, // batch_size,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&alpha, // &alpha,
d_B, // d_B,
d_A, // d_A,
&beta, // &beta,
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure( // printPerfStructure(
batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0); // batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
exec_time = perfResults[0].time; // exec_time = perfResults[0].time;
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
...@@ -144,23 +144,23 @@ int igemm_config_INT8IO(int m, int n, int k, FILE* fout, void* buffer) ...@@ -144,23 +144,23 @@ int igemm_config_INT8IO(int m, int n, int k, FILE* fout, void* buffer)
int8_t* d_B = d_A + m * k; // k * n, stored in column-major int8_t* d_B = d_A + m * k; // k * n, stored in column-major
int8_t* d_C = (int8_t*)(d_B + k * n); // m * n, stored in column-major int8_t* d_C = (int8_t*)(d_B + k * n); // m * n, stored in column-major
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
cublasLtCreate(&ltHandle); // cublasLtCreate(&ltHandle);
LtIgemmCustomFind(ltHandle, // LtIgemmCustomFind(ltHandle,
m, // m,
n, // n,
k, // k,
&alpha, /* host pointer */ // &alpha, /* host pointer */
d_A, // d_A,
d_B, // d_B,
&beta, /* host pointer */ // &beta, /* host pointer */
d_C, // d_C,
NULL, // NULL,
0, // 0,
fout); // fout);
cublasLtDestroy(ltHandle); // cublasLtDestroy(ltHandle);
return 0; return 0;
} }
......
...@@ -195,8 +195,8 @@ void generate_t5_gemm_config(int batch_size, ...@@ -195,8 +195,8 @@ void generate_t5_gemm_config(int batch_size,
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -221,8 +221,10 @@ void generate_t5_gemm_config(int batch_size, ...@@ -221,8 +221,10 @@ void generate_t5_gemm_config(int batch_size,
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -231,8 +233,10 @@ void generate_t5_gemm_config(int batch_size, ...@@ -231,8 +233,10 @@ void generate_t5_gemm_config(int batch_size,
BType = CUDA_R_16BF; BType = CUDA_R_16BF;
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
float f_alpha = (float)1.0f; float f_alpha = (float)1.0f;
...@@ -442,60 +446,61 @@ void generate_t5_gemm_config(int batch_size, ...@@ -442,60 +446,61 @@ void generate_t5_gemm_config(int batch_size,
scaleT alpha_scale = (scaleT)1.0f; scaleT alpha_scale = (scaleT)1.0f;
scaleT beta_scale = (scaleT)0.0f; scaleT beta_scale = (scaleT)0.0f;
LtHgemmCustomFind<T, scaleT>(ltHandle, // LtHgemmCustomFind<T, scaleT>(ltHandle,
m, // m,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&(alpha_scale), // &(alpha_scale),
d_B, // d_B,
d_A, // d_A,
&(beta_scale), // &(beta_scale),
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
} }
else { else {
LtHgemmCustomFind<T, float>(ltHandle, // LtHgemmCustomFind<T, float>(ltHandle,
m, // m,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&(f_alpha), // &(f_alpha),
d_B, // d_B,
d_A, // d_A,
&(f_beta), // &(f_beta),
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
} }
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure(batch_size * (i <= 5 || i == 1 ? 1 : beam_width), // printPerfStructure(batch_size * (i <= 5 || i == 1 ? 1 : beam_width),
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
perfResults[0], // perfResults[0],
fd, // fd,
data_type, // data_type,
0); // 0);
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment