Unverified commit 3295eac3, authored by q.yao and committed by GitHub

Support turbomind bf16 (#803)

* Add bf16 template sp

* prepare merge

* add enable bf

* add bf16 decode attention support

* fix python lint

* fix yapf

* fix c format

* c format11

* fix cast

* fix on sm<80

* fix linux bf162 cast

* fix type cast

* fix lint

* support from hf pretrained

* fix pybind

* fix converter

* add trust remote code

* fix comment

* fix convert qwen

* fix lint

* fix baichuan

* update weight map
parent b190521b
......@@ -17,10 +17,10 @@ project(TurboMind LANGUAGES CXX CUDA)
find_package(CUDA 10.2 REQUIRED)
# if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11")
# add_definitions("-DENABLE_BF16")
# message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag")
# endif()
if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11")
add_definitions("-DENABLE_BF16")
message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag")
endif()
# if((${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11" AND ${CUDA_VERSION_MINOR} VERSION_GREATER_EQUAL "8") OR (${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "12"))
# add_definitions("-DENABLE_FP8")
......
......@@ -6,6 +6,7 @@ import shutil
from pathlib import Path
import fire
import torch
from huggingface_hub import snapshot_download
from lmdeploy.model import MODELS
......@@ -113,6 +114,55 @@ def copy_tokenizer(model_path: str, tokenizer_path: str,
osp.join(triton_models_path, 'tokenizer'))


def update_output_format(model_name: str, model_format: str, model_path: str,
                         output_format: str):
    """Update output format according to model info."""
    TORCH_DTYPE_MAP = {torch.bfloat16: 'bf16'}
    MODEL_NAME_MAP = {'qwen': 'bf16', 'llama': 'half'}
    model_name = model_name.split('-')[0]

    def _fix_device_support(output_format):
        """fix device support."""
        if output_format == 'bf16':
            if not torch.cuda.is_bf16_supported():
                # device does not support bf16
                print('Device does not support bf16.')
                output_format = 'fp16'
        return output_format

    def _infer_output_format(config):
        """_infer_output_format."""
        torch_dtype = getattr(config, 'torch_dtype', None)
        if torch_dtype:
            updated_output_format = TORCH_DTYPE_MAP.get(
                torch_dtype, output_format)
        else:
            # get model name prefix
            updated_output_format = MODEL_NAME_MAP.get(model_name,
                                                       output_format)
        return _fix_device_support(updated_output_format)

    if model_format in MODEL_NAME_MAP:
        updated_output_format = MODEL_NAME_MAP.get(model_name, output_format)
        return _fix_device_support(updated_output_format)
    else:
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        return _infer_output_format(config)


def update_config_weight_type(output_format: str,
                              config: TurbomindModelConfig):
    WEIGHT_TYPE_MAP = {
        'fp32': 'fp32',
        'fp16': 'fp16',
        'bf16': 'bf16',
        'w4': 'int4',
        'w8': 'int8'
    }
    config.weight_type = WEIGHT_TYPE_MAP[output_format]


def pack_model_repository(workspace_path: str):
    """package the model repository.
......@@ -215,6 +265,10 @@ def main(model_name: str,
        cfg.weight_type = 'int4'
        output_format = 'w4'
        assert group_size > 0, f'group_size: {group_size} should > 0'
    else:
        output_format = update_output_format(model_name, inferred_model_format,
                                             model_path, output_format)
        update_config_weight_type(output_format, cfg)

    # convert
    print('model_name ', model_name)
......
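The two helpers added above decide the dtype turbomind serializes to before conversion starts: an explicit torch_dtype in the HF config wins, otherwise the model-name prefix does (MODEL_NAME_MAP defaults 'qwen' to bf16 and 'llama' to half), and either result falls back to fp16 when the GPU reports no bf16 support. A minimal sketch of how the two compose, assuming the helpers are importable as lmdeploy.turbomind.deploy.converter (inferred from the relative import further down) and using a placeholder checkpoint path:

# Hedged usage sketch, not part of this commit. The checkpoint path is a
# placeholder, and a SimpleNamespace stands in for TurbomindModelConfig since
# update_config_weight_type only assigns config.weight_type.
from types import SimpleNamespace

from lmdeploy.turbomind.deploy.converter import (update_config_weight_type,
                                                 update_output_format)

cfg = SimpleNamespace(weight_type=None)
output_format = update_output_format(model_name='qwen-7b',
                                     model_format='hf',
                                     model_path='/path/to/Qwen-7B-Chat',
                                     output_format='fp16')
# 'hf' is not a key of MODEL_NAME_MAP, so the HF config is loaded with
# trust_remote_code=True and its torch_dtype picks the format ('bf16' for a
# bf16 checkpoint), unless torch.cuda.is_bf16_supported() is False, in which
# case the helper falls back to 'fp16'.
update_config_weight_type(output_format, cfg)
print(output_format, cfg.weight_type)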
......@@ -83,6 +83,14 @@ class TurbomindModelConfig:
        return True


_WEIGHT_DTYPE_MAP = dict(
    int4=torch.float16,
    fp16=torch.float16,
    fp32=torch.float16,
    bf16=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
)


class BaseOutputModel(ABC):
    """Base output model."""
......@@ -136,19 +144,26 @@ class BaseOutputModel(ABC):
    def export_weight(self, param: torch.Tensor, name: str) -> None:
        """export turbomind weight."""

        def _tofile(tensor, path):
            """to file."""
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.view(torch.half)
            tensor.contiguous().cpu().numpy().tofile(path)

        if self.to_file:
            if param.dtype in [torch.float, torch.bfloat16]:
                param = param.half()
            torch_type = _WEIGHT_DTYPE_MAP.get(self.cfg.weight_type,
                                               torch.float16)
            param = param.to(torch_type)
            tprint(name, param.shape)
            param.contiguous().cpu().numpy().tofile(
                osp.join(self.out_dir, name))
            _tofile(param, osp.join(self.out_dir, name))
        elif len(self.tm_params) > 0:
            tm_params = self.tm_params
            weight_type = self.cfg.weight_type
            assert weight_type in ['fp16', 'fp32', 'int4']
            assert weight_type in ['fp16', 'fp32', 'bf16', 'int4']

            # currently, the tensor type should in
            # [torch.float, torch.half, torch.int32]
            # [torch.float, torch.half, torch.bfloat16, torch.int32]
            torch_tensor = param.cuda().contiguous()
            assert torch_tensor.dtype in [
                torch.int32, torch.float, torch.half, torch.bfloat16
......@@ -156,6 +171,8 @@ class BaseOutputModel(ABC):
            if torch_tensor.dtype != torch.int32:
                if weight_type in ['fp16', 'int4']:
                    torch_tensor = torch_tensor.half()
                elif weight_type == 'bf16':
                    torch_tensor = torch_tensor.bfloat16()
                else:
                    torch_tensor = torch_tensor.float()
            for tm_tensor in tm_params[name]:
......
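The _tofile helper above exists because numpy has no bfloat16 dtype: a bf16 tensor is reinterpreted as float16 (same 16-bit payload, no numeric conversion) purely so .numpy().tofile() can write the raw bytes, and the engine reads those bytes back as bf16 because config.weight_type says so. A small round-trip sketch of that trick, with an illustrative file path and shape:

# Hedged sketch of the bfloat16 serialization trick used by _tofile above.
import numpy as np
import torch

w = torch.randn(4, 4, dtype=torch.bfloat16)

# what _tofile does: reinterpret the bf16 payload as fp16 and dump raw bytes
w.view(torch.half).contiguous().cpu().numpy().tofile('/tmp/w.bin')

# reading the bytes back and reinterpreting them as bf16 is bit-exact
raw = torch.from_numpy(np.fromfile('/tmp/w.bin', dtype=np.float16))
restored = raw.view(torch.bfloat16).reshape(4, 4)
assert torch.equal(restored, w)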
......@@ -14,7 +14,7 @@ def transpose_tensor(input: List[torch.Tensor]):
    return output


@OUTPUT_MODELS.register_module(name='fp16')
@OUTPUT_MODELS.register_module(name=['fp16', 'bf16'])
class TurbomindModel(BaseOutputModel):
    """Export to turbomind fp16 format."""
......
......@@ -22,7 +22,8 @@ from lmdeploy.model import MODELS, BaseModel
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger
from .deploy.converter import get_model_format, supported_formats
from .deploy.converter import (get_model_format, supported_formats,
                               update_config_weight_type, update_output_format)
from .deploy.source_model.base import INPUT_MODELS
from .deploy.target_model.base import OUTPUT_MODELS, TurbomindModelConfig
from .utils import (ModelSource, check_tm_model_input, create_hf_download_args,
......@@ -246,6 +247,12 @@ class TurboMind:
            output_format = 'w4'
            data_type = 'int4'
            assert group_size > 0, f'group_size: {group_size} should > 0'
        else:
            output_format = update_output_format(model_name,
                                                 inferred_model_format,
                                                 model_path, output_format)
            data_type = output_format
            update_config_weight_type(output_format, cfg)

        self.config = cfg
        self.model_name = model_name
......
......@@ -117,7 +117,11 @@ struct ReluActivation<__nv_bfloat162> {
static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val)
{
const __nv_bfloat16 zero_bf16 = static_cast<__nv_bfloat16>(0.0f);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
return turbomind::make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);
#else
return make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);
#endif
}
};
#endif
......
......@@ -590,6 +590,39 @@ __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
return b;
}
#ifdef ENABLE_BF16
template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a)
{
return cuda_cast<__nv_bfloat162, float2>(a);
}
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, float4>(const float4& a)
{
bf16_4_t b;
float2 val;
val.x = a.x;
val.y = a.y;
b.x = vec_conversion<__nv_bfloat162, float2>(val);
val.x = a.z;
val.y = a.w;
b.y = vec_conversion<__nv_bfloat162, float2>(val);
return b;
}
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
{
bf16_8_t b;
b.x = vec_conversion<__nv_bfloat162, float2>(a.x);
b.y = vec_conversion<__nv_bfloat162, float2>(a.y);
b.z = vec_conversion<__nv_bfloat162, float2>(a.z);
b.w = vec_conversion<__nv_bfloat162, float2>(a.w);
return b;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS_PER_KEY, typename K_vec, int N>
......@@ -1053,6 +1086,55 @@ inline __device__ int64_t quant(uint4 a, const float scale, const float zp)
int16[3] = quant(a.w, scale, zp);
return int64;
}
// bfloat16 to int8
inline __device__ int8_t quant(__nv_bfloat16 a, const float scale, const float zp)
{
int8_t int8;
float b = bfloat16_to_float(a);
int8 = round(max(-128.f, min(127.f, (b - zp) / scale)));
return int8;
}
// bfloat16x2 to int8x2
inline __device__ int16_t quant(__nv_bfloat162 a, const float scale, const float zp)
{
union {
int8_t int8[2];
short int16;
};
float2 b = bfloat162_to_float2(a);
int8[0] = round(max(-128.f, min(127.f, (b.x - zp) / scale)));
int8[1] = round(max(-128.f, min(127.f, (b.y - zp) / scale)));
return int16;
}
// bfloat16x4 to int8x4
inline __device__ int32_t quant(bf16_4_t a, const float scale, const float zp)
{
union {
int16_t int16[2];
int32_t int32;
};
int16[0] = quant(a.x, scale, zp);
int16[1] = quant(a.y, scale, zp);
return int32;
}
// bfloat16x8 to int8x8
inline __device__ int64_t quant(bf16_8_t a, const float scale, const float zp)
{
union {
int16_t int16[4];
int64_t int64;
};
int16[0] = quant(a.x, scale, zp);
int16[1] = quant(a.y, scale, zp);
int16[2] = quant(a.z, scale, zp);
int16[3] = quant(a.w, scale, zp);
return int64;
}
// int8 to float32, then `vec_conversion` to target format
inline __device__ float dequant(int8_t a, const float scale, const float zp)
{
......
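The new quant() overloads extend the int8 KV-cache path to bf16 inputs: each bf16 lane is widened to float, affine-quantized as round(clamp((x - zp) / scale, -128, 127)), and the wider vector variants simply pack the resulting bytes. A scalar sketch of the same arithmetic, with made-up scale/zero-point values (torch's round-half-to-even can differ from CUDA's round() exactly at .5 boundaries):

# Hedged sketch of the bf16 -> int8 quantization math above, written per-tensor
# instead of per-lane; scale and zp are illustrative only.
import torch

def quant_int8(x: torch.Tensor, scale: float, zp: float) -> torch.Tensor:
    # mirrors: round(max(-128.f, min(127.f, (bfloat16_to_float(a) - zp) / scale)))
    return ((x.float() - zp) / scale).clamp(-128, 127).round().to(torch.int8)

def dequant_bf16(q: torch.Tensor, scale: float, zp: float) -> torch.Tensor:
    # mirrors the read path: int8 -> float32, then convert to the target format
    return (q.float() * scale + zp).to(torch.bfloat16)

x = torch.randn(8, dtype=torch.bfloat16)
scale, zp = 0.05, 0.0
q = quant_int8(x, scale, zp)
x_hat = dequant_bf16(q, scale, zp)
print(q.tolist())
print((x.float() - x_hat.float()).abs().max().item())  # roughly bounded by scale / 2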
......@@ -43,6 +43,24 @@ struct Float4_ {
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
struct bf16_4_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct bf16_8_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct num_elems;
template<>
......@@ -79,6 +97,21 @@ struct num_elems<uint4> {
static constexpr int value = 8;
};
#ifdef ENABLE_BF16
template<>
struct num_elems<__nv_bfloat162> {
static constexpr int value = 2;
};
template<>
struct num_elems<bf16_4_t> {
static constexpr int value = 4;
};
template<>
struct num_elems<bf16_8_t> {
static constexpr int value = 8;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int N>
......@@ -144,6 +177,44 @@ inline __device__ float4 add(float4 a, float4 b)
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
{
return a + b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b)
{
bf16_4_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b)
{
bf16_8_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint16_t add(uint16_t a, uint16_t b)
{
uint16_t c;
......@@ -236,6 +307,26 @@ inline __device__ float2 half2_to_float2(uint32_t v)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float bfloat16_to_float(__nv_bfloat16 h)
{
return __bfloat162float(h);
// float f;
// asm volatile("cvt.f32.bf16 %0, %1;\n" : "=f"(f) : "h"(h));
// return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 bfloat162_to_float2(__nv_bfloat162 v)
{
return cuda_cast<float2, __nv_bfloat162>(v);
// __nv_bfloat16 lo, hi;
// asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
// return make_float2(bfloat16_to_float(lo), bfloat16_to_float(hi));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float add(float a, uint16_t b)
{
return a + half_to_float(b);
......@@ -243,6 +334,15 @@ inline __device__ float add(float a, uint16_t b)
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float add(float a, __nv_bfloat16 b)
{
return a + __bfloat162float(b);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 add(uint32_t a, float2 fb)
{
float2 fa = half2_to_float2(a);
......@@ -367,6 +467,37 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c)
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float2 add(__nv_bfloat162 a, float2 fb)
{
float2 fa = bf1622float2(a);
return add(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb)
{
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb)
{
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c)
{
uint32_t d;
......@@ -498,6 +629,134 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc)
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(a, b, c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(bf162bf162(a), b, c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c)
{
bf16_4_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c)
{
bf16_8_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc)
{
return __bfloat162float(a) * __bfloat162float(b) + fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc)
{
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return fma(fa, fb, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc)
{
return fma(bf162bf162(a), b, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc)
{
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc)
{
__nv_bfloat162 s = bf162bf162(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc)
{
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc)
{
__nv_bfloat162 s = bf162bf162(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B>
......@@ -746,6 +1005,171 @@ inline __device__ float sum(float v)
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __hmul(a, b);
#else
return bf16hmul(a, b);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hmul2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b)
{
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b)
{
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b)
{
float fa = (float)a;
float fb = (float)b;
return fa * fb;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(__nv_bfloat16 a, float b)
{
return __bfloat162float(a) * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b)
{
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b)
{
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b)
{
__nv_bfloat162 s = bf162bf162(a);
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b)
{
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b)
{
__nv_bfloat162 s = bf162bf162(a);
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return fc;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float2 v)
{
return v.x + v.y;
......@@ -758,6 +1182,31 @@ inline __device__ float sum(float4 v)
return v.x + v.y + v.z + v.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float sum(__nv_bfloat162 v)
{
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(bf16_4_t v)
{
return sum(v.x) + sum(v.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(bf16_8_t v)
{
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint16_t v)
{
return half_to_float(v);
......@@ -1048,4 +1497,85 @@ inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int r
k.w = rotary_embedding_transform(k.w, coef3);
}
#ifdef ENABLE_BF16
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, float base, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, t_step);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void
apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, float base, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, t_step);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, float base, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
}
inline __device__ void
apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, float base, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}
inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, float base, int t_step)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
}
inline __device__ void
apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, float base, int t_step)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}
#endif // ENABLE_BF16
} // namespace mmha
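Most of the bf16 overloads above that return float2 / Float4_ / Float8_ first widen their operands with bf1622float2 and only then multiply or accumulate, so the attention dot products and reductions run in fp32 even when the data is stored as bf16. A quick illustration of why that matters, assuming nothing beyond stock PyTorch (the vector length is arbitrary):

# Hedged illustration: bf16 keeps only 8 mantissa bits, so accumulating a long
# dot product directly in bf16 drifts, while widening each term to fp32 first
# (the mul<float2, __nv_bfloat162, __nv_bfloat162> pattern) stays close.
import torch

torch.manual_seed(0)
q = torch.randn(4096, dtype=torch.bfloat16)
k = torch.randn(4096, dtype=torch.bfloat16)
ref = torch.dot(q.double(), k.double()).item()

acc_bf16 = torch.zeros((), dtype=torch.bfloat16)   # accumulate in bf16
acc_fp32 = torch.zeros((), dtype=torch.float32)    # widen, then accumulate
for a, b in zip(q, k):
    acc_bf16 = acc_bf16 + a * b
    acc_fp32 = acc_fp32 + a.float() * b.float()

print(abs(acc_bf16.item() - ref), abs(acc_fp32.item() - ref))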
......@@ -3,6 +3,7 @@
#pragma once
#include "src/turbomind/kernels/gemm_s_f16/common.h"
#include "src/turbomind/utils/cuda_bf16_wrapper.h"
#include <cfloat>
#include <limits>
......@@ -487,4 +488,34 @@ struct ConvertKvCache<int8_t, half> {
}
};
#ifdef ENABLE_BF16
template<>
struct ConvertKvCache<int8_t, __nv_bfloat16> {
float scale_;
float zero_;
__device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero)
{
zero_ = zero_ - 32896.f * scale_;
}
template<int N>
inline __device__ auto operator()(const Array<int8_t, N>& vi) const -> Array<__nv_bfloat16, N>
{
Array<__nv_bfloat16, N> vo;
PRAGMA_UNROLL
for (int i = 0; i < N; i += 4) {
auto& vec = (Array<__nv_bfloat16, 4>&)vo[i];
auto tmp = fast_i2f_f32_s8((const Array<int8_t, 4>&)vi[i]);
PRAGMA_UNROLL
for (int j = 0; j < 4; ++j) {
vec[j] = __nv_bfloat16(tmp[j] * scale_ + zero_);
// vec[j] = half(tmp[j] * scale_ + (zero_ - 32896.f * scale_));
}
}
return vo;
}
};
#endif // ENABLE_BF16
} // namespace turbomind
......@@ -107,9 +107,26 @@ void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>&
invokeDecoderMultiheadAttention<T, T, HeadDim, 1>(params);
}
}
#ifdef ENABLE_BF16
if constexpr (std::is_same_v<T, __nv_bfloat16>) {
int group_size = params.num_heads / params.num_kv_heads;
if (group_size % 4 == 0) {
invokeDecoderMultiheadAttention<T, T, HeadDim, 4>(params);
}
else if (group_size % 2 == 0) {
invokeDecoderMultiheadAttention<T, T, HeadDim, 2>(params);
}
else {
invokeDecoderMultiheadAttention<T, T, HeadDim, 1>(params);
}
}
#endif // ENABLE_BF16
}
template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<half>& params);
template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<float>& params);
#ifdef ENABLE_BF16
template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<__nv_bfloat16>& params);
#endif // ENABLE_BF16
} // namespace turbomind
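The bf16 branch added to DispatchDecoderMultiheadAttention mirrors the existing half/float dispatch: it computes the query-to-KV head ratio and instantiates the kernel with the widest grouping factor (4, 2, or 1) that divides it. The selection rule, restated as a tiny sketch with illustrative head counts:

# Hedged restatement of the dispatch rule above; head counts are examples only.
def pick_group_factor(num_heads: int, num_kv_heads: int) -> int:
    group_size = num_heads // num_kv_heads
    for factor in (4, 2, 1):
        if group_size % factor == 0:
            return factor
    return 1

print(pick_group_factor(32, 32))  # MHA: group_size 1 -> factor 1
print(pick_group_factor(32, 8))   # GQA: group_size 4 -> factor 4
print(pick_group_factor(40, 20))  # GQA: group_size 2 -> factor 2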
......@@ -477,5 +477,21 @@ template void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
int quant_policy,
const float* kv_params,
cudaStream_t st);
#ifdef ENABLE_BF16
template void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
const void** src_v_block_ptrs,
__nv_bfloat16** dst_k_ptrs,
__nv_bfloat16** dst_v_ptrs,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size,
int quant_policy,
const float* kv_params,
cudaStream_t st);
#endif
} // namespace turbomind
......@@ -1631,5 +1631,8 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
template class LlamaBatch<half>;
template class LlamaBatch<float>;
#ifdef ENABLE_BF16
template class LlamaBatch<__nv_bfloat16>;
#endif
} // namespace turbomind
......@@ -377,5 +377,8 @@ TensorMap LlamaDecoderLayerWeight<T>::getParams(std::string prefix)
template struct LlamaDecoderLayerWeight<float>;
template struct LlamaDecoderLayerWeight<half>;
#ifdef ENABLE_BF16
template struct LlamaDecoderLayerWeight<__nv_bfloat16>;
#endif
} // namespace turbomind
......@@ -30,6 +30,7 @@ enum class WeightType : int
kFP32,
kFP16,
kFP8, // not supported yet
kBF16,
kINT8,
kINT4
};
......@@ -43,6 +44,8 @@ inline size_t getBitSize(WeightType type)
return 16;
case WeightType::kFP8:
return 8;
case WeightType::kBF16:
return 16;
case WeightType::kINT8:
return 8;
case WeightType::kINT4:
......
......@@ -125,5 +125,8 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
template class LlamaFfnLayer<float>;
template class LlamaFfnLayer<half>;
#ifdef ENABLE_BF16
template class LlamaFfnLayer<__nv_bfloat16>;
#endif
} // namespace turbomind
......@@ -31,6 +31,7 @@ public:
switch (weight.type) {
case WeightType::kFP16:
case WeightType::kFP32:
case WeightType::kBF16:
forwardFp(output_data, input_data, batch_size, weight, type);
break;
case WeightType::kINT4:
......
......@@ -510,5 +510,8 @@ void LlamaV2<T>::forward(std::unordered_map<std::string, Tensor>* outputs,
template class LlamaV2<half>;
template class LlamaV2<float>;
#ifdef ENABLE_BF16
template class LlamaV2<__nv_bfloat16>;
#endif
} // namespace turbomind
......@@ -90,6 +90,9 @@ template<typename T>
void LlamaWeight<T>::loadModel(std::string dir_path)
{
FtCudaDataType model_file_type = FtCudaDataType::FP16;
if(weight_type_ == WeightType::kBF16){
model_file_type = FtCudaDataType::BF16;
}
dir_path += '/';
loadWeightFromBin((T*)pre_decoder_embedding_table,
......@@ -140,5 +143,8 @@ TensorMap LlamaWeight<T>::getParams()
template struct LlamaWeight<float>;
template struct LlamaWeight<half>;
#ifdef ENABLE_BF16
template struct LlamaWeight<__nv_bfloat16>;
#endif
} // namespace turbomind
......@@ -7,6 +7,7 @@ add_library(${PROJECT_NAME} STATIC
# flash_fwd_hdim32_fp16_sm80.cu
# flash_fwd_hdim64_fp16_sm80.cu
flash_fwd_hdim128_fp16_sm80.cu
flash_fwd_hdim128_bf16_sm80.cu
# flash_fwd_hdim256_fp16_sm80.cu
)
target_include_directories(${PROJECT_NAME} PRIVATE ${CUTLASS_DIR}/include)
......
......@@ -12,7 +12,7 @@
void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream)
{
FP16_SWITCH(true,
FP16_SWITCH(!params.is_bf16,
[&] { FWD_HEADDIM_SWITCH(params.d, [&] { run_mha_fwd_<elem_type, kHeadDim>(params, stream); }); });
}
......@@ -126,7 +126,11 @@ public:
fwd_params.blockmask = reinterpret_cast<void*>(params.mask);
fwd_params.is_bf16 = false;
#ifdef ENABLE_BF16
fwd_params.is_bf16 = std::is_same<T, __nv_bfloat16>::value;
#else
fwd_params.is_bf16 = false;
#endif
fwd_params.is_causal = true;
fwd_params.q_enable_seqlen = params.layout_q.use_seqlens;
......@@ -163,5 +167,8 @@ void FlashAttentionOpImpl<T, FMHA_VERSION>::operator()(Params& params, cudaStrea
template class FlashAttentionOpImpl<float, FMHA_VERSION>;
template class FlashAttentionOpImpl<half, FMHA_VERSION>;
#ifdef ENABLE_BF16
template class FlashAttentionOpImpl<__nv_bfloat16, FMHA_VERSION>;
#endif
} // namespace turbomind