Commit fccf8d17 authored by Michael Yang's avatar Michael Yang
Browse files

partial decode ggml bin for more info

parent 5b5cc9c9
package llm
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
// ModelFamily identifies the model architecture stored in a ggml file
// (e.g. "llama"); DecodeGGML uses it to pick the hyperparameter layout.
type ModelFamily string

// ModelFamilyLlama is the only family DecodeGGML currently supports.
const ModelFamilyLlama ModelFamily = "llama"
// ModelType encodes a model's size class as its transformer layer
// count; DecodeGGML derives it from llamaHyperparameters.NumLayer.
type ModelType uint32

// Layer counts for the known llama parameter sizes.
const (
	ModelType3B  ModelType = 26
	ModelType7B  ModelType = 32
	ModelType13B ModelType = 40
	ModelType30B ModelType = 60
	ModelType65B ModelType = 80
)
// FileType describes the quantization level of the tensors in a ggml
// file. The numeric values mirror llama.cpp's llama_ftype enum, which
// is what the on-disk hyperparameter block stores.
type FileType uint32

const (
	FileTypeF32 FileType = iota
	FileTypeF16
	FileTypeQ4_0
	FileTypeQ4_1
	FileTypeQ4_1_F16
	// Values 5 and 6 (Q4_2, Q4_3) were removed from llama.cpp, so the
	// sequence skips two: Q8_0 must be 7 to match llama_ftype.
	// (was `= iota + 3`, which produced 8 and shifted every later value)
	FileTypeQ8_0 FileType = iota + 2
	FileTypeQ5_0
	FileTypeQ5_1
	FileTypeQ2_K
	FileTypeQ3_K
	FileTypeQ4_K
	FileTypeQ5_K
	FileTypeQ6_K

	// FileTypeUnknown marks an unrecognized quantization level. Kept
	// untyped so the -1 sentinel does not conflict with uint32.
	FileTypeUnknown = -1
)
// GGML holds the metadata decoded from the header of a ggml-format
// model file: the container variant, the family/size, and the
// hyperparameter block.
type GGML struct {
	ModelFamily
	ModelType

	// magic is the 4-byte file magic read from the start of the file.
	magic uint32

	container

	// NOTE(review): hyperparameters are embedded with a llama-specific
	// layout; supporting another family would need its own field.
	llamaHyperparameters
}
// container abstracts the ggml container variants (ggml, ggmf, ggjt,
// ggla). Decode reads any container-specific header fields that follow
// the file magic already consumed by DecodeGGML.
type container interface {
	Name() string
	Decode(io.Reader) error
}
// containerGGML is the legacy, unversioned ggml container.
type containerGGML struct {
}

// Name returns the container format identifier.
func (c *containerGGML) Name() string {
	return "ggml"
}

// Decode is a no-op: the unversioned container carries no header
// fields beyond the magic.
func (c *containerGGML) Decode(r io.Reader) error {
	return nil
}
type containerGGMF struct {
version uint32
}
func (c *containerGGMF) Name() string {
return "ggmf"
}
func (c *containerGGMF) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
type containerGGJT struct {
version uint32
}
func (c *containerGGJT) Name() string {
return "ggjt"
}
func (c *containerGGJT) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1, 2, 3:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
type containerLORA struct {
version uint32
}
func (c *containerLORA) Name() string {
return "ggla"
}
func (c *containerLORA) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
const (
	// Magic constant for `ggml` files (unversioned).
	FILE_MAGIC_GGML = 0x67676d6c
	// Magic constant for `ggml` files (versioned, ggmf).
	FILE_MAGIC_GGMF = 0x67676d66
	// Magic constant for `ggml` files (versioned, ggjt).
	FILE_MAGIC_GGJT = 0x67676a74
	// Magic constant for `ggla` files (LoRA adapter).
	FILE_MAGIC_GGLA = 0x67676C61
)
// DecodeGGML partially decodes a ggml model file from r: it reads the
// file magic, the container-specific header, and the family-specific
// hyperparameter block, without loading tensor data. hint selects the
// hyperparameter layout; only ModelFamilyLlama is currently supported.
func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
	var ggml GGML
	// was ignored: a failed read left magic==0 and reported a
	// misleading "invalid file magic" instead of the real I/O error
	if err := binary.Read(r, binary.LittleEndian, &ggml.magic); err != nil {
		return nil, fmt.Errorf("reading magic: %w", err)
	}

	switch ggml.magic {
	case FILE_MAGIC_GGML:
		ggml.container = &containerGGML{}
	case FILE_MAGIC_GGMF:
		ggml.container = &containerGGMF{}
	case FILE_MAGIC_GGJT:
		ggml.container = &containerGGJT{}
	case FILE_MAGIC_GGLA:
		ggml.container = &containerLORA{}
	default:
		return nil, errors.New("invalid file magic")
	}

	// container-specific header (e.g. version) follows the magic
	if err := ggml.Decode(r); err != nil {
		return nil, err
	}

	// different model families may have different layouts for hyperparameters
	switch hint {
	case ModelFamilyLlama:
		// was ignored: a truncated file silently yielded zeroed
		// hyperparameters and ModelType(0)
		if err := binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters); err != nil {
			return nil, fmt.Errorf("reading hyperparameters: %w", err)
		}
		// TODO: sanity check hyperparameters
	default:
		return nil, fmt.Errorf("unsupported model type: %s", hint)
	}

	// the layer count determines the final model size class (7B, 13B, ...)
	ggml.ModelFamily = hint
	ggml.ModelType = ModelType(ggml.NumLayer)

	return &ggml, nil
}
package llama
package llm
/*
#cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
......@@ -105,7 +105,7 @@ import (
//go:embed ggml-metal.metal
var fs embed.FS
type LLM struct {
type llama struct {
params *C.struct_llama_context_params
model *C.struct_llama_model
ctx *C.struct_llama_context
......@@ -120,12 +120,28 @@ type LLM struct {
api.Options
}
func New(model string, opts api.Options) (*LLM, error) {
// llamaHyperparameters mirrors the hyperparameter block at the start of
// a llama-family ggml file. Field order and widths must match the
// on-disk layout exactly, because the struct is filled with a single
// binary.Read in DecodeGGML.
type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32

	// NumMult — presumably llama's FFN size multiplier; TODO confirm
	// against llama.cpp's hparams.
	NumMult uint32
	// NumHead — assumed attention head count; verify against llama.cpp.
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	// NumRot — assumed rotary-embedding dimension count; verify.
	NumRot uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType
}
func newLlama(model string, opts api.Options) (*llama, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
llm := LLM{Options: opts}
llm := llama{Options: opts}
C.llama_backend_init(C.bool(llm.UseNUMA))
......@@ -168,7 +184,7 @@ func New(model string, opts api.Options) (*LLM, error) {
return &llm, nil
}
func (llm *LLM) Close() {
func (llm *llama) Close() {
llm.gc = true
llm.mu.Lock()
......@@ -180,17 +196,16 @@ func (llm *LLM) Close() {
C.llama_print_timings(llm.ctx)
}
// SetOptions replaces the generation options used by this model.
func (llm *llama) SetOptions(opts api.Options) {
	llm.Options = opts
}
// errNeedMoreData is a sentinel error. NOTE(review): its call sites are
// not visible in this chunk — presumably returned when decoding needs
// additional input before it can make progress; confirm at usage sites.
var errNeedMoreData = errors.New("need more data")
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
C.llama_reset_timings(llm.ctx)
tokens := make([]C.llama_token, len(ctx))
for i := range tokens {
tokens[i] = C.llama_token(ctx[i])
}
llm.marshalPrompt(tokens, prompt)
llm.marshalPrompt(ctx, prompt)
C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
......@@ -205,7 +220,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
return err
}
b.WriteString(llm.Decode(token))
b.WriteString(llm.Decode(int(token)))
if err := llm.checkStopConditions(b); err != nil {
if errors.Is(err, io.EOF) {
......@@ -243,7 +258,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
return nil
}
func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
func (llm *llama) checkStopConditions(b bytes.Buffer) error {
for _, stopCondition := range llm.Stop {
if stopCondition == strings.TrimSpace(b.String()) {
return io.EOF
......@@ -255,12 +270,17 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
return nil
}
func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {
tokens := append(ctx, llm.Encode(prompt)...)
if llm.NumKeep < 0 {
llm.NumKeep = len(tokens)
}
cTokens := make([]C.llama_token, len(tokens))
for i := range tokens {
cTokens[i] = C.llama_token(tokens[i])
}
// min(llm.NumCtx - 4, llm.NumKeep)
if llm.NumCtx-4 < llm.NumKeep {
llm.NumKeep = llm.NumCtx - 4
......@@ -269,25 +289,25 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
if len(tokens) >= llm.NumCtx {
// truncate input
numLeft := (llm.NumCtx - llm.NumKeep) / 2
truncated := tokens[:llm.NumKeep]
erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
truncated := cTokens[:llm.NumKeep]
erasedBlocks := (len(cTokens) - llm.NumKeep - numLeft - 1) / numLeft
truncated = append(truncated, cTokens[llm.NumKeep+erasedBlocks*numLeft:]...)
copy(llm.last, cTokens[len(cTokens)-llm.NumCtx:])
tokens = truncated
cTokens = truncated
log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
} else {
llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
llm.last = append(llm.last, tokens...)
llm.last = make([]C.llama_token, llm.NumCtx-len(cTokens))
llm.last = append(llm.last, cTokens...)
}
var i int
for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
for i = 0; i < len(llm.embd) && i < len(cTokens) && llm.embd[i] == cTokens[i]; i++ {
// noop
}
llm.embd = tokens
if i == len(tokens) {
llm.embd = cTokens
if i == len(cTokens) {
// evaluate at least one token to generate logits
i--
}
......@@ -295,31 +315,36 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
llm.cursor = i
log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
return tokens
return cTokens
}
func (llm *LLM) Encode(prompt string) []C.llama_token {
func (llm *llama) Encode(prompt string) []int {
cPrompt := C.CString(prompt)
defer C.free(unsafe.Pointer(cPrompt))
tokens := make([]C.llama_token, len(prompt)+1)
if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(tokens), C.int(len(tokens)), true); n > 0 {
return tokens[:n]
cTokens := make([]C.llama_token, len(prompt)+1)
if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(cTokens), C.int(len(cTokens)), true); n > 0 {
tokens := make([]int, n)
for i := range cTokens[:n] {
tokens[i] = int(cTokens[i])
}
return tokens
}
return nil
}
func (llm *LLM) Decode(tokens ...C.llama_token) string {
func (llm *llama) Decode(tokens ...int) string {
var sb strings.Builder
for _, token := range tokens {
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, C.llama_token(token))))
}
return sb.String()
}
func (llm *LLM) next() (C.llama_token, error) {
func (llm *llama) next() (C.llama_token, error) {
llm.mu.Lock()
defer llm.mu.Unlock()
......@@ -410,7 +435,7 @@ func (llm *LLM) next() (C.llama_token, error) {
return token, nil
}
func (llm *LLM) Embedding(input string) ([]float64, error) {
func (llm *llama) Embedding(input string) ([]float64, error) {
if !llm.EmbeddingOnly {
return nil, errors.New("llama: embedding not enabled")
}
......@@ -420,7 +445,12 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
return nil, errors.New("llama: tokenize embedding")
}
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
cTokens := make([]C.llama_token, len(tokens))
for i := range tokens {
cTokens[i] = C.llama_token(tokens[i])
}
retval := C.llama_eval(llm.ctx, unsafe.SliceData(cTokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
if retval != 0 {
return nil, errors.New("llama: eval")
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment