// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include <type_traits>

namespace turbomind {

template<typename T>
class LlamaLinear {
public:
    enum Type
    {
        kGemm,
        kFusedSiluFfn
    };

    LlamaLinear(cublasMMWrapper* cublas_wrapper, cudaStream_t stream):
        cublas_wrapper_(cublas_wrapper), stream_(stream)
    {
    }

    void forward(T*                         output_data,
                 const T*                   input_data,
                 int                        batch_size,
                 const LlamaDenseWeight<T>& weight,
                 Type                       type = kGemm)
    {
        switch (weight.type) {
            case WeightType::kFP16:
            case WeightType::kFP32:
            case WeightType::kBF16:
                forwardFp(output_data, input_data, batch_size, weight, type);
                break;
            case WeightType::kINT4:
                forwardInt4(output_data, input_data, batch_size, weight, type);
                break;
            default:
                FT_CHECK(0);
        }
    }

    void forward_ffn(T*                         output_data,
                     T*                         output_tmp,
                     const T*                   input_data,
                     int                        batch_size,
                     const LlamaDenseWeight<T>& weight,
                     Type                       type = kGemm)
    {
        switch (weight.type) {
            case WeightType::kFP16:
            case WeightType::kFP32:
            case WeightType::kBF16:
                forwardFp(output_data, input_data, batch_size, weight, type);
                break;
            case WeightType::kINT4: {
                if (type == kFusedSiluFfn) {
                    forwardInt4_ffn(output_data, output_tmp, input_data, batch_size, weight, type);
                }
                else {
                    forwardInt4(output_data, input_data, batch_size, weight, type);
                }
                break;
            }
            default:
                FT_CHECK(0);
        }
    }

private:
    void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
    {
        FT_CHECK(type == kGemm);
        cublas_wrapper_->Gemm(CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              weight.output_dims,
                              batch_size,
                              weight.input_dims,
                              (const T*)weight.kernel,
                              weight.output_dims,
                              input_data,
                              weight.input_dims,
                              output_data,
                              weight.output_dims);
        sync_check_cuda_error();
    }

    // INT4 weight-only GEMM. The path is selected by weight.w4_weight_layout:
    //   0 - weight dequantized to FP16 in row-major (NN) form, then rocBLAS GEMM
    //   1 - weight dequantized to FP16 in column-major (TN) form, with optional K padding, then rocBLAS GEMM
    //   2 - fused weight-only GEMM through the CK kernel (run_weight_only_gemm), with optional K padding
    void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
    {
        if constexpr (std::is_same_v<T, half>) {
            if (weight.w4_weight_layout == 0) {  // plain NN layout, rocBLAS path
                // check that the dequantized-weight workspace is large enough
                if (batch_size * weight.output_dims > M_max * N_max) {
                    FT_CHECK_WITH_INFO(0, "error! batch_size * weight.output_dims > M_max * N_max");
                }
                dequant_w4_gemm(stream_,
                                reinterpret_cast<half*>(cublas_wrapper_->deweight_workspace_),
                                (const uint32_t*)weight.kernel,
                                (const half2*)weight.scales_and_zeros,
                                weight.input_dims,
                                weight.output_dims,
                                weight.group_size);
                cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                      CUBLAS_OP_N,
                                      weight.output_dims,  // m
                                      batch_size,          // n
                                      weight.input_dims,   // k
                                      (const T*)cublas_wrapper_->deweight_workspace_,
                                      weight.output_dims,  // m
                                      input_data,
                                      weight.input_dims,  // k
                                      output_data,
                                      weight.output_dims);  // m
            }
            else if (weight.w4_weight_layout == 1) {  // TN layout with padding, rocBLAS path
                // check that the dequantized-weight workspace is large enough
                if (batch_size * weight.output_dims > M_max * N_max) {
                    FT_CHECK_WITH_INFO(0, "error! batch_size * weight.output_dims > M_max * N_max");
                }
                // check whether the padded-input workspace is needed
                if (weight.input_dims % 4096 == 0) {  // padding required
                    int pad_group_count = 2;
                    input_padding(stream_,
                                  reinterpret_cast<T*>(cublas_wrapper_->xpading_workspace_),
                                  (const T*)input_data,
                                  batch_size,
                                  weight.input_dims,
                                  weight.group_size,
                                  pad_group_count);
                    dequant_w4_gemm_colmajor(stream_,
                                             reinterpret_cast<half*>(cublas_wrapper_->deweight_workspace_),
                                             (const uint32_t*)weight.kernel,
                                             (const half2*)weight.scales_and_zeros,
                                             weight.input_dims + pad_group_count * weight.group_size,
                                             weight.output_dims,
                                             weight.group_size);
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                                          CUBLAS_OP_N,
                                          weight.output_dims,                                        // m
                                          batch_size,                                                // n
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          (const T*)cublas_wrapper_->deweight_workspace_,
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          (const T*)cublas_wrapper_->xpading_workspace_,
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          output_data,
                                          weight.output_dims);  // m
                }
                else {  // no padding required
                    dequant_w4_gemm_colmajor(stream_,
                                             reinterpret_cast<half*>(cublas_wrapper_->deweight_workspace_),
                                             (const uint32_t*)weight.kernel,
                                             (const half2*)weight.scales_and_zeros,
                                             weight.input_dims,
                                             weight.output_dims,
                                             weight.group_size);
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                                          CUBLAS_OP_N,
                                          weight.output_dims,  // m
                                          batch_size,          // n
                                          weight.input_dims,   // k
                                          (const T*)cublas_wrapper_->deweight_workspace_,
                                          weight.input_dims,  // k
                                          input_data,
                                          weight.input_dims,  // k
                                          output_data,
                                          weight.output_dims);  // m
                }
            }
            else if (weight.w4_weight_layout == 2) {  // TN layout with padding, CK path
                // check that the CK workspace is large enough
                // argument order: A, B0, B1, C, M, N, K, strideA, strideB, strideBpad, strideC, group_size, ...
                // NOTE: cast target types below are assumed to match the dequant paths above
                if (weight.input_dims % 4096 == 0) {
                    int pad_groupcount = 2;
                    run_weight_only_gemm(reinterpret_cast<const half*>(input_data),
                                         reinterpret_cast<const uint32_t*>(weight.kernel),
                                         reinterpret_cast<const half2*>(weight.scales_and_zeros),
                                         reinterpret_cast<half*>(output_data),
                                         batch_size,
                                         weight.output_dims,
                                         weight.input_dims,
                                         weight.input_dims,
                                         weight.input_dims,
                                         weight.input_dims + pad_groupcount * weight.group_size,
                                         weight.output_dims,
                                         weight.group_size,
                                         reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),
                                         CK_WORKSPACE_SIZE,
                                         (hipStream_t)stream_);
                }
                else {
                    run_weight_only_gemm(reinterpret_cast<const half*>(input_data),
                                         reinterpret_cast<const uint32_t*>(weight.kernel),
                                         reinterpret_cast<const half2*>(weight.scales_and_zeros),
                                         reinterpret_cast<half*>(output_data),
                                         batch_size,
                                         weight.output_dims,
                                         weight.input_dims,
                                         weight.input_dims,
                                         weight.input_dims,
                                         weight.input_dims,
                                         weight.output_dims,
                                         weight.group_size,
                                         reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),
                                         CK_WORKSPACE_SIZE,
                                         (hipStream_t)stream_);
                }
            }
            sync_check_cuda_error();
        }
        else {
            FT_CHECK_WITH_INFO(0, "Not implemented");
        }
    }

    // Same INT4 paths as forwardInt4, but the GEMM result is written to output_tmp
    // and the fused SiLU activation then produces output_data.
    void forwardInt4_ffn(T*                         output_data,
                         T*                         output_tmp,
                         const T*                   input_data,
                         int                        batch_size,
                         const LlamaDenseWeight<T>& weight,
                         Type                       type)
    {
        if constexpr (std::is_same_v<T, half>) {
            if (weight.w4_weight_layout == 0) {  // plain NN layout, rocBLAS path
                // check that the dequantized-weight workspace is large enough
                if (batch_size * weight.output_dims > M_max * N_max) {
                    FT_CHECK_WITH_INFO(0, "error! batch_size * weight.output_dims > M_max * N_max");
                }
                dequant_w4_gemm(stream_,
                                reinterpret_cast<half*>(cublas_wrapper_->deweight_workspace_),
                                (const uint32_t*)weight.kernel,
                                (const half2*)weight.scales_and_zeros,
                                weight.input_dims,
                                weight.output_dims,
                                weight.group_size);
                cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                      CUBLAS_OP_N,
                                      weight.output_dims,  // m
                                      batch_size,          // n
                                      weight.input_dims,   // k
                                      (const T*)cublas_wrapper_->deweight_workspace_,
                                      weight.output_dims,  // m
                                      input_data,
                                      weight.input_dims,  // k
                                      output_tmp,
                                      weight.output_dims);  // m
            }
            else if (weight.w4_weight_layout == 1) {  // TN layout with padding, rocBLAS path
                // check that the dequantized-weight workspace is large enough
                if (batch_size * weight.output_dims > M_max * N_max) {
                    FT_CHECK_WITH_INFO(0, "error! batch_size * weight.output_dims > M_max * N_max");
                }
                // check whether the padded-input workspace is needed
                if (weight.input_dims % 4096 == 0) {  // padding required
                    int pad_group_count = 2;
                    input_padding(stream_,
                                  reinterpret_cast<T*>(cublas_wrapper_->xpading_workspace_),
                                  (const T*)input_data,
                                  batch_size,
                                  weight.input_dims,
                                  weight.group_size,
                                  pad_group_count);
                    dequant_w4_gemm_colmajor(stream_,
                                             reinterpret_cast<half*>(cublas_wrapper_->deweight_workspace_),
                                             (const uint32_t*)weight.kernel,
                                             (const half2*)weight.scales_and_zeros,
                                             weight.input_dims + pad_group_count * weight.group_size,
                                             weight.output_dims,
                                             weight.group_size);
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                                          CUBLAS_OP_N,
                                          weight.output_dims,                                        // m
                                          batch_size,                                                // n
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          (const T*)cublas_wrapper_->deweight_workspace_,
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          (const T*)cublas_wrapper_->xpading_workspace_,
                                          weight.input_dims + pad_group_count * weight.group_size,  // k
                                          output_tmp,
                                          weight.output_dims);  // m
                }
                else {  // no padding required
                    dequant_w4_gemm_colmajor(stream_,
                                             reinterpret_cast<half*>(cublas_wrapper_->deweight_workspace_),
                                             (const uint32_t*)weight.kernel,
                                             (const half2*)weight.scales_and_zeros,
                                             weight.input_dims,
                                             weight.output_dims,
                                             weight.group_size);
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                                          CUBLAS_OP_N,
                                          weight.output_dims,  // m
                                          batch_size,          // n
                                          weight.input_dims,   // k
                                          (const T*)cublas_wrapper_->deweight_workspace_,
                                          weight.input_dims,  // k
                                          input_data,
                                          weight.input_dims,  // k
                                          output_tmp,
                                          weight.output_dims);  // m
                }
            }
            else if (weight.w4_weight_layout == 2) {  // TN layout with padding, CK path
                // check that the CK workspace is large enough
                // argument order: A, B0, B1, C, M, N, K, strideA, strideB, strideBpad, strideC, group_size, ...
                // NOTE: cast target types below are assumed to match the dequant paths above
                if (weight.input_dims % 4096 == 0) {
                    int pad_groupcount = 2;
                    run_weight_only_gemm(reinterpret_cast<const half*>(input_data),
                                         reinterpret_cast<const uint32_t*>(weight.kernel),
                                         reinterpret_cast<const half2*>(weight.scales_and_zeros),
                                         reinterpret_cast<half*>(output_tmp),
                                         batch_size,
                                         weight.output_dims,
                                         weight.input_dims,
                                         weight.input_dims,
                                         weight.input_dims,
                                         weight.input_dims + pad_groupcount * weight.group_size,
                                         weight.output_dims,
                                         weight.group_size,
                                         reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),
                                         CK_WORKSPACE_SIZE,
                                         (hipStream_t)stream_);
                }
                else {
                    run_weight_only_gemm(reinterpret_cast<const half*>(input_data),
                                         reinterpret_cast<const uint32_t*>(weight.kernel),
                                         reinterpret_cast<const half2*>(weight.scales_and_zeros),
                                         reinterpret_cast<half*>(output_tmp),
                                         batch_size,
                                         weight.output_dims,
                                         weight.input_dims,
                                         weight.input_dims,
                                         weight.input_dims,
                                         weight.input_dims,
                                         weight.output_dims,
                                         weight.group_size,
                                         reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),
                                         CK_WORKSPACE_SIZE,
                                         (hipStream_t)stream_);
                }
            }
            addFusedSiluActivation(stream_, output_data, output_tmp, batch_size, weight.output_dims, 1);
            sync_check_cuda_error();
        }
        else {
            FT_CHECK_WITH_INFO(0, "Not implemented");
        }
    }

private:
    cublasMMWrapper* cublas_wrapper_;
    cudaStream_t     stream_{};
    // GemmS4F16 gemm_s4_f16_;
};

}  // namespace turbomind
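// Usage sketch (illustrative only, not compiled as part of this header): how a caller
// might drive LlamaLinear for a plain projection and for the fused-SiLU FFN path.
// The buffer and weight names (out_buf, ffn_tmp, ffn_out, w_proj, w_gating) are
// hypothetical; allocating them and sizing the wrapper workspaces
// (deweight_workspace_, xpading_workspace_, ck_workspace_) is assumed to happen elsewhere.
//
//   turbomind::LlamaLinear<half> linear{cublas_wrapper, stream};
//
//   // Plain GEMM; dispatches to forwardFp or forwardInt4 based on weight.type.
//   linear.forward(out_buf, input, batch_size, w_proj);
//
//   // Fused FFN path: the GEMM result lands in ffn_tmp, then addFusedSiluActivation
//   // writes the activated result to ffn_out.
//   linear.forward_ffn(ffn_out, ffn_tmp, input, batch_size, w_gating,
//                      turbomind::LlamaLinear<half>::kFusedSiluFfn);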