Unverified Commit c5e1bbab authored by Michael Yang's avatar Michael Yang Committed by GitHub
Browse files

instead of static number of parameters for each model family, get the real...

instead of static number of parameters for each model family, get the real number from the tensors (#1022)

* parse tensor info

* refactor decoder

* return actual parameter count

* explicit rounding

* s/Human/HumanNumber/
parent a49d6acc
package format
import (
"fmt"
"math"
)
const (
Thousand = 1000
Million = Thousand * 1000
Billion = Million * 1000
)
func HumanNumber(b uint64) string {
switch {
case b > Billion:
return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion))
case b > Million:
return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million))
case b > Thousand:
return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand))
default:
return fmt.Sprintf("%d", b)
}
}
...@@ -5,6 +5,8 @@ import ( ...@@ -5,6 +5,8 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"github.com/jmorganca/ollama/format"
) )
type containerGGUF struct { type containerGGUF struct {
...@@ -21,6 +23,8 @@ type containerGGUF struct { ...@@ -21,6 +23,8 @@ type containerGGUF struct {
NumTensor uint64 NumTensor uint64
NumKV uint64 NumKV uint64
} }
parameters uint64
} }
func (c *containerGGUF) Name() string { func (c *containerGGUF) Name() string {
...@@ -75,6 +79,14 @@ func newGGUFModel(container *containerGGUF) *ggufModel { ...@@ -75,6 +79,14 @@ func newGGUFModel(container *containerGGUF) *ggufModel {
} }
} }
func (llm *ggufModel) NumTensor() uint64 {
if llm.Version == 1 {
return uint64(llm.V1.NumTensor)
}
return llm.V2.NumTensor
}
func (llm *ggufModel) NumKV() uint64 { func (llm *ggufModel) NumKV() uint64 {
if llm.Version == 1 { if llm.Version == 1 {
return uint64(llm.V1.NumKV) return uint64(llm.V1.NumKV)
...@@ -93,6 +105,10 @@ func (llm *ggufModel) ModelFamily() string { ...@@ -93,6 +105,10 @@ func (llm *ggufModel) ModelFamily() string {
} }
func (llm *ggufModel) ModelType() string { func (llm *ggufModel) ModelType() string {
if llm.parameters > 0 {
return format.HumanNumber(llm.parameters)
}
switch llm.ModelFamily() { switch llm.ModelFamily() {
case "llama": case "llama":
if blocks, ok := llm.kv["llama.block_count"].(uint32); ok { if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
...@@ -127,13 +143,9 @@ func (llm *ggufModel) FileType() string { ...@@ -127,13 +143,9 @@ func (llm *ggufModel) FileType() string {
} }
func (llm *ggufModel) Decode(r io.Reader) error { func (llm *ggufModel) Decode(r io.Reader) error {
read := llm.readString // decode key-values
if llm.Version == 1 {
read = llm.readStringV1
}
for i := 0; uint64(i) < llm.NumKV(); i++ { for i := 0; uint64(i) < llm.NumKV(); i++ {
k, err := read(r) k, err := llm.readString(r)
if err != nil { if err != nil {
return err return err
} }
...@@ -165,24 +177,14 @@ func (llm *ggufModel) Decode(r io.Reader) error { ...@@ -165,24 +177,14 @@ func (llm *ggufModel) Decode(r io.Reader) error {
case ggufTypeBool: case ggufTypeBool:
v = llm.readBool(r) v = llm.readBool(r)
case ggufTypeString: case ggufTypeString:
fn := llm.readString s, err := llm.readString(r)
if llm.Version == 1 {
fn = llm.readStringV1
}
s, err := fn(r)
if err != nil { if err != nil {
return err return err
} }
v = s v = s
case ggufTypeArray: case ggufTypeArray:
fn := llm.readArray a, err := llm.readArray(r)
if llm.Version == 1 {
fn = llm.readArrayV1
}
a, err := fn(r)
if err != nil { if err != nil {
return err return err
} }
...@@ -195,6 +197,25 @@ func (llm *ggufModel) Decode(r io.Reader) error { ...@@ -195,6 +197,25 @@ func (llm *ggufModel) Decode(r io.Reader) error {
llm.kv[k] = v llm.kv[k] = v
} }
// decode tensors
for i := 0; uint64(i) < llm.NumTensor(); i++ {
if _, err := llm.readString(r); err != nil {
return err
}
dimensions := llm.readU32(r)
var elements uint64 = 1
for i := 0; uint32(i) < dimensions; i++ {
elements *= llm.readU64(r)
}
llm.readU32(r) // type
llm.readU64(r) // offset
llm.parameters += elements
}
return nil return nil
} }
...@@ -290,6 +311,10 @@ func (llm ggufModel) readStringV1(r io.Reader) (string, error) { ...@@ -290,6 +311,10 @@ func (llm ggufModel) readStringV1(r io.Reader) (string, error) {
} }
func (llm ggufModel) readString(r io.Reader) (string, error) { func (llm ggufModel) readString(r io.Reader) (string, error) {
if llm.Version == 1 {
return llm.readStringV1(r)
}
var nameLength uint64 var nameLength uint64
binary.Read(r, llm.bo, &nameLength) binary.Read(r, llm.bo, &nameLength)
...@@ -339,6 +364,10 @@ func (llm *ggufModel) readArrayV1(r io.Reader) (arr []any, err error) { ...@@ -339,6 +364,10 @@ func (llm *ggufModel) readArrayV1(r io.Reader) (arr []any, err error) {
} }
func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) { func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) {
if llm.Version == 1 {
return llm.readArrayV1(r)
}
atype := llm.readU32(r) atype := llm.readU32(r)
n := llm.readU64(r) n := llm.readU64(r)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment