Commit a5deda33 authored by PanZezhong

support bf16

parent 4837543a
@@ -85,6 +85,8 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
             dt_ = DataType.INFINI_DTYPE_F16
         elif dtype == torch.float32:
             dt_ = DataType.INFINI_DTYPE_F32
+        elif dtype == torch.bfloat16:
+            dt_ = DataType.INFINI_DTYPE_BF16
         else:
             dt_ = DataType.INFINI_DTYPE_F16
         super().__init__(
@@ -134,12 +136,16 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
             self.dt_mat = DataType.INFINI_DTYPE_F16
         elif torch_dt_mat == torch.float32:
             self.dt_mat = DataType.INFINI_DTYPE_F32
+        elif torch_dt_mat == torch.bfloat16:
+            self.dt_mat = DataType.INFINI_DTYPE_BF16
         else:
             raise ValueError("Unsupported proj weight data type")
         if torch_dt_norm == torch.float16:
             self.dt_norm = DataType.INFINI_DTYPE_F16
         elif torch_dt_norm == torch.float32:
             self.dt_norm = DataType.INFINI_DTYPE_F32
+        elif torch_dt_norm == torch.bfloat16:
+            self.dt_norm = DataType.INFINI_DTYPE_BF16
         else:
             raise ValueError("Unsupported norm weight data type")
...
@@ -142,6 +142,8 @@ inline std::shared_ptr<Tensor> getSinTable(JiugeMeta const *meta) {
             static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
         if (meta->dt_logits == INFINI_DTYPE_F16) {
             ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin);
+        } else if (meta->dt_logits == INFINI_DTYPE_BF16) {
+            ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_sin);
         } else if (meta->dt_logits == INFINI_DTYPE_F32) {
             ((float *)table)[i * half_dh + j] = _sin;
         } else {
@@ -167,6 +169,8 @@ inline std::shared_ptr<Tensor> getCosTable(JiugeMeta const *meta) {
            static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
         if (meta->dt_logits == INFINI_DTYPE_F16) {
             ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos);
+        } else if (meta->dt_logits == INFINI_DTYPE_BF16) {
+            ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_cos);
         } else if (meta->dt_logits == INFINI_DTYPE_F32) {
             ((float *)table)[i * half_dh + j] = _cos;
         } else {
...
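Both half-precision branches write through a uint16_t view, since f16 and bf16 are each 16-bit encodings; only the narrowing conversion differs. For orientation, here is a minimal standalone sketch, not part of the commit, of how one bf16 sin-table entry is produced (the helper name sin_entry_bf16 is hypothetical):

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical helper mirroring getSinTable's bf16 branch for a single entry.
uint16_t sin_entry_bf16(size_t i, size_t j, size_t half_dh, float theta) {
    // RoPE angle: i / theta^(j / half_dh), computed in f32.
    float angle = static_cast<float>(i) /
                  std::pow(theta, static_cast<float>(j) / half_dh);
    float s = std::sin(angle);
    // Narrow f32 -> bf16 with round-to-nearest-even, as f32_to_bf16 does.
    uint32_t bits;
    std::memcpy(&bits, &s, sizeof(bits));
    uint32_t bias = 0x7FFF + ((bits >> 16) & 1);
    return static_cast<uint16_t>((bits + bias) >> 16);
}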
@@ -234,6 +234,20 @@ void print_data(uint16_t const *data, const std::vector<size_t> &shape,
     }
 }

+void print_data_bf16(uint16_t const *data, const std::vector<size_t> &shape,
+                     const std::vector<ptrdiff_t> &strides, size_t dim) {
+    if (dim == shape.size() - 1) {
+        for (size_t i = 0; i < shape[dim]; i++) {
+            std::cout << bf16_to_f32(data[i * strides[dim]]) << " ";
+        }
+        std::cout << std::endl;
+    } else if (dim < shape.size() - 1) {
+        for (size_t i = 0; i < shape[dim]; i++) {
+            print_data_bf16(data + i * strides[dim], shape, strides, dim + 1);
+        }
+    }
+}
+
 std::string Tensor::info() const {
     std::stringstream ss;
@@ -296,6 +310,10 @@ void Tensor::debug(const std::string &filename) const {
         print_data((int32_t const *)((char const *)cpu_data + dataOffset()),
                    this->shape(), this->strides(), 0);
         break;
+    case INFINI_DTYPE_BF16:
+        print_data_bf16((uint16_t const *)((char const *)cpu_data + dataOffset()),
+                        this->shape(), this->strides(), 0);
+        break;
     default:
         PANIC("Unsupported data type");
     }
...
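print_data_bf16 follows the same strided recursion as the existing f16 printer; only the per-element decode differs. A tiny self-contained sketch of the traversal it performs, using a hypothetical 2x2 row-major buffer that is not from the commit:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
    // bf16 bit patterns for {1, 2, 3, 4}: each is the top 16 bits of the f32.
    std::vector<uint16_t> buf = {0x3F80, 0x4000, 0x4040, 0x4080};
    size_t shape[2] = {2, 2};
    ptrdiff_t strides[2] = {2, 1}; // row-major, counted in elements
    for (size_t i = 0; i < shape[0]; i++) {
        for (size_t j = 0; j < shape[1]; j++) {
            // Decode as bf16_to_f32 does: shift into the high half of an f32.
            uint32_t bits =
                static_cast<uint32_t>(buf[i * strides[0] + j * strides[1]]) << 16;
            float f;
            std::memcpy(&f, &bits, sizeof(f));
            std::cout << f << " ";
        }
        std::cout << std::endl; // prints: 1 2 / 3 4
    }
    return 0;
}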
@@ -97,4 +97,26 @@ inline uint16_t f32_to_f16(float val) {
     }
 }

+inline float bf16_to_f32(uint16_t val) {
+    // Widening is just bit placement: put the bf16 bits in the high 16 bits
+    // of a float32 and zero the remaining 16 bits.
+    uint32_t bits32 = static_cast<uint32_t>(val) << 16;
+    float out;
+    std::memcpy(&out, &bits32, sizeof(out));
+    return out;
+}
+
+inline uint16_t f32_to_bf16(float val) {
+    uint32_t bits32;
+    std::memcpy(&bits32, &val, sizeof(bits32));
+    // Add 0x7FFF before truncating, then adjust by the parity of bit 16 (the
+    // lowest bit that survives the shift) to get round-to-nearest-even.
+    const uint32_t rounding_bias = 0x00007FFF + // 0111 1111 1111 1111
+                                   ((bits32 >> 16) & 1); // +1 when the lowest kept bit is odd
+    uint16_t bf16_bits = static_cast<uint16_t>((bits32 + rounding_bias) >> 16);
+    return bf16_bits;
+}
+
 #endif
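A quick sanity check of the pair, as a standalone sketch assuming the two helpers above are in scope: ties round to the nearest even bf16 mantissa, so 1 + 1/256 rounds down to 1.0 while 1 + 3/256 rounds up to 1.015625.

#include <cassert>

int main() {
    // 1.0f is exactly representable in bf16, so it round-trips unchanged.
    assert(bf16_to_f32(f32_to_bf16(1.0f)) == 1.0f);
    // 1 + 1/256: exactly halfway, lower candidate is even -> rounds down.
    assert(bf16_to_f32(f32_to_bf16(1.00390625f)) == 1.0f);
    // 1 + 3/256: exactly halfway, lower candidate is odd -> rounds up.
    assert(bf16_to_f32(f32_to_bf16(1.01171875f)) == 1.015625f);
    return 0;
}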