package llama

/*
#cgo CFLAGS: -std=c11
#cgo windows CFLAGS: -Wno-dll-attribute-on-redeclaration
#cgo CXXFLAGS: -std=c++17
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/include
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/common
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/examples/llava
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/src
#cgo CPPFLAGS: -I${SRCDIR}/../ml/backend/ggml/ggml/include

#include <stdlib.h>
#include "ggml.h"
#include "llama.h"
#include "clip.h"
#include "llava.h"
#include "gguf.h"
#include "mllama.h"
#include "sampling_ext.h"

extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);
*/
import "C"

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"runtime"
	"runtime/cgo"
	"slices"
	"strings"
	"sync"
	"unsafe"

	_ "github.com/ollama/ollama/llama/llama.cpp/common"
	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
	_ "github.com/ollama/ollama/llama/llama.cpp/src"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)

48
49
50
51
52
53
54
55
56
57
58
59
// init routes llama.cpp's native logging through the Go llamaLog callback.
func init() {
	C.llama_log_set(C.ggml_log_callback(C.llamaLog), nil)
}

// llamaLog is the cgo callback that forwards llama.cpp log messages to
// stderr, filtered by the current slog level.
//
//export llamaLog
func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
	// Map GGML log levels onto slog levels: slog uses 0 for INFO and
	// multiples of 4 for the other levels, so offset from GGML's INFO.
	if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
		fmt.Fprint(os.Stderr, C.GoString(text))
	}
}

60
// BackendInit loads the ggml backend libraries and initializes the
// llama.cpp backend. Call once before using any other function here.
func BackendInit() {
	// Ensure the native ggml libraries are loaded exactly once before
	// llama.cpp initialization touches them.
	ggml.OnceLoad()
	C.llama_backend_init()
}

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// GetModelArch reads the GGUF metadata of the file at modelPath and returns
// the value of the "general.architecture" key.
func GetModelArch(modelPath string) (string, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	// no_alloc: only metadata is parsed; tensor data is not loaded.
	ggufCtx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: nil})
	if ggufCtx == nil {
		return "", errors.New("unable to load model file")
	}
	defer C.gguf_free(ggufCtx)

	key := C.CString("general.architecture")
	defer C.free(unsafe.Pointer(key))

	archIndex := C.gguf_find_key(ggufCtx, key)
	if int(archIndex) < 0 {
		return "", errors.New("unknown model architecture")
	}

	return C.GoString(C.gguf_get_val_str(ggufCtx, archIndex)), nil
}

87
88
89
90
// ContextParams wraps llama.cpp's llama_context_params.
type ContextParams struct {
	c C.struct_llama_context_params
}

91
// NewContextParams builds llama context parameters from the given settings.
// kvCacheType selects the KV cache quantization ("q8_0", "q4_0"; anything
// else falls back to f16).
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)

	// K and V caches share the same element type; resolve it once.
	cacheType := kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	params.type_k = cacheType
	params.type_v = cacheType

	return ContextParams{c: params}
}

106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
	switch s {
	case "q8_0":
		return C.GGML_TYPE_Q8_0
	case "q4_0":
		return C.GGML_TYPE_Q4_0
	default:
		// The empty string and any unrecognized value fall back to f16.
		return C.GGML_TYPE_F16
	}
}

122
123
124
125
126
// Context wraps a llama_context plus the thread count it was created with.
type Context struct {
	c          *C.struct_llama_context
	numThreads int // thread count, reused by image encoders
}

127
// ErrKvCacheFull is returned by Decode when no KV cache slot is available.
var ErrKvCacheFull = errors.New("could not find a kv cache slot")
128
129
130
131
132
133
134
135
136
137
138
139
140

// Decode runs the model over the contents of batch.
// Returns ErrKvCacheFull when the batch could not be placed in the KV cache.
func (c *Context) Decode(batch *Batch) error {
	// Positive return values do not mean a fatal error, but rather a warning.
	//   0 - success
	//   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
	// < 0 - error
	code := int(C.llama_decode(c.c, batch.c))

	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}

	if code > 0 {
		return ErrKvCacheFull
	}

	return nil
}

// Model returns the model this context was created from.
func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

// KvCacheSeqAdd shifts the positions in [p0, p1) of sequence seqId by delta.
func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_kv_self_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

// KvCacheSeqRm removes positions [p0, p1) of sequence seqId from the KV
// cache, reporting whether the removal succeeded.
func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_self_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

// KvCacheSeqCp copies KV cache entries in [p0, p1) from srcSeqId to dstSeqId.
func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_kv_self_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

163
// KvCacheClear removes all entries from the KV cache.
func (c *Context) KvCacheClear() {
	C.llama_kv_self_clear(c.c)
}

// KvCacheDefrag defragments the KV cache.
func (c *Context) KvCacheDefrag() {
	C.llama_kv_self_defrag(c.c)
}

171
// KvCacheCanShift reports whether the KV cache supports position shifting.
func (c *Context) KvCacheCanShift() bool {
	return bool(C.llama_kv_self_can_shift(c.c))
}

175
176
// GetEmbeddingsSeq returns a copy of the embeddings for a sequence id, or
// nil if none are available.
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if e == nil {
		return nil
	}

	// Copy out of the C-owned buffer so the result stays valid after the
	// context advances.
	n := c.Model().NEmbd()
	embeddings := make([]float32, n)
	copy(embeddings, unsafe.Slice((*float32)(e), n))
	return embeddings
}

// GetEmbeddingsIth returns a copy of the embeddings for the i-th token of
// the last decoded batch, or nil if none are available.
func (c *Context) GetEmbeddingsIth(i int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if e == nil {
		return nil
	}

	// Copy out of the C-owned buffer so the result stays valid after the
	// context advances.
	n := c.Model().NEmbd()
	embeddings := make([]float32, n)
	copy(embeddings, unsafe.Slice((*float32)(e), n))
	return embeddings
}

// ModelParams configures how a model file is loaded.
type ModelParams struct {
	NumGpuLayers int             // number of layers to offload to the GPU
	MainGpu      int             // index of the primary GPU
	UseMmap      bool            // memory-map the model file
	UseMlock     bool            // lock model memory to prevent swapping
	TensorSplit  []float32       // per-GPU tensor split fractions
	Progress     func(float32)   // optional load-progress callback (0..1)
	VocabOnly    bool            // load only the vocabulary, not the weights
}

// llamaProgressCallback is the cgo trampoline that forwards llama.cpp load
// progress to the Go callback stored in the cgo.Handle at userData.
//
//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}

216
func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
	cparams := C.llama_model_default_params()
	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
	cparams.main_gpu = C.int32_t(params.MainGpu)
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.use_mlock = C.bool(params.UseMlock)
	cparams.vocab_only = C.bool(params.VocabOnly)

	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

		var tensorSplitPin runtime.Pinner
		tensorSplitPin.Pin(tensorSplitData)
		defer tensorSplitPin.Unpin()

		cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
	}

	if params.Progress != nil {
		handle := cgo.NewHandle(params.Progress)
		defer handle.Delete()

		var handlePin runtime.Pinner
		handlePin.Pin(&handle)
		defer handlePin.Unpin()

		cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
	}

246
	m := Model{c: C.llama_model_load_from_file(C.CString(modelPath), cparams)}
Jesse Gross's avatar
Jesse Gross committed
247
	if m.c == nil {
248
249
250
251
		return nil, fmt.Errorf("unable to load model: %s", modelPath)
	}

	return &m, nil
252
253
254
}

// FreeModel releases the native resources held by model.
func FreeModel(model *Model) {
	C.llama_model_free(model.c)
}

258
259
// NewContextWithModel creates an inference context for model using params.
func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
	c := Context{
		c:          C.llama_init_from_model(model.c, params.c),
		numThreads: int(params.c.n_threads),
	}
	if c.c == nil {
		return nil, errors.New("unable to create llama context")
	}

	return &c, nil
}

// NumVocab returns the number of tokens in the model's vocabulary.
func (m *Model) NumVocab() int {
	return int(C.llama_vocab_n_tokens(m.Vocab()))
}

// TokenIsEog reports whether token is an end-of-generation token.
func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_vocab_is_eog(m.Vocab(), C.llama_token(token)))
}

// AddBOSToken reports whether the vocabulary expects a BOS token to be
// prepended to input.
func (m *Model) AddBOSToken() bool {
	return bool(C.llama_vocab_get_add_bos(m.Vocab()))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

286
	loraAdapter := C.llama_adapter_lora_init(m.c, cLoraPath)
Jesse Gross's avatar
Jesse Gross committed
287
288
289
	if loraAdapter == nil {
		return errors.New("unable to load lora")
	}
290
291
292

	err := -1
	if loraAdapter != nil {
293
		err = int(C.llama_set_adapter_lora(context.c, loraAdapter, C.float(scale)))
294
295
296
297
298
299
300
301
	}
	if err != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}

302
303
304
305
// Vocab returns the model's native vocabulary handle.
func (m *Model) Vocab() *C.struct_llama_vocab {
	return C.llama_model_get_vocab(m.c)
}

306
307
308
// Batch wraps a llama_batch along with the sizing it was allocated with.
type Batch struct {
	c         C.struct_llama_batch
	batchSize int // max entries per sequence
	maxSeq    int // max number of sequences
	embedSize int // non-zero for embedding batches; 0 for token batches
}

313
314
315
// Creates a new batch for either word tokens or image embeddings (if embedSize is non-zero).
// Batches cannot contain both types at the same time. batchSize is the maximum number of entries
// that can be added per sequence
func NewBatch(batchSize int, maxSeq int, embedSize int) (*Batch, error) {
	b := Batch{
		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
		batchSize: batchSize,
		maxSeq:    maxSeq,
		embedSize: embedSize,
	}

	// Check to see if any of the allocations in llama_batch_init() failed
	nilPointer := (embedSize == 0 && b.c.token == nil) || (embedSize != 0 && b.c.embd == nil) ||
		b.c.pos == nil || b.c.n_seq_id == nil || b.c.seq_id == nil || b.c.logits == nil ||
		slices.Contains(unsafe.Slice(b.c.seq_id, b.allocSize()), nil)

	if nilPointer {
		C.llama_batch_free(b.c)
		return nil, fmt.Errorf("unable to allocate batch (batchSize=%v maxSeq=%v embedSize=%v)", batchSize, maxSeq, embedSize)
	}

	return &b, nil
}

337
338
339
340
341
342
343
344
// Size returns the maximum number of entries per sequence.
func (b *Batch) Size() int {
	return b.batchSize
}

// allocSize returns the total allocated capacity across all sequences.
func (b *Batch) allocSize() int {
	return b.batchSize * b.maxSeq
}

345
346
347
348
349
350
351
352
353
354
355
356
// NumTokens returns the number of entries currently in the batch.
func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

// IsEmbedding reports whether this batch holds embeddings rather than tokens.
func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}

// Add adds either a token or an image embedding to the batch depending on the type
// when the batch was initialized. The other argument will be ignored. Adds to the
// batch with the given position for the given sequence ids, and optionally instructs
// to include logits.
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
	// Write the payload at index n_tokens: either a token id or an
	// embedding row of embedSize floats.
	if !b.IsEmbedding() {
		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
	} else {
		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}
	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))

	// seq_id is an array of per-entry arrays; fill the entry's sequence ids.
	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	// The logits flag marks entries whose logits the caller wants computed.
	if logits {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
	} else {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 0
	}

	b.c.n_tokens += 1
}

// Clear resets the batch to empty without releasing its storage.
func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

// Free releases the batch's native storage.
func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}

// Model wraps a llama_model handle.
type Model struct {
	c *C.struct_llama_model
}

// TokenToPiece converts a token id to its text piece. It first tries a
// small fixed buffer; a negative return from llama_token_to_piece reports
// the required size, in which case the call is retried with a buffer of
// exactly that size.
func (m *Model) TokenToPiece(token int) string {
	tokenLen := 12
	buf := make([]byte, tokenLen)
	tokenLen = int(C.llama_token_to_piece(
		m.Vocab(),
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(tokenLen),
		C.int32_t(0),
		C.bool(true),
	))
	if tokenLen < 0 {
		// Buffer was too small; -tokenLen is the required size.
		tokenLen = -tokenLen

		buf = make([]byte, tokenLen)
		C.llama_token_to_piece(
			m.Vocab(),
			C.int32_t(token),
			(*C.char)(unsafe.Pointer(&buf[0])),
			C.int32_t(tokenLen),
			C.int32_t(0),
			C.bool(true),
		)
	}
	// Trim any unused zero bytes left in the fixed-size buffer.
	return strings.TrimRight(string(buf), "\x00")
}

// Tokenize converts text to token ids. addSpecial controls insertion of
// BOS/EOS-style tokens; parseSpecial controls recognition of special-token
// text in the input.
func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	// One token per byte plus space for BOS/EOS is an upper bound for the
	// first attempt.
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.Vocab(),
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	// if the result is negative, reallocate and retry with the correct buffer size
	if result < 0 {
		maxTokens = int(-result)
		cTokens = make([]C.llama_token, maxTokens)
		result = C.llama_tokenize(
			m.Vocab(),
			cText,
			C.int32_t(len(text)),
			&cTokens[0],
			C.int32_t(maxTokens),
			C.bool(addSpecial),
			C.bool(parseSpecial),
		)
		if result < 0 {
			return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
		}
	}

	tokens := make([]int, result)
	for i := range result {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}

// NEmbd returns the model's embedding dimension.
func (m *Model) NEmbd() int {
	return int(C.llama_model_n_embd(m.c))
}

// Quantize rewrites the model at infile to outfile using the quantization
// format ftype. nthread = -1 lets llama.cpp pick the thread count.
func Quantize(infile, outfile string, ftype uint32) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = ftype

	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}

	return nil
}

483
// vision processing

// ClipContext wraps a CLIP (llava) image encoder context.
type ClipContext struct {
	c *C.struct_clip_ctx
}

488
func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
489
490
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
491
	c := C.clip_model_load(mp, 1)
Jesse Gross's avatar
Jesse Gross committed
492
493
494
	if c == nil {
		return nil, fmt.Errorf("unable to load clip model: %v", modelPath)
	}
495

496
497
498
499
	projEmbedSize := int(C.clip_n_mmproj_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
500
501
	}

502
	return &ClipContext{c: c}, nil
503
504
505
}

// Free releases the native CLIP context.
func (c *ClipContext) Free() {
	C.clip_free(c.c)
}

Jesse Gross's avatar
Jesse Gross committed
509
// NewEmbed encodes image bytes with the CLIP model and returns one
// embedding row per image token position.
func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) {
	l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
	if l == nil {
		return nil, errors.New("unable to make llava embedding from image")
	}

	numTokens := int(l.n_image_pos)
	numEmbed := llamaContext.Model().NEmbd()

	// View over the C buffer; copied into Go memory before it is freed.
	s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)

	// One backing slice; embed[i] are non-overlapping row views into it.
	embed := make([][]float32, numTokens)
	rows := make([]float32, len(s))
	copy(rows, s)

	for i := range embed {
		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}

	C.llava_image_embed_free(l)

	return embed, nil
}

533
534
535
536
537
538
539
540
// MllamaContext wraps an mllama cross-attention image encoder context.
type MllamaContext struct {
	c *C.struct_mllama_ctx
}

func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.mllama_model_load(mp, 1)
Jesse Gross's avatar
Jesse Gross committed
541
542
543
	if c == nil {
		return nil, fmt.Errorf("unable to load mllama model: %v", modelPath)
	}
544
545
546
547
548
549
550
551
552
553
554
555
556
557

	projEmbedSize := int(C.mllama_n_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &MllamaContext{c: c}, nil
}

// Free releases the native mllama context.
func (m *MllamaContext) Free() {
	C.mllama_free(m.c)
}

Jesse Gross's avatar
Jesse Gross committed
558
// NewEmbed encodes preprocessed image data with the mllama model and
// returns the embedding as a single row of EmbedSize floats.
func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) ([][]float32, error) {
	img := C.mllama_image_init()
	defer C.mllama_image_free(img)

	// 560x560x3 with 4 tiles matches the model's expected input layout.
	ok := bool(C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img))
	if !ok {
		return nil, errors.New("unable to load mllama image data")
	}

	rows := make([]float32, m.EmbedSize(llamaContext))
	ok = bool(C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))))
	if !ok {
		return nil, errors.New("unable to make mllama embedding from image")
	}

	// The whole encoding is returned as one logical embedding row.
	embed := make([][]float32, 1)
	embed[0] = rows

	return embed, nil
}

579
580
581
// EmbedSize returns the total float count of one mllama embedding:
// positions * tiles * model embedding dimension.
func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
	numEmbed := llamaContext.Model().NEmbd()

	return numTokens * numEmbed
}
585

586
587
// SetCrossAttention enables or disables cross-attention (used by mllama).
func (c *Context) SetCrossAttention(state bool) {
	C.llama_set_cross_attention(c.c, C.bool(state))
}

590
591
592
593
// Synchronize blocks until all pending backend computation completes.
func (c *Context) Synchronize() {
	C.llama_synchronize(c.c)
}

594
595
596
// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_common_sampler
}

// SamplingParams configures token sampling for a SamplingContext.
type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool // NOTE(review): not forwarded to the C sampler in this file
	Seed           uint32
	Grammar        string
}

Jesse Gross's avatar
Jesse Gross committed
618
func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
619
	var cparams C.struct_common_sampler_cparams
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.min_p = C.float(params.MinP)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	cparams.penalty_present = C.float(params.PenaltyFreq)
	cparams.mirostat = C.int32_t(params.Mirostat)
	cparams.mirostat_tau = C.float(params.MirostatTau)
	cparams.mirostat_eta = C.float(params.MirostatEta)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))

	cparams.grammar = grammar
638
	context := &SamplingContext{c: C.common_sampler_cinit(model.c, &cparams)}
Jesse Gross's avatar
Jesse Gross committed
639
640
641
642
	if context.c == nil {
		return nil, errors.New("unable to create sampling context")
	}

643
	runtime.SetFinalizer(context, func(s *SamplingContext) { C.common_sampler_cfree(s.c) })
644

Jesse Gross's avatar
Jesse Gross committed
645
	return context, nil
646
647
648
}

// Reset clears the sampler's accumulated state.
func (s *SamplingContext) Reset() {
	C.common_sampler_creset(s.c)
}

652
// Sample draws the next token from the logits at batch index idx.
func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
	return int(C.common_sampler_csample(s.c, llamaContext.c, C.int(idx)))
}

656
// Accept records token id in the sampler's history, optionally advancing
// the grammar state.
func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}
659

660
661
662
663
// SchemaToGrammar converts the provided JSON schema to a grammar. It returns
// nil if the provided schema is invalid JSON or an invalid JSON schema.
func SchemaToGrammar(schema []byte) []byte {
	cStr := C.CString(string(schema))
	defer C.free(unsafe.Pointer(cStr))

	// Allocate buffer for grammar output with reasonable size
	const maxLen = 32768 // 32KB
	buf := make([]byte, maxLen)

	// Call C function to convert schema to grammar
	n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
	if n == 0 {
		// preserve nil
		return nil
	}
	return buf[:n]
}
678

679
680
681
682
683
684
685
686
// TokenData pairs a token id with its logit for grammar filtering.
type TokenData struct {
	ID    int32
	Logit float32
}

// Grammar wraps a native llama grammar. The mutex serializes Apply, Accept,
// and Free, and guards against use after Free (c is set to nil).
type Grammar struct {
	c  *C.struct_llama_grammar
	mu sync.Mutex
}

689
// NewGrammar compiles a GBNF grammar against the given vocabulary
// (parallel id/piece slices) and end-of-generation token list. Returns nil
// if the grammar fails to initialize.
//
// NOTE(review): empty vocabIds, vocabValues, or eogTokens would panic on
// the &slice[0] accesses below — callers appear to always pass the full
// vocabulary; confirm before relying on empty inputs.
func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []uint32) *Grammar {
	cGrammar := C.CString(grammar)
	defer C.free(unsafe.Pointer(cGrammar))

	// Convert the Go slices into C-compatible arrays.
	cTokens := make([]C.uint32_t, len(vocabIds))
	for i, token := range vocabIds {
		cTokens[i] = C.uint32_t(token)
	}

	cPieces := make([]*C.char, len(vocabValues))
	for i, piece := range vocabValues {
		cPieces[i] = C.CString(piece)
		defer C.free(unsafe.Pointer(cPieces[i]))
	}

	cEogTokens := make([]C.uint32_t, len(eogTokens))
	for i, token := range eogTokens {
		cEogTokens[i] = C.uint32_t(token)
	}

	g := C.grammar_init(cGrammar, (*C.uint32_t)(unsafe.Pointer(&cTokens[0])), C.size_t(len(cTokens)), (**C.char)(unsafe.Pointer(&cPieces[0])), (*C.uint32_t)(unsafe.Pointer(&cEogTokens[0])), C.size_t(len(cEogTokens)))
	if g == nil {
		return nil
	}

	return &Grammar{c: g}
}

717
718
719
720
721
722
723
// Free releases the native grammar. Safe to call more than once; later
// Apply/Accept calls become no-ops.
func (g *Grammar) Free() {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.c != nil {
		C.grammar_free(g.c)
		g.c = nil
	}
}

726
727
728
729
730
731
732
733
// Apply filters tokens against the grammar in place: logits of tokens the
// grammar disallows are updated by the native grammar_apply call. No-op if
// the grammar has been freed.
func (g *Grammar) Apply(tokens []TokenData) {
	g.mu.Lock()
	defer g.mu.Unlock()

	if g.c == nil {
		return
	}

	// Mirror the Go token data into C llama_token_data entries.
	tds := make([]C.struct_llama_token_data, len(tokens))
	for i, token := range tokens {
		tds[i] = C.struct_llama_token_data{
			id:    C.int32_t(token.ID),
			logit: C.float(token.Logit),
			p:     C.float(0.0),
		}
	}
	tda := &C.llama_token_data_array{
		data:     (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
		size:     C.size_t(len(tokens)),
		selected: C.int64_t(-1),
		sorted:   C.bool(false),
	}
	// Pin the Go-allocated array while C holds a pointer into it.
	var pinner runtime.Pinner
	pinner.Pin(&tds[0])
	defer pinner.Unpin()

	C.grammar_apply(g.c, tda)

	// Copy the possibly-modified logits back into the caller's slice.
	for i := range tokens {
		tokens[i].Logit = float32(tds[i].logit)
	}
}
757
758
759
760
761
762
763
764
765
766
767
768

// Accept advances the grammar state past token. No-op if the grammar has
// been freed.
func (g *Grammar) Accept(token int32) {
	g.mu.Lock()
	defer g.mu.Unlock()

	// Check if grammar was freed
	if g.c == nil {
		return
	}

	C.grammar_accept(g.c, C.llama_token(token))
}