Commit 4a8fc3f9 authored by Michael Yang's avatar Michael Yang
Browse files

bf16

parent 4183bb05
......@@ -31,8 +31,9 @@ func (t tensorBase) Shape() []uint64 {
}
const (
tensorKindF32 uint32 = iota
tensorKindF16
tensorKindFP32 uint32 = iota
tensorKindFP16
tensorKindBF16 = 30
)
func (t tensorBase) Kind() uint32 {
......@@ -43,16 +44,16 @@ func (t tensorBase) Kind() uint32 {
t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" {
// these tensors are always F32
return 0
return tensorKindFP32
}
switch len(t.shape) {
case 0:
panic("invalid tensor shape")
case 1:
return tensorKindF32
return tensorKindFP32
default:
return tensorKindF16
return tensorKindBF16
}
}
......
......@@ -162,15 +162,18 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
}
switch st.Kind() {
case tensorKindF32:
case tensorKindFP32:
return 0, binary.Write(w, binary.LittleEndian, f32s)
case tensorKindF16:
case tensorKindFP16:
f16s := make([]uint16, len(f32s))
for i := range f32s {
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
}
return 0, binary.Write(w, binary.LittleEndian, f16s)
case tensorKindBF16:
u8s := bfloat16.EncodeFloat32(f32s)
return 0, binary.Write(w, binary.LittleEndian, u8s)
default:
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment