Commit 9484fd1c authored by xiabo
Browse files

Adapt to 0.1.0

parent 477f2db8
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
#elif (CUDART_VERSION >= 11000) #elif (CUDART_VERSION >= 11000)
#include <cub/cub.cuh> #include <cub/cub.cuh>
#else #else
#include "3rdparty/cub/cub.cuh" // #include "3rdparty/cub/cub.cuh"
#include <cub/cub.cuh>
#endif #endif
#include "src/turbomind/kernels/reduce_kernel_utils.cuh" #include "src/turbomind/kernels/reduce_kernel_utils.cuh"
......
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
#elif (CUDART_VERSION >= 11000) #elif (CUDART_VERSION >= 11000)
#include <cub/cub.cuh> #include <cub/cub.cuh>
#else #else
#include "3rdparty/cub/cub.cuh" // #include "3rdparty/cub/cub.cuh"
#include <cub/cub.cuh>
#endif #endif
#include "src/turbomind/kernels/reduce_kernel_utils.cuh" #include "src/turbomind/kernels/reduce_kernel_utils.cuh"
......
...@@ -145,7 +145,8 @@ void invokeLengthCriterion(bool* finished, ...@@ -145,7 +145,8 @@ void invokeLengthCriterion(bool* finished,
// Check if we have attained the sequence length limit. If so, stop the sequence. // Check if we have attained the sequence length limit. If so, stop the sequence.
// In addition, check if all sequences are stopped and return the result in should_stop // In addition, check if all sequences are stopped and return the result in should_stop
TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
dim3 block{min(512, uint32_t(batch_size * beam_width))}; // dim3 block{min(512, uint32_t(batch_size * beam_width))};
dim3 block{static_cast<unsigned int>(min(512, uint32_t(batch_size * beam_width)))};
dim3 grid{1}; dim3 grid{1};
h_pinned_finished_sum_[0] = -1; h_pinned_finished_sum_[0] = -1;
......
...@@ -178,7 +178,11 @@ __global__ void softmax_kernel_h2(T* attn_score, ...@@ -178,7 +178,11 @@ __global__ void softmax_kernel_h2(T* attn_score,
qk_bias = hadd2<T2>(qk_bias, hmul2<T2>(hsub2<T2>(ONE, mask_val), NEG_INFTY)); qk_bias = hadd2<T2>(qk_bias, hmul2<T2>(hsub2<T2>(ONE, mask_val), NEG_INFTY));
data[i] = hadd2<T2>(hmul2<T2>(qk, qk_scale_h2), qk_bias); data[i] = hadd2<T2>(hmul2<T2>(qk, qk_scale_h2), qk_bias);
local_max = fmax(local_max, fmax((float)data[i].x, (float)data[i].y)); // if (std::is_same<T2, half2>::value) {
local_max = fmax(local_max, fmax((float)data[i].data[0], (float)data[i].data[1]));
// } else {
// local_max = fmax(local_max, fmax((float)data[i].x, (float)data[i].y));
// }
} }
float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max); float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max);
...@@ -190,7 +194,11 @@ __global__ void softmax_kernel_h2(T* attn_score, ...@@ -190,7 +194,11 @@ __global__ void softmax_kernel_h2(T* attn_score,
float local_sum = 0.0f; float local_sum = 0.0f;
for (int i = 0; blockDim.x * i + threadIdx.x < (k_length / 2) && i < ITEMS_PER_THREAD; i++) { for (int i = 0; blockDim.x * i + threadIdx.x < (k_length / 2) && i < ITEMS_PER_THREAD; i++) {
data[i] = hexp2<T2>(hsub2<T2>(data[i], cuda_cast<T2>(s_max))); data[i] = hexp2<T2>(hsub2<T2>(data[i], cuda_cast<T2>(s_max)));
local_sum += (float)(data[i].x + data[i].y); // if (std::is_same<T2, half2>::value) {
local_sum += (float)(data[i].data[0] + data[i].data[1]);
// } else {
// local_sum += (float)(data[i].x + data[i].y);
// }
} }
float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum<float>(local_sum); float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum<float>(local_sum);
...@@ -310,7 +318,11 @@ __global__ void softmax_kernel_h2_v2(T* attn_score, ...@@ -310,7 +318,11 @@ __global__ void softmax_kernel_h2_v2(T* attn_score,
val = hadd2<T2>(val, pos_bias[j]); val = hadd2<T2>(val, pos_bias[j]);
} }
data[j][i] = val; data[j][i] = val;
local_max[j] = fmax(local_max[j], fmax((float)data[j][i].x, (float)data[j][i].y)); // if (std::is_same<T2, half2>::value) {
local_max[j] = fmax(local_max[j], fmax((float)data[j][i].data[0], (float)data[j][i].data[1]));
// } else {
// local_max[j] = fmax(local_max[j], fmax((float)data[j][i].x, (float)data[j][i].y));
// }
} }
} }
...@@ -343,7 +355,11 @@ __global__ void softmax_kernel_h2_v2(T* attn_score, ...@@ -343,7 +355,11 @@ __global__ void softmax_kernel_h2_v2(T* attn_score,
#pragma unroll #pragma unroll
for (int j = 0; j < Q_ITEMS; j++) { for (int j = 0; j < Q_ITEMS; j++) {
local_sum[j] += (float)(data[j][i].x + data[j][i].y); // if (std::is_same<T2, half2>::value) {
local_sum[j] += (float)(data[j][i].data[0] + data[j][i].data[1]);
// } else {
// local_sum[j] += (float)(data[j][i].x + data[j][i].y);
// }
} }
} }
...@@ -1885,6 +1901,7 @@ void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf, ...@@ -1885,6 +1901,7 @@ void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf,
qk_scale); qk_scale);
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
printf("============xiabo_test %s:%d\n", __FILE__,__LINE__);
softmax_withRelPosBias_element2_kernel<half2, half> softmax_withRelPosBias_element2_kernel<half2, half>
<<<grid, block, 0, stream>>>((half2*)qk_buf, <<<grid, block, 0, stream>>>((half2*)qk_buf,
(const half2*)attn_mask, (const half2*)attn_mask,
......
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
# limitations under the License. # limitations under the License.
cmake_minimum_required(VERSION 3.8) cmake_minimum_required(VERSION 3.8)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
add_subdirectory(sampling_layers) add_subdirectory(sampling_layers)
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
add_library(DynamicDecodeLayer STATIC DynamicDecodeLayer.cc) add_library(DynamicDecodeLayer STATIC DynamicDecodeLayer.cc)
set_property(TARGET DynamicDecodeLayer PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET DynamicDecodeLayer PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET DynamicDecodeLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET DynamicDecodeLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(DynamicDecodeLayer PUBLIC CUDA::cudart TopKSamplingLayer target_link_libraries(DynamicDecodeLayer PUBLIC cudart TopKSamplingLayer
TopPSamplingLayer ban_bad_words stop_criteria gpt_kernels tensor nvtx_utils) TopPSamplingLayer ban_bad_words stop_criteria gpt_kernels tensor nvtx_utils)
...@@ -14,19 +14,23 @@ ...@@ -14,19 +14,23 @@
cmake_minimum_required(VERSION 3.8) cmake_minimum_required(VERSION 3.8)
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
add_library(BaseSamplingLayer STATIC BaseSamplingLayer.cc) add_library(BaseSamplingLayer STATIC BaseSamplingLayer.cc)
set_property(TARGET BaseSamplingLayer PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET BaseSamplingLayer PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET BaseSamplingLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET BaseSamplingLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(BaseSamplingLayer PUBLIC CUDA::cudart sampling_penalty_kernels memory_utils) target_link_libraries(BaseSamplingLayer PUBLIC cudart sampling_penalty_kernels memory_utils)
add_library(TopKSamplingLayer STATIC TopKSamplingLayer.cu) add_library(TopKSamplingLayer STATIC TopKSamplingLayer.cu)
set_property(TARGET TopKSamplingLayer PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET TopKSamplingLayer PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET TopKSamplingLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET TopKSamplingLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(TopKSamplingLayer PUBLIC CUDA::cudart BaseSamplingLayer sampling_topk_kernels) target_link_libraries(TopKSamplingLayer PUBLIC cudart BaseSamplingLayer sampling_topk_kernels)
add_library(TopPSamplingLayer STATIC TopPSamplingLayer.cu) add_library(TopPSamplingLayer STATIC TopPSamplingLayer.cu)
set_property(TARGET TopPSamplingLayer PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET TopPSamplingLayer PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET TopPSamplingLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET TopPSamplingLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(TopPSamplingLayer PUBLIC CUDA::cudart BaseSamplingLayer sampling_topk_kernels sampling_topp_kernels) target_link_libraries(TopPSamplingLayer PUBLIC cudart BaseSamplingLayer sampling_topk_kernels sampling_topp_kernels)
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
cmake_minimum_required(VERSION 3.8) cmake_minimum_required(VERSION 3.8)
add_subdirectory(fused_multi_head_attention) #add_subdirectory(fused_multi_head_attention)
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
add_library(Llama STATIC add_library(Llama STATIC
LlamaV2.cc LlamaV2.cc
...@@ -19,10 +20,12 @@ add_library(Llama STATIC ...@@ -19,10 +20,12 @@ add_library(Llama STATIC
llama_kernels.cu llama_kernels.cu
llama_decoder_kernels.cu llama_decoder_kernels.cu
llama_utils.cu) llama_utils.cu)
set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
target_link_libraries(Llama PUBLIC CUDA::cudart #set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
gemm_s4_f16 #set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(Llama PUBLIC cudart
# gemm_s4_f16
cublasMMWrapper cublasMMWrapper
DynamicDecodeLayer DynamicDecodeLayer
activation_kernels activation_kernels
...@@ -38,17 +41,16 @@ target_link_libraries(Llama PUBLIC CUDA::cudart ...@@ -38,17 +41,16 @@ target_link_libraries(Llama PUBLIC CUDA::cudart
memory_utils memory_utils
nccl_utils nccl_utils
cuda_utils cuda_utils
logger logger)
llama_fmha) # llama_fmha)
if (NOT MSVC) if (NOT MSVC)
add_subdirectory(flash_attention2) # add_subdirectory(flash_attention2)
target_link_libraries(Llama PUBLIC flash_attention2) # target_link_libraries(Llama PUBLIC flash_attention2)
endif() endif()
add_executable(llama_gemm llama_gemm.cc) add_executable(llama_gemm llama_gemm.cc)
target_link_libraries(llama_gemm PUBLIC CUDA::cudart gpt_gemm_func memory_utils cuda_utils logger) target_link_libraries(llama_gemm PUBLIC cudart gpt_gemm_func memory_utils cuda_utils logger)
install(TARGETS llama_gemm DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/bin) install(TARGETS llama_gemm DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/bin)
find_package(Catch2 3 QUIET) find_package(Catch2 3 QUIET)
......
...@@ -22,10 +22,18 @@ ...@@ -22,10 +22,18 @@
#include "src/turbomind/models/llama/LlamaDenseWeight.h" #include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/utils/logger.h" #include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/memory_utils.h" #include "src/turbomind/utils/memory_utils.h"
#include <filesystem> // #include <filesystem>
#include <experimental/filesystem>
#include <sys/stat.h>
#include <string>
namespace turbomind { namespace turbomind {
// Reports whether `path` can be stat()'ed successfully.
// Returns true when stat() returns 0 for the given path and false on any
// stat() failure; the filled-in stat record itself is not inspected.
bool fileExists(const std::string& path)
{
    struct stat info;
    const int rc = stat(path.c_str(), &info);
    return rc == 0;
}
template<typename T> template<typename T>
LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num, LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
size_t kv_head_num, size_t kv_head_num,
...@@ -170,7 +178,8 @@ void loadWeights(LlamaDenseWeight<T>& w, ...@@ -170,7 +178,8 @@ void loadWeights(LlamaDenseWeight<T>& w,
} }
else { else {
// Disable slice if weight has already been sliced // Disable slice if weight has already been sliced
if (std::filesystem::exists(max_prefix + ".weight") || std::filesystem::exists(max_prefix + ".qweight")) { // if (std::filesystem::exists(max_prefix + ".weight") || std::filesystem::exists(max_prefix + ".qweight")) {
if (fileExists(max_prefix + ".weight") || fileExists(max_prefix + ".qweight")) {
TM_LOG_DEBUG("TP weight exists. Disable runtime TP."); TM_LOG_DEBUG("TP weight exists. Disable runtime TP.");
enable_slice = false; enable_slice = false;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#pragma once #pragma once
#include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h" // #include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h" #include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/cublasMMWrapper.h"
...@@ -62,29 +62,29 @@ private: ...@@ -62,29 +62,29 @@ private:
void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type) void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
{ {
if constexpr (std::is_same_v<T, half>) { // if constexpr (std::is_same_v<T, half>) {
gemm_s4_f16_.Run(output_data, // gemm_s4_f16_.Run(output_data,
(const uint*)weight.kernel, // (const uint*)weight.kernel,
input_data, // input_data,
(const half2*)weight.scales_and_zeros, // (const half2*)weight.scales_and_zeros,
weight.output_dims, // weight.output_dims,
batch_size, // batch_size,
weight.input_dims, // weight.input_dims,
weight.group_size, // weight.group_size,
type == kFusedSiluFfn ? GemmS4F16::kFusedSiluFfn : GemmS4F16::kGemm, // type == kFusedSiluFfn ? GemmS4F16::kFusedSiluFfn : GemmS4F16::kGemm,
-1, // -1,
stream_); // stream_);
sync_check_cuda_error(); // sync_check_cuda_error();
} // }
else { // else {
FT_CHECK_WITH_INFO(0, "Not implemented"); FT_CHECK_WITH_INFO(0, "Not implemented");
} // }
} }
private: private:
cublasMMWrapper* cublas_wrapper_; cublasMMWrapper* cublas_wrapper_;
cudaStream_t stream_{}; cudaStream_t stream_{};
GemmS4F16 gemm_s4_f16_; // GemmS4F16 gemm_s4_f16_;
}; };
} // namespace turbomind } // namespace turbomind
...@@ -110,24 +110,58 @@ struct res_norm_ops_t<__nv_bfloat16> { ...@@ -110,24 +110,58 @@ struct res_norm_ops_t<__nv_bfloat16> {
#endif #endif
template<typename T> // template<typename T>
__device__ T blockReduceSum(const cg::thread_block& block, T value) // __device__ T blockReduceSum(const cg::thread_block& block, T value)
{ // {
__shared__ float partial[32]; // __shared__ float partial[32];
auto tile = cg::tiled_partition<32>(block); // auto tile = cg::tiled_partition<32>(block);
value = cg::reduce(tile, value, cg::plus<float>{}); // value = cg::reduce(tile, value, cg::plus<float>{});
if (tile.thread_rank() == 0) { // if (tile.thread_rank() == 0) {
partial[tile.meta_group_rank()] = value; // partial[tile.meta_group_rank()] = value;
} // }
block.sync(); // block.sync();
value = tile.thread_rank() < tile.meta_group_size() ? partial[tile.thread_rank()] : T{}; // value = tile.thread_rank() < tile.meta_group_size() ? partial[tile.thread_rank()] : T{};
return cg::reduce(tile, value, cg::plus<float>{}); // return cg::reduce(tile, value, cg::plus<float>{});
// }
#define WARPSIZE 64
// Accumulates `value` across lanes with a shuffle-down tree: the offset starts at
// WARPSIZE / 2 and is halved every iteration, and each iteration adds the value that
// __shfl_down_sync returns for that offset to the caller's running value.
// Returns the calling lane's accumulated value.
// NOTE(review): presumably only the first lane ends up holding the complete sum, as is
// usual for shuffle-down reductions; confirm before relying on other lanes' results.
// NOTE(review): WARPSIZE is defined as 64 just above, yet the participation mask passed
// to __shfl_down_sync is the 32-bit constant 0xffffffff; confirm the target toolchain
// treats this mask as "all lanes" for a 64-wide warp.
template<typename T>
__inline__ __device__ T warpReduceSum_xiabo(T value)
{
#pragma unroll
// Halve the shuffle distance each step until it reaches zero.
for (int offset = WARPSIZE / 2; offset > 0; offset >>= 1)
value += __shfl_down_sync(0xffffffff, value, offset);
return value;
} }
// Block-wide sum built on top of warpReduceSum_xiabo.
//
// Steps:
//   1. every warp reduces its own values with warpReduceSum_xiabo;
//   2. the thread with tid % WARPSIZE == 0 of each warp stores that warp's partial
//      sum in shared memory, and unused slots below WARPSIZE are zero-filled;
//   3. the threads with tid < WARPSIZE reduce the partials, and thread 0 publishes
//      the total through shared[0];
//   4. every thread returns shared[0].
//
// The thread id is linearised over threadIdx.x and threadIdx.y only.
// The function contains __syncthreads(), so every thread of the block must call it.
//
// The number of partial slots is computed with a ceiling division. With the previous
// floor division, a block whose size is not a multiple of WARPSIZE had the slot of
// its last (partial) warp zero-filled by another thread (a write-write race), and a
// block smaller than WARPSIZE had thread 0 overwrite its own partial in shared[0]
// with zero, so the result was always 0.
//
// NOTE(review): a partially filled warp still shuffles from lanes that have no
// thread inside warpReduceSum_xiabo; confirm callers launch blocks whose size is a
// multiple of WARPSIZE, or that the target defines the value read from such lanes.
template<typename T>
__inline__ __device__ T blockReduceSum_xiabo(T val)
{
    __shared__ T shared[WARPSIZE];

    const int tid         = threadIdx.x + threadIdx.y * blockDim.x;
    const int num_threads = blockDim.x * blockDim.y;
    // Round up so that a trailing partial warp owns a slot of its own.
    const int num_warps = (num_threads + WARPSIZE - 1) / WARPSIZE;

    T sum = warpReduceSum_xiabo(val);

    // Presumably guards against a previous call in the same kernel whose readers of
    // shared[0] have not finished yet; kept exactly where the original had it.
    __syncthreads();

    if (tid % WARPSIZE == 0) {
        shared[tid / WARPSIZE] = sum;
    }
    // Zero only the slots that no warp writes, so the second-stage reduction can read
    // all of shared[0 .. WARPSIZE) without picking up stale data.
    if (tid >= num_warps && tid < WARPSIZE) {
        shared[tid] = (T)(0.0f);
    }
    __syncthreads();

    if (tid < WARPSIZE) {
        sum = warpReduceSum_xiabo(shared[tid]);
        if (tid == 0) {
            shared[0] = sum;
        }
    }
    __syncthreads();

    return shared[0];
}
// r' = r + x // r' = r + x
// x' = norm(r') * scales // x' = norm(r') * scales
template<typename T> template<typename T>
...@@ -140,7 +174,7 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data, ...@@ -140,7 +174,7 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
int n_dims) int n_dims)
{ {
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto grid = cg::this_grid(); // auto grid = cg::this_grid();
constexpr int PACK_DIM = sizeof(uint4) / sizeof(T); constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
...@@ -160,7 +194,8 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data, ...@@ -160,7 +194,8 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
r_ptr[i] = r; r_ptr[i] = r;
} }
auto total_sum = blockReduceSum(block, thread_sum); // auto total_sum = blockReduceSum(block, thread_sum);
auto total_sum = blockReduceSum_xiabo(thread_sum);
float s_inv_mean = rsqrt(total_sum / n_dims + eps); float s_inv_mean = rsqrt(total_sum / n_dims + eps);
......
...@@ -918,50 +918,50 @@ void invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cud ...@@ -918,50 +918,50 @@ void invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cud
} \ } \
}() }()
template<typename T> // template<typename T>
FlashAttentionOp<T>::FlashAttentionOp(int batch_size, int head_num, int key_len, int seq_len, int size_per_head): // FlashAttentionOp<T>::FlashAttentionOp(int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
batch_size_(batch_size), head_num_(head_num), key_len_(key_len), seq_len_(seq_len), size_per_head_(size_per_head) // batch_size_(batch_size), head_num_(head_num), key_len_(key_len), seq_len_(seq_len), size_per_head_(size_per_head)
{ // {
#ifdef _MSC_VER // #ifdef _MSC_VER
op_version_ = 1; // op_version_ = 1;
#else // #else
op_version_ = std::is_same<float, typename std::decay<T>::type>::value ? 1 : 2; // op_version_ = std::is_same<float, typename std::decay<T>::type>::value ? 1 : 2;
if (op_version_ == 2 && getSMVersion() < 80) { // if (op_version_ == 2 && getSMVersion() < 80) {
op_version_ = 1; // op_version_ = 1;
} // }
#endif // #endif
} // }
template<typename T> // template<typename T>
int FlashAttentionOp<T>::get_workspace_size() const // int FlashAttentionOp<T>::get_workspace_size() const
{ // {
#ifdef _MSC_VER // #ifdef _MSC_VER
FlashAttentionOpImpl<T, 1> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_); // FlashAttentionOpImpl<T, 1> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
return attention_op.get_workspace_size(); // return attention_op.get_workspace_size();
#else // #else
return VERSION_SWITCH(op_version_, OP_VERSION, [&]() { // return VERSION_SWITCH(op_version_, OP_VERSION, [&]() {
FlashAttentionOpImpl<T, OP_VERSION> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_); // FlashAttentionOpImpl<T, OP_VERSION> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
return attention_op.get_workspace_size(); // return attention_op.get_workspace_size();
}); // });
#endif // #endif
} // }
template<typename T> // template<typename T>
void FlashAttentionOp<T>::operator()(Params& params, cudaStream_t st) const // void FlashAttentionOp<T>::operator()(Params& params, cudaStream_t st) const
{ // {
#ifdef _MSC_VER // #ifdef _MSC_VER
FlashAttentionOpImpl<T, 1> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_); // FlashAttentionOpImpl<T, 1> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
return attention_op(params, st); // return attention_op(params, st);
#else // #else
return VERSION_SWITCH(op_version_, OP_VERSION, [&]() { // return VERSION_SWITCH(op_version_, OP_VERSION, [&]() {
FlashAttentionOpImpl<T, OP_VERSION> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_); // FlashAttentionOpImpl<T, OP_VERSION> attention_op(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
return attention_op(params, st); // return attention_op(params, st);
}); // });
#endif // #endif
} // }
template class FlashAttentionOp<float>; // template class FlashAttentionOp<float>;
template class FlashAttentionOp<half>; // template class FlashAttentionOp<half>;
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template class FlashAttentionOp<__nv_bfloat16>; template class FlashAttentionOp<__nv_bfloat16>;
#endif #endif
......
#include "src/turbomind/kernels/gemm_s_f16/format.h" // #include "src/turbomind/kernels/gemm_s_f16/format.h"
#include "src/turbomind/python/dlpack.h" #include "src/turbomind/python/dlpack.h"
#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h" #include "src/turbomind/triton_backend/llama/LlamaTritonModel.h"
#include "src/turbomind/triton_backend/transformer_triton_backend.hpp" #include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
...@@ -43,7 +43,8 @@ DLDevice getDLDevice(triton::Tensor& tensor) ...@@ -43,7 +43,8 @@ DLDevice getDLDevice(triton::Tensor& tensor)
device.device_type = DLDeviceType::kDLCUDAHost; device.device_type = DLDeviceType::kDLCUDAHost;
break; break;
case triton::MEMORY_GPU: case triton::MEMORY_GPU:
device.device_type = DLDeviceType::kDLCUDA; // device.device_type = DLDeviceType::kDLCUDA;
device.device_type = DLDeviceType::kDLROCM;
break; break;
default: default:
break; break;
...@@ -456,15 +457,15 @@ PYBIND11_MODULE(_turbomind, m) ...@@ -456,15 +457,15 @@ PYBIND11_MODULE(_turbomind, m)
auto src_tensor = GetDLTensor(src); auto src_tensor = GetDLTensor(src);
auto dst_tensor = GetDLTensor(dst); auto dst_tensor = GetDLTensor(dst);
turbomind::transpose_qk_s4_k_m8_hf( // turbomind::transpose_qk_s4_k_m8_hf(
(uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, size_per_head, nullptr); // (uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, size_per_head, nullptr);
}); });
m.def("fuse_w1_w3_s4_k_m8", [](py::object src, py::object dst, int m, int k) { m.def("fuse_w1_w3_s4_k_m8", [](py::object src, py::object dst, int m, int k) {
auto src_tensor = GetDLTensor(src); auto src_tensor = GetDLTensor(src);
auto dst_tensor = GetDLTensor(dst); auto dst_tensor = GetDLTensor(dst);
turbomind::fuse_w1_w3_s4_k_m8((uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, nullptr); // turbomind::fuse_w1_w3_s4_k_m8((uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, nullptr);
}); });
m.def("convert_s4_k_m8", m.def("convert_s4_k_m8",
...@@ -484,16 +485,16 @@ PYBIND11_MODULE(_turbomind, m) ...@@ -484,16 +485,16 @@ PYBIND11_MODULE(_turbomind, m)
auto s = GetDLTensor(scales); auto s = GetDLTensor(scales);
auto qz = GetDLTensor(qzeros); auto qz = GetDLTensor(qzeros);
turbomind::convert_s4_k_m8((uint32_t*)a_dst.data, // turbomind::convert_s4_k_m8((uint32_t*)a_dst.data,
(half2*)q_dst.data, // (half2*)q_dst.data,
(half*)w.data, // (half*)w.data,
(const uint32_t*)a_src.data, // (const uint32_t*)a_src.data,
(const half*)s.data, // (const half*)s.data,
(const uint32_t*)qz.data, // (const uint32_t*)qz.data,
m, // m,
k, // k,
group_size, // group_size,
nullptr); // nullptr);
}); });
m.def("dequantize_s4", [](py::object src, py::object dst) { m.def("dequantize_s4", [](py::object src, py::object dst) {
...@@ -502,6 +503,6 @@ PYBIND11_MODULE(_turbomind, m) ...@@ -502,6 +503,6 @@ PYBIND11_MODULE(_turbomind, m)
auto src_count = std::accumulate(src_tensor.shape, src_tensor.shape + src_tensor.ndim, size_t{1}); auto src_count = std::accumulate(src_tensor.shape, src_tensor.shape + src_tensor.ndim, size_t{1});
auto dst_count = std::accumulate(dst_tensor.shape, dst_tensor.shape + dst_tensor.ndim, size_t{1}); auto dst_count = std::accumulate(dst_tensor.shape, dst_tensor.shape + dst_tensor.ndim, size_t{1});
turbomind::FT_CHECK(src_count * 8 == dst_count); turbomind::FT_CHECK(src_count * 8 == dst_count);
turbomind::dequantize_s4((uint4*)dst_tensor.data, (uint32_t*)src_tensor.data, src_count, nullptr); // turbomind::dequantize_s4((uint4*)dst_tensor.data, (uint32_t*)src_tensor.data, src_count, nullptr);
}); });
} }
...@@ -24,13 +24,17 @@ ...@@ -24,13 +24,17 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required (VERSION 3.18) #cmake_minimum_required (VERSION 3.18)
cmake_minimum_required (VERSION 3.16)
project(tritonturbomindbackend LANGUAGES C CXX) project(tritonturbomindbackend LANGUAGES C CXX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
add_library(TransformerTritonBackend STATIC transformer_triton_backend.cpp) add_library(TransformerTritonBackend STATIC transformer_triton_backend.cpp)
target_link_libraries(TransformerTritonBackend PUBLIC nccl_utils) target_link_libraries(TransformerTritonBackend PUBLIC nccl_utils)
set_property(TARGET TransformerTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET TransformerTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON)
install(TARGETS TransformerTritonBackend DESTINATION ${CMAKE_INSTALL_LIBDIR}) install(TARGETS TransformerTritonBackend DESTINATION ${CMAKE_INSTALL_LIBDIR})
add_subdirectory(llama) add_subdirectory(llama)
...@@ -70,21 +74,24 @@ include(FetchContent) ...@@ -70,21 +74,24 @@ include(FetchContent)
FetchContent_Declare( FetchContent_Declare(
repo-common repo-common
GIT_REPOSITORY https://github.com/triton-inference-server/common.git URL ../../../3rdparty/common-r22.12
GIT_TAG ${TRITON_COMMON_REPO_TAG} #GIT_REPOSITORY https://github.com/triton-inference-server/common.git
GIT_SHALLOW ON #GIT_TAG ${TRITON_COMMON_REPO_TAG}
#GIT_SHALLOW ON
) )
FetchContent_Declare( FetchContent_Declare(
repo-core repo-core
GIT_REPOSITORY https://github.com/triton-inference-server/core.git URL ../../../3rdparty/core-r22.12
GIT_TAG ${TRITON_CORE_REPO_TAG} #GIT_REPOSITORY https://github.com/triton-inference-server/core.git
GIT_SHALLOW ON #GIT_TAG ${TRITON_CORE_REPO_TAG}
#GIT_SHALLOW ON
) )
FetchContent_Declare( FetchContent_Declare(
repo-backend repo-backend
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git URL ../../../3rdparty/backend-r22.12
GIT_TAG ${TRITON_BACKEND_REPO_TAG} #GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
GIT_SHALLOW ON #GIT_TAG ${TRITON_BACKEND_REPO_TAG}
#GIT_SHALLOW ON
) )
FetchContent_MakeAvailable(repo-common repo-core repo-backend) FetchContent_MakeAvailable(repo-common repo-core repo-backend)
...@@ -92,7 +99,8 @@ FetchContent_MakeAvailable(repo-common repo-core repo-backend) ...@@ -92,7 +99,8 @@ FetchContent_MakeAvailable(repo-common repo-core repo-backend)
# CUDA # CUDA
# #
if(${TRITON_ENABLE_GPU}) if(${TRITON_ENABLE_GPU})
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
endif() # TRITON_ENABLE_GPU endif() # TRITON_ENABLE_GPU
# #
...@@ -109,7 +117,8 @@ add_library( ...@@ -109,7 +117,8 @@ add_library(
TritonTurboMindBackend::triton-turbomind-backend ALIAS triton-turbomind-backend TritonTurboMindBackend::triton-turbomind-backend ALIAS triton-turbomind-backend
) )
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
find_package(CUDA 10.1 REQUIRED) find_package(CUDA 10.1 REQUIRED)
if (${CUDA_VERSION} GREATER_EQUAL 11.0) if (${CUDA_VERSION} GREATER_EQUAL 11.0)
message(STATUS "Add DCUDA11_MODE") message(STATUS "Add DCUDA11_MODE")
...@@ -158,10 +167,14 @@ if(${TRITON_ENABLE_GPU}) ...@@ -158,10 +167,14 @@ if(${TRITON_ENABLE_GPU})
) )
endif() # TRITON_ENABLE_GPU endif() # TRITON_ENABLE_GPU
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
set_target_properties( set_target_properties(
triton-turbomind-backend triton-turbomind-backend
PROPERTIES PROPERTIES
POSITION_INDEPENDENT_CODE ON # POSITION_INDEPENDENT_CODE ON
POSITION_INDEPENDENT_CODE OFF
OUTPUT_NAME triton_turbomind OUTPUT_NAME triton_turbomind
SKIP_BUILD_RPATH TRUE SKIP_BUILD_RPATH TRUE
BUILD_WITH_INSTALL_RPATH TRUE BUILD_WITH_INSTALL_RPATH TRUE
...@@ -194,7 +207,7 @@ target_link_libraries( ...@@ -194,7 +207,7 @@ target_link_libraries(
transformer-shared # from repo-ft transformer-shared # from repo-ft
${TRITON_PYTORCH_LDFLAGS} ${TRITON_PYTORCH_LDFLAGS}
-lcublas -lcublas
-lcublasLt # -lcublasLt
-lcudart -lcudart
-lcurand -lcurand
) )
...@@ -228,7 +241,8 @@ if(${TRITON_ENABLE_GPU}) ...@@ -228,7 +241,8 @@ if(${TRITON_ENABLE_GPU})
target_link_libraries( target_link_libraries(
triton-turbomind-backend triton-turbomind-backend
PRIVATE PRIVATE
CUDA::cudart # CUDA::cudart
cudart
) )
endif() # TRITON_ENABLE_GPU endif() # TRITON_ENABLE_GPU
......
...@@ -22,8 +22,10 @@ set(llama_triton_backend_files ...@@ -22,8 +22,10 @@ set(llama_triton_backend_files
LlamaTritonModelInstance.cc LlamaTritonModelInstance.cc
) )
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
add_library(LlamaTritonBackend STATIC ${llama_triton_backend_files}) add_library(LlamaTritonBackend STATIC ${llama_triton_backend_files})
set_property(TARGET LlamaTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET LlamaTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(LlamaTritonBackend PUBLIC TransformerTritonBackend Llama tensor memory_utils CUDA::cublasLt) #target_link_libraries(LlamaTritonBackend PUBLIC TransformerTritonBackend Llama tensor memory_utils CUDA::cublasLt)
target_link_libraries(LlamaTritonBackend PUBLIC TransformerTritonBackend Llama tensor memory_utils)
target_compile_features(LlamaTritonBackend PRIVATE cxx_std_14) target_compile_features(LlamaTritonBackend PRIVATE cxx_std_14)
...@@ -258,7 +258,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh ...@@ -258,7 +258,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
cublasLtHandle_t cublaslt_handle; cublasLtHandle_t cublaslt_handle;
cublasCreate(&cublas_handle); cublasCreate(&cublas_handle);
cublasLtCreate(&cublaslt_handle); // cublasLtCreate(&cublaslt_handle);
cublasSetStream(cublas_handle, stream); cublasSetStream(cublas_handle, stream);
std::unique_ptr<ft::cublasAlgoMap> cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); std::unique_ptr<ft::cublasAlgoMap> cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in"));
...@@ -270,7 +270,8 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh ...@@ -270,7 +270,8 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id)); ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id));
if (std::is_same<T, half>::value) { if (std::is_same<T, half>::value) {
cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); // cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F);
cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F);
} }
else if (std::is_same<T, float>::value) { else if (std::is_same<T, float>::value) {
cublas_wrapper->setFP32GemmConfig(); cublas_wrapper->setFP32GemmConfig();
......
...@@ -14,98 +14,104 @@ ...@@ -14,98 +14,104 @@
cmake_minimum_required(VERSION 3.8) cmake_minimum_required(VERSION 3.8)
find_package(CUDAToolkit REQUIRED) #find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
add_subdirectory(gemm_test) add_subdirectory(gemm_test)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
add_library(cuda_utils STATIC cuda_utils.cc) add_library(cuda_utils STATIC cuda_utils.cc)
set_property(TARGET cuda_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cuda_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cuda_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cuda_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(cuda_utils PUBLIC CUDA::cudart) target_link_libraries(cuda_utils PUBLIC cudart)
add_library(logger STATIC logger.cc) add_library(logger STATIC logger.cc)
set_property(TARGET logger PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET logger PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET logger PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET logger PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(logger PUBLIC CUDA::cudart) target_link_libraries(logger PUBLIC cudart)
add_library(cublasAlgoMap STATIC cublasAlgoMap.cc) add_library(cublasAlgoMap STATIC cublasAlgoMap.cc)
set_property(TARGET cublasAlgoMap PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cublasAlgoMap PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cublasAlgoMap PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cublasAlgoMap PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(cublasAlgoMap PUBLIC CUDA::cublas CUDA::cudart CUDA::curand cuda_utils logger) target_link_libraries(cublasAlgoMap PUBLIC cublas cudart curand cuda_utils logger)
add_library(cublasMMWrapper STATIC cublasMMWrapper.cc) add_library(cublasMMWrapper STATIC cublasMMWrapper.cc)
set_property(TARGET cublasMMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cublasMMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cublasMMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cublasMMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(cublasMMWrapper PUBLIC CUDA::cublas CUDA::cudart CUDA::curand cublasAlgoMap cuda_utils logger) target_link_libraries(cublasMMWrapper PUBLIC cublas cudart curand cublasAlgoMap cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(cublasMMWrapper PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(cublasMMWrapper PUBLIC cusparse -lcusparseLt)
endif() endif()
add_library(word_list STATIC word_list.cc) add_library(word_list STATIC word_list.cc)
set_property(TARGET word_list PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET word_list PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET word_list PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET word_list PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(nvtx_utils STATIC nvtx_utils.cc) add_library(nvtx_utils STATIC nvtx_utils.cc)
set_property(TARGET nvtx_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET nvtx_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET nvtx_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET nvtx_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
if(${CMAKE_VERSION} VERSION_LESS "3.25") if(${CMAKE_VERSION} VERSION_LESS "3.25")
target_link_libraries(nvtx_utils PUBLIC CUDA::nvToolsExt -ldl) # target_link_libraries(nvtx_utils PUBLIC nvToolsExt -ldl)
else() else()
target_link_libraries(nvtx_utils PUBLIC CUDA::nvtx3 -ldl) # target_link_libraries(nvtx_utils PUBLIC nvtx3 -ldl)
endif() endif()
add_library(memory_utils STATIC memory_utils.cu) add_library(memory_utils STATIC memory_utils.cu)
set_property(TARGET memory_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET memory_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET memory_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET memory_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(memory_utils PUBLIC cuda_utils logger tensor) target_link_libraries(memory_utils PUBLIC cuda_utils logger tensor)
add_library(mpi_utils STATIC mpi_utils.cc) add_library(mpi_utils STATIC mpi_utils.cc)
set_property(TARGET mpi_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET mpi_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET mpi_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET mpi_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
if (BUILD_MULTI_GPU) if (BUILD_MULTI_GPU)
target_link_libraries(mpi_utils PUBLIC mpi logger) target_link_libraries(mpi_utils PUBLIC mpi logger)
endif() endif()
add_library(nccl_utils STATIC nccl_utils.cc) add_library(nccl_utils STATIC nccl_utils.cc)
set_property(TARGET nccl_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET nccl_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET nccl_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET nccl_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
if (BUILD_MULTI_GPU) if (BUILD_MULTI_GPU)
target_link_libraries(nccl_utils PUBLIC ${NCCL_LIBRARIES} logger) target_link_libraries(nccl_utils PUBLIC ${NCCL_LIBRARIES} logger)
endif() endif()
add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc) # add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc)
set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cublasINT8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cublasINT8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(cublasINT8MMWrapper PUBLIC CUDA::cublasLt CUDA::cudart CUDA::curand cublasAlgoMap cublasMMWrapper cuda_utils logger) #target_link_libraries(cublasINT8MMWrapper PUBLIC cublasLt cudart curand cublasAlgoMap cublasMMWrapper cuda_utils logger)
if(ENABLE_FP8) if(ENABLE_FP8)
add_library(cublasFP8MMWrapper STATIC cublasFP8MMWrapper.cu) add_library(cublasFP8MMWrapper STATIC cublasFP8MMWrapper.cu)
set_property(TARGET cublasFP8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cublasFP8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cublasFP8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cublasFP8MMWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(cublasFP8MMWrapper PUBLIC CUDA::cublasLt CUDA::cudart CUDA::curand #target_link_libraries(cublasFP8MMWrapper PUBLIC cublasLt cudart curand
target_link_libraries(cublasFP8MMWrapper PUBLIC cudart curand
cublasAlgoMap cublasMMWrapper nvtx_utils fp8_qgmma_1x1_utils) cublasAlgoMap cublasMMWrapper nvtx_utils fp8_qgmma_1x1_utils)
endif() endif()
add_library(custom_ar_comm STATIC custom_ar_comm.cc) add_library(custom_ar_comm STATIC custom_ar_comm.cc)
set_property(TARGET custom_ar_comm PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET custom_ar_comm PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET custom_ar_comm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET custom_ar_comm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(custom_ar_comm PUBLIC custom_ar_kernels memory_utils cuda_utils logger) target_link_libraries(custom_ar_comm PUBLIC custom_ar_kernels memory_utils cuda_utils logger)
add_library(gemm STATIC gemm.cc) add_library(gemm STATIC gemm.cc)
set_property(TARGET gemm PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET gemm PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET gemm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET gemm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(gemm PUBLIC target_link_libraries(gemm PUBLIC
CUDA::cublas CUDA::cublasLt CUDA::cudart CUDA::curand # cublas cublasLt cudart curand
cublas cudart curand
cublasAlgoMap memory_utils cuda_utils logger) cublasAlgoMap memory_utils cuda_utils logger)
if (SPARSITY_SUPPORT) if (SPARSITY_SUPPORT)
target_link_libraries(gemm PUBLIC CUDA::cusparse -lcusparseLt) target_link_libraries(gemm PUBLIC cusparse -lcusparseLt)
endif() endif()
add_library(cuda_fp8_utils STATIC cuda_fp8_utils.cu) # add_library(cuda_fp8_utils STATIC cuda_fp8_utils.cu)
set_property(TARGET cuda_fp8_utils PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET cuda_fp8_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cuda_fp8_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET cuda_fp8_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(tensor STATIC Tensor.cc) add_library(tensor STATIC Tensor.cc)
set_property(TARGET tensor PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET tensor PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET tensor PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET tensor PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(tensor PUBLIC cuda_utils logger) target_link_libraries(tensor PUBLIC cuda_utils logger)
...@@ -44,9 +44,9 @@ ...@@ -44,9 +44,9 @@
#include "src/turbomind/utils/logger.h" #include "src/turbomind/utils/logger.h"
#if defined(CUDART_VERSION) && CUDART_VERSION < 11020 // #if defined(CUDART_VERSION) && CUDART_VERSION < 11020
#define CUDA_MEMORY_POOL_DISABLED #define CUDA_MEMORY_POOL_DISABLED
#endif // #endif
namespace turbomind { namespace turbomind {
...@@ -158,36 +158,36 @@ public: ...@@ -158,36 +158,36 @@ public:
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
pointer_mapping_ = new std::unordered_map<void*, std::pair<size_t, MemoryType>>(); pointer_mapping_ = new std::unordered_map<void*, std::pair<size_t, MemoryType>>();
#if defined(CUDA_MEMORY_POOL_DISABLED) // #if defined(CUDA_MEMORY_POOL_DISABLED)
TM_LOG_WARNING( // TM_LOG_WARNING(
"Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free." // "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free."
"Note this may lead to hang with NCCL kernels launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP"); // "Note this may lead to hang with NCCL kernels launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP");
#else // #else
int device_count = 1; // int device_count = 1;
check_cuda_error(cudaGetDeviceCount(&device_count)); // check_cuda_error(cudaGetDeviceCount(&device_count));
cudaMemPool_t mempool; // cudaMemPool_t mempool;
check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id)); // check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id));
cudaMemAccessDesc desc = {}; // cudaMemAccessDesc desc = {};
int peer_access_available = 0; // int peer_access_available = 0;
for (int i = 0; i < device_count; i++) { // for (int i = 0; i < device_count; i++) {
if (i == device_id) { // if (i == device_id) {
continue; // continue;
} // }
check_cuda_error(cudaDeviceCanAccessPeer(&peer_access_available, device_id, i)); // check_cuda_error(cudaDeviceCanAccessPeer(&peer_access_available, device_id, i));
if (!peer_access_available) { // if (!peer_access_available) {
TM_LOG_WARNING("Device " + std::to_string(device_id) + " peer access Device " + std::to_string(i) // TM_LOG_WARNING("Device " + std::to_string(device_id) + " peer access Device " + std::to_string(i)
+ " is not available."); // + " is not available.");
continue; // continue;
} // }
desc.location.type = cudaMemLocationTypeDevice; // desc.location.type = cudaMemLocationTypeDevice;
desc.location.id = i; // desc.location.id = i;
desc.flags = cudaMemAccessFlagsProtReadWrite; // desc.flags = cudaMemAccessFlagsProtReadWrite;
check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1)); // check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1));
} // }
// set memory pool threshold to avoid shrinking the pool // // set memory pool threshold to avoid shrinking the pool
uint64_t setVal = UINT64_MAX; // uint64_t setVal = UINT64_MAX;
check_cuda_error(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &setVal)); // check_cuda_error(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &setVal));
#endif // #endif
} }
virtual ~Allocator() virtual ~Allocator()
......
...@@ -139,7 +139,8 @@ cublasAlgoMap::getAlgo(const int batch_count, const int m, const int n, const in ...@@ -139,7 +139,8 @@ cublasAlgoMap::getAlgo(const int batch_count, const int m, const int n, const in
else { else {
cublasLtMatmulAlgo_info tmp_algo; cublasLtMatmulAlgo_info tmp_algo;
tmp_algo.algoId = tmp_algo.algoId =
static_cast<int>(data_type == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP); // static_cast<int>(data_type == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP);
static_cast<int>(data_type == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT);
tmp_algo.customOption = -1; tmp_algo.customOption = -1;
tmp_algo.tile = -1; tmp_algo.tile = -1;
tmp_algo.splitK_val = -1; tmp_algo.splitK_val = -1;
......
...@@ -237,10 +237,10 @@ void cublasFP8MMWrapper::Gemm(__nv_bfloat16* res, ...@@ -237,10 +237,10 @@ void cublasFP8MMWrapper::Gemm(__nv_bfloat16* res,
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)); // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
#endif // #endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
...@@ -462,10 +462,10 @@ void cublasFP8MMWrapper::Gemm(__nv_fp8_e4m3* res, ...@@ -462,10 +462,10 @@ void cublasFP8MMWrapper::Gemm(__nv_fp8_e4m3* res,
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)); // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
#endif // #endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
......
...@@ -94,11 +94,11 @@ void cublasINT8MMWrapper::Gemm(int* res, ...@@ -94,11 +94,11 @@ void cublasINT8MMWrapper::Gemm(int* res,
{ {
mu_->lock(); mu_->lock();
cublasOperation_t opTranspose = CUBLAS_OP_T; cublasOperation_t opTranspose = CUBLAS_OP_T;
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; // cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
#else // #else
cudaDataType_t computeType = CUDA_R_32I; cudaDataType_t computeType = CUDA_R_32I;
#endif // #endif
cublasLtMatmulDesc_t matmulDesc; cublasLtMatmulDesc_t matmulDesc;
cublasLtMatrixLayout_t AtransformDesc = NULL; cublasLtMatrixLayout_t AtransformDesc = NULL;
cublasLtMatrixLayout_t BtransformDesc = NULL; cublasLtMatrixLayout_t BtransformDesc = NULL;
...@@ -106,16 +106,16 @@ void cublasINT8MMWrapper::Gemm(int* res, ...@@ -106,16 +106,16 @@ void cublasINT8MMWrapper::Gemm(int* res,
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t order_matrixB; cublasLtOrder_t order_matrixB;
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
if (use_ORDER_COL32_2R_4R4_) { // if (use_ORDER_COL32_2R_4R4_) {
order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; // order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
} // }
else { // else {
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; // order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
} // }
#else // #else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif // #endif
int ldaTransform = 32 * m; int ldaTransform = 32 * m;
int ldbTransform; int ldbTransform;
...@@ -128,11 +128,11 @@ void cublasINT8MMWrapper::Gemm(int* res, ...@@ -128,11 +128,11 @@ void cublasINT8MMWrapper::Gemm(int* res,
int ldcTransform = 32 * m; int ldcTransform = 32 * m;
// create matmulDesc // create matmulDesc
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I); // cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I);
#else // #else
cublasLtMatmulDescCreate(&matmulDesc, computeType); cublasLtMatmulDescCreate(&matmulDesc, computeType);
#endif // #endif
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t)); cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform); cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform);
cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)); cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32));
...@@ -187,10 +187,10 @@ void cublasINT8MMWrapper::Gemm(int* res, ...@@ -187,10 +187,10 @@ void cublasINT8MMWrapper::Gemm(int* res,
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), sizeof(tmp_info.swizzle)); &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), sizeof(tmp_info.swizzle));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(tmp_info.reductionScheme), sizeof(int)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(tmp_info.reductionScheme), sizeof(int));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages)); // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages));
#endif // #endif
} }
else { else {
findAlgo = 1; findAlgo = 1;
...@@ -215,16 +215,16 @@ void cublasINT8MMWrapper::Gemm(int* res, ...@@ -215,16 +215,16 @@ void cublasINT8MMWrapper::Gemm(int* res,
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
int stages; // int stages;
if (use_ORDER_COL32_2R_4R4_) { // if (use_ORDER_COL32_2R_4R4_) {
stages = 15; // stages = 15;
} // }
else { // else {
stages = 13; // stages = 13;
} // }
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); // cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
#endif // #endif
} }
cublasLtMatmul(cublaslt_handle_, cublasLtMatmul(cublaslt_handle_,
...@@ -273,11 +273,11 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, ...@@ -273,11 +273,11 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
// int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE
// cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; // cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
cudaDataType_t scaleType = CUDA_R_32F; cudaDataType_t scaleType = CUDA_R_32F;
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; // cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
#else // #else
cudaDataType_t computeType = CUDA_R_32I; cudaDataType_t computeType = CUDA_R_32I;
#endif // #endif
cublasLtMatmulDesc_t matmulDesc; cublasLtMatmulDesc_t matmulDesc;
cublasLtMatrixLayout_t AtransformDesc = NULL; cublasLtMatrixLayout_t AtransformDesc = NULL;
cublasLtMatrixLayout_t BtransformDesc = NULL; cublasLtMatrixLayout_t BtransformDesc = NULL;
...@@ -285,16 +285,16 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, ...@@ -285,16 +285,16 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t order_matrixB; cublasLtOrder_t order_matrixB;
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
if (use_ORDER_COL32_2R_4R4_) { // if (use_ORDER_COL32_2R_4R4_) {
order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; // order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4;
} // }
else { // else {
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; // order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
} // }
#else // #else
order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C;
#endif // #endif
int ldaTransform = 32 * m; int ldaTransform = 32 * m;
...@@ -309,11 +309,11 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, ...@@ -309,11 +309,11 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
int ldcTransform = 32 * m; int ldcTransform = 32 * m;
// create matmulDesc // create matmulDesc
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); // cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType);
#else // #else
cublasLtMatmulDescCreate(&matmulDesc, computeType); cublasLtMatmulDescCreate(&matmulDesc, computeType);
#endif // #endif
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t)); cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scaleType, sizeof(scaleType)); cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scaleType, sizeof(scaleType));
// cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, // cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode,
...@@ -367,10 +367,10 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, ...@@ -367,10 +367,10 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), sizeof(tmp_info.swizzle)); &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), sizeof(tmp_info.swizzle));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(tmp_info.reductionScheme), sizeof(int)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(tmp_info.reductionScheme), sizeof(int));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
cublasLtMatmulAlgoConfigSetAttribute( // cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages)); // &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(tmp_info.stages), sizeof(tmp_info.stages));
#endif // #endif
} }
else { else {
findAlgo = 1; findAlgo = 1;
...@@ -395,16 +395,16 @@ void cublasINT8MMWrapper::Gemm(int8_t* res, ...@@ -395,16 +395,16 @@ void cublasINT8MMWrapper::Gemm(int8_t* res,
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle));
cublasLtMatmulAlgoConfigSetAttribute( cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int)); &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int));
#if (CUDART_VERSION >= 11000) // #if (CUDART_VERSION >= 11000)
int stages; // int stages;
if (use_ORDER_COL32_2R_4R4_) { // if (use_ORDER_COL32_2R_4R4_) {
stages = 15; // stages = 15;
} // }
else { // else {
stages = 13; // stages = 13;
} // }
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); // cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
#endif // #endif
} }
float beta = 0.0f; float beta = 0.0f;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment