jerrrrry / infinilm · Commits

Commit a5deda33, authored Jul 09, 2025 by PanZezhong

support bf16

parent 4837543a
Showing 4 changed files with 50 additions and 0 deletions (+50 -0)
scripts/jiuge.py                    +6  -0
src/models/jiuge/jiuge_weight.hpp   +4  -0
src/tensor/tensor.cpp              +18  -0
src/utils.hpp                      +22  -0
scripts/jiuge.py
@@ -85,6 +85,8 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
             dt_ = DataType.INFINI_DTYPE_F16
         elif dtype == torch.float32:
             dt_ = DataType.INFINI_DTYPE_F32
+        elif dtype == torch.bfloat16:
+            dt_ = DataType.INFINI_DTYPE_BF16
         else:
             dt_ = DataType.INFINI_DTYPE_F16
         super().__init__(
@@ -134,12 +136,16 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
             self.dt_mat = DataType.INFINI_DTYPE_F16
         elif torch_dt_mat == torch.float32:
             self.dt_mat = DataType.INFINI_DTYPE_F32
+        elif torch_dt_mat == torch.bfloat16:
+            self.dt_mat = DataType.INFINI_DTYPE_BF16
         else:
             raise ValueError("Unsupported proj weight data type")
         if torch_dt_norm == torch.float16:
             self.dt_norm = DataType.INFINI_DTYPE_F16
         elif torch_dt_norm == torch.float32:
             self.dt_norm = DataType.INFINI_DTYPE_F32
+        elif torch_dt_norm == torch.bfloat16:
+            self.dt_norm = DataType.INFINI_DTYPE_BF16
         else:
             raise ValueError("Unsupported norm weight data type")
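Why the new branches are this small: bfloat16, like float16, is a 16-bit storage format, so buffer sizes and strides are unchanged and only the dtype tag passed to the C side differs. A minimal sketch of that invariant; the dsize_of helper and the enum values are illustrative assumptions, not the library's actual definitions:

#include <cstddef>
#include <stdexcept>

// Hypothetical mirror of the DataType codes used above (values are assumptions).
enum DataType { INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16 };

// Hypothetical helper: element size in bytes for each dtype.
inline size_t dsize_of(DataType dt) {
    switch (dt) {
    case INFINI_DTYPE_F16:
    case INFINI_DTYPE_BF16:
        return 2; // both 16-bit formats; allocation and pointer math are identical
    case INFINI_DTYPE_F32:
        return 4;
    default:
        throw std::runtime_error("Unsupported data type");
    }
}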
src/models/jiuge/jiuge_weight.hpp
@@ -142,6 +142,8 @@ inline std::shared_ptr<Tensor> getSinTable(JiugeMeta const *meta) {
             float _sin = std::sin(static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
             if (meta->dt_logits == INFINI_DTYPE_F16) {
                 ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin);
+            } else if (meta->dt_logits == INFINI_DTYPE_BF16) {
+                ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_sin);
             } else if (meta->dt_logits == INFINI_DTYPE_F32) {
                 ((float *)table)[i * half_dh + j] = _sin;
             } else {
@@ -167,6 +169,8 @@ inline std::shared_ptr<Tensor> getCosTable(JiugeMeta const *meta) {
             float _cos = std::cos(static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
             if (meta->dt_logits == INFINI_DTYPE_F16) {
                 ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos);
+            } else if (meta->dt_logits == INFINI_DTYPE_BF16) {
+                ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_cos);
             } else if (meta->dt_logits == INFINI_DTYPE_F32) {
                 ((float *)table)[i * half_dh + j] = _cos;
             } else {
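The pattern in both hunks is the same: compute the rotary angle in float32, then narrow on store according to meta->dt_logits. A standalone sketch of filling one bf16 sin-table row; half_dh and theta values are illustrative, and the truncating encoder stands in for the rounding f32_to_bf16 added in src/utils.hpp:

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Truncating bf16 encode (stand-in; the version in src/utils.hpp rounds to nearest even).
static uint16_t f32_to_bf16_trunc(float v) {
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

int main() {
    const size_t half_dh = 64;    // half the head dim (illustrative)
    const float theta = 10000.0f; // rotary base (illustrative)
    const size_t i = 5;           // position index
    std::vector<uint16_t> sin_row(half_dh);
    for (size_t j = 0; j < half_dh; j++) {
        // Same angle formula as getSinTable, computed in f32 before narrowing.
        float _sin = std::sin(static_cast<float>(i) /
                              std::pow(theta, static_cast<float>(j) / half_dh));
        sin_row[j] = f32_to_bf16_trunc(_sin); // bf16 row for position i
    }
    (void)sin_row;
}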
src/tensor/tensor.cpp
@@ -234,6 +234,20 @@ void print_data(uint16_t const *data, const std::vector<size_t> &shape,
     }
 }
 
+void print_data_bf16(uint16_t const *data, const std::vector<size_t> &shape,
+                     const std::vector<ptrdiff_t> &strides, size_t dim) {
+    if (dim == shape.size() - 1) {
+        for (size_t i = 0; i < shape[dim]; i++) {
+            std::cout << bf16_to_f32(data[i * strides[dim]]) << " ";
+        }
+        std::cout << std::endl;
+    } else if (dim < shape.size() - 1) {
+        for (size_t i = 0; i < shape[dim]; i++) {
+            print_data_bf16(data + i * strides[dim], shape, strides, dim + 1);
+        }
+    }
+}
+
 std::string Tensor::info() const {
     std::stringstream ss;
@@ -296,6 +310,10 @@ void Tensor::debug(const std::string &filename) const {
         print_data((int32_t const *)((char const *)cpu_data + dataOffset()),
                    this->shape(), this->strides(), 0);
         break;
+    case INFINI_DTYPE_BF16:
+        print_data_bf16((uint16_t const *)((char const *)cpu_data + dataOffset()),
+                        this->shape(), this->strides(), 0);
+        break;
     default:
         PANIC("Unsupported data type");
     }
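print_data_bf16 indexes data[i * strides[dim]] rather than assuming contiguity, so the same buffer can be printed under any view. A small self-contained sketch of that stride walk (decoder inlined, mirroring bf16_to_f32 from src/utils.hpp), printing a transposed view of a row-major buffer:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Decode bf16 by placing its bits in the high half of a float32 (as in src/utils.hpp).
static float bf16_decode(uint16_t v) {
    uint32_t bits = static_cast<uint32_t>(v) << 16;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

// Truncating encode, used here only to build test data.
static uint16_t bf16_encode(float v) {
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

int main() {
    // 2x3 row-major buffer: [[0, 1, 2], [10, 11, 12]] (all exactly representable in bf16).
    const float vals[] = {0.f, 1.f, 2.f, 10.f, 11.f, 12.f};
    std::vector<uint16_t> buf;
    for (float v : vals) buf.push_back(bf16_encode(v));

    // Transposed 3x2 view over the same memory: strides {1, 3}, no copy.
    std::vector<size_t> shape{3, 2};
    std::vector<ptrdiff_t> strides{1, 3};
    for (size_t r = 0; r < shape[0]; r++) {
        for (size_t c = 0; c < shape[1]; c++)
            std::cout << bf16_decode(buf[r * strides[0] + c * strides[1]]) << " ";
        std::cout << "\n"; // prints: 0 10 / 1 11 / 2 12
    }
}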
src/utils.hpp
@@ -97,4 +97,26 @@ inline uint16_t f32_to_f16(float val) {
     }
 }
 
+inline float bf16_to_f32(uint16_t val) {
+    // Just place the bf16 bits in the high 16 bits of the float32; the low 16 bits are zero.
+    uint32_t bits32 = static_cast<uint32_t>(val) << 16;
+    float out;
+    std::memcpy(&out, &bits32, sizeof(out));
+    return out;
+}
+
+inline uint16_t f32_to_bf16(float val) {
+    uint32_t bits32;
+    std::memcpy(&bits32, &val, sizeof(bits32));
+    // Add 0x7FFF before truncating, then use the parity of bit 16 (the lowest
+    // retained mantissa bit) to round to nearest even.
+    const uint32_t rounding_bias = 0x00007FFF + // 0111 1111 1111 1111
+                                   ((bits32 >> 16) & 1); // +1 when the lowest retained bit is odd, giving round-to-nearest-even
+    uint16_t bf16_bits = static_cast<uint16_t>((bits32 + rounding_bias) >> 16);
+    return bf16_bits;
+}
+
 #endif
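The rounding bias can be sanity-checked on exact ties, where the discarded low 16 bits are 0x8000 and the result must land on the even 16-bit pattern. A quick standalone check using the f32_to_bf16 from this diff:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Copied from the diff above (comments omitted).
inline uint16_t f32_to_bf16(float val) {
    uint32_t bits32;
    std::memcpy(&bits32, &val, sizeof(bits32));
    const uint32_t rounding_bias = 0x00007FFF + ((bits32 >> 16) & 1);
    return static_cast<uint16_t>((bits32 + rounding_bias) >> 16);
}

int main() {
    // Two exact ties (low 16 bits == 0x8000). Round-to-nearest-even keeps the
    // even bf16 pattern: 0x3F808000 -> 0x3F80 (down), 0x3F818000 -> 0x3F82 (up).
    const uint32_t ties[] = {0x3F808000u, 0x3F818000u};
    for (uint32_t bits : ties) {
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        std::printf("0x%08X -> 0x%04X\n", bits, (unsigned)f32_to_bf16(f));
    }
}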