Commit fccf8d17 authored by Michael Yang's avatar Michael Yang
Browse files

partial decode ggml bin for more info

parent 5b5cc9c9
package llm
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
// ModelFamily identifies the architecture family of a model (e.g. "llama").
type ModelFamily string

// ModelFamilyLlama is the llama model family; it is the only family
// DecodeGGML currently knows how to decode hyperparameters for.
const ModelFamilyLlama ModelFamily = "llama"

// ModelType encodes a model's size class as its transformer layer count
// (see DecodeGGML, which sets it from NumLayer).
type ModelType uint32

// Known model sizes, expressed as the number of layers in the network.
const (
	ModelType3B  ModelType = 26
	ModelType7B  ModelType = 32
	ModelType13B ModelType = 40
	ModelType30B ModelType = 60
	ModelType65B ModelType = 80
)
// FileType describes the quantization format of a model's tensor data.
type FileType uint32

const (
	FileTypeF32 FileType = iota
	FileTypeF16
	FileTypeQ4_0
	FileTypeQ4_1
	FileTypeQ4_1_F16
	// NOTE(review): iota is 5 here, so FileTypeQ8_0 = 8 and the values
	// 5-7 are skipped — presumably to track llama.cpp's ftype enum,
	// which has removed entries in that range; verify the offset matches
	// the upstream enumeration. Also note the `FileType` type is dropped
	// from this point on, so these constants are untyped ints.
	FileTypeQ8_0 = iota + 3
	FileTypeQ5_0
	FileTypeQ5_1
	FileTypeQ2_K
	FileTypeQ3_K
	FileTypeQ4_K
	FileTypeQ5_K
	FileTypeQ6_K
	// FileTypeUnknown marks an unrecognized quantization format. It is
	// untyped: FileType is uint32, so a typed -1 would not compile.
	FileTypeUnknown = -1
)
// GGML holds the metadata decoded from the header of a ggml model file:
// the file magic, the container wrapper, and the model's hyperparameters.
// Instances are populated by DecodeGGML.
type GGML struct {
	ModelFamily
	ModelType
	// magic is the leading uint32 that identifies the container format.
	magic uint32
	container
	llamaHyperparameters
}
// container abstracts the outer file-format wrapper (ggml, ggmf, ggjt,
// or ggla) around a model file.
type container interface {
	// Name returns the short identifier of the container format.
	Name() string
	// Decode reads the container's header fields (e.g. a version) from r.
	Decode(io.Reader) error
}
// containerGGML is the legacy, unversioned "ggml" container format.
type containerGGML struct{}

// Name returns the container format identifier.
func (c *containerGGML) Name() string { return "ggml" }

// Decode is a no-op: the unversioned format has no container header to read.
func (c *containerGGML) Decode(r io.Reader) error { return nil }
type containerGGMF struct {
version uint32
}
func (c *containerGGMF) Name() string {
return "ggmf"
}
func (c *containerGGMF) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
type containerGGJT struct {
version uint32
}
func (c *containerGGJT) Name() string {
return "ggjt"
}
func (c *containerGGJT) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1, 2, 3:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
type containerLORA struct {
version uint32
}
func (c *containerLORA) Name() string {
return "ggla"
}
func (c *containerLORA) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
const (
	// FILE_MAGIC_GGML is the magic constant for `ggml` files (unversioned).
	FILE_MAGIC_GGML = 0x67676d6c
	// FILE_MAGIC_GGMF is the magic constant for `ggml` files (versioned, ggmf).
	FILE_MAGIC_GGMF = 0x67676d66
	// FILE_MAGIC_GGJT is the magic constant for `ggml` files (versioned, ggjt).
	FILE_MAGIC_GGJT = 0x67676a74
	// FILE_MAGIC_GGLA is the magic constant for `ggla` files (LoRA adapter).
	FILE_MAGIC_GGLA = 0x67676C61
)
// DecodeGGML partially decodes a ggml model file from r: it reads the
// file magic to determine the container format, decodes the container
// header, and reads the hyperparameters for the hinted model family.
// It does not read the tensor data itself.
func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
	var ggml GGML
	// Previously read errors here were silently ignored; a short or
	// unreadable file now fails with a descriptive error.
	if err := binary.Read(r, binary.LittleEndian, &ggml.magic); err != nil {
		return nil, fmt.Errorf("reading file magic: %w", err)
	}

	// The magic number selects which container wraps the model data.
	switch ggml.magic {
	case FILE_MAGIC_GGML:
		ggml.container = &containerGGML{}
	case FILE_MAGIC_GGMF:
		ggml.container = &containerGGMF{}
	case FILE_MAGIC_GGJT:
		ggml.container = &containerGGJT{}
	case FILE_MAGIC_GGLA:
		ggml.container = &containerLORA{}
	default:
		return nil, errors.New("invalid file magic")
	}

	if err := ggml.Decode(r); err != nil {
		return nil, err
	}

	// different model types may have different layouts for hyperparameters
	switch hint {
	case ModelFamilyLlama:
		if err := binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters); err != nil {
			return nil, fmt.Errorf("reading llama hyperparameters: %w", err)
		}
		// TODO: sanity check hyperparameters
	default:
		return nil, fmt.Errorf("unsupported model type: %s", hint)
	}

	// final model type: the layer count determines the size class.
	ggml.ModelFamily = hint
	ggml.ModelType = ModelType(ggml.NumLayer)

	return &ggml, nil
}
package llama package llm
/* /*
#cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS #cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
...@@ -105,7 +105,7 @@ import ( ...@@ -105,7 +105,7 @@ import (
//go:embed ggml-metal.metal //go:embed ggml-metal.metal
var fs embed.FS var fs embed.FS
type LLM struct { type llama struct {
params *C.struct_llama_context_params params *C.struct_llama_context_params
model *C.struct_llama_model model *C.struct_llama_model
ctx *C.struct_llama_context ctx *C.struct_llama_context
...@@ -120,12 +120,28 @@ type LLM struct { ...@@ -120,12 +120,28 @@ type LLM struct {
api.Options api.Options
} }
func New(model string, opts api.Options) (*LLM, error) { type llamaHyperparameters struct {
// NumVocab is the size of the model's vocabulary.
NumVocab uint32
// NumEmbd is the size of the model's embedding layer.
NumEmbd uint32
NumMult uint32
NumHead uint32
// NumLayer is the number of layers in the model.
NumLayer uint32
NumRot uint32
// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
FileType
}
func newLlama(model string, opts api.Options) (*llama, error) {
if _, err := os.Stat(model); err != nil { if _, err := os.Stat(model); err != nil {
return nil, err return nil, err
} }
llm := LLM{Options: opts} llm := llama{Options: opts}
C.llama_backend_init(C.bool(llm.UseNUMA)) C.llama_backend_init(C.bool(llm.UseNUMA))
...@@ -168,7 +184,7 @@ func New(model string, opts api.Options) (*LLM, error) { ...@@ -168,7 +184,7 @@ func New(model string, opts api.Options) (*LLM, error) {
return &llm, nil return &llm, nil
} }
func (llm *LLM) Close() { func (llm *llama) Close() {
llm.gc = true llm.gc = true
llm.mu.Lock() llm.mu.Lock()
...@@ -180,17 +196,16 @@ func (llm *LLM) Close() { ...@@ -180,17 +196,16 @@ func (llm *LLM) Close() {
C.llama_print_timings(llm.ctx) C.llama_print_timings(llm.ctx)
} }
func (llm *llama) SetOptions(opts api.Options) {
llm.Options = opts
}
var errNeedMoreData = errors.New("need more data") var errNeedMoreData = errors.New("need more data")
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
C.llama_reset_timings(llm.ctx) C.llama_reset_timings(llm.ctx)
tokens := make([]C.llama_token, len(ctx)) llm.marshalPrompt(ctx, prompt)
for i := range tokens {
tokens[i] = C.llama_token(ctx[i])
}
llm.marshalPrompt(tokens, prompt)
C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed)) C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
...@@ -205,7 +220,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) ...@@ -205,7 +220,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
return err return err
} }
b.WriteString(llm.Decode(token)) b.WriteString(llm.Decode(int(token)))
if err := llm.checkStopConditions(b); err != nil { if err := llm.checkStopConditions(b); err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
...@@ -243,7 +258,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) ...@@ -243,7 +258,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
return nil return nil
} }
func (llm *LLM) checkStopConditions(b bytes.Buffer) error { func (llm *llama) checkStopConditions(b bytes.Buffer) error {
for _, stopCondition := range llm.Stop { for _, stopCondition := range llm.Stop {
if stopCondition == strings.TrimSpace(b.String()) { if stopCondition == strings.TrimSpace(b.String()) {
return io.EOF return io.EOF
...@@ -255,12 +270,17 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error { ...@@ -255,12 +270,17 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
return nil return nil
} }
func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token { func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {
tokens := append(ctx, llm.Encode(prompt)...) tokens := append(ctx, llm.Encode(prompt)...)
if llm.NumKeep < 0 { if llm.NumKeep < 0 {
llm.NumKeep = len(tokens) llm.NumKeep = len(tokens)
} }
cTokens := make([]C.llama_token, len(tokens))
for i := range tokens {
cTokens[i] = C.llama_token(tokens[i])
}
// min(llm.NumCtx - 4, llm.NumKeep) // min(llm.NumCtx - 4, llm.NumKeep)
if llm.NumCtx-4 < llm.NumKeep { if llm.NumCtx-4 < llm.NumKeep {
llm.NumKeep = llm.NumCtx - 4 llm.NumKeep = llm.NumCtx - 4
...@@ -269,25 +289,25 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke ...@@ -269,25 +289,25 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
if len(tokens) >= llm.NumCtx { if len(tokens) >= llm.NumCtx {
// truncate input // truncate input
numLeft := (llm.NumCtx - llm.NumKeep) / 2 numLeft := (llm.NumCtx - llm.NumKeep) / 2
truncated := tokens[:llm.NumKeep] truncated := cTokens[:llm.NumKeep]
erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft erasedBlocks := (len(cTokens) - llm.NumKeep - numLeft - 1) / numLeft
truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...) truncated = append(truncated, cTokens[llm.NumKeep+erasedBlocks*numLeft:]...)
copy(llm.last, tokens[len(tokens)-llm.NumCtx:]) copy(llm.last, cTokens[len(cTokens)-llm.NumCtx:])
tokens = truncated cTokens = truncated
log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated)) log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
} else { } else {
llm.last = make([]C.llama_token, llm.NumCtx-len(tokens)) llm.last = make([]C.llama_token, llm.NumCtx-len(cTokens))
llm.last = append(llm.last, tokens...) llm.last = append(llm.last, cTokens...)
} }
var i int var i int
for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ { for i = 0; i < len(llm.embd) && i < len(cTokens) && llm.embd[i] == cTokens[i]; i++ {
// noop // noop
} }
llm.embd = tokens llm.embd = cTokens
if i == len(tokens) { if i == len(cTokens) {
// evaluate at least one token to generate logits // evaluate at least one token to generate logits
i-- i--
} }
...@@ -295,31 +315,36 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke ...@@ -295,31 +315,36 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
llm.cursor = i llm.cursor = i
log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:])) log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
return tokens return cTokens
} }
func (llm *LLM) Encode(prompt string) []C.llama_token { func (llm *llama) Encode(prompt string) []int {
cPrompt := C.CString(prompt) cPrompt := C.CString(prompt)
defer C.free(unsafe.Pointer(cPrompt)) defer C.free(unsafe.Pointer(cPrompt))
tokens := make([]C.llama_token, len(prompt)+1) cTokens := make([]C.llama_token, len(prompt)+1)
if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(tokens), C.int(len(tokens)), true); n > 0 { if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(cTokens), C.int(len(cTokens)), true); n > 0 {
return tokens[:n] tokens := make([]int, n)
for i := range cTokens[:n] {
tokens[i] = int(cTokens[i])
}
return tokens
} }
return nil return nil
} }
func (llm *LLM) Decode(tokens ...C.llama_token) string { func (llm *llama) Decode(tokens ...int) string {
var sb strings.Builder var sb strings.Builder
for _, token := range tokens { for _, token := range tokens {
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token))) sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, C.llama_token(token))))
} }
return sb.String() return sb.String()
} }
func (llm *LLM) next() (C.llama_token, error) { func (llm *llama) next() (C.llama_token, error) {
llm.mu.Lock() llm.mu.Lock()
defer llm.mu.Unlock() defer llm.mu.Unlock()
...@@ -410,7 +435,7 @@ func (llm *LLM) next() (C.llama_token, error) { ...@@ -410,7 +435,7 @@ func (llm *LLM) next() (C.llama_token, error) {
return token, nil return token, nil
} }
func (llm *LLM) Embedding(input string) ([]float64, error) { func (llm *llama) Embedding(input string) ([]float64, error) {
if !llm.EmbeddingOnly { if !llm.EmbeddingOnly {
return nil, errors.New("llama: embedding not enabled") return nil, errors.New("llama: embedding not enabled")
} }
...@@ -420,7 +445,12 @@ func (llm *LLM) Embedding(input string) ([]float64, error) { ...@@ -420,7 +445,12 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
return nil, errors.New("llama: tokenize embedding") return nil, errors.New("llama: tokenize embedding")
} }
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread)) cTokens := make([]C.llama_token, len(tokens))
for i := range tokens {
cTokens[i] = C.llama_token(tokens[i])
}
retval := C.llama_eval(llm.ctx, unsafe.SliceData(cTokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
if retval != 0 { if retval != 0 {
return nil, errors.New("llama: eval") return nil, errors.New("llama: eval")
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment