Unverified commit 208b6841, authored by AllentDan, committed by GitHub

fix clang-format (#68)

parent b2393467
@@ -26,11 +26,18 @@
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY, stream) \
-    size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
-    dim3 grid(params.num_heads, params.batch_size); \
-    mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY> \
-        <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
+#define MMHA_LAUNCH_KERNEL(                                                                       \
+    T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY, stream) \
+    size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK);         \
+    dim3   grid(params.num_heads, params.batch_size);                                             \
+    mmha::masked_multihead_attention_kernel<T,                                                    \
+                                            Dh,                                                   \
+                                            Dh_MAX,                                               \
+                                            THDS_PER_KEY,                                         \
+                                            THDS_PER_VALUE,                                       \
+                                            THDS_PER_BLOCK,                                       \
+                                            HAS_BEAMS,                                            \
+                                            QUANT_POLICY><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
 ////////////////////////////////////////////////////////////////////////////////////////////////////
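For readers skimming the diff: the macro body is unchanged, clang-format only rewraps it to the column limit. As a sketch, one expansion of the macro with concrete launch parameters taken from the call sites below amounts to:

// Sketch of MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, 4, stream)
// after substitution: smem_sz sizes the dynamic shared memory and the grid
// launches one thread block per (head, batch) pair.
size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THREADS_PER_VALUE, 256);
dim3   grid(params.num_heads, params.batch_size);
mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, 4>
    <<<grid, 256, smem_sz, stream>>>(params);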
@@ -54,7 +61,8 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st
         else {
             MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, 4, stream);
         }
-    } else {
+    }
+    else {
         if (tlength < 32) {
             MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, 0, stream);
         }
...
@@ -1272,7 +1272,7 @@ template<typename T, // The type of the inputs. Supported types: float and half
          int THREADS_PER_VALUE,  // The number of threads per value.
          int THREADS_PER_BLOCK,  // The number of threads in a threadblock.
          bool HAS_BEAMS,
-         int QUANT_POLICY>       // quantization method
+         int QUANT_POLICY>  // quantization method
 __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> params)
 {
@@ -1464,7 +1464,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         if (not QUANT_POLICY) {
             *reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
-        } else if (QUANT_POLICY == 4) {
+        }
+        else if (QUANT_POLICY == 4) {
             using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
             Packed_Int8_t k_int8 = quant(k, k_scale);
@@ -1486,7 +1487,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         if (not QUANT_POLICY) {
             *reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
                 vec_conversion<Qk_vec_m, Qk_vec_k>(k);
-        } else if (QUANT_POLICY == 4) {
+        }
+        else if (QUANT_POLICY == 4) {
             using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
             Packed_Int8_t k_int8 = quant(k, k_scale);
@@ -1575,11 +1577,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
     if (not QUANT_POLICY) {
         k_cache_batch = params.k_cache_per_sample ? (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
                                                      + hi * params.memory_max_len * Dh + ki) :
                                                     &params.k_cache[bhi * params.memory_max_len * Dh + ki];
         // Base pointer for the beam's batch, before offsetting with indirection buffer
         // T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
-    } else if (QUANT_POLICY == 4) {
+    }
+    else if (QUANT_POLICY == 4) {
         // convert k_cache_per_sample to int8
         if (params.k_cache_per_sample) {
             int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
@@ -1628,7 +1631,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             if (not QUANT_POLICY) {
                 k[ii] = vec_conversion<K_vec_k, K_vec_m>(
                     (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
-            } else if (QUANT_POLICY == 4) {
+            }
+            else if (QUANT_POLICY == 4) {
                 using Packed_Int8_t = typename packed_type<int8_t, num_elems<K_vec_m>::value>::type;
                 using Packed_Float_t = typename packed_type<float, num_elems<K_vec_m>::value>::type;
@@ -1766,7 +1770,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         // Base pointer for the beam's batch, before offsetting with indirection buffer
         // T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
         v_cache_batch = v_cache;
-    } else if (QUANT_POLICY == 4) {
+    }
+    else if (QUANT_POLICY == 4) {
         if (params.v_cache_per_sample) {
             int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
             v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
@@ -1831,7 +1836,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
                 if (not QUANT_POLICY) {
                     v = vec_conversion<V_vec_k, V_vec_m>(
                         *reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
-                } else if (QUANT_POLICY == 4) {
+                }
+                else if (QUANT_POLICY == 4) {
                     Packed_Int8_t v_vec_m_int8 =
                         *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
                     Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
@@ -1877,7 +1883,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
                 if (not QUANT_POLICY) {
                     v = vec_conversion<V_vec_k, V_vec_m>(
                         *reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
-                } else if (QUANT_POLICY == 4) {
+                }
+                else if (QUANT_POLICY == 4) {
                     Packed_Int8_t v_vec_m_int8 =
                         *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
                     Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
@@ -1931,7 +1938,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         if (not QUANT_POLICY) {
             *reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
-        } else if (QUANT_POLICY == 4) {
+        }
+        else if (QUANT_POLICY == 4) {
             using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
             Packed_Int8_t v_int8 = quant(v, v_scale);
             *reinterpret_cast<Packed_Int8_t*>(&v_cache_int8[tlength_circ * Dh]) = v_int8;
...
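The hunks above repeatedly touch the QUANT_POLICY == 4 branches, where K/V vectors are written to and read from an int8 cache via quant() and dequant(). Those helpers are defined elsewhere in the repository; as a rough mental model only, a symmetric int8 scheme for a float4/char4 pair might look like the sketch below (names and scale handling are assumptions, not the repository's code):

#include <cuda_runtime.h>
#include <stdint.h>

// Illustrative stand-ins for the kernel's quant()/dequant() helpers:
// symmetric int8 quantization with a single float scale.
__device__ inline char4 quant_sketch(float4 v, float scale)
{
    // Round to nearest and saturate to the int8 range before narrowing.
    auto q = [&](float x) { return (int8_t)fminf(fmaxf(rintf(x / scale), -128.f), 127.f); };
    return {q(v.x), q(v.y), q(v.z), q(v.w)};
}

__device__ inline float4 dequant_sketch(char4 v, float scale)
{
    // Multiplying the int8 payload by the scale recovers an approximate float value.
    return {v.x * scale, v.y * scale, v.z * scale, v.w * scale};
}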
#include "src/turbomind/python/dlpack.h" #include "src/turbomind/python/dlpack.h"
#include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h" #include "src/turbomind/triton_backend/llama/LlamaTritonModel.h"
#include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
#include <memory> #include <memory>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
@@ -302,37 +302,37 @@ PYBIND11_MODULE(_turbomind, m)
     py::class_<AbstractTransformerModelInstance>(m, "AbstractTransformerModelInstance")
         .def(
             "forward",
-            [](AbstractTransformerModelInstance* model, std::shared_ptr<TensorMap> input_tensors, ft::AbstractInstanceComm* inst_comm) {
-                return model->forward(input_tensors, inst_comm);
-            }, py::call_guard<py::gil_scoped_release>(),
+            [](AbstractTransformerModelInstance* model,
+               std::shared_ptr<TensorMap> input_tensors,
+               ft::AbstractInstanceComm* inst_comm) { return model->forward(input_tensors, inst_comm); },
+            py::call_guard<py::gil_scoped_release>(),
             "input_tensors"_a,
             "inst_comm"_a = nullptr);
     // transformer model
     py::class_<AbstractTransformerModel, std::shared_ptr<AbstractTransformerModel>>(m, "AbstractTransformerModel")
         // .def_static("create_llama_model", &AbstractTransformerModel::createLlamaModel, "model_dir"_a)
-        .def_static("create_llama_model", [](std::string model_dir,
-                                             size_t tensor_para_size,
-                                             size_t pipeline_para_size,
-                                             int enable_custom_all_reduce,
-                                             std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
-            if (data_type == "half" || data_type == "fp16") {
-                return std::make_shared<LlamaTritonModel<half>>(tensor_para_size,
-                                                                pipeline_para_size,
-                                                                enable_custom_all_reduce,
-                                                                model_dir);
-            }else {
-                return std::make_shared<LlamaTritonModel<float>>(tensor_para_size,
-                                                                 pipeline_para_size,
-                                                                 enable_custom_all_reduce,
-                                                                 model_dir);
-            }
-        }, "model_dir"_a,
-           "tensor_para_size"_a=1,
-           "pipeline_para_size"_a=1,
-           "enable_custom_all_reduce"_a=0,
-           "data_type"_a="half")
+        .def_static(
+            "create_llama_model",
+            [](std::string model_dir,
+               size_t tensor_para_size,
+               size_t pipeline_para_size,
+               int enable_custom_all_reduce,
+               std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
+                if (data_type == "half" || data_type == "fp16") {
+                    return std::make_shared<LlamaTritonModel<half>>(
+                        tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                }
+                else {
+                    return std::make_shared<LlamaTritonModel<float>>(
+                        tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                }
+            },
+            "model_dir"_a,
+            "tensor_para_size"_a = 1,
+            "pipeline_para_size"_a = 1,
+            "enable_custom_all_reduce"_a = 0,
+            "data_type"_a = "half")
         .def("create_nccl_params",
              &AbstractTransformerModel::createNcclParams,
              "node_id"_a,
...
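The reformatted binding keeps the same behavior: a static factory bound with pybind11 keyword arguments and defaults. A minimal, self-contained sketch of that pattern (illustrative names, not from this repository):

#include <pybind11/pybind11.h>
#include <string>
namespace py = pybind11;
using namespace pybind11::literals;

struct Model {
    std::string dir;
    static Model create(std::string dir, int tp) { return Model{dir}; }
};

PYBIND11_MODULE(example, m)
{
    // The "_a" literal builds py::arg objects, so Python callers can write
    // Model.create("weights/", tp=2) and omit defaulted arguments entirely.
    py::class_<Model>(m, "Model").def_static("create", &Model::create, "dir"_a, "tp"_a = 1);
}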
@@ -69,9 +69,11 @@ typedef struct {
  * \brief The device type in DLDevice.
  */
 #ifdef __cplusplus
-typedef enum: int32_t {
+typedef enum: int32_t
+{
 #else
-typedef enum {
+typedef enum
+{
 #endif
     /*! \brief CPU device */
     kDLCPU = 1,
@@ -134,7 +136,8 @@ typedef struct {
 /*!
  * \brief The type code options DLDataType.
  */
-typedef enum {
+typedef enum
+{
     /*! \brief signed integer */
     kDLInt = 0U,
     /*! \brief unsigned integer */
...
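The dlpack.h change is brace placement only; the C++ branch still pins the enum's underlying type to int32_t. That matters because the header is shared across the C/C++ ABI boundary: without a fixed underlying type, a C++ compiler may pick any integral type wide enough for the enumerators. A hedged illustration, assuming the enum in question is DLPack's DLDeviceType:

// Illustrative check only: fixing ": int32_t" keeps sizeof stable across compilers.
static_assert(sizeof(DLDeviceType) == sizeof(int32_t), "DLPack expects a 32-bit device type");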