Commit 8efb9210 authored by zhouxiang

Merge branch 'dtk24.04-v0.2.6_awq' into 'dtk24.04-v0.2.6'

Merge the int4 quantization inference support into the dtk2404-v0.2.6 release

See merge request dcutoolkit/deeplearing/lmdeploy!2
parents 2326380c 175eaedb
......@@ -35,6 +35,7 @@ public:
size_t inter_size,
WeightType weight_type,
int group_size,
int w4_weight_layout,
bool attn_bias,
size_t tensor_para_size,
size_t tensor_para_rank);
......
......@@ -63,6 +63,7 @@ struct LlamaDenseWeight {
T* bias;
T* scales_and_zeros;
int group_size;
int w4_weight_layout;
};
template<typename T>
......
......@@ -29,7 +29,7 @@ namespace turbomind {
template<typename T>
void LlamaFfnLayer<T>::allocateBuffer(size_t token_num)
{
inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, sizeof(T) * token_num * inter_size_, false);
inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, 2*sizeof(T) * token_num * inter_size_, false);
gating_buf_ = (T*)allocator_->reMalloc(gating_buf_, sizeof(T) * token_num * inter_size_, false);
is_allocate_buffer_ = true;
}
......@@ -90,8 +90,11 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
if (weights->fused_gating_intermediate.kernel) {
NvtxScope scope("fused_silu_ffn");
linear_.forward(
gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
// linear_.forward(
// gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
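// forward_ffn writes the raw fused w1/w3 GEMM result into inter_buf_ (which is why
// allocateBuffer above now reserves 2 * sizeof(T) * token_num * inter_size_); the
// SiLU-gated result is then written into gating_buf_.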
linear_.forward_ffn(
gating_buf_, inter_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
}
else {
{ // w1(x)
......
......@@ -2,7 +2,7 @@
#pragma once
// #include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h"
#include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
......@@ -41,7 +41,27 @@ public:
FT_CHECK(0);
}
}
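// forward_ffn mirrors forward(), but is used for the fused gating/intermediate projection:
// for INT4 weights with kFusedSiluFfn it takes an extra scratch buffer (output_tmp) that
// holds the plain GEMM result before the SiLU activation is applied.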
void forward_ffn(T* output_data, T* output_tmp, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type = kGemm)
{
switch (weight.type) {
case WeightType::kFP16:
case WeightType::kFP32:
case WeightType::kBF16:
forwardFp(output_data, input_data, batch_size, weight, type);
break;
case WeightType::kINT4:
{
if (type == kFusedSiluFfn)
forwardInt4_ffn(output_data, output_tmp, input_data, batch_size, weight, type);
else
forwardInt4(output_data, input_data, batch_size, weight, type);
break;
}
default:
FT_CHECK(0);
}
}
private:
void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
{
......@@ -62,23 +82,184 @@ private:
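// forwardInt4 dispatches on weight.w4_weight_layout:
//   0 - dequantize into deweight_workspace_ and run a plain NN rocBLAS GEMM;
//   1 - dequantize column-major and run a TN rocBLAS GEMM, padding the K dimension by
//       pad_group_count * group_size (and padding the input via input_padding) when
//       input_dims is a multiple of 4096;
//   2 - call the CK weight-only GEMM (run_weight_only_gemm) directly on the packed INT4
//       weights, applying the same K-padding rule to the padded weight stride.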
void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
{
// if constexpr (std::is_same_v<T, half>) {
// gemm_s4_f16_.Run(output_data,
// (const uint*)weight.kernel,
// input_data,
// (const half2*)weight.scales_and_zeros,
// weight.output_dims,
// batch_size,
// weight.input_dims,
// weight.group_size,
// type == kFusedSiluFfn ? GemmS4F16::kFusedSiluFfn : GemmS4F16::kGemm,
// -1,
// stream_);
// sync_check_cuda_error();
// }
// else {
if constexpr (std::is_same_v<T, half>) {
if(weight.w4_weight_layout==0) // plain NN layout, rocBLAS
{
// check that the dequantized-weight workspace is large enough
if(batch_size*weight.output_dims>M_max*N_max)
{
FT_CHECK_WITH_INFO(0, "error! batch_size * weight.output_dims > M_max * N_max");
}
dequant_w4_gemm(stream_, reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims,weight.output_dims,weight.group_size);
cublas_wrapper_->Gemm(CUBLAS_OP_N,
CUBLAS_OP_N,
weight.output_dims,//m
batch_size,//n
weight.input_dims,//k
(const T*) cublas_wrapper_->deweight_workspace_, //[]
weight.output_dims,//m
input_data,
weight.input_dims, //k
output_data,
weight.output_dims); //m
}
else if(weight.w4_weight_layout==1) // TN layout with padding, rocBLAS
{
// check that the dequantized-weight workspace is large enough
if(batch_size*weight.output_dims>M_max*N_max)
{
FT_CHECK_WITH_INFO(0, "error! batch_size * weight.output_dims > M_max * N_max");
}
// check that the x-padding workspace is large enough
if(weight.input_dims%4096==0) // K is a multiple of 4096, padding required
{
int pad_group_count=2;
input_padding(stream_,reinterpret_cast<half*>(cublas_wrapper_->xpading_workspace_),(const T*)input_data,batch_size,weight.input_dims,weight.group_size,pad_group_count);
dequant_w4_gemm_colmajor(stream_,reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims+pad_group_count*weight.group_size ,weight.output_dims,weight.group_size);
cublas_wrapper_->Gemm(CUBLAS_OP_T,
CUBLAS_OP_N,
weight.output_dims,//m
batch_size,//n
weight.input_dims+pad_group_count*weight.group_size,//k
(const T*) reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_), //[]
weight.input_dims+pad_group_count*weight.group_size, //k
(const T*) cublas_wrapper_->xpading_workspace_,
weight.input_dims+pad_group_count*weight.group_size, //k
output_data,
weight.output_dims); //m
}
else // no padding required
{
dequant_w4_gemm_colmajor(stream_,reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims,weight.output_dims,weight.group_size);
cublas_wrapper_->Gemm(CUBLAS_OP_T,
CUBLAS_OP_N,
weight.output_dims,//m
batch_size,//n
weight.input_dims,//k
(const T*) reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_), //[]
weight.input_dims, //k
input_data,
weight.input_dims, //k
output_data,
weight.output_dims); //m
}
}
else if(weight.w4_weight_layout==2) // TN layout with padding, CK
{
// check that the CK workspace is large enough
if(weight.input_dims%4096==0)
{
int pad_groupcount=2;
run_weight_only_gemm(reinterpret_cast<const void*>(input_data), reinterpret_cast<const void*>(weight.kernel), reinterpret_cast<const void*>(weight.scales_and_zeros), reinterpret_cast<void*> (output_data), batch_size, weight.output_dims, (weight.input_dims), (weight.input_dims),(weight.input_dims), (weight.input_dims+pad_groupcount*weight.group_size), weight.output_dims, weight.group_size,reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),CK_WORKSPACE_SIZE,(hipStream_t)stream_);
}
// A B0 B1 C M N K strideA strideB strideBpad strideC group_size
else{
run_weight_only_gemm(reinterpret_cast<const void*>(input_data), reinterpret_cast<const void*>(weight.kernel), reinterpret_cast<const void*>(weight.scales_and_zeros), reinterpret_cast<void*> (output_data), batch_size, weight.output_dims, (weight.input_dims), (weight.input_dims),(weight.input_dims), (weight.input_dims), weight.output_dims, weight.group_size,reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),CK_WORKSPACE_SIZE,(hipStream_t)stream_);
}
}
sync_check_cuda_error();
}
else {
FT_CHECK_WITH_INFO(0, "Not implemented");
}
}
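// forwardInt4_ffn follows the same three GEMM paths as forwardInt4, but writes the raw
// GEMM result into output_tmp and then applies addFusedSiluActivation to produce the
// final SiLU-gated FFN activation in output_data.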
void forwardInt4_ffn(T* output_data, T* output_tmp, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
{
if constexpr (std::is_same_v<T, half>) {
if(weight.w4_weight_layout==0) // plain NN layout, rocBLAS
{
// check that the dequantized-weight workspace is large enough
if(batch_size*weight.output_dims>M_max*N_max)
{
FT_CHECK_WITH_INFO(0, "error! batch_size * weight.output_dims > M_max * N_max");
}
dequant_w4_gemm(stream_, reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims,weight.output_dims,weight.group_size);
cublas_wrapper_->Gemm(CUBLAS_OP_N,
CUBLAS_OP_N,
weight.output_dims,//m
batch_size,//n
weight.input_dims,//k
(const T*) cublas_wrapper_->deweight_workspace_, //[]
weight.output_dims,//m
input_data,
weight.input_dims, //k
output_tmp,
weight.output_dims); //m
}
else if(weight.w4_weight_layout==1) // TN layout with padding, rocBLAS
{
// check that the dequantized-weight workspace is large enough
if(batch_size*weight.output_dims>M_max*N_max)
{
FT_CHECK_WITH_INFO(0, "error! batch_size * weight.output_dims > M_max * N_max");
}
// check that the x-padding workspace is large enough
if(weight.input_dims%4096==0) // K is a multiple of 4096, padding required
{
int pad_group_count=2;
input_padding<T>(stream_,reinterpret_cast<half*>(cublas_wrapper_->xpading_workspace_),(const T*)input_data,batch_size,weight.input_dims,weight.group_size,pad_group_count);
dequant_w4_gemm_colmajor(stream_,reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims+pad_group_count*weight.group_size,weight.output_dims,weight.group_size);
cublas_wrapper_->Gemm(CUBLAS_OP_T,
CUBLAS_OP_N,
weight.output_dims,//m
batch_size,//n
weight.input_dims+pad_group_count*weight.group_size,//k
(const T*) reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_), //[]
weight.input_dims+pad_group_count*weight.group_size, //k
(const T*) cublas_wrapper_->xpading_workspace_,
weight.input_dims+pad_group_count*weight.group_size, //k
output_tmp,
weight.output_dims); //m
}
else // no padding required
{
dequant_w4_gemm_colmajor(stream_,reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims,weight.output_dims,weight.group_size);
cublas_wrapper_->Gemm(CUBLAS_OP_T,
CUBLAS_OP_N,
weight.output_dims,//m
batch_size,//n
weight.input_dims,//k
(const T*) reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_), //[]
weight.input_dims, //k
input_data,
weight.input_dims, //k
output_tmp,
weight.output_dims); //m
}
}
else if(weight.w4_weight_layout==2) // TN layout with padding, CK
{
// check that the CK workspace is large enough
if(weight.input_dims%4096==0)
{
int pad_groupcount=2;
run_weight_only_gemm(reinterpret_cast<const void*>(input_data), reinterpret_cast<const void*>(weight.kernel), reinterpret_cast<const void*>(weight.scales_and_zeros), reinterpret_cast<void*> (output_tmp), batch_size, weight.output_dims, (weight.input_dims), (weight.input_dims),(weight.input_dims), (weight.input_dims+pad_groupcount*weight.group_size), weight.output_dims, weight.group_size,reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),CK_WORKSPACE_SIZE,(hipStream_t)stream_);
}
// A B0 B1 C M N K strideA strideB strideBpad strideC group_size
else{
run_weight_only_gemm(reinterpret_cast<const void*>(input_data), reinterpret_cast<const void*>(weight.kernel), reinterpret_cast<const void*>(weight.scales_and_zeros), reinterpret_cast<void*> (output_tmp), batch_size, weight.output_dims, (weight.input_dims), (weight.input_dims),(weight.input_dims), (weight.input_dims), weight.output_dims, weight.group_size,reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),CK_WORKSPACE_SIZE,(hipStream_t)stream_);
}
}
addFusedSiluActivation(stream_,output_data,output_tmp,batch_size,weight.output_dims,1);
sync_check_cuda_error();
}
else {
FT_CHECK_WITH_INFO(0, "Not implemented");
// }
}
}
private:
......
......@@ -32,6 +32,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
bool attn_bias,
WeightType weight_type,
int group_size,
int w4_weight_layout,
size_t tensor_para_size,
size_t tensor_para_rank):
hidden_units_(head_num * size_per_head),
......@@ -55,11 +56,28 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
inter_size_,
weight_type_,
group_size,
w4_weight_layout,
attn_bias,
tensor_para_size_,
tensor_para_rank_));
}
// setenv's last argument controls overwriting: 1 overwrites an existing value, 0 keeps the original value unchanged.
const char* env_name = "LMDEPLOY_WEIGHTLAYOUT_SWITCH";
if (weight_type_ == WeightType::kINT4) {
std::string str_w4_weight_layout = std::to_string(w4_weight_layout);
const char* env_value = str_w4_weight_layout.c_str();
setenv(env_name, env_value, 1);
//printf("set LMDEPLOY_WEIGHTLAYOUT_SWITCH env: %d \n",w4_weight_layout);
}
else
{
std::string str_w4_weight_layout = std::to_string(-1);
const char* env_value = str_w4_weight_layout.c_str();
setenv(env_name, env_value, 1);
//printf("set LMDEPLOY_WEIGHTLAYOUT_SWITCH env: %d \n",-1);
}
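// The value is consumed by the cublasMMWrapper constructors (via std::getenv) to decide
// which INT4 workspaces to allocate.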
mallocWeights();
}
......
......@@ -37,6 +37,7 @@ struct LlamaWeight {
bool attn_bias,
WeightType weight_type,
int group_size,
int w4_weight_layout,
size_t tensor_para_size,
size_t tensor_para_rank);
......
// #include "src/turbomind/kernels/gemm_s_f16/format.h"
#include "src/turbomind/kernels/gemm_s_f16/format.h"
#include "src/turbomind/python/dlpack.h"
#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h"
#include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
......@@ -457,15 +457,15 @@ PYBIND11_MODULE(_turbomind, m)
auto src_tensor = GetDLTensor(src);
auto dst_tensor = GetDLTensor(dst);
// turbomind::transpose_qk_s4_k_m8_hf(
// (uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, size_per_head, nullptr);
turbomind::transpose_qk_s4_k_m8_hf(
(uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, size_per_head, nullptr);
});
m.def("fuse_w1_w3_s4_k_m8", [](py::object src, py::object dst, int m, int k) {
auto src_tensor = GetDLTensor(src);
auto dst_tensor = GetDLTensor(dst);
// turbomind::fuse_w1_w3_s4_k_m8((uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, nullptr);
turbomind::fuse_w1_w3_s4_k_m8((uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, nullptr);
});
m.def("convert_s4_k_m8",
......@@ -485,16 +485,45 @@ PYBIND11_MODULE(_turbomind, m)
auto s = GetDLTensor(scales);
auto qz = GetDLTensor(qzeros);
// turbomind::convert_s4_k_m8((uint32_t*)a_dst.data,
// (half2*)q_dst.data,
// (half*)w.data,
// (const uint32_t*)a_src.data,
// (const half*)s.data,
// (const uint32_t*)qz.data,
// m,
// k,
// group_size,
// nullptr);
turbomind::convert_s4_k_m8((uint32_t*)a_dst.data,
(half2*)q_dst.data,
(half*)w.data,
(const uint32_t*)a_src.data,
(const half*)s.data,
(const uint32_t*)qz.data,
m,
k,
group_size,
nullptr);
});
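// convert_s4_k_m8_ (note the trailing underscore) takes the same argument list as
// convert_s4_k_m8; it presumably produces the weight/scale layout expected by the new
// DCU w4 GEMM paths (dequant_w4_gemm / run_weight_only_gemm).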
m.def("convert_s4_k_m8_",
[](py::object A_dst,
py::object Q_dst,
py::object ws,
py::object A_src,
py::object scales,
py::object qzeros,
int m,
int k,
int group_size) {
auto a_dst = GetDLTensor(A_dst);
auto q_dst = GetDLTensor(Q_dst);
auto w = GetDLTensor(ws);
auto a_src = GetDLTensor(A_src);
auto s = GetDLTensor(scales);
auto qz = GetDLTensor(qzeros);
turbomind::convert_s4_k_m8_((uint32_t*)a_dst.data,
(half2*)q_dst.data,
(half*)w.data,
(const uint32_t*)a_src.data,
(const half*)s.data,
(const uint32_t*)qz.data,
m,
k,
group_size,
nullptr);
});
m.def("dequantize_s4", [](py::object src, py::object dst) {
......
......@@ -186,6 +186,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
attn_bias_ = reader.GetInteger("llama", "attn_bias", 0);
quant_policy_ = reader.GetInteger("llama", "quant_policy", 0);
group_size_ = reader.GetInteger("llama", "group_size", 0);
w4_weight_layout_ = reader.GetInteger("llama", "w4_weight_layout", 2);
// rotary embedding parameters
attn_params_.rotary_embedding_dim = reader.GetInteger("llama", "rotary_embedding");
......@@ -381,6 +382,7 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
attn_bias_,
weight_type_,
group_size_,
w4_weight_layout_,
tensor_para_size_,
tensor_para_rank);
// model inited with model_dir
......
......@@ -101,7 +101,8 @@ private:
bool attn_bias_;
int quant_policy_;
int group_size_;
int w4_weight_layout_;
// shared weights for each device
std::vector<std::shared_ptr<ft::LlamaWeight<T>>> shared_weights_;
......
......@@ -36,9 +36,39 @@ cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle,
mu_(mu),
allocator_(allocator)
{
// read the environment variable before allocating memory to determine the weight layout
// m_weightlayout_switch = 0 --> NN layout, rocBLAS
// m_weightlayout_switch = 1 --> TN padded layout, rocBLAS
// m_weightlayout_switch = 2 --> TN padded layout, CK
const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
if (env_weightlayout_str != nullptr) {
m_weightlayout_switch = std::stoi(env_weightlayout_str);
}
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (allocator_ != nullptr) {
cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false);
// rocBLAS paths (and CK with the dump feature enabled) need a dequantization workspace
if(m_weightlayout_switch ==1||m_weightlayout_switch==0)
{
// temporary storage for the dequantized weights
printf("alloc space for deweight\n");
deweight_workspace_=allocator_->reMalloc(deweight_workspace_, DEQ_WORKSPACE_SIZE, false);
if(m_weightlayout_switch ==1)
{
printf("alloc space for xpadding\n");
printf("weight layout is tn padding rocblas\n");
xpading_workspace_=allocator_->reMalloc(xpading_workspace_, XPAD_WORKSPACE_SIZE, false);
}
}
else if(m_weightlayout_switch ==2)
{
printf("alloc space for ck workspace\n");
printf("weight layout is tn padding ck\n");
ck_workspace_ = allocator_->reMalloc(ck_workspace_, CK_WORKSPACE_SIZE, false);
}
}
// hgemm-switch 0:fp32r,1:fp16r-fp32r,2:fp16r ----xzhou 20240427
m_ihgemm_switch = 0;
......@@ -70,9 +100,34 @@ cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle,
mu_(mu),
allocator_(allocator)
{
const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
if (env_weightlayout_str != nullptr) {
m_weightlayout_switch = std::stoi(env_weightlayout_str);
}
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (allocator_ != nullptr) {
cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false);
// rocBLAS paths (and CK with the dump feature enabled) need a dequantization workspace
if(m_weightlayout_switch ==1||m_weightlayout_switch==0)
{
// temporary storage for the dequantized weights
printf("alloc space for deweight\n");
deweight_workspace_=allocator_->reMalloc(deweight_workspace_, DEQ_WORKSPACE_SIZE, false);
if(m_weightlayout_switch ==1)
{
printf("alloc space for xpadding\n");
printf("weight layout is tn padding rocblas\n");
xpading_workspace_=allocator_->reMalloc(xpading_workspace_, XPAD_WORKSPACE_SIZE, false);
}
}
else if(m_weightlayout_switch ==2)
{
printf("alloc space for ck workspace\n");
printf("weight layout is tn padding ck\n");
ck_workspace_ = allocator_->reMalloc(ck_workspace_, CK_WORKSPACE_SIZE, false);
}
}
}
#endif
......@@ -83,6 +138,22 @@ cublasMMWrapper::~cublasMMWrapper()
mu_ = nullptr;
if (allocator_ != nullptr) {
allocator_->free((void**)(&cublas_workspace_));
if(m_weightlayout_switch ==1||m_weightlayout_switch==0)
{
// free the temporary storage that held the dequantized weights
printf("free space for deweight\n");
allocator_->free((void**)(&deweight_workspace_));
if(m_weightlayout_switch ==1)
{
printf("free space for xpadding\n");
allocator_->free((void**)(&xpading_workspace_));
}
}
else if(m_weightlayout_switch ==2)
{
printf("free space for ck workspace\n");
allocator_->free((void**)(&ck_workspace_));
}
allocator_ = nullptr;
}
}
......@@ -98,9 +169,34 @@ cublasMMWrapper::cublasMMWrapper(const cublasMMWrapper& wrapper):
mu_(wrapper.mu_),
allocator_(wrapper.allocator_)
{
const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
if (env_weightlayout_str != nullptr) {
m_weightlayout_switch = std::stoi(env_weightlayout_str);
}
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (allocator_ != nullptr) {
cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false);
// rocBLAS paths (and CK with the dump feature enabled) need a dequantization workspace
if(m_weightlayout_switch ==1||m_weightlayout_switch==0)
{
// temporary storage for the dequantized weights
printf("alloc space for deweight\n");
deweight_workspace_=allocator_->reMalloc(deweight_workspace_, DEQ_WORKSPACE_SIZE, false);
if(m_weightlayout_switch ==1)
{
printf("alloc space for xpadding\n");
printf("weight layout is tn padding rocblas\n");
xpading_workspace_=allocator_->reMalloc(xpading_workspace_, XPAD_WORKSPACE_SIZE, false);
}
}
else if(m_weightlayout_switch ==2)
{
printf("alloc space for ck workspace\n");
printf("weight layout is tn padding ck\n");
ck_workspace_ = allocator_->reMalloc(ck_workspace_, CK_WORKSPACE_SIZE, false);
}
}
}
......
......@@ -70,6 +70,12 @@ protected:
const bool per_column_scaling);
public:
void* ck_workspace_ = nullptr;
// padding buffer for the input x
void* xpading_workspace_ = nullptr;
void* deweight_workspace_ = nullptr;
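// weight-layout switch parsed from LMDEPLOY_WEIGHTLAYOUT_SWITCH:
// 0 = NN rocBLAS, 1 = TN padded rocBLAS, 2 = TN padded CK (-1 when not running INT4)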
int m_weightlayout_switch = 0;
cublasMMWrapper(cublasHandle_t cublas_handle_,
cublasLtHandle_t cublaslt_handle_,
cudaStream_t stream,
......
......@@ -38,7 +38,14 @@ namespace turbomind {
#define COL32_ 32
// workspace for cublas gemm : 32MB
#define CUBLAS_WORKSPACE_SIZE 33554432
#define CK_WORKSPACE_SIZE 1056768000
#define N_max 3000
#define M_max 22016
#define XPAD_WORKSPACE_SIZE 132096000
#define DEQ_WORKSPACE_SIZE 232096000
// workspace for ck gemm : 3000*22016*8*2= 1,056,768,000
// XPAD_WORKSPACE_SIZE :3000*22016*2 = 132,096,000
// DEQ_WORKSPACE_SIZE :4096*22016*2 = 180,355,072 < 232,096,000
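// N_max and M_max bound the problem size assumed by the INT4 workspace checks in
// LlamaLinear::forwardInt4 and by the XPAD/DEQ sizes above: N_max is the assumed maximum
// token count per GEMM and M_max the assumed maximum weight dimension (22016 is likely
// 2 x 11008, the fused w1/w3 width of a LLaMA-7B-style model).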
typedef struct __align__(4)
{
half x, y, z, w;
......