Unverified Commit cc93136e authored by tpoisonooo, committed by GitHub

feat(src): add kv cache int8 quantization (#22)

* feat(src): add int8 and compile passed

* feat(kernels): fix

* feat(llama): update kernel

* feat(src): add debug

* fix(kernel): k_cache use int8_t pointer

* style(llama): clean code

* feat(deploy.py): revert to enable fmha

* style(LlamaV2): clean code

* feat(deploy.py): add default quant policy
parent 4d42a781
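
For readers unfamiliar with the feature this commit introduces: KV-cache int8 quantization stores each cached key/value block as int8 values plus a float scale, roughly halving the cache footprint compared with FP16. The sketch below only illustrates the idea of symmetric int8 quantize/dequantize with made-up helper names; it is not code from this commit, and the real kernels do this on the GPU per cache block.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Symmetric per-tensor int8 quantization: scale = max|x| / 127.
// (Illustrative only; helper names are hypothetical.)
static float computeScale(const std::vector<float>& x)
{
    float max_abs = 0.f;
    for (float v : x) {
        max_abs = std::max(max_abs, std::fabs(v));
    }
    return max_abs > 0.f ? max_abs / 127.f : 1.f;
}

static std::vector<int8_t> quantizeInt8(const std::vector<float>& x, float scale)
{
    std::vector<int8_t> q(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        float v = std::round(x[i] / scale);
        q[i] = static_cast<int8_t>(std::max(-127.f, std::min(127.f, v)));
    }
    return q;
}

static std::vector<float> dequantizeInt8(const std::vector<int8_t>& q, float scale)
{
    std::vector<float> x(q.size());
    for (size_t i = 0; i < q.size(); ++i) {
        x[i] = static_cast<float>(q[i]) * scale;
    }
    return x;
}
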
......@@ -9,6 +9,15 @@
namespace fastertransformer {
enum QuantPolicy {
kNone = 0x00,
// reserve 0x01 and 0x02 for backward compatibility
kReserve1 = 0x01,
kReserve2 = 0x02,
// quantize cache kv
kCacheKVInt8 = 0x04,
};
enum CmpMode
{
kCmpNone,
......
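
Since the QuantPolicy values are powers of two, they act as bit flags and can be combined; downstream code would test for KV-cache int8 with a bitwise AND. A hypothetical helper (not part of this commit) that assumes the enum above is in scope:

// Hypothetical check; the policy values are bit flags, so test them with &.
inline bool useKvCacheInt8(int quant_policy)
{
    return (quant_policy & fastertransformer::kCacheKVInt8) != 0;
}
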
......@@ -129,6 +129,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
cache_chunk_size_ = reader.GetInteger("llama", "cache_chunk_size", 0);
prefix_cache_len_ = reader.GetInteger("llama", "prefix_cache_len", 0);
attn_bias_ = reader.GetInteger("llama", "attn_bias", 0);
quant_policy_ = reader.GetInteger("llama", "quant_policy", 0);
handleMissingParams();
......@@ -224,6 +225,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
end_id_,
cache_max_entry_count_,
cache_chunk_size_,
quant_policy_,
use_context_fmha_,
shared_state_,
shared_weights_[device_id].get(),
......@@ -307,7 +309,7 @@ std::string LlamaTritonModel<T>::toString()
<< "\nuse_context_fmha: " << use_context_fmha_ << "\nstart_id: " << start_id_
<< "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_
<< "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_
<< "\nprefix_cache_len: " << prefix_cache_len_ << "\nmodel_dir: " << model_dir_ << std::endl;
<< "\nprefix_cache_len: " << prefix_cache_len_ << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ << std::endl;
return ss.str();
}
......
......@@ -93,6 +93,7 @@ private:
size_t pipeline_para_size_;
ft::WeightType weight_type_;
bool attn_bias_;
int quant_policy_;
size_t prefix_cache_len_{};
......
......@@ -344,6 +344,11 @@ std::vector<T> loadWeightFromBinHelper(std::vector<size_t> shape, std::string fi
return host_array;
}
std::vector<float> loadArrayFromBin(std::vector<size_t> shape, std::string filename)
{
return loadWeightFromBinHelper<float>(shape, filename);
}
template<typename T, typename T_IN>
int loadWeightFromBinFunc(T* ptr, std::vector<size_t> shape, std::string filename)
{
......@@ -523,7 +528,7 @@ void saveToBinary(const T* ptr, const size_t size, std::string filename)
std::vector<T> h_ptr(size);
cudaD2Hcpy(h_ptr.data(), ptr, size);
std::vector<float> float_ptr(size);
for (int i = 0; i < size; i++) {
for (size_t i = 0; i < size; i++) {
float_ptr[i] = (float)h_ptr[i];
}
......
......@@ -55,6 +55,8 @@ int loadWeightFromBin(T* ptr,
std::string filename,
FtCudaDataType model_file_type = FtCudaDataType::FP32);
std::vector<float> loadArrayFromBin(std::vector<size_t> shape, std::string filename);
// template<typename T>
// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr,
// T* scale_ptr,
......
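
For completeness, a hedged usage sketch of the new loadArrayFromBin helper: the file name, the shape, and the assumption that the function lives in the fastertransformer namespace alongside the other loaders are illustrative, not taken from this diff.

#include <string>
#include <vector>

// Hypothetical call: read 64 float values (e.g. per-layer KV-cache scales)
// from a binary file. The path and shape are invented for illustration.
std::vector<float> scales = fastertransformer::loadArrayFromBin({64}, "kv_scales.bin");
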