package llama

/*
#cgo CFLAGS: -std=c11
#cgo CXXFLAGS: -std=c++17
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/include
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/common
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/examples/llava
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/src
#cgo CPPFLAGS: -I${SRCDIR}/../ml/backend/ggml/ggml/include

#include <stdlib.h>
#include "ggml.h"
#include "llama.h"
#include "clip.h"
#include "llava.h"
#include "gguf.h"

#include "mllama.h"
#include "sampling_ext.h"

extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);
*/
import "C"

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"runtime"
	"runtime/cgo"
	"slices"
	"strings"
	"unsafe"

	_ "github.com/ollama/ollama/llama/llama.cpp/common"
	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
	_ "github.com/ollama/ollama/llama/llama.cpp/src"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)

func init() {
	C.llama_log_set(C.ggml_log_callback(C.llamaLog), nil)
}

//export llamaLog
func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
	// ggml levels map onto slog levels: INFO is 0 and adjacent levels are 4 apart
	if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
		fmt.Fprint(os.Stderr, C.GoString(text))
	}
}

func BackendInit() {
	ggml.OnceLoad()
	C.llama_backend_init()
}

func GetModelArch(modelPath string) (string, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
	if gguf_ctx == nil {
		return "", errors.New("unable to load model file")
	}
	defer C.gguf_free(gguf_ctx)

	key := C.CString("general.architecture")
	defer C.free(unsafe.Pointer(key))
	arch_index := C.gguf_find_key(gguf_ctx, key)
	if int(arch_index) < 0 {
		return "", errors.New("unknown model architecture")
	}

	arch := C.gguf_get_val_str(gguf_ctx, arch_index)

	return C.GoString(arch), nil
}
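
// Reading the architecture up front is useful for dispatching
// model-specific setup before a full load. A minimal caller-side sketch
// (hypothetical path, error handling elided):
//
//	arch, err := llama.GetModelArch("/models/model.gguf")
//	if err != nil {
//		return err
//	}
//	fmt.Println(arch) // e.g. "llama" or "mllama"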

type ContextParams struct {
	c C.struct_llama_context_params
}

func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)
	params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))

	return ContextParams{c: params}
}

// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
	if s == "" {
		return C.GGML_TYPE_F16
	}

	switch s {
	case "q8_0":
		return C.GGML_TYPE_Q8_0
	case "q4_0":
		return C.GGML_TYPE_Q4_0
	default:
		return C.GGML_TYPE_F16
	}
}

type Context struct {
	c          *C.struct_llama_context
	numThreads int
}

var ErrKvCacheFull = errors.New("could not find a kv cache slot")

func (c *Context) Decode(batch *Batch) error {
	// Positive return values do not mean a fatal error, but rather a warning.
	//   0 - success
	//   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
	// < 0 - error
	code := int(C.llama_decode(c.c, batch.c))

	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}

	if code > 0 {
		return ErrKvCacheFull
	}

	return nil
}
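
// Callers are expected to treat a full KV cache as recoverable rather than
// fatal. A sketch of the intended handling (assumes ctx and batch already
// exist; not the runner's actual loop):
//
//	if err := ctx.Decode(batch); err != nil {
//		if errors.Is(err, llama.ErrKvCacheFull) {
//			// free space (e.g. KvCacheSeqRm) or split the batch, then retry
//		} else {
//			return err // fatal decode error
//		}
//	}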

func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

func (c *Context) KvCacheDefrag() {
	C.llama_kv_cache_defrag(c.c)
}

func (c *Context) KvCacheCanShift() bool {
	return bool(C.llama_kv_cache_can_shift(c.c))
}

// Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

func (c *Context) GetEmbeddingsIth(i int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

type ModelParams struct {
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
	UseMlock     bool
	TensorSplit  []float32
	Progress     func(float32)
	VocabOnly    bool
}

//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}

func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams := C.llama_model_default_params()
	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
	cparams.main_gpu = C.int32_t(params.MainGpu)
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.use_mlock = C.bool(params.UseMlock)
	cparams.vocab_only = C.bool(params.VocabOnly)

	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

		var tensorSplitPin runtime.Pinner
		tensorSplitPin.Pin(tensorSplitData)
		defer tensorSplitPin.Unpin()

		cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
	}

	if params.Progress != nil {
		handle := cgo.NewHandle(params.Progress)
		defer handle.Delete()

		var handlePin runtime.Pinner
		handlePin.Pin(&handle)
		defer handlePin.Unpin()

		cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
	}

	cPath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cPath))
	m := Model{c: C.llama_model_load_from_file(cPath, cparams)}
	if m.c == nil {
		return nil, fmt.Errorf("unable to load model: %s", modelPath)
	}

	return &m, nil
}
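
// A caller-side loading sketch (path and layer count are illustrative; the
// progress callback is optional):
//
//	model, err := llama.LoadModelFromFile("/models/model.gguf", llama.ModelParams{
//		NumGpuLayers: 32,
//		UseMmap:      true,
//		Progress:     func(p float32) { fmt.Printf("loading: %.0f%%\n", p*100) },
//	})
//	if err != nil {
//		return err
//	}
//	defer llama.FreeModel(model)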

func LoadVocabFromFile(path string) (*Vocab, error) {
	mp := C.CString(path)
	defer C.free(unsafe.Pointer(mp))
	v := Vocab{c: C.llama_load_vocab_from_file(mp)}
	if v.c == nil {
		return nil, fmt.Errorf("unable to load vocab: %s", path)
	}
	return &v, nil
}

func FreeVocab(vocab *Vocab) {
	C.llama_free_vocab(vocab.c)
}

func FreeModel(model *Model) {
	C.llama_model_free(model.c)
}

func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
	c := Context{
		c:          C.llama_init_from_model(model.c, params.c),
		numThreads: int(params.c.n_threads),
	}
	if c.c == nil {
		return nil, errors.New("unable to create llama context")
	}

	return &c, nil
}
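
// Contexts are created from an already-loaded model. A sketch pairing
// NewContextParams with NewContextWithModel (argument values are
// illustrative):
//
//	params := llama.NewContextParams(4096, 512, 1, 4, false, "q8_0")
//	ctx, err := llama.NewContextWithModel(model, params)
//	if err != nil {
//		return err
//	}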

func (m *Model) NumVocab() int {
	return int(C.llama_vocab_n_tokens(m.Vocab()))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_vocab_is_eog(m.Vocab(), C.llama_token(token)))
}

func (m *Model) AddBOSToken() bool {
	return bool(C.llama_vocab_get_add_bos(m.Vocab()))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

	loraAdapter := C.llama_adapter_lora_init(m.c, cLoraPath)
	if loraAdapter == nil {
		return errors.New("unable to load lora")
	}

	if err := int(C.llama_set_adapter_lora(context.c, loraAdapter, C.float(scale))); err != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}

type Vocab struct {
	c *C.struct_llama_vocab
}

func (m *Model) Vocab() *C.struct_llama_vocab {
	return C.llama_model_get_vocab(m.c)
}

type Batch struct {
	c         C.struct_llama_batch
	batchSize int
	maxSeq    int
	embedSize int
}

// NewBatch creates a new batch for either word tokens or image embeddings (if embedSize is non-zero).
// Batches cannot contain both types at the same time. batchSize is the maximum number of entries
// that can be added per sequence.
func NewBatch(batchSize int, maxSeq int, embedSize int) (*Batch, error) {
	b := Batch{
		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
		batchSize: batchSize,
		maxSeq:    maxSeq,
		embedSize: embedSize,
	}

	// Check to see if any of the allocations in llama_batch_init() failed
	nilPointer := (embedSize == 0 && b.c.token == nil) || (embedSize != 0 && b.c.embd == nil) ||
		b.c.pos == nil || b.c.n_seq_id == nil || b.c.seq_id == nil || b.c.logits == nil ||
		slices.Contains(unsafe.Slice(b.c.seq_id, b.allocSize()), nil)

	if nilPointer {
		C.llama_batch_free(b.c)
		return nil, fmt.Errorf("unable to allocate batch (batchSize=%v maxSeq=%v embedSize=%v)", batchSize, maxSeq, embedSize)
	}

	return &b, nil
}

func (b *Batch) Size() int {
	return b.batchSize
}

func (b *Batch) allocSize() int {
	return b.batchSize * b.maxSeq
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}

// Add adds either a token or an image embedding to the batch depending on the type
// when the batch was initialized. The other argument will be ignored. Adds to the
// batch with the given position for the given sequence ids, and optionally instructs
// to include logits.
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
	if !b.IsEmbedding() {
		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
	} else {
		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}
	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
	} else {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 0
	}

	b.c.n_tokens += 1
}
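
// A sketch of filling and decoding a token batch (assumes tokens came from
// Tokenize and everything belongs to sequence 0; logits are requested only
// for the last token):
//
//	batch, err := llama.NewBatch(512, 1, 0)
//	if err != nil {
//		return err
//	}
//	defer batch.Free()
//	for i, t := range tokens {
//		batch.Add(t, nil, i, i == len(tokens)-1, 0)
//	}
//	if err := ctx.Decode(batch); err != nil {
//		return err
//	}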

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}

type Model struct {
	c *C.struct_llama_model
}

func (m *Model) TokenToPiece(token int) string {
	tokenLen := 12
	buf := make([]byte, tokenLen)
	tokenLen = int(C.llama_token_to_piece(
		m.Vocab(),
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(tokenLen),
		C.int32_t(0),
		C.bool(true),
	))
	if tokenLen < 0 {
		tokenLen = -tokenLen

		buf = make([]byte, tokenLen)
		C.llama_token_to_piece(
			m.Vocab(),
			C.int32_t(token),
			(*C.char)(unsafe.Pointer(&buf[0])),
			C.int32_t(tokenLen),
			C.int32_t(0),
			C.bool(true),
		)
	}
	return strings.TrimRight(string(buf), "\x00")
}

func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.Vocab(),
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	// if the result is negative, reallocate and retry with the correct buffer size
	if result < 0 {
		maxTokens = int(-result)
		cTokens = make([]C.llama_token, maxTokens)
		result = C.llama_tokenize(
			m.Vocab(),
			cText,
			C.int32_t(len(text)),
			&cTokens[0],
			C.int32_t(maxTokens),
			C.bool(addSpecial),
			C.bool(parseSpecial),
		)
		if result < 0 {
			return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
		}
	}

	tokens := make([]int, result)
	for i := range result {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}
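
// A round-trip sketch pairing Tokenize with TokenToPiece (caller-side,
// error handling elided):
//
//	tokens, err := model.Tokenize("Hello, world!", true, true)
//	if err != nil {
//		return err
//	}
//	for _, t := range tokens {
//		fmt.Print(model.TokenToPiece(t))
//	}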

func (m *Model) NEmbd() int {
	return int(C.llama_model_n_embd(m.c))
}

func Quantize(infile, outfile string, ftype uint32) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = ftype

	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}

	return nil
}
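
// A quantization sketch (hypothetical paths; ftype values come from
// llama.cpp's llama_ftype enum, where Q8_0 is 7 at the time of writing):
//
//	if err := llama.Quantize("model-f16.gguf", "model-q8_0.gguf", 7); err != nil {
//		return err
//	}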

// vision processing
type ClipContext struct {
	c *C.struct_clip_ctx
}

func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.clip_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load clip model: %v", modelPath)
	}

	projEmbedSize := int(C.clip_n_mmproj_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &ClipContext{c: c}, nil
}

func (c *ClipContext) Free() {
	C.clip_free(c.c)
}

func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) {
	l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
	if l == nil {
		return nil, errors.New("unable to make llava embedding from image")
	}

	numTokens := int(l.n_image_pos)
	numEmbed := llamaContext.Model().NEmbd()

	s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)

	embed := make([][]float32, numTokens)
	rows := make([]float32, len(s))
	copy(rows, s)

	for i := range embed {
		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}

	C.llava_image_embed_free(l)

	return embed, nil
}
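
// An image-embedding sketch (imageBytes holds an encoded image supplied by
// the caller; each returned row is one NEmbd-sized embedding destined for
// an embedding batch):
//
//	clip, err := llama.NewClipContext(ctx, "/models/mmproj.gguf")
//	if err != nil {
//		return err
//	}
//	defer clip.Free()
//	embeds, err := clip.NewEmbed(ctx, imageBytes)
//	if err != nil {
//		return err
//	}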

type MllamaContext struct {
	c *C.struct_mllama_ctx
}

func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.mllama_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load mllama model: %v", modelPath)
	}

	projEmbedSize := int(C.mllama_n_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &MllamaContext{c: c}, nil
}

func (m *MllamaContext) Free() {
	C.mllama_free(m.c)
}

func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) ([][]float32, error) {
	img := C.mllama_image_init()
	defer C.mllama_image_free(img)

	ok := bool(C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img))
	if !ok {
		return nil, errors.New("unable to load mllama image data")
	}

	rows := make([]float32, m.EmbedSize(llamaContext))
	ok = bool(C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))))
	if !ok {
		return nil, errors.New("unable to make mllama embedding from image")
	}

	embed := make([][]float32, 1)
	embed[0] = rows

	return embed, nil
}

func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
	numEmbed := llamaContext.Model().NEmbd()

	return numTokens * numEmbed
}

func (c *Context) SetCrossAttention(state bool) {
	C.llama_set_cross_attention(c.c, C.bool(state))
}

func (c *Context) Synchronize() {
	C.llama_synchronize(c.c)
}

// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_common_sampler
}

type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}

func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
	var cparams C.struct_common_sampler_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.min_p = C.float(params.MinP)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.mirostat = C.int32_t(params.Mirostat)
	cparams.mirostat_tau = C.float(params.MirostatTau)
	cparams.mirostat_eta = C.float(params.MirostatEta)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))

	cparams.grammar = grammar
	context := &SamplingContext{c: C.common_sampler_cinit(model.c, &cparams)}
	if context.c == nil {
		return nil, errors.New("unable to create sampling context")
	}

	runtime.SetFinalizer(context, func(s *SamplingContext) { C.common_sampler_cfree(s.c) })

	return context, nil
}

func (s *SamplingContext) Reset() {
	C.common_sampler_creset(s.c)
}

func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
	return int(C.common_sampler_csample(s.c, llamaContext.c, C.int(idx)))
}

func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}
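
// A sketch of the sample/accept loop these wrappers support (assumes idx
// points at a batch entry that was decoded with logits; feeding each new
// token back into the next Decode is elided):
//
//	sc, err := llama.NewSamplingContext(model, llama.SamplingParams{Temp: 0.8, TopK: 40, TopP: 0.9})
//	if err != nil {
//		return err
//	}
//	for {
//		token := sc.Sample(ctx, idx)
//		sc.Accept(token, true)
//		if model.TokenIsEog(token) {
//			break
//		}
//	}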

// SchemaToGrammar converts the provided JSON schema to a grammar. It returns
// nil if the provided schema is invalid JSON or an invalid JSON schema.
func SchemaToGrammar(schema []byte) []byte {
	cStr := C.CString(string(schema))
	defer C.free(unsafe.Pointer(cStr))

	// Allocate buffer for grammar output with reasonable size
	const maxLen = 32768 // 32KB
	buf := make([]byte, maxLen)

	// Call C function to convert schema to grammar
	n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
	if n == 0 {
		// preserve nil
		return nil
	}
	return buf[:n]
}
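
// A sketch converting a JSON schema into a grammar for constrained
// sampling (the schema literal is illustrative):
//
//	grammar := llama.SchemaToGrammar([]byte(`{"type":"object","properties":{"name":{"type":"string"}}}`))
//	if grammar == nil {
//		return errors.New("invalid JSON schema")
//	}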

type Sampler struct {
	c *C.struct_llama_sampler
}

func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
	cGrammar := C.CString(grammar)
	cRoot := C.CString("root")
	defer C.free(unsafe.Pointer(cGrammar))
	defer C.free(unsafe.Pointer(cRoot))

	sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}

	return sampler
}

func (s *Sampler) Accept(token int32) {
	C.llama_sampler_accept(s.c, C.llama_token(token))
}

type TokenData struct {
	Id    int32
	Logit float32
}

func (s *Sampler) Apply(tokens []TokenData) {
	tds := make([]C.struct_llama_token_data, len(tokens))
	for i, token := range tokens {
		tds[i] = C.struct_llama_token_data{
			id:    C.int32_t(token.Id),
			logit: C.float(token.Logit),
			p:     C.float(0.0),
		}
	}
	tda := &C.llama_token_data_array{
		data:     (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
		size:     C.size_t(len(tokens)),
		selected: C.int64_t(-1),
		sorted:   C.bool(false),
	}

	var pinner runtime.Pinner
	pinner.Pin(&tds[0])
	defer pinner.Unpin()

	C.llama_sampler_apply(s.c, tda)
	for i := range tokens {
		tokens[i].Logit = float32(tds[i].logit)
	}
}
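
// A sketch of grammar-constrained sampling with Sampler: Apply adjusts the
// logits so tokens the grammar disallows are masked out, the caller picks a
// token from the adjusted logits, and Accept advances the grammar state
// (pickToken is a hypothetical helper, e.g. argmax over tokens[i].Logit):
//
//	sampler := llama.NewGrammarSampler(vocab, string(grammar))
//	sampler.Apply(tokens) // tokens is []llama.TokenData carrying raw logits
//	id := pickToken(tokens)
//	sampler.Accept(id)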