Commit a5deda33 authored by PanZezhong's avatar PanZezhong
Browse files

support bf16

parent 4837543a
......@@ -85,6 +85,8 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
dt_ = DataType.INFINI_DTYPE_F16
elif dtype == torch.float32:
dt_ = DataType.INFINI_DTYPE_F32
elif dtype == torch.bfloat16:
dt_ = DataType.INFINI_DTYPE_BF16
else:
dt_ = DataType.INFINI_DTYPE_F16
super().__init__(
......@@ -134,12 +136,16 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
self.dt_mat = DataType.INFINI_DTYPE_F16
elif torch_dt_mat == torch.float32:
self.dt_mat = DataType.INFINI_DTYPE_F32
elif torch_dt_mat == torch.bfloat16:
self.dt_mat = DataType.INFINI_DTYPE_BF16
else:
raise ValueError("Unsupported proj weight data type")
if torch_dt_norm == torch.float16:
self.dt_norm = DataType.INFINI_DTYPE_F16
elif torch_dt_norm == torch.float32:
self.dt_norm = DataType.INFINI_DTYPE_F32
elif torch_dt_norm == torch.bfloat16:
self.dt_norm = DataType.INFINI_DTYPE_BF16
else:
raise ValueError("Unsupported norm weight data type")
......
......@@ -142,6 +142,8 @@ inline std::shared_ptr<Tensor> getSinTable(JiugeMeta const *meta) {
static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
if (meta->dt_logits == INFINI_DTYPE_F16) {
((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin);
} else if (meta->dt_logits == INFINI_DTYPE_BF16) {
((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_sin);
} else if (meta->dt_logits == INFINI_DTYPE_F32) {
((float *)table)[i * half_dh + j] = _sin;
} else {
......@@ -167,6 +169,8 @@ inline std::shared_ptr<Tensor> getCosTable(JiugeMeta const *meta) {
static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
if (meta->dt_logits == INFINI_DTYPE_F16) {
((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos);
} else if (meta->dt_logits == INFINI_DTYPE_BF16) {
((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_cos);
} else if (meta->dt_logits == INFINI_DTYPE_F32) {
((float *)table)[i * half_dh + j] = _cos;
} else {
......
......@@ -234,6 +234,20 @@ void print_data(uint16_t const *data, const std::vector<size_t> &shape,
}
}
void print_data_bf16(uint16_t const *data, const std::vector<size_t> &shape,
                     const std::vector<ptrdiff_t> &strides, size_t dim) {
    // Recursively print a bf16 tensor stored as raw uint16_t words.
    // `dim` is the current axis; the innermost axis prints one row per line,
    // outer axes recurse with the pointer advanced by that axis' stride.
    if (dim == shape.size() - 1) {
        for (size_t i = 0; i < shape[dim]; i++) {
            std::cout << bf16_to_f32(data[i * strides[dim]]) << " ";
        }
        std::cout << std::endl;
    } else if (dim < shape.size() - 1) {
        for (size_t i = 0; i < shape[dim]; i++) {
            // BUG FIX: recurse into the bf16 printer itself. The original
            // called print_data() here, which decodes values as f16, so any
            // tensor with rank > 1 printed garbage for the inner dimensions.
            print_data_bf16(data + i * strides[dim], shape, strides, dim + 1);
        }
    }
}
std::string Tensor::info() const {
std::stringstream ss;
......@@ -296,6 +310,10 @@ void Tensor::debug(const std::string &filename) const {
print_data((int32_t const *)((char const *)cpu_data + dataOffset()),
this->shape(), this->strides(), 0);
break;
case INFINI_DTYPE_BF16:
print_data_bf16((uint16_t const *)((char const *)cpu_data + dataOffset()),
this->shape(), this->strides(), 0);
break;
default:
PANIC("Unsupported data type");
}
......
......@@ -97,4 +97,26 @@ inline uint16_t f32_to_f16(float val) {
}
}
inline float bf16_to_f32(uint16_t val) {
    // A bf16 value is exactly the upper 16 bits of an IEEE-754 binary32:
    // widen it into the high half of a 32-bit word (low half zero) and
    // reinterpret those bits as a float. memcpy avoids aliasing UB.
    const uint32_t widened = static_cast<uint32_t>(val) << 16;
    float result;
    std::memcpy(&result, &widened, sizeof(result));
    return result;
}
inline uint16_t f32_to_bf16(float val) {
    // Convert an IEEE-754 binary32 to bfloat16 (its top 16 bits) using
    // round-to-nearest-even on the truncated low 16 bits.
    uint32_t bits32;
    std::memcpy(&bits32, &val, sizeof(bits32));
    // BUG FIX: handle NaN before rounding. Adding the rounding bias to a
    // NaN payload can carry into the exponent and turn NaN into +/-Inf
    // (e.g. 0x7F800001 + 0x8000 -> bf16 0x7F80 == +Inf). Preserve the sign
    // and force a quiet-NaN mantissa bit instead.
    if ((bits32 & 0x7FFFFFFFu) > 0x7F800000u) {
        return static_cast<uint16_t>((bits32 >> 16) | 0x0040u);
    }
    // Round to nearest, ties to even: add 0x7FFF, plus 1 more when the
    // lowest kept bit (bit 16) is odd, then truncate the low 16 bits.
    const uint32_t rounding_bias = 0x00007FFFu + ((bits32 >> 16) & 1u);
    return static_cast<uint16_t>((bits32 + rounding_bias) >> 16);
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment