Unverified commit c3290cad authored by Li Zhang, committed by GitHub

[Feature] Blazing fast W4A16 inference (#202)

* add w4a16

* fix `deploy.py`

* add doc

* add w4a16 kernels

* fuse w1/w3 & bugfixes

* fix typo

* python

* guard sm75/80 features

* add missing header

* refactor

* qkvo bias

* update cost model

* fix lint

* update `deploy.py`
parent d3dbe179
......@@ -2,29 +2,39 @@
#pragma once
#include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include <type_traits>
namespace turbomind {
template<typename T>
class LlamaLinear {
public:
enum Type
{
kGemm,
kFusedSiluFfn
};
LlamaLinear(cublasMMWrapper* cublas_wrapper, cudaStream_t stream): cublas_wrapper_(cublas_wrapper), stream_(stream)
{
}
void forward(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight)
void
forward(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type = kGemm)
{
switch (weight.type) {
case WeightType::kFP16:
case WeightType::kFP32:
forwardFp(output_data, input_data, batch_size, weight);
forwardFp(output_data, input_data, batch_size, weight, type);
break;
case WeightType::kINT4:
forwardInt4(output_data, input_data, batch_size, weight);
forwardInt4(output_data, input_data, batch_size, weight, type);
break;
default:
FT_CHECK(0);
......@@ -32,8 +42,9 @@ public:
}
private:
void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight)
void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
{
FT_CHECK(type == kGemm);
cublas_wrapper_->Gemm(CUBLAS_OP_N,
CUBLAS_OP_N,
weight.output_dims,
......@@ -48,14 +59,31 @@ private:
sync_check_cuda_error();
}
void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight)
void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
{
FT_CHECK_WITH_INFO(0, "Not implemented");
if constexpr (std::is_same_v<T, half>) {
gemm_s4_f16_.Run(output_data,
(const uint*)weight.kernel,
input_data,
(const half2*)weight.scales_and_zeros,
weight.output_dims,
batch_size,
weight.input_dims,
weight.group_size,
type == kFusedSiluFfn ? GemmS4F16::kFusedSiluFfn : GemmS4F16::kGemm,
-1,
stream_);
sync_check_cuda_error();
}
else {
FT_CHECK_WITH_INFO(0, "Not implemented");
}
}
private:
cublasMMWrapper* cublas_wrapper_;
cudaStream_t stream_{};
GemmS4F16 gemm_s4_f16_;
};
} // namespace turbomind
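Below is a minimal, hypothetical caller sketch (not part of this diff) showing how the new `Type` argument selects between the default GEMM path and the fused SiLU FFN path added for W4A16 weights; the helper name, header path, and the weight/buffer parameters are assumptions made only for illustration.

```cpp
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h"  // header path assumed

// Hypothetical FFN helper, for illustration only.
template<typename T>
void ffn_sketch(turbomind::LlamaLinear<T>&            linear,
                T*                                    inter_buf,
                T*                                    output,
                const T*                              input,
                int                                   batch_size,
                const turbomind::LlamaDenseWeight<T>& fused_gate_up,  // w1/w3 packed by fuse_w1_w3_s4_k_m8
                const turbomind::LlamaDenseWeight<T>& w2)
{
    // INT4 gate/up projection with the SiLU activation fused into the GEMM kernel:
    linear.forward(inter_buf, input, batch_size, fused_gate_up,
                   turbomind::LlamaLinear<T>::kFusedSiluFfn);

    // Plain GEMM (the default Type) for the down projection:
    linear.forward(output, inter_buf, batch_size, w2);
}
```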
......@@ -29,19 +29,18 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
size_t inter_size,
size_t vocab_size,
size_t num_layer,
WeightType weight_type,
bool attn_bias,
WeightType weight_type,
int group_size,
size_t tensor_para_size,
size_t tensor_para_rank,
int prefix_cache_len):
size_t tensor_para_rank):
hidden_units_(head_num * size_per_head),
inter_size_(inter_size),
vocab_size_(vocab_size),
num_layer_(num_layer),
weight_type_(weight_type),
tensor_para_size_(tensor_para_size),
tensor_para_rank_(tensor_para_rank),
prefix_cache_len_(prefix_cache_len)
tensor_para_rank_(tensor_para_rank)
{
decoder_layer_weights.reserve(num_layer_);
for (unsigned l = 0; l < num_layer_; ++l) {
......@@ -50,6 +49,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
size_per_head,
inter_size_,
weight_type_,
group_size,
attn_bias,
tensor_para_size_,
tensor_para_rank_));
......@@ -65,17 +65,8 @@ LlamaWeight<T>::~LlamaWeight()
cudaFree((void*)output_norm_weight);
cudaFree((void*)post_decoder_embedding_kernel);
if (prefix_cache_key) {
cudaFree((void*)prefix_cache_key);
cudaFree((void*)prefix_cache_token);
}
pre_decoder_embedding_table = nullptr;
post_decoder_embedding_kernel = nullptr;
prefix_cache_token = nullptr;
prefix_cache_key = nullptr;
prefix_cache_value = nullptr;
}
template<typename T>
......@@ -84,13 +75,6 @@ void LlamaWeight<T>::mallocWeights()
deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_ * hidden_units_);
deviceMalloc((T**)&output_norm_weight, hidden_units_);
deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_);
if (prefix_cache_len_) {
size_t cache_size = num_layer_ * prefix_cache_len_ * hidden_units_ / tensor_para_size_;
deviceMalloc((T**)&prefix_cache_key, cache_size * 2);
prefix_cache_value = prefix_cache_key + cache_size;
deviceMalloc((int**)&prefix_cache_token, prefix_cache_len_);
}
}
template<typename T>
......@@ -109,18 +93,6 @@ void LlamaWeight<T>::loadModel(std::string dir_path)
loadWeightFromBin(
(T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_}, dir_path + "output.weight", model_file_type);
if (prefix_cache_len_) {
loadWeightFromBin((float*)prefix_cache_token, {prefix_cache_len_}, dir_path + "prefix_cache.token");
loadWeightFromBin((T*)prefix_cache_key,
{num_layer_ * prefix_cache_len_, hidden_units_ / tensor_para_size_},
dir_path + "prefix_cache." + std::to_string(tensor_para_rank_) + ".key",
model_file_type);
loadWeightFromBin((T*)prefix_cache_value,
{num_layer_ * prefix_cache_len_, hidden_units_ / tensor_para_size_},
dir_path + "prefix_cache." + std::to_string(tensor_para_rank_) + ".value",
model_file_type);
}
for (unsigned layer = 0; layer < num_layer_; ++layer) {
decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
}
......
......@@ -34,11 +34,11 @@ struct LlamaWeight {
size_t inter_size,
size_t vocab_size,
size_t num_layer,
WeightType weight_type,
bool attn_bias,
WeightType weight_type,
int group_size,
size_t tensor_para_size,
size_t tensor_para_rank,
int prefix_cache_len);
size_t tensor_para_rank);
~LlamaWeight();
......@@ -52,11 +52,6 @@ struct LlamaWeight {
const T* output_norm_weight{};
const T* post_decoder_embedding_kernel{};
size_t prefix_cache_len_;
int* prefix_cache_token{};
T* prefix_cache_key{};
T* prefix_cache_value{};
private:
void mallocWeights();
......
#include "src/turbomind/kernels/gemm_s_f16/format.h"
#include "src/turbomind/python/dlpack.h"
#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h"
#include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/nccl_utils.h"
#include <cuda_runtime.h>
#include <memory>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
......@@ -211,6 +214,13 @@ std::shared_ptr<triton::Tensor> DLManagedTensorToTritonTensor(DLManagedTensor* t
return std::make_shared<triton::Tensor>(where, dtype, shape, data);
}
DLTensor GetDLTensor(py::object obj)
{
py::capsule cap = obj.attr("__dlpack__")();
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName));
return dlmt->dl_tensor;
}
PYBIND11_MODULE(_turbomind, m)
{
// nccl param
......@@ -335,7 +345,7 @@ PYBIND11_MODULE(_turbomind, m)
size_t pipeline_para_size,
int enable_custom_all_reduce,
std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
if (data_type == "half" || data_type == "fp16") {
if (data_type == "half" || data_type == "fp16" || data_type == "int4") {
return std::make_shared<LlamaTritonModel<half>>(
tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
}
......@@ -389,4 +399,57 @@ PYBIND11_MODULE(_turbomind, m)
.def("__repr__", &AbstractTransformerModel::toString)
.def("get_tensor_para_size", &AbstractTransformerModel::getTensorParaSize)
.def("get_pipeline_para_size", &AbstractTransformerModel::getPipelineParaSize);
m.def("transpose_qk_s4_k_m8", [](py::object src, py::object dst, int m, int k, int size_per_head) {
auto src_tensor = GetDLTensor(src);
auto dst_tensor = GetDLTensor(dst);
turbomind::transpose_qk_s4_k_m8_hf(
(uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, size_per_head, nullptr);
});
m.def("fuse_w1_w3_s4_k_m8", [](py::object src, py::object dst, int m, int k) {
auto src_tensor = GetDLTensor(src);
auto dst_tensor = GetDLTensor(dst);
turbomind::fuse_w1_w3_s4_k_m8((uint32_t*)dst_tensor.data, (const uint32_t*)src_tensor.data, m, k, nullptr);
});
m.def("convert_s4_k_m8",
[](py::object A_dst,
py::object Q_dst,
py::object ws,
py::object A_src,
py::object scales,
py::object qzeros,
int m,
int k,
int group_size) {
auto a_dst = GetDLTensor(A_dst);
auto q_dst = GetDLTensor(Q_dst);
auto w = GetDLTensor(ws);
auto a_src = GetDLTensor(A_src);
auto s = GetDLTensor(scales);
auto qz = GetDLTensor(qzeros);
turbomind::convert_s4_k_m8((uint32_t*)a_dst.data,
(half2*)q_dst.data,
(half*)w.data,
(const uint32_t*)a_src.data,
(const half*)s.data,
(const uint32_t*)qz.data,
m,
k,
group_size,
nullptr);
});
m.def("dequantize_s4", [](py::object src, py::object dst) {
auto src_tensor = GetDLTensor(src);
auto dst_tensor = GetDLTensor(dst);
auto src_count = std::accumulate(src_tensor.shape, src_tensor.shape + src_tensor.ndim, size_t{1});
auto dst_count = std::accumulate(dst_tensor.shape, dst_tensor.shape + dst_tensor.ndim, size_t{1});
turbomind::FT_CHECK(src_count * 8 == dst_count);
turbomind::dequantize_s4((uint4*)dst_tensor.data, (uint32_t*)src_tensor.data, src_count, nullptr);
});
}
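The `FT_CHECK(src_count * 8 == dst_count)` in the `dequantize_s4` binding above reflects the packing layout: each 32-bit word of the source holds eight 4-bit weights, so the dequantized destination must have eight times as many elements. A standalone sketch of that arithmetic (not repository code):

```cpp
#include <cstddef>

// Eight 4-bit values fit in one 32-bit word.
constexpr std::size_t kS4PerWord = 32 / 4;

constexpr std::size_t dequantized_count(std::size_t packed_word_count)
{
    return packed_word_count * kS4PerWord;
}

// Example: a 4096 x 4096 W4 matrix packs into 4096 * 4096 / 8 words.
static_assert(dequantized_count(4096u * 4096u / 8) == 4096u * 4096u,
              "dst must hold 8x as many elements as the packed src");
```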
......@@ -133,9 +133,9 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
cache_max_entry_count_ = reader.GetInteger("llama", "cache_max_entry_count", 0);
use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1);
cache_chunk_size_ = reader.GetInteger("llama", "cache_chunk_size", 0);
prefix_cache_len_ = reader.GetInteger("llama", "prefix_cache_len", 0);
attn_bias_ = reader.GetInteger("llama", "attn_bias", 0);
quant_policy_ = reader.GetInteger("llama", "quant_policy", 0);
group_size_ = reader.GetInteger("llama", "group_size", 0);
handleMissingParams();
......@@ -296,11 +296,11 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
inter_size_,
vocab_size_,
num_layer_,
weight_type_,
attn_bias_,
weight_type_,
group_size_,
tensor_para_size_,
tensor_para_rank,
prefix_cache_len_);
tensor_para_rank);
shared_weights_[device_id]->loadModel(model_dir_);
return;
}
......@@ -318,8 +318,8 @@ std::string LlamaTritonModel<T>::toString()
<< "\ncache_chunk_size: " << cache_chunk_size_ << "\nuse_context_fmha: " << use_context_fmha_
<< "\nstart_id: " << start_id_ << "\ntensor_para_size: " << tensor_para_size_
<< "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_
<< "\nmodel_name: " << model_name_ << "\nprefix_cache_len: " << prefix_cache_len_
<< "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ << std::endl;
<< "\nmodel_name: " << model_name_ << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_
<< "\ngroup_size: " << group_size_ << std::endl;
return ss.str();
}
......
......@@ -95,8 +95,7 @@ private:
ft::WeightType weight_type_;
bool attn_bias_;
int quant_policy_;
size_t prefix_cache_len_{};
int group_size_;
// shared weights for each device
std::vector<std::shared_ptr<ft::LlamaWeight<T>>> shared_weights_;
......@@ -107,15 +106,6 @@ private:
std::vector<std::weak_ptr<LlamaTritonSharedModelInstance<T>>> shared_instances_;
std::deque<std::mutex> shared_mutexes_; // is locking really needed?
// // residual type
// bool use_gptj_residual_ = true;
// // number of tasks (for prefix-prompt, p/prompt-tuning)
// size_t num_tasks_ = 0;
// int prompt_learning_start_id_ = 0;
// ft::PromptLearningType prompt_learning_type_ = ft::PromptLearningType::no_prompt;
// std::map<std::string, std::pair<int, int>> prompt_learning_table_pair_ = {};
bool is_fp16_;
int enable_custom_all_reduce_ = 0;
......
......@@ -330,16 +330,16 @@ loadWeightFromBinHelper(std::vector<size_t> shape, std::string filename, std::ve
size_t loaded_data_size = sizeof(T) * size;
in.seekg(0, in.end);
const auto file_size_in_bytes = (size_t)in.tellg();
in.seekg(0, in.beg);
TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename);
in.read((char*)host_array.data(), loaded_data_size);
size_t in_get_size = in.gcount();
if (in_get_size != loaded_data_size) {
TM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n",
if (file_size_in_bytes != loaded_data_size) {
TM_LOG_WARNING("file %s has %ld, but request %ld, loading model fails!",
filename.c_str(),
in_get_size,
file_size_in_bytes,
loaded_data_size);
return std::vector<T>();
}
......
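Since the old and new lines of this hunk are interleaved above, here is a condensed, self-contained sketch of what the change amounts to: the helper now compares the on-disk file size against the requested byte count instead of checking `in.gcount()` after the read. The function name below is mine, and the real `loadWeightFromBinHelper` also handles on-the-fly dtype conversion and logging, which are omitted here.

```cpp
#include <cstddef>
#include <fstream>
#include <string>
#include <vector>

// Hedged sketch of the size check after this hunk; names follow the diff.
template<typename T>
std::vector<T> load_weight_sketch(const std::string& filename, std::size_t size)
{
    std::ifstream     in(filename, std::ios::binary);
    const std::size_t loaded_data_size = sizeof(T) * size;

    // Measure the file instead of relying on gcount() after the read.
    in.seekg(0, std::ios::end);
    const auto file_size_in_bytes = static_cast<std::size_t>(in.tellg());
    in.seekg(0, std::ios::beg);

    std::vector<T> host_array(size);
    in.read(reinterpret_cast<char*>(host_array.data()), loaded_data_size);

    if (file_size_in_bytes != loaded_data_size) {
        return std::vector<T>();  // loading fails
    }
    return host_array;
}
```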