Unverified Commit 3295eac3 authored by q.yao, committed by GitHub

Support turbomind bf16 (#803)

* Add bf16 template sp

* prepare merge

* add enable bf

* add bf16 decode attention support

* fix python lint

* fix yapf

* fix c format

* c format11

* fix cast

* fix on sm<80

* fix linux bf162 cast

* fix type cast

* fix lint

* support from hf pretrained

* fix pybind

* fix converter

* add trust remote code

* fix comment

* fix convert qwen

* fix lint

* fix baichuan

* update weight map
parent b190521b
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
#ifdef ENABLE_BF16
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
}
#endif
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
#ifdef ENABLE_BF16
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
}
#endif
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
#ifdef ENABLE_BF16
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
}
#endif
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
#ifdef ENABLE_BF16
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
}
#endif
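For orientation, the sketch below (not part of this commit) shows how explicit instantiations like these are typically selected at runtime: the head dimension is rounded up to the nearest instantiated kernel size and the matching run_mha_fwd_ specialization is called. The dispatcher name and the rounding policy are assumptions made for illustration.

// Hypothetical bf16 dispatcher, for illustration only.
#include "flash_fwd_launch_template.h"

#ifdef ENABLE_BF16
inline void run_mha_fwd_bf16(Flash_fwd_params& params, cudaStream_t stream, int head_dim)
{
    // Round the head dimension up to the nearest instantiated kernel size.
    if (head_dim <= 32) {
        run_mha_fwd_<cutlass::bfloat16_t, 32>(params, stream);
    }
    else if (head_dim <= 64) {
        run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
    }
    else if (head_dim <= 128) {
        run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
    }
    else {
        run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
    }
}
#endif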
@@ -13,6 +13,27 @@
namespace turbomind {
template<typename T>
struct ToCutlassType_ {
};
template<>
struct ToCutlassType_<float> {
using Type = float;
};
template<>
struct ToCutlassType_<half> {
using Type = cutlass::half_t;
};
#ifdef ENABLE_BF16
template<>
struct ToCutlassType_<__nv_bfloat16> {
using Type = cutlass::bfloat16_t;
};
#endif
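As a quick sanity check of the new trait, which replaces the inline std::conditional_t expressions further down (illustrative only; assumes <type_traits> is available in this translation unit), the mapping can be verified at compile time:

#include <type_traits>

static_assert(std::is_same<ToCutlassType_<float>::Type, float>::value, "fp32 maps to itself");
static_assert(std::is_same<ToCutlassType_<half>::Type, cutlass::half_t>::value, "half maps to cutlass::half_t");
#ifdef ENABLE_BF16
static_assert(std::is_same<ToCutlassType_<__nv_bfloat16>::Type, cutlass::bfloat16_t>::value,
              "bf16 maps to cutlass::bfloat16_t");
#endif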
template<
    // dtype of Q/K/V/M
    typename Element_,

@@ -655,8 +676,7 @@ void invokeFlashAttention_impl(int batc
    auto layout_o    = attention_params.layout_o;
    auto group_size  = attention_params.group_size;
-   using scalar_t =
-       typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
+   using scalar_t = typename ToCutlassType_<T>::Type;
    const float qk_scale = static_cast<float>(1.f / sqrtf(size_per_head * 1.f));
@@ -742,8 +762,7 @@ void invokeFlashAttention_impl(int batc
template<typename T, int kQueriesPerBlock, int kKeysPerBlock>
bool get_needs_accum_buffer()
{
-   using scalar_t =
-       typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
+   using scalar_t = typename ToCutlassType_<T>::Type;

#define GET_NEED_ACCUM_BUFFER(sm) \
    ATTENTION_KERNEL(scalar_t, sm, kQueriesPerBlock, kKeysPerBlock, false)::kNeedsOutputAccumulatorBuffer
@@ -774,8 +793,7 @@ void invoke_attention_impl(bool single_v
                           typename FlashAttentionOpImpl<T, 1>::Params& params,
                           cudaStream_t st)
{
-   using scalar_t =
-       typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
+   using scalar_t = typename ToCutlassType_<T>::Type;

#define INVOKE_ATTEN_IMPL(sm, single_value) \
    { \
@@ -836,9 +854,8 @@ class FlashAttentionOpImpl<T, 1>::impl {
private:
    static constexpr int kQueriesPerBlock = 32;
    static constexpr int kKeysPerBlock    = 128;
-   using scalar_t =
-       typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
+   using scalar_t = typename ToCutlassType_<T>::Type;
    using Params = typename FlashAttentionOpImpl<T, 1>::Params;

    int batch_size_;
    int head_num_;
@@ -909,5 +926,8 @@ void FlashAttentionOpImpl<T, 1>::operator()(Params& params, cudaStream_t st) con
template class FlashAttentionOpImpl<float, 1>;
template class FlashAttentionOpImpl<half, 1>;
#ifdef ENABLE_BF16
template class FlashAttentionOpImpl<__nv_bfloat16, 1>;
#endif // ENABLE_BF16
}  // namespace turbomind
@@ -2,6 +2,7 @@
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
+#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

@@ -83,6 +84,32 @@ struct res_norm_ops_t<float> {
    }
};
#ifdef ENABLE_BF16
template<>
struct res_norm_ops_t<__nv_bfloat16> {
__device__ float2 cast(const uint& x) const
{
return cuda_cast<float2, __nv_bfloat162>(reinterpret_cast<const __nv_bfloat162&>(x));
}
__device__ uint cast(const float2& x) const
{
auto y = cuda_cast<__nv_bfloat162, float2>(x);
return reinterpret_cast<uint&>(y);
}
__device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const
{
float2 c{a.x + b.x + bias.x, a.y + b.y + bias.y};
accum += c.x * c.x + c.y * c.y;
return c;
}
__device__ float2 norm(const float2& a, const float2& s, float factor) const
{
return {a.x * s.x * factor, a.y * s.y * factor};
}
};
#endif
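The uint in this specialization is simply two bf16 lanes packed into a 32-bit word: cast() widens them to float2 so the residual add and the squared-sum accumulation run in fp32, and the reverse cast narrows the result back before the store. A standalone sketch of the same round trip using the raw cuda_bf16.h intrinsics instead of turbomind's cuda_cast helper (illustration only; the helper names are hypothetical):

#include <cuda_bf16.h>

// Unpack two bf16 values stored in one 32-bit word into an fp32 pair.
__device__ inline float2 unpack_bf16x2(unsigned int x)
{
    const __nv_bfloat162 v = reinterpret_cast<const __nv_bfloat162&>(x);
    return __bfloat1622float2(v);
}

// Round an fp32 pair back to bf16 and repack it into a 32-bit word.
__device__ inline unsigned int pack_bf16x2(float2 f)
{
    __nv_bfloat162 v = __float22bfloat162_rn(f);
    return reinterpret_cast<unsigned int&>(v);
}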
template<typename T>
__device__ T blockReduceSum(const cg::thread_block& block, T value)
{

@@ -164,5 +191,8 @@ void invokeFusedAddBiasResidualRMSNorm(
template void
invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t);
template void invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t);
#ifdef ENABLE_BF16
template void invokeFusedAddBiasResidualRMSNorm(
__nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t);
#endif
}  // namespace turbomind

@@ -90,6 +90,10 @@ void invokeRootMeanSquareNorm(T* out, const T* input, const T* scale, float eps,
template void invokeRootMeanSquareNorm(float*, const float*, const float*, float, int, int, cudaStream_t);
template void invokeRootMeanSquareNorm(half*, const half*, const half*, float, int, int, cudaStream_t);
#ifdef ENABLE_BF16
template void
invokeRootMeanSquareNorm(__nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t);
#endif
// #ifdef ENABLE_BF16

@@ -208,6 +212,23 @@ void invokeCreateCausalMasks(
template void invokeCreateCausalMasks(float* mask, const int*, const int*, int, int, int, cudaStream_t);
template void invokeCreateCausalMasks(half* mask, const int*, const int*, int, int, int, cudaStream_t);
#ifdef ENABLE_BF16
template<>
__global__ void createCausalMasks<__nv_bfloat16>(
__nv_bfloat16* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len)
{
const auto q_len = q_lens[blockIdx.x];
const auto k_len = k_lens[blockIdx.x];
mask += blockIdx.x * max_q_len * max_k_len;
for (int i = threadIdx.x; i < max_q_len * max_k_len; i += blockDim.x) {
const int q = i / max_k_len; // [0, max_q_len)
const int k = i % max_k_len; // [0, max_k_len)
bool is_valid = q < q_len && k < k_len && k <= q + (k_len - q_len);
mask[i] = static_cast<__nv_bfloat16>(float(is_valid));
}
}
template void invokeCreateCausalMasks(__nv_bfloat16* mask, const int*, const int*, int, int, int, cudaStream_t);
#endif
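To make the masking rule concrete (a worked example, not part of the diff): the condition k <= q + (k_len - q_len) aligns each query row's causal window to the right edge of the key sequence, which is what is needed when q_len < k_len and the queries are the last q_len tokens. A host-side check for q_len = 2, k_len = 5:

#include <cstdio>

int main()
{
    const int q_len = 2, k_len = 5, max_q_len = 2, max_k_len = 5;
    for (int q = 0; q < max_q_len; ++q) {
        for (int k = 0; k < max_k_len; ++k) {
            // Same validity rule as in createCausalMasks above.
            const bool is_valid = q < q_len && k < k_len && k <= q + (k_len - q_len);
            std::printf("%d ", is_valid ? 1 : 0);
        }
        std::printf("\n");
    }
    // Prints:
    // 1 1 1 1 0
    // 1 1 1 1 1
    return 0;
}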
template<typename Ti, typename To>
struct ExtendKvCache {

@@ -377,6 +398,24 @@ template void invokeExtendKVCache(void** k_dst_ptrs,
    int quant,
    const float* kv_scale,
    cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeExtendKVCache(void** k_dst_ptrs,
void** v_dst_ptrs,
const __nv_bfloat16* k_src,
const __nv_bfloat16* v_src,
const int* cu_block_counts,
const int* query_length,
const int* history_length,
int batch_size,
int block_length,
size_t dst_layer_offset,
int max_q_len,
int head_dim,
int head_num,
int quant,
const float* kv_scale,
cudaStream_t stream);
#endif
template<typename Ti, typename To>
struct TransposeKvCache {

@@ -527,6 +566,23 @@ template void invokeTransposeKVCache(half*,
    cudaStream_t stream,
    int,
    const float*);
#ifdef ENABLE_BF16
template void invokeTransposeKVCache(__nv_bfloat16*,
__nv_bfloat16*,
const __nv_bfloat16**,
const __nv_bfloat16**,
size_t,
int,
const int*,
int,
int,
int,
int,
int,
cudaStream_t stream,
int,
const float*);
#endif
__global__ void gatherOutput(int* output_ids,
    const int* ids,

@@ -776,6 +832,9 @@ void invokeGetFeatureOfLastToken(
template void invokeGetFeatureOfLastToken(half*, const half*, const int*, int, int, cudaStream_t);
template void invokeGetFeatureOfLastToken(float*, const float*, const int*, int, int, cudaStream_t);
#ifdef ENABLE_BF16
template void invokeGetFeatureOfLastToken(__nv_bfloat16*, const __nv_bfloat16*, const int*, int, int, cudaStream_t);
#endif // ENABLE_BF16
template<class T, int C>
struct BatchedCopyParam {

@@ -866,7 +925,7 @@ FlashAttentionOp<T>::FlashAttentionOp(int batch_size, int head_num, int key_len,
#ifdef _MSC_VER
    op_version_ = 1;
#else
-   op_version_ = std::is_same<half, typename std::decay<T>::type>::value ? 2 : 1;
+   op_version_ = std::is_same<float, typename std::decay<T>::type>::value ? 1 : 2;
    if (op_version_ == 2 && getSMVersion() < 80) {
        op_version_ = 1;
    }
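The effect of this change (a sketch mirroring the logic above; the helper name is hypothetical): fp32 always uses implementation version 1, while fp16 and bf16 now both prefer version 2, falling back to version 1 on MSVC builds or on GPUs older than SM80.

// Illustration only: the version-selection logic above, written as a helper.
constexpr int select_op_version(bool is_float, int sm_version, bool is_msvc)
{
    if (is_msvc) {
        return 1;  // MSVC builds always take version 1
    }
    int version = is_float ? 1 : 2;  // fp16/bf16 prefer version 2
    if (version == 2 && sm_version < 80) {
        version = 1;  // pre-SM80 GPUs fall back to version 1
    }
    return version;
}
// e.g. select_op_version(false, 80, false) == 2   // half / __nv_bfloat16 on SM80+
//      select_op_version(false, 75, false) == 1   // half / __nv_bfloat16 on SM75
//      select_op_version(true,  90, false) == 1   // float always uses version 1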
@@ -903,5 +962,8 @@ void FlashAttentionOp<T>::operator()(Params& params, cudaStream_t st) const
template class FlashAttentionOp<float>;
template class FlashAttentionOp<half>;
#ifdef ENABLE_BF16
template class FlashAttentionOp<__nv_bfloat16>;
#endif
}  // namespace turbomind

@@ -626,5 +626,8 @@ void UnifiedAttentionLayer<T>::unfusedMultiHeadAttention(T* output,
template class UnifiedAttentionLayer<float>;
template class UnifiedAttentionLayer<half>;
#ifdef ENABLE_BF16
template class UnifiedAttentionLayer<__nv_bfloat16>;
#endif // ENABLE_BF16
}  // namespace turbomind

@@ -261,5 +261,8 @@ void UnifiedDecoder<T>::forward(TensorMap* outputs, const TensorMap* inputs, con
template class UnifiedDecoder<float>;
template class UnifiedDecoder<half>;
#ifdef ENABLE_BF16
template class UnifiedDecoder<__nv_bfloat16>;
#endif // ENABLE_BF16
}  // namespace turbomind

@@ -11,6 +11,7 @@
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
+#include <stdexcept>
namespace py = pybind11;
namespace ft = turbomind;

@@ -23,12 +24,6 @@ using TensorMap = std::unordered_map<std::string, triton::Tensor>;
PYBIND11_MAKE_OPAQUE(TensorMap);

static const char kDlTensorCapsuleName[] = "dltensor";
-template<typename T>
-std::shared_ptr<T> make_shared_nodel(T data)
-{
-    return std::shared_ptr<T>(&data, [](T*) {});
-}
DLDevice getDLDevice(triton::Tensor& tensor)
{
    int device_id = 0;

@@ -46,6 +41,7 @@ DLDevice getDLDevice(triton::Tensor& tensor)
            break;
        case triton::MEMORY_CPU_PINNED:
            device.device_type = DLDeviceType::kDLCUDAHost;
+           break;
        case triton::MEMORY_GPU:
            device.device_type = DLDeviceType::kDLCUDA;
            break;
@@ -132,12 +128,11 @@ std::unique_ptr<DLManagedTensor> TritonTensorToDLManagedTensor(triton::Tensor& t
triton::MemoryType getMemoryType(DLDevice device)
{
    switch (device.device_type) {
-       case DLDeviceType::kDLCPU:
-           return triton::MemoryType::MEMORY_CPU;
        case DLDeviceType::kDLCUDAHost:
            return triton::MemoryType::MEMORY_CPU_PINNED;
        case DLDeviceType::kDLCUDA:
            return triton::MemoryType::MEMORY_GPU;
+       case DLDeviceType::kDLCPU:
        default:
            return triton::MemoryType::MEMORY_CPU;
    }
@@ -289,17 +284,21 @@ PYBIND11_MODULE(_turbomind, m)
            DLManagedTensor* dlmt =
                static_cast<DLManagedTensor*>(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName));
            auto src = DLManagedTensorToTritonTensor(dlmt);
-           if (self->type == triton::TYPE_FP16 || self->type == triton::TYPE_FP32
-               || self->type == triton::TYPE_INT32) {
-               auto num_element =
-                   std::accumulate(src->shape.begin(), src->shape.end(), 1LL, std::multiplies<int64_t>());
-               auto num_bytes = num_element * dlmt->dl_tensor.dtype.bits / 8;
-               ft::FT_CHECK(self->shape.size() == 1 && num_bytes == self->shape[0]);
-               cudaMemcpy(
-                   const_cast<void*>(self->data), const_cast<void*>(src->data), num_bytes, cudaMemcpyDefault);
-           }
-           else {
-               ft::FT_CHECK(0);
-           }
+           switch (self->type) {
+               case triton::TYPE_FP16:
+               case triton::TYPE_FP32:
+               case triton::TYPE_INT32:
+               case triton::TYPE_BF16: {
+                   auto num_element =
+                       std::accumulate(src->shape.begin(), src->shape.end(), 1LL, std::multiplies<int64_t>());
+                   auto num_bytes = num_element * dlmt->dl_tensor.dtype.bits / 8;
+                   ft::FT_CHECK(self->shape.size() == 1 && num_bytes == self->shape[0]);
+                   cudaMemcpy(
+                       const_cast<void*>(self->data), const_cast<void*>(src->data), num_bytes, cudaMemcpyDefault);
+                   break;
+               }
+               default:
+                   ft::FT_CHECK(0);
+           }
        },
        "tensor"_a)
@@ -380,6 +379,16 @@ PYBIND11_MODULE(_turbomind, m)
    model->setFfiLock(gil_control);
    return model;
}
else if (data_type == "bf16") {
#ifdef ENABLE_BF16
auto model = std::make_shared<LlamaTritonModel<__nv_bfloat16>>(
tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config);
model->setFfiLock(gil_control);
return model;
#else
throw std::runtime_error("Error: turbomind has not been built with bf16 support.");
#endif
}
else {
    auto model = std::make_shared<LlamaTritonModel<float>>(
        tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir, config);
...
@@ -47,6 +47,18 @@ std::shared_ptr<AbstractTransformerModel> AbstractTransformerModel::createLlamaM
    reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0),
    model_dir);
}
else if (data_type == "bf16") {
#ifdef ENABLE_BF16
return std::make_shared<LlamaTritonModel<__nv_bfloat16>>(
reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"),
reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"),
reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0),
model_dir);
#else
TM_LOG_ERROR("[ERROR] Turbomind is not built with ENABLE_BF16");
ft::FT_CHECK(false);
#endif
}
else {
    return std::make_shared<LlamaTritonModel<float>>(
        reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"),
@@ -205,6 +217,9 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
if (weight_type_str == "fp16") {
    weight_type_ = ft::WeightType::kFP16;
}
else if (weight_type_str == "bf16") {
weight_type_ = ft::WeightType::kBF16;
}
else if (weight_type_str == "fp32") { else if (weight_type_str == "fp32") {
weight_type_ = ft::WeightType::kFP32; weight_type_ = ft::WeightType::kFP32;
} }
@@ -260,6 +275,11 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
else if (std::is_same<T, float>::value) {
    cublas_wrapper->setFP32GemmConfig();
}
#ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) {
cublas_wrapper->setBF16GemmConfig();
}
#endif
ft::NcclParam tensor_para   = nccl_params.first[comms_rank];
ft::NcclParam pipeline_para = nccl_params.second[comms_rank];

@@ -449,3 +469,6 @@ int LlamaTritonModel<T>::getPipelineParaSize()
template struct LlamaTritonModel<float>;
template struct LlamaTritonModel<half>;
#ifdef ENABLE_BF16
template struct LlamaTritonModel<__nv_bfloat16>;
#endif
@@ -244,3 +244,6 @@ void LlamaTritonModelInstance<T>::freeBuffer()
template struct LlamaTritonModelInstance<float>;
template struct LlamaTritonModelInstance<half>;
#ifdef ENABLE_BF16
template struct LlamaTritonModelInstance<__nv_bfloat16>;
#endif
@@ -507,12 +507,12 @@ __device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
template<>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
-   return fabs(val);
+   return fabs(cuda_cast<float>(val));
}

template<>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
-   return make_bfloat162(fabs(val.x), fabs(val.y));
+   return make_bfloat162(fabs(cuda_cast<float>(val.x)), fabs(cuda_cast<float>(val.y)));
}
#endif
...
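Context for the cuda_abs fix above (a sketch, not part of the commit): fabs is not reliably overloaded for __nv_bfloat16 across the supported architectures and toolchains, so the value is widened to fp32, the absolute value is taken there, and the result is narrowed back. The same pattern written with the raw cuda_bf16.h conversions; the helper name is hypothetical.

#include <cuda_bf16.h>
#include <cmath>

__device__ inline __nv_bfloat16 abs_bf16(__nv_bfloat16 v)
{
    const float f = __bfloat162float(v);  // widen to fp32
    return __float2bfloat16(fabsf(f));    // abs in fp32, narrow back to bf16
}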