Commit 12b174b1 (unverified), authored by Michael Yang, committed by GitHub
Browse files

fix tensor merge (#13053)

parent 333203d8
...@@ -2,10 +2,12 @@ package convert ...@@ -2,10 +2,12 @@ package convert
import ( import (
"cmp" "cmp"
"errors"
"io" "io"
"iter" "iter"
"path" "path"
"slices" "slices"
"strconv"
"strings" "strings"
"github.com/pdevine/tensor" "github.com/pdevine/tensor"
...@@ -94,6 +96,26 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ [] ...@@ -94,6 +96,26 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []
return matched return matched
}) })
// Sort the matched tensors by name so the merged data has a deterministic
// layout. Names are split on "." and compared segment by segment; when both
// segments parse as integers they are compared numerically, so
// "layer.2.weight" orders before "layer.10.weight". Names with fewer
// segments order first.
slices.SortStableFunc(matched, func(a, b Tensor) int {
	x := strings.Split(a.Name(), ".")
	y := strings.Split(b.Name(), ".")
	if len(x) != len(y) {
		return cmp.Compare(len(x), len(y))
	}

	for i := range x {
		// Identical segments compare equal both lexically and numerically,
		// so there is nothing to parse.
		if x[i] == y[i] {
			continue
		}

		c := strings.Compare(x[i], y[i])
		// Base 0 lets ParseInt infer the base from the segment's prefix.
		m, err := strconv.ParseInt(x[i], 0, 0)
		n, err2 := strconv.ParseInt(y[i], 0, 0)
		if errors.Join(err, err2) == nil {
			c = cmp.Compare(m, n)
		}

		// The first differing segment decides the order; returning here
		// avoids allocating a per-comparison result slice.
		if c != 0 {
			return c
		}
	}

	return 0
})
if len(matched) > 0 { if len(matched) > 0 {
out = append(out, &ggml.Tensor{ out = append(out, &ggml.Tensor{
Name: merges[i].name, Name: merges[i].name,
......
...@@ -3,8 +3,10 @@ package convert ...@@ -3,8 +3,10 @@ package convert
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"iter" "iter"
"math/rand/v2"
"slices" "slices"
"strings" "strings"
"testing" "testing"
...@@ -951,3 +953,45 @@ func TestMerge(t *testing.T) { ...@@ -951,3 +953,45 @@ func TestMerge(t *testing.T) {
} }
}) })
} }
func TestMergeOrder(t *testing.T) {
for range 8 {
t.Run("", func(t *testing.T) {
tensors := make([]Tensor, 16)
for i := range tensors {
tensors[i] = &fakeTensor{
name: fmt.Sprintf("layer.%d.weight", i),
shape: []uint64{1},
data: []float32{float32(i)},
}
}
rand.Shuffle(len(tensors), func(i, j int) {
tensors[i], tensors[j] = tensors[j], tensors[i]
})
matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
if len(unmatched) != 0 {
t.Error("expected no remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
var b bytes.Buffer
if _, err := matched[0].WriteTo(&b); err != nil {
t.Fatal(err)
}
var f32s [16]float32
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.IsSorted(f32s[:]) {
t.Errorf("merged tensor data is not in order: %+v", f32s)
}
})
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment