Commit c8ac4cc5 authored by Michael Yang's avatar Michael Yang
Browse files

convert to mxfp4

parent 6ca094a3
package convert package convert
import ( import (
"bytes"
"cmp" "cmp"
"encoding/binary"
"io"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
) )
type gptossModel struct { type gptossModel struct {
...@@ -58,12 +65,39 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV { ...@@ -58,12 +65,39 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor { func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor var out []*ggml.Tensor
mxfp4s := make(map[string]*mxfp4)
for _, t := range ts { for _, t := range ts {
if strings.HasSuffix(t.Name(), ".blocks") || strings.HasSuffix(t.Name(), ".scales") {
dot := strings.LastIndex(t.Name(), ".")
name, suffix := t.Name()[:dot], t.Name()[dot+1:]
if _, ok := mxfp4s[name]; !ok {
mxfp4s[name] = &mxfp4{}
}
switch suffix {
case "blocks":
mxfp4s[name].blocks = t
case "scales":
mxfp4s[name].scales = t
}
} else {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape()
out = append(out, &ggml.Tensor{ out = append(out, &ggml.Tensor{
Name: t.Name(), Name: name,
Kind: t.Kind(), Kind: uint32(ggml.TensorTypeMXFP4),
Shape: t.Shape(), Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: t, WriterTo: mxfp4,
}) })
} }
...@@ -72,6 +106,10 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor { ...@@ -72,6 +106,10 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
func (m *gptossModel) Replacements() []string { func (m *gptossModel) Replacements() []string {
return []string{ return []string{
// noop replacements so other replacements will not be applied
".blocks", ".blocks",
".scales", ".scales",
// real replacements
"block", "blk", "block", "blk",
"attn.norm", "attn_norm", "attn.norm", "attn_norm",
"attn.qkv", "attn_qkv", "attn.qkv", "attn_qkv",
...@@ -87,3 +125,55 @@ func (m *gptossModel) Replacements() []string { ...@@ -87,3 +125,55 @@ func (m *gptossModel) Replacements() []string {
"scale", "weight", "scale", "weight",
} }
} }
// mxfp4 pairs the two halves of an MXFP4-quantized tensor as exported by
// the source model: the packed 4-bit mantissa blocks and their per-block
// scales. The pair is merged into GGML's on-disk MXFP4 layout by WriteTo.
type mxfp4 struct {
	blocks, scales Tensor
}
// WriteTo materializes the MXFP4 tensor in GGML's on-disk layout and writes
// it to w: for each block, the uint8 scale is concatenated in front of the
// packed mantissa bytes along the innermost axis, then the whole tensor is
// flattened and emitted as raw little-endian bytes.
//
// It returns the number of bytes written, as required by io.WriterTo.
func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
	// Drain the packed mantissa blocks into memory so they can be wrapped
	// in a tensor for the concat below.
	var b bytes.Buffer
	if _, err := m.blocks.WriteTo(&b); err != nil {
		return 0, err
	}

	blocksDims := make([]int, len(m.blocks.Shape()))
	for i, d := range m.blocks.Shape() {
		blocksDims[i] = int(d)
	}

	var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))

	var s bytes.Buffer
	if _, err := m.scales.WriteTo(&s); err != nil {
		return 0, err
	}

	// Pad the scales shape with trailing singleton dimensions up to the
	// rank of blocks so Concat along axis 3 sees compatible ranks.
	scalesDims := slices.Repeat([]int{1}, len(m.blocks.Shape()))
	for i, d := range m.scales.Shape() {
		scalesDims[i] = int(d)
	}

	var scales tensor.Tensor = tensor.New(tensor.WithShape(scalesDims...), tensor.WithBacking(s.Bytes()))

	// Interleave: scale byte first, then the block's mantissa bytes.
	out, err := tensor.Concat(3, scales, blocks)
	if err != nil {
		return 0, err
	}

	out = tensor.Materialize(out)

	if err := out.Reshape(out.Shape().TotalSize()); err != nil {
		return 0, err
	}

	u8s, err := native.VectorU8(out.(*tensor.Dense))
	if err != nil {
		return 0, err
	}

	// Write the raw bytes directly; a plain Write avoids binary.Write's
	// reflection path and yields the byte count the io.WriterTo contract
	// requires (the previous version always reported 0 bytes written).
	n, err := w.Write(u8s)
	return int64(n), err
}
...@@ -33,7 +33,8 @@ func (t tensorBase) Shape() []uint64 { ...@@ -33,7 +33,8 @@ func (t tensorBase) Shape() []uint64 {
const ( const (
tensorKindFP32 uint32 = iota tensorKindFP32 uint32 = iota
tensorKindFP16 tensorKindFP16
tensorKindBF16 = 30 tensorKindMXFP4 = 4
tensorKindBF16 = 30
) )
func (t tensorBase) Kind() uint32 { func (t tensorBase) Kind() uint32 {
......
...@@ -282,7 +282,7 @@ func (t Tensor) block() (n int) { ...@@ -282,7 +282,7 @@ func (t Tensor) block() (n int) {
} }
func (t Tensor) blockSize() uint64 { func (t Tensor) blockSize() uint64 {
return (TensorType)(t.Kind).BlockSize() return TensorType(t.Kind).BlockSize()
} }
func (t TensorType) BlockSize() uint64 { func (t TensorType) BlockSize() uint64 {
...@@ -300,6 +300,7 @@ func (t TensorType) BlockSize() uint64 { ...@@ -300,6 +300,7 @@ func (t TensorType) BlockSize() uint64 {
case case
2, // Q4_0 2, // Q4_0
3, // Q4_1 3, // Q4_1
4, // MXFP4
6, // Q5_0 6, // Q5_0
7, // Q5_1 7, // Q5_1
8, // Q8_0 8, // Q8_0
...@@ -327,6 +328,8 @@ func (t TensorType) TypeSize() uint64 { ...@@ -327,6 +328,8 @@ func (t TensorType) TypeSize() uint64 {
return 2 + blockSize/2 return 2 + blockSize/2
case TensorTypeQ4_1: case TensorTypeQ4_1:
return 2 + 2 + blockSize/2 return 2 + 2 + blockSize/2
case TensorTypeMXFP4:
return 1 + blockSize/2
case TensorTypeQ5_0: case TensorTypeQ5_0:
return 2 + 4 + blockSize/2 return 2 + 4 + blockSize/2
case TensorTypeQ5_1: case TensorTypeQ5_1:
......
...@@ -14,9 +14,9 @@ const ( ...@@ -14,9 +14,9 @@ const (
FileTypeF16 FileTypeF16
fileTypeQ4_0 fileTypeQ4_0
fileTypeQ4_1 fileTypeQ4_1
fileTypeQ4_1_F16 // unused by GGML fileTypeMXFP4 // originally fileTypeQ4_1_F16 // unused by GGML
fileTypeQ4_2 // unused by GGML fileTypeQ4_2 // unused by GGML
fileTypeQ4_3 // unused by GGML fileTypeQ4_3 // unused by GGML
FileTypeQ8_0 FileTypeQ8_0
fileTypeQ5_0 fileTypeQ5_0
fileTypeQ5_1 fileTypeQ5_1
...@@ -97,6 +97,8 @@ func (t FileType) String() string { ...@@ -97,6 +97,8 @@ func (t FileType) String() string {
return "Q4_0" return "Q4_0"
case fileTypeQ4_1: case fileTypeQ4_1:
return "Q4_1" return "Q4_1"
case fileTypeMXFP4:
return "MXFP4"
case FileTypeQ8_0: case FileTypeQ8_0:
return "Q8_0" return "Q8_0"
case fileTypeQ5_0: case fileTypeQ5_0:
...@@ -144,6 +146,8 @@ func (ftype FileType) ToTensorType() TensorType { ...@@ -144,6 +146,8 @@ func (ftype FileType) ToTensorType() TensorType {
return TensorTypeQ4_0 return TensorTypeQ4_0
case fileTypeQ4_1: case fileTypeQ4_1:
return TensorTypeQ4_1 return TensorTypeQ4_1
case fileTypeMXFP4:
return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
case FileTypeQ8_0: case FileTypeQ8_0:
return TensorTypeQ8_0 return TensorTypeQ8_0
case fileTypeQ5_0: case fileTypeQ5_0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment