"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f1f38ffbeed793d684684e00e6e1213bcaf494d6"
Commit a6fbfc88 (unverified), authored by Michael Yang, committed by GitHub
Browse files

gguf: fix write order (#11068)

* ggml: test write gguf order
* ggml: fix write tensor order
parent 50202896
...@@ -527,23 +527,17 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error { ...@@ -527,23 +527,17 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
return err return err
} }
keys := slices.Collect(maps.Keys(kv)) for _, key := range slices.Sorted(maps.Keys(kv)) {
slices.Sort(keys)
for _, key := range keys {
if err := ggufWriteKV(f, key, kv[key]); err != nil { if err := ggufWriteKV(f, key, kv[key]); err != nil {
return err return err
} }
} }
slices.SortStableFunc(ts, func(a, b *Tensor) int { slices.SortStableFunc(ts, func(a, b *Tensor) int {
if i, j := a.block(), b.block(); i < 0 && j > 0 { if i, j := a.block(), b.block(); i > 0 && j > 0 {
return 1
} else if i > 0 && j < 0 {
return -1
} else {
return cmp.Compare(i, j) return cmp.Compare(i, j)
} }
return cmp.Compare(a.Name, b.Name)
}) })
var s uint64 var s uint64
......
...@@ -2,62 +2,82 @@ package ggml ...@@ -2,62 +2,82 @@ package ggml
import ( import (
"bytes" "bytes"
"math/rand/v2"
"os" "os"
"slices" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
func TestWriteGGUF(t *testing.T) { func TestWriteGGUF(t *testing.T) {
w, err := os.CreateTemp(t.TempDir(), "*.bin") r := rand.New(rand.NewPCG(0, 0))
if err != nil { for range 8 {
t.Fatal(err) t.Run("shuffle", func(t *testing.T) {
} t.Parallel()
defer w.Close()
if err := WriteGGUF(w, KV{
"general.alignment": uint32(16),
}, []*Tensor{
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
}); err != nil {
t.Fatal(err)
}
r, err := os.Open(w.Name()) ts := []*Tensor{
if err != nil { {Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
t.Fatal(err) {Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
} {Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
defer r.Close() {Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
}
ff, err := Decode(r, 0) r.Shuffle(len(ts), func(i, j int) {
if err != nil { ts[i], ts[j] = ts[j], ts[i]
t.Fatal(err) })
}
if diff := cmp.Diff(ff.KV(), KV{ w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
"general.alignment": uint32(16), if err != nil {
"general.parameter_count": uint64(36), t.Fatal(err)
}); diff != "" { }
t.Errorf("Mismatch (-want +got):\n%s", diff) defer w.Close()
}
if err := WriteGGUF(w, KV{
"general.alignment": uint32(16),
}, ts); err != nil {
t.Fatal(err)
}
r, err := os.Open(w.Name())
if err != nil {
t.Fatal(err)
}
defer r.Close()
ff, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(KV{
"general.alignment": uint32(16),
"general.parameter_count": uint64(54),
}, ff.KV()); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(ff.Tensors(), Tensors{ if diff := cmp.Diff(Tensors{
Offset: 336, Offset: 608,
items: []*Tensor{ items: []*Tensor{
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}}, {Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}}, {Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}}, {Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}}, {Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}}, {Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}}, {Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
}, {Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
}, cmp.AllowUnexported(Tensors{})); diff != "" { {Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
t.Errorf("Mismatch (-want +got):\n%s", diff) {Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
},
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
})
} }
} }
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment