Unverified commit c3290cad authored by Li Zhang, committed by GitHub

[Feature] Blazing fast W4A16 inference (#202)

* add w4a16

* fix `deploy.py`

* add doc

* add w4a16 kernels

* fuse w1/w3 & bugfixes

* fix typo

* python

* guard sm75/80 features

* add missing header

* refactor

* qkvo bias

* update cost model

* fix lint

* update `deploy.py`
parent d3dbe179
@@ -2,29 +2,39 @@
 #pragma once

+#include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h"
 #include "src/turbomind/models/llama/LlamaDenseWeight.h"
 #include "src/turbomind/models/llama/llama_kernels.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
 #include "src/turbomind/utils/cuda_utils.h"
+#include "src/turbomind/utils/logger.h"
+#include <type_traits>

 namespace turbomind {

 template<typename T>
 class LlamaLinear {
 public:
+    enum Type
+    {
+        kGemm,
+        kFusedSiluFfn
+    };
+
     LlamaLinear(cublasMMWrapper* cublas_wrapper, cudaStream_t stream): cublas_wrapper_(cublas_wrapper), stream_(stream)
     {
     }

-    void forward(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight)
+    void
+    forward(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type = kGemm)
     {
         switch (weight.type) {
             case WeightType::kFP16:
             case WeightType::kFP32:
-                forwardFp(output_data, input_data, batch_size, weight);
+                forwardFp(output_data, input_data, batch_size, weight, type);
                 break;
             case WeightType::kINT4:
-                forwardInt4(output_data, input_data, batch_size, weight);
+                forwardInt4(output_data, input_data, batch_size, weight, type);
                 break;
             default:
                 FT_CHECK(0);
@@ -32,8 +42,9 @@ public:
     }

 private:
-    void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight)
+    void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
     {
+        FT_CHECK(type == kGemm);
         cublas_wrapper_->Gemm(CUBLAS_OP_N,
                               CUBLAS_OP_N,
                               weight.output_dims,
@@ -48,14 +59,31 @@ private:
         sync_check_cuda_error();
     }

-    void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight)
+    void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
     {
-        FT_CHECK_WITH_INFO(0, "Not implemented");
+        if constexpr (std::is_same_v<T, half>) {
+            gemm_s4_f16_.Run(output_data,
+                             (const uint*)weight.kernel,
+                             input_data,
+                             (const half2*)weight.scales_and_zeros,
+                             weight.output_dims,
+                             batch_size,
+                             weight.input_dims,
+                             weight.group_size,
+                             type == kFusedSiluFfn ? GemmS4F16::kFusedSiluFfn : GemmS4F16::kGemm,
+                             -1,
+                             stream_);
+            sync_check_cuda_error();
+        }
+        else {
+            FT_CHECK_WITH_INFO(0, "Not implemented");
+        }
     }

 private:
     cublasMMWrapper* cublas_wrapper_;
     cudaStream_t     stream_{};
+    GemmS4F16        gemm_s4_f16_;
 };

 }  // namespace turbomind
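For orientation: the new `kFusedSiluFfn` type corresponds to the gate/up half of a LLaMA FFN, where the W1 and W3 projections are fused into a single quantized GEMM ("fuse w1/w3" above) and the SiLU activation plus elementwise product are applied in the kernel epilogue. Below is a minimal NumPy sketch of the math that path produces, not of the kernel itself; function and variable names are illustrative only.

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def fused_silu_ffn_reference(x, w1, w3):
    # Gate/up half of a LLaMA FFN block: silu(x @ W1) * (x @ W3).
    # The W4A16 kernel evaluates this with W1 and W3 packed into one int4 weight.
    return silu(x @ w1) * (x @ w3)

# Toy shapes for illustration: batch=2, hidden=8, intermediate=16.
rng = np.random.default_rng(0)
x  = rng.standard_normal((2, 8)).astype(np.float32)
w1 = rng.standard_normal((8, 16)).astype(np.float32)
w3 = rng.standard_normal((8, 16)).astype(np.float32)
print(fused_silu_ffn_reference(x, w1, w3).shape)  # (2, 16)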
@@ -29,19 +29,18 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
                             size_t     inter_size,
                             size_t     vocab_size,
                             size_t     num_layer,
-                            WeightType weight_type,
                             bool       attn_bias,
+                            WeightType weight_type,
+                            int        group_size,
                             size_t     tensor_para_size,
-                            size_t     tensor_para_rank,
-                            int        prefix_cache_len):
+                            size_t     tensor_para_rank):
     hidden_units_(head_num * size_per_head),
     inter_size_(inter_size),
     vocab_size_(vocab_size),
     num_layer_(num_layer),
     weight_type_(weight_type),
     tensor_para_size_(tensor_para_size),
-    tensor_para_rank_(tensor_para_rank),
-    prefix_cache_len_(prefix_cache_len)
+    tensor_para_rank_(tensor_para_rank)
 {
     decoder_layer_weights.reserve(num_layer_);
     for (unsigned l = 0; l < num_layer_; ++l) {
@@ -50,6 +49,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
                                                                size_per_head,
                                                                inter_size_,
                                                                weight_type_,
+                                                               group_size,
                                                                attn_bias,
                                                                tensor_para_size_,
                                                                tensor_para_rank_));
@@ -65,17 +65,8 @@ LlamaWeight<T>::~LlamaWeight()
     cudaFree((void*)output_norm_weight);
     cudaFree((void*)post_decoder_embedding_kernel);

-    if (prefix_cache_key) {
-        cudaFree((void*)prefix_cache_key);
-        cudaFree((void*)prefix_cache_token);
-    }
-
     pre_decoder_embedding_table   = nullptr;
     post_decoder_embedding_kernel = nullptr;
-
-    prefix_cache_token = nullptr;
-    prefix_cache_key   = nullptr;
-    prefix_cache_value = nullptr;
 }
@@ -84,13 +75,6 @@ void LlamaWeight<T>::mallocWeights()
     deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_ * hidden_units_);
     deviceMalloc((T**)&output_norm_weight, hidden_units_);
     deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_);
-
-    if (prefix_cache_len_) {
-        size_t cache_size = num_layer_ * prefix_cache_len_ * hidden_units_ / tensor_para_size_;
-        deviceMalloc((T**)&prefix_cache_key, cache_size * 2);
-        prefix_cache_value = prefix_cache_key + cache_size;
-        deviceMalloc((int**)&prefix_cache_token, prefix_cache_len_);
-    }
 }
@@ -109,18 +93,6 @@ void LlamaWeight<T>::loadModel(std::string dir_path)
     loadWeightFromBin(
         (T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_}, dir_path + "output.weight", model_file_type);

-    if (prefix_cache_len_) {
-        loadWeightFromBin((float*)prefix_cache_token, {prefix_cache_len_}, dir_path + "prefix_cache.token");
-        loadWeightFromBin((T*)prefix_cache_key,
-                          {num_layer_ * prefix_cache_len_, hidden_units_ / tensor_para_size_},
-                          dir_path + "prefix_cache." + std::to_string(tensor_para_rank_) + ".key",
-                          model_file_type);
-        loadWeightFromBin((T*)prefix_cache_value,
-                          {num_layer_ * prefix_cache_len_, hidden_units_ / tensor_para_size_},
-                          dir_path + "prefix_cache." + std::to_string(tensor_para_rank_) + ".value",
-                          model_file_type);
-    }
-
     for (unsigned layer = 0; layer < num_layer_; ++layer) {
         decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
     }
......
@@ -34,11 +34,11 @@ struct LlamaWeight {
                 size_t     inter_size,
                 size_t     vocab_size,
                 size_t     num_layer,
-                WeightType weight_type,
                 bool       attn_bias,
+                WeightType weight_type,
+                int        group_size,
                 size_t     tensor_para_size,
-                size_t     tensor_para_rank,
-                int        prefix_cache_len);
+                size_t     tensor_para_rank);

     ~LlamaWeight();
@@ -52,11 +52,6 @@ struct LlamaWeight {
     const T* output_norm_weight{};
     const T* post_decoder_embedding_kernel{};

-    size_t prefix_cache_len_;
-    int*   prefix_cache_token{};
-    T*     prefix_cache_key{};
-    T*     prefix_cache_value{};

 private:
     void mallocWeights();
......
#include "src/turbomind/kernels/gemm_s_f16/format.h"
#include "src/turbomind/python/dlpack.h" #include "src/turbomind/python/dlpack.h"
#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h" #include "src/turbomind/triton_backend/llama/LlamaTritonModel.h"
#include "src/turbomind/triton_backend/transformer_triton_backend.hpp" #include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/nccl_utils.h" #include "src/turbomind/utils/nccl_utils.h"
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <memory> #include <memory>
#include <pybind11/functional.h> #include <pybind11/functional.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <pybind11/stl_bind.h> #include <pybind11/stl_bind.h>
...@@ -211,6 +214,13 @@ std::shared_ptr<triton::Tensor> DLManagedTensorToTritonTensor(DLManagedTensor* t ...@@ -211,6 +214,13 @@ std::shared_ptr<triton::Tensor> DLManagedTensorToTritonTensor(DLManagedTensor* t
return std::make_shared<triton::Tensor>(where, dtype, shape, data); return std::make_shared<triton::Tensor>(where, dtype, shape, data);
} }
DLTensor GetDLTensor(py::object obj)
{
py::capsule cap = obj.attr("__dlpack__")();
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName));
return dlmt->dl_tensor;
}
PYBIND11_MODULE(_turbomind, m) PYBIND11_MODULE(_turbomind, m)
{ {
// nccl param // nccl param
...@@ -335,7 +345,7 @@ PYBIND11_MODULE(_turbomind, m) ...@@ -335,7 +345,7 @@ PYBIND11_MODULE(_turbomind, m)
size_t pipeline_para_size, size_t pipeline_para_size,
int enable_custom_all_reduce, int enable_custom_all_reduce,
std::string data_type) -> std::shared_ptr<AbstractTransformerModel> { std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
if (data_type == "half" || data_type == "fp16") { if (data_type == "half" || data_type == "fp16" || data_type == "int4") {
return std::make_shared<LlamaTritonModel<half>>( return std::make_shared<LlamaTritonModel<half>>(
tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir); tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
} }
@@ -389,4 +399,57 @@ PYBIND11_MODULE(_turbomind, m)
         .def("__repr__", &AbstractTransformerModel::toString)
         .def("get_tensor_para_size", &AbstractTransformerModel::getTensorParaSize)
         .def("get_pipeline_para_size", &AbstractTransformerModel::getPipelineParaSize);
+
+    m.def("transpose_qk_s4_k_m8", [](py::object src, py::object dst, int m, int k, int size_per_head) {
+        auto src_tensor = GetDLTensor(src);
+        auto dst_tensor = GetDLTensor(dst);
+        turbomind::transpose_qk_s4_k_m8_hf(
+            (uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, size_per_head, nullptr);
+    });
+
+    m.def("fuse_w1_w3_s4_k_m8", [](py::object src, py::object dst, int m, int k) {
+        auto src_tensor = GetDLTensor(src);
+        auto dst_tensor = GetDLTensor(dst);
+        turbomind::fuse_w1_w3_s4_k_m8((uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, nullptr);
+    });
+
+    m.def("convert_s4_k_m8",
+          [](py::object A_dst,
+             py::object Q_dst,
+             py::object ws,
+             py::object A_src,
+             py::object scales,
+             py::object qzeros,
+             int        m,
+             int        k,
+             int        group_size) {
+              auto a_dst = GetDLTensor(A_dst);
+              auto q_dst = GetDLTensor(Q_dst);
+              auto w     = GetDLTensor(ws);
+              auto a_src = GetDLTensor(A_src);
+              auto s     = GetDLTensor(scales);
+              auto qz    = GetDLTensor(qzeros);
+              turbomind::convert_s4_k_m8((uint32_t*)a_dst.data,
+                                         (half2*)q_dst.data,
+                                         (half*)w.data,
+                                         (const uint32_t*)a_src.data,
+                                         (const half*)s.data,
+                                         (const uint32_t*)qz.data,
+                                         m,
+                                         k,
+                                         group_size,
+                                         nullptr);
+          });
+
+    m.def("dequantize_s4", [](py::object src, py::object dst) {
+        auto src_tensor = GetDLTensor(src);
+        auto dst_tensor = GetDLTensor(dst);
+        auto src_count  = std::accumulate(src_tensor.shape, src_tensor.shape + src_tensor.ndim, size_t{1}, std::multiplies<>{});
+        auto dst_count  = std::accumulate(dst_tensor.shape, dst_tensor.shape + dst_tensor.ndim, size_t{1}, std::multiplies<>{});
+        turbomind::FT_CHECK(src_count * 8 == dst_count);
+        turbomind::dequantize_s4((uint4*)dst_tensor.data, (uint32_t*)src_tensor.data, src_count, nullptr);
+    });
 }
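The bindings above accept any Python object exposing `__dlpack__` (for example CUDA tensors from PyTorch) and hand the raw device pointers to the conversion kernels; `deploy.py` drives them when repacking int4 weights. A hedged sketch of calling one of them follows; the import path, shapes, dtypes, and values are assumptions for illustration, while the function name, argument order, and the 8-to-1 packing check come from the code above.

import torch
import _turbomind as tm   # module name from PYBIND11_MODULE; the packaged import path may differ

k, m = 4096, 11008        # illustrative input/output dims

# 8 int4 values are packed into each 32-bit element, so the fp16 destination
# must hold 8x as many elements as the packed source (enforced by dequantize_s4).
packed   = torch.zeros(k, m // 8, dtype=torch.int32, device="cuda")
unpacked = torch.zeros(k, m, dtype=torch.float16, device="cuda")
tm.dequantize_s4(packed, unpacked)   # args: (src, dst); runs on the default stream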
@@ -133,9 +133,9 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     cache_max_entry_count_ = reader.GetInteger("llama", "cache_max_entry_count", 0);
     use_context_fmha_      = reader.GetInteger("llama", "use_context_fmha", 1);
     cache_chunk_size_      = reader.GetInteger("llama", "cache_chunk_size", 0);
-    prefix_cache_len_      = reader.GetInteger("llama", "prefix_cache_len", 0);
     attn_bias_             = reader.GetInteger("llama", "attn_bias", 0);
     quant_policy_          = reader.GetInteger("llama", "quant_policy", 0);
+    group_size_            = reader.GetInteger("llama", "group_size", 0);

     handleMissingParams();
@@ -296,11 +296,11 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
                                                          inter_size_,
                                                          vocab_size_,
                                                          num_layer_,
-                                                         weight_type_,
                                                          attn_bias_,
+                                                         weight_type_,
+                                                         group_size_,
                                                          tensor_para_size_,
-                                                         tensor_para_rank,
-                                                         prefix_cache_len_);
+                                                         tensor_para_rank);
     shared_weights_[device_id]->loadModel(model_dir_);
     return;
 }
@@ -318,8 +318,8 @@ std::string LlamaTritonModel<T>::toString()
        << "\ncache_chunk_size: " << cache_chunk_size_ << "\nuse_context_fmha: " << use_context_fmha_
        << "\nstart_id: " << start_id_ << "\ntensor_para_size: " << tensor_para_size_
        << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_
-       << "\nmodel_name: " << model_name_ << "\nprefix_cache_len: " << prefix_cache_len_
-       << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ << std::endl;
+       << "\nmodel_name: " << model_name_ << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_
+       << "\ngroup_size: " << group_size_ << std::endl;
     return ss.str();
 }
......
@@ -95,8 +95,7 @@ private:
     ft::WeightType weight_type_;
     bool           attn_bias_;
     int            quant_policy_;
-    size_t prefix_cache_len_{};
+    int            group_size_;

     // shared weights for each device
     std::vector<std::shared_ptr<ft::LlamaWeight<T>>> shared_weights_;
@@ -107,15 +106,6 @@ private:
     std::vector<std::weak_ptr<LlamaTritonSharedModelInstance<T>>> shared_instances_;
     std::deque<std::mutex>                                        shared_mutexes_;  // is locking really needed?

-    // // residual type
-    // bool use_gptj_residual_ = true;
-    // // number of tasks (for prefix-prompt, p/prompt-tuning)
-    // size_t num_tasks_ = 0;
-    // int prompt_learning_start_id_ = 0;
-    // ft::PromptLearningType prompt_learning_type_ = ft::PromptLearningType::no_prompt;
-    // std::map<std::string, std::pair<int, int>> prompt_learning_table_pair_ = {};
-
     bool is_fp16_;
     int  enable_custom_all_reduce_ = 0;
......
@@ -330,16 +330,16 @@ loadWeightFromBinHelper(std::vector<size_t> shape, std::string filename, std::ve
     size_t loaded_data_size = sizeof(T) * size;
     in.seekg(0, in.end);
+    const auto file_size_in_bytes = (size_t)in.tellg();
     in.seekg(0, in.beg);
     TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename);
     in.read((char*)host_array.data(), loaded_data_size);
-    size_t in_get_size = in.gcount();
-    if (in_get_size != loaded_data_size) {
-        TM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n",
+    if (file_size_in_bytes != loaded_data_size) {
+        TM_LOG_WARNING("file %s has %ld, but request %ld, loading model fails!",
                        filename.c_str(),
-                       in_get_size,
+                       file_size_in_bytes,
                        loaded_data_size);
         return std::vector<T>();
     }
......