Commit c00fa9cc authored by Michael Yang, committed by Michael Yang
Browse files

convert: split gate_up bias

parent df411c4b
...@@ -85,6 +85,17 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor { ...@@ -85,6 +85,17 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
case "scales": case "scales":
mxfp4s[name].scales = t mxfp4s[name].scales = t
} }
} else if strings.HasSuffix(t.Name(), "gate_up_exps.bias") {
out = append(out, slices.Collect(splitDim(t, 1,
split{
Replacer: strings.NewReplacer("gate_up_exps", "gate_exps"),
slices: []tensor.Slice{nil, tensor.S(0, int(t.Shape()[1]), 2)},
},
split{
Replacer: strings.NewReplacer("gate_up_exps", "up_exps"),
slices: []tensor.Slice{nil, tensor.S(1, int(t.Shape()[1]), 2)},
},
))...)
} else { } else {
out = append(out, &ggml.Tensor{ out = append(out, &ggml.Tensor{
Name: t.Name(), Name: t.Name(),
......
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
type split struct { type split struct {
*strings.Replacer *strings.Replacer
dim int dim int
slices []tensor.Slice
// fn is an optional function to apply to the tensor after slicing // fn is an optional function to apply to the tensor after slicing
fn func(tensor.Tensor) (tensor.Tensor, error) fn func(tensor.Tensor) (tensor.Tensor, error)
...@@ -32,9 +33,12 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] { ...@@ -32,9 +33,12 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
shape := slices.Clone(t.Shape()) shape := slices.Clone(t.Shape())
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits))) shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
slice := split.slices
if len(slice) == 0 {
slice := slices.Repeat([]tensor.Slice{nil}, len(shape)) slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(offset, offset+int(shape[dim])) slice[dim] = tensor.S(offset, offset+int(shape[dim]))
offset += int(shape[dim]) offset += int(shape[dim])
}
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) { t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape)) dims := make([]int, len(shape))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment