package llama

/*
#cgo CFLAGS: -std=c11
#cgo windows CFLAGS: -Wno-dll-attribute-on-redeclaration
#cgo CXXFLAGS: -std=c++17
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/include
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/common
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/vendor
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/tools/mtmd
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/src
#cgo CPPFLAGS: -I${SRCDIR}/../ml/backend/ggml/ggml/include

#include <stdlib.h>
#include "ggml.h"
#include "llama.h"
#include "mtmd.h"
#include "mtmd-helper.h"
#include "gguf.h"

#include "sampling_ext.h"

extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);
*/
import "C"

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"runtime"
	"runtime/cgo"
	"slices"
	"strings"
	"sync"
	"unsafe"

	_ "github.com/ollama/ollama/llama/llama.cpp/common"
	_ "github.com/ollama/ollama/llama/llama.cpp/src"
	_ "github.com/ollama/ollama/llama/llama.cpp/tools/mtmd"
	"github.com/ollama/ollama/ml"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)

func init() {
	C.llama_log_set(C.ggml_log_callback(C.llamaLog), nil)
}

//export llamaLog
func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
	// slog levels are multiples of 4 with INFO at zero, so shift the ggml level relative to INFO
	if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
		fmt.Fprint(os.Stderr, C.GoString(text))
	}
}

func BackendInit() {
	ggml.OnceLoad()
	C.llama_backend_init()
}

func EnumerateGPUs() []ml.DeviceID {
	var ids []ml.DeviceID

	for i := range C.ggml_backend_dev_count() {
		device := C.ggml_backend_dev_get(i)

		switch C.ggml_backend_dev_type(device) {
		case C.GGML_BACKEND_DEVICE_TYPE_GPU,
			C.GGML_BACKEND_DEVICE_TYPE_IGPU:
			var props C.struct_ggml_backend_dev_props
			C.ggml_backend_dev_get_props(device, &props)
			ids = append(ids, ml.DeviceID{
				ID:      C.GoString(props.id),
				Library: C.GoString(props.library),
			})
		}
	}

	return ids
}
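
// Example (sketch): logging the devices ggml reports; assumes BackendInit
// has already been called.
//
//	for _, id := range EnumerateGPUs() {
//		slog.Info("gpu", "id", id.ID, "library", id.Library)
//	}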

func GetModelArch(modelPath string) (string, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
	if gguf_ctx == nil {
		return "", errors.New("unable to load model file")
	}
	defer C.gguf_free(gguf_ctx)

	key := C.CString("general.architecture")
	defer C.free(unsafe.Pointer(key))
	arch_index := C.gguf_find_key(gguf_ctx, key)
	if int(arch_index) < 0 {
		return "", errors.New("unknown model architecture")
	}

	arch := C.gguf_get_val_str(gguf_ctx, arch_index)

	return C.GoString(arch), nil
}
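
// Example (sketch): probing a GGUF file before deciding how to load it;
// the path is hypothetical.
//
//	arch, err := GetModelArch("/models/example.gguf")
//	if err != nil {
//		return err
//	}
//	slog.Debug("model architecture", "arch", arch)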

type ContextParams struct {
	c C.struct_llama_context_params
}

func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	if flashAttention {
		params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED
	} else {
		params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED
	}
	params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))

	return ContextParams{c: params}
}
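
// Example (sketch): illustrative values for a single-sequence 4k context
// with flash attention and the default f16 KV cache; model is assumed to
// be a previously loaded *Model.
//
//	params := NewContextParams(4096, 512, 1, runtime.NumCPU(), true, "")
//	lc, err := NewContextWithModel(model, params)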

// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
	if s == "" {
		return C.GGML_TYPE_F16
	}

	switch s {
	case "q8_0":
		return C.GGML_TYPE_Q8_0
	case "q4_0":
		return C.GGML_TYPE_Q4_0
	default:
		return C.GGML_TYPE_F16
	}
}

type Context struct {
	c          *C.struct_llama_context
	numThreads int
}

var ErrKvCacheFull = errors.New("could not find a kv cache slot")

func (c *Context) Decode(batch *Batch) error {
	// A positive return value is a warning rather than a fatal error:
	//   0 - success
	//   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
	// < 0 - error
	code := int(C.llama_decode(c.c, batch.c))

	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}

	if code > 0 {
		return ErrKvCacheFull
	}

	return nil
}
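
// Example (sketch): a decode call that treats a full KV cache as a
// recoverable condition. lc and batch are a hypothetical *Context and
// *Batch owned by the caller.
//
//	if err := lc.Decode(batch); err != nil {
//		if errors.Is(err, ErrKvCacheFull) {
//			// free space in the cache (or shrink the batch) and retry
//		} else {
//			return err
//		}
//	}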

func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_memory_seq_add(C.llama_get_memory(c.c), C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_memory_seq_rm(C.llama_get_memory(c.c), C.int(seqId), C.int(p0), C.int(p1)))
}

func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_memory_seq_cp(C.llama_get_memory(c.c), C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

func (c *Context) KvCacheClear() {
	C.llama_memory_clear(C.llama_get_memory(c.c), true)
}

func (c *Context) KvCacheCanShift() bool {
	return bool(C.llama_memory_can_shift(C.llama_get_memory(c.c)))
}

// Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

func (c *Context) GetEmbeddingsIth(i int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

// GetLogitsIth gets the logits for the ith token
func (c *Context) GetLogitsIth(i int) []float32 {
	logits := unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int32_t(i)))
	if logits == nil {
		return nil
	}

	vocabSize := c.Model().NumVocab()
	result := make([]float32, vocabSize)
	_ = copy(result, unsafe.Slice((*float32)(logits), vocabSize))
	return result
}
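
// Example (sketch): greedy (argmax) selection over the logits of the last
// batch entry; real callers normally go through SamplingContext instead.
//
//	logits := lc.GetLogitsIth(batch.NumTokens() - 1)
//	best := 0
//	for i, l := range logits {
//		if l > logits[best] {
//			best = i
//		}
//	}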

type ModelParams struct {
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
	TensorSplit  []float32
	Progress     func(float32)
	VocabOnly    bool
}

//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}

func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams := C.llama_model_default_params()
	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
	cparams.main_gpu = C.int32_t(params.MainGpu)
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.vocab_only = C.bool(params.VocabOnly)

	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

		var tensorSplitPin runtime.Pinner
		tensorSplitPin.Pin(tensorSplitData)
		defer tensorSplitPin.Unpin()

		cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
	}

	if params.Progress != nil {
		handle := cgo.NewHandle(params.Progress)
		defer handle.Delete()

		var handlePin runtime.Pinner
		handlePin.Pin(&handle)
		defer handlePin.Unpin()

		cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
	}

	m := Model{c: C.llama_model_load_from_file(C.CString(modelPath), cparams)}
	if m.c == nil {
		return nil, fmt.Errorf("unable to load model: %s", modelPath)
	}

	return &m, nil
}
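
// Example (sketch): loading a model with full GPU offload and a progress
// callback; the path is hypothetical.
//
//	model, err := LoadModelFromFile("/models/example.gguf", ModelParams{
//		NumGpuLayers: 999,
//		UseMmap:      true,
//		Progress:     func(p float32) { slog.Debug("loading", "progress", p) },
//	})
//	if err != nil {
//		return err
//	}
//	defer FreeModel(model)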

func FreeModel(model *Model) {
	C.llama_model_free(model.c)
}

func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
	c := Context{
		c:          C.llama_init_from_model(model.c, params.c),
		numThreads: int(params.c.n_threads),
	}
	if c.c == nil {
		return nil, errors.New("unable to create llama context")
	}

	return &c, nil
}

func (m *Model) NumVocab() int {
	return int(C.llama_vocab_n_tokens(m.Vocab()))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_vocab_is_eog(m.Vocab(), C.llama_token(token)))
}

func (m *Model) AddBOSToken() bool {
	return bool(C.llama_vocab_get_add_bos(m.Vocab()))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

	loraAdapter := C.llama_adapter_lora_init(m.c, cLoraPath)
	if loraAdapter == nil {
		return errors.New("unable to load lora")
	}

	// loraAdapter is known to be non-nil at this point, so apply it directly
	if err := int(C.llama_set_adapter_lora(context.c, loraAdapter, C.float(scale))); err != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}
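
// Example (sketch): applying an adapter at full strength; the path is
// hypothetical and lc is a previously created *Context.
//
//	if err := model.ApplyLoraFromFile(lc, "/loras/adapter.gguf", 1.0, runtime.NumCPU()); err != nil {
//		return err
//	}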

func (m *Model) Vocab() *C.struct_llama_vocab {
	return C.llama_model_get_vocab(m.c)
}

type Batch struct {
	c         C.struct_llama_batch
	batchSize int
	maxSeq    int
	embedSize int
}

// NewBatch creates a new batch for either word tokens or image embeddings
// (if embedSize is non-zero). Batches cannot contain both types at the same
// time. batchSize is the maximum number of entries that can be added per
// sequence.
func NewBatch(batchSize int, maxSeq int, embedSize int) (*Batch, error) {
	b := Batch{
		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
		batchSize: batchSize,
		maxSeq:    maxSeq,
		embedSize: embedSize,
	}

	// Check to see if any of the allocations in llama_batch_init() failed
	nilPointer := (embedSize == 0 && b.c.token == nil) || (embedSize != 0 && b.c.embd == nil) ||
		b.c.pos == nil || b.c.n_seq_id == nil || b.c.seq_id == nil || b.c.logits == nil ||
		slices.Contains(unsafe.Slice(b.c.seq_id, b.allocSize()), nil)

	if nilPointer {
		C.llama_batch_free(b.c)
		return nil, fmt.Errorf("unable to allocate batch (batchSize=%v maxSeq=%v embedSize=%v)", batchSize, maxSeq, embedSize)
	}

	return &b, nil
}
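
// Example (sketch): filling a token batch for sequence 0 and decoding it,
// requesting logits only for the final position. tokens and lc are owned
// by the hypothetical caller.
//
//	batch, err := NewBatch(512, 1, 0)
//	if err != nil {
//		return err
//	}
//	defer batch.Free()
//	for i, tok := range tokens {
//		batch.Add(tok, nil, i, i == len(tokens)-1, 0)
//	}
//	err = lc.Decode(batch)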

func (b *Batch) Size() int {
	return b.batchSize
}

func (b *Batch) allocSize() int {
	return b.batchSize * b.maxSeq
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}

// Add adds either a token or an image embedding to the batch, depending on
// the type the batch was initialized with; the other argument is ignored.
// The entry is added at the given position for the given sequence ids, and
// logits can optionally be requested for it.
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
	if !b.IsEmbedding() {
		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
	} else {
		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}
	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
	} else {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 0
	}

	b.c.n_tokens += 1
}

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}

type Model struct {
	c *C.struct_llama_model
}

func (m *Model) TokenToPiece(token int) string {
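	// Start with a small guess; llama_token_to_piece returns the negated
	// required length when the buffer is too small, and we retry below.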
	tokenLen := 12
	buf := make([]byte, tokenLen)
	tokenLen = int(C.llama_token_to_piece(
		m.Vocab(),
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(tokenLen),
		C.int32_t(0),
		C.bool(true),
	))
	if tokenLen < 0 {
		tokenLen = -tokenLen

		buf = make([]byte, tokenLen)
		C.llama_token_to_piece(
			m.Vocab(),
			C.int32_t(token),
			(*C.char)(unsafe.Pointer(&buf[0])),
			C.int32_t(tokenLen),
			C.int32_t(0),
			C.bool(true),
		)
	}
	return strings.TrimRight(string(buf), "\x00")
}

func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.Vocab(),
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	// if the result is negative, reallocate and retry with the correct buffer size
	if result < 0 {
		maxTokens = int(-result)
		cTokens = make([]C.llama_token, maxTokens)
		result = C.llama_tokenize(
			m.Vocab(),
			cText,
			C.int32_t(len(text)),
			&cTokens[0],
			C.int32_t(maxTokens),
			C.bool(addSpecial),
			C.bool(parseSpecial),
		)
		if result < 0 {
			return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
		}
	}

	tokens := make([]int, result)
	for i := range result {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}
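
// Example (sketch): a tokenize/detokenize round trip using the helpers
// above; special tokens are added and parsed as for a plain prompt.
//
//	tokens, err := model.Tokenize("Hello, world!", true, true)
//	if err != nil {
//		return err
//	}
//	var sb strings.Builder
//	for _, t := range tokens {
//		sb.WriteString(model.TokenToPiece(t))
//	}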

func (m *Model) NEmbd() int {
	return int(C.llama_model_n_embd(m.c))
}

// vision processing
type MtmdContext struct {
	c *C.struct_mtmd_context
}

func NewMtmdContext(llamaContext *Context, modelPath string) (*MtmdContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	// TODO: Support non-default params
	cp := C.mtmd_context_params_default()

	// NOTE: The model and projector embedding lengths are checked during init
	c := C.mtmd_init_from_file(mp, C.llama_get_model(llamaContext.c), cp)
	if c == nil {
		return nil, fmt.Errorf("unable to load mmtd model: %v", modelPath)
513
514
	}

	return &MtmdContext{c: c}, nil
}

func (c *MtmdContext) Free() {
	C.mtmd_free(c.c)
}

type MtmdChunk struct {
	Embed  []float32
	Tokens []int
}

func (c *MtmdContext) MultimodalTokenize(llamaContext *Context, data []byte) ([]MtmdChunk, error) {
	// Initialize the input chunks pointer
	ic := C.mtmd_input_chunks_init()
	defer C.mtmd_input_chunks_free(ic)

	// Initialize an empty text prompt so we can tokenize
	it := C.mtmd_input_text_init(C.mtmd_default_marker(), true, true)
	defer C.mtmd_input_text_free(it)

	// Initialize a bitmap with the image data
	bm := C.mtmd_helper_bitmap_init_from_buf(c.c, (*C.uchar)(unsafe.Pointer(&data[0])), C.size_t(len(data)))
	defer C.mtmd_bitmap_free(bm)

	// Tokenize the image
	if C.int32_t(0) != C.mtmd_tokenize(c.c, ic, it, &bm, 1) {
		return nil, errors.New("unable to tokenize mtmd embedding from image")
	}
	nChunks := C.mtmd_input_chunks_size(ic)
	numEmbed := llamaContext.Model().NEmbd()
	outChunks := make([]MtmdChunk, 0)
	for i := range int(nChunks) {
		chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
		numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
		slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)

		if C.mtmd_input_chunk_get_type(chunk) == C.MTMD_INPUT_CHUNK_TYPE_TEXT {
			// If this is a text chunk, add the tokens
			cNumTokens := C.size_t(0)
			cTokens := C.mtmd_input_chunk_get_tokens_text(chunk, &cNumTokens)
			cTokensArr := unsafe.Slice(cTokens, int(cNumTokens))
			tokens := make([]int, int(cNumTokens))
			for j := range int(cNumTokens) {
				tokens[j] = int(cTokensArr[j])
			}
			outChunks = append(outChunks, MtmdChunk{Tokens: tokens})
		} else {
			// Otherwise, encode the image chunk to embeddings

			// Encode the chunk
			if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
				return nil, errors.New("unable to encode mtmd image chunk")
			}

			// Get the embeddings for this chunk
			chunkEmbed := make([][]float32, numTokens)
			chunkEmbd := C.mtmd_get_output_embd(c.c)
			if nil == chunkEmbd {
				return nil, errors.New("no mtmd image embedding")
			}

			// Extend the embedding array for each token
			s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
			rows := make([]float32, len(s))
			copy(rows, s)
			for i := range numTokens {
				chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
			}
			for _, e := range chunkEmbed {
				outChunks = append(outChunks, MtmdChunk{Embed: e})
			}
		}
	}

	slog.Debug("image tokenization chunks", "totalChunks", len(outChunks))
	return outChunks, nil
}
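
// Example (sketch): routing mtmd chunks into batches. Text chunks carry
// tokens while image chunks carry one embedding row per position; mc and
// imageBytes are owned by the hypothetical caller.
//
//	chunks, err := mc.MultimodalTokenize(lc, imageBytes)
//	if err != nil {
//		return err
//	}
//	for _, chunk := range chunks {
//		if len(chunk.Tokens) > 0 {
//			// queue chunk.Tokens into a token batch
//		} else {
//			// queue chunk.Embed into an embedding batch
//		}
//	}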

func (c *Context) Synchronize() {
	C.llama_synchronize(c.c)
}

// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_common_sampler
}

type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}

func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
	var cparams C.struct_common_sampler_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.min_p = C.float(params.MinP)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))

	cparams.grammar = grammar
	context := &SamplingContext{c: C.common_sampler_cinit(model.c, &cparams)}
	if context.c == nil {
		return nil, errors.New("unable to create sampling context")
	}

	runtime.SetFinalizer(context, func(s *SamplingContext) { C.common_sampler_cfree(s.c) })

	return context, nil
}

func (s *SamplingContext) Reset() {
	C.common_sampler_creset(s.c)
}

func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
	return int(C.common_sampler_csample(s.c, llamaContext.c, C.int(idx)))
}

func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}
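
// Example (sketch): a sampler with illustrative settings. Sample reads the
// logits produced at index idx of the last decode and Accept feeds the
// chosen token back into the sampler state.
//
//	sc, err := NewSamplingContext(model, SamplingParams{
//		TopK: 40,
//		TopP: 0.9,
//		Temp: 0.8,
//		Seed: 42,
//	})
//	if err != nil {
//		return err
//	}
//	tok := sc.Sample(lc, batch.NumTokens()-1)
//	sc.Accept(tok, true)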

// SchemaToGrammar converts the provided JSON schema to a grammar. It returns
// nil if the provided schema is invalid JSON or an invalid JSON schema.
func SchemaToGrammar(schema []byte) []byte {
	cStr := C.CString(string(schema))
	defer C.free(unsafe.Pointer(cStr))

	// Allocate a buffer for the grammar based on the schema length, with an upper bound
	maxLen := max(32768, min(1024*1024, len(schema)*4))
	buf := make([]byte, maxLen)

	// Call C function to convert schema to grammar
	n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
	if n == 0 {
		// preserve nil
		return nil
	}
	return buf[:n]
}

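// Example (sketch): deriving a grammar for constrained sampling; a nil
// result means the schema was rejected.
//
//	grammar := SchemaToGrammar([]byte(`{"type":"object","properties":{"name":{"type":"string"}}}`))
//	if grammar == nil {
//		return errors.New("invalid JSON schema")
//	}
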
type TokenData struct {
	ID    int32
	Logit float32
}

type Grammar struct {
	c  *C.struct_llama_grammar
	mu sync.Mutex
}

func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []int32) *Grammar {
	cGrammar := C.CString(grammar)
	defer C.free(unsafe.Pointer(cGrammar))

	cTokens := make([]C.uint32_t, len(vocabIds))
	for i, token := range vocabIds {
		cTokens[i] = C.uint32_t(token)
	}

	cPieces := make([]*C.char, len(vocabValues))
	for i, piece := range vocabValues {
		cPieces[i] = C.CString(piece)
		defer C.free(unsafe.Pointer(cPieces[i]))
	}

	cEogTokens := make([]C.uint32_t, len(eogTokens))
	for i, token := range eogTokens {
		cEogTokens[i] = C.uint32_t(token)
	}

	g := C.grammar_init(cGrammar, unsafe.SliceData(cTokens), C.size_t(len(cTokens)), unsafe.SliceData(cPieces), unsafe.SliceData(cEogTokens), C.size_t(len(cEogTokens)))
	if g == nil {
		return nil
	}

	return &Grammar{c: g}
}

func (g *Grammar) Free() {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.c != nil {
		C.grammar_free(g.c)
		g.c = nil
	}
}

func (g *Grammar) Apply(tokens []TokenData) {
	g.mu.Lock()
	defer g.mu.Unlock()

	if g.c == nil {
		return
	}

	tds := make([]C.struct_llama_token_data, len(tokens))
	for i, token := range tokens {
		tds[i] = C.struct_llama_token_data{
			id:    C.int32_t(token.ID),
			logit: C.float(token.Logit),
			p:     C.float(0.0),
		}
	}
	tda := &C.llama_token_data_array{
		data:     (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
		size:     C.size_t(len(tokens)),
		selected: C.int64_t(-1),
		sorted:   C.bool(false),
	}
	var pinner runtime.Pinner
	pinner.Pin(&tds[0])
	defer pinner.Unpin()

	C.grammar_apply(g.c, tda)
	for i := range tokens {
		tokens[i].Logit = float32(tds[i].logit)
	}
}

func (g *Grammar) Accept(token int32) {
	g.mu.Lock()
	defer g.mu.Unlock()

	// Check if grammar was freed
	if g.c == nil {
		return
	}

	C.grammar_accept(g.c, C.llama_token(token))
}
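
// Example (sketch): masking logits with a grammar before sampling, then
// committing the choice. tokens is built from the current logits and
// pickToken is a hypothetical sampling helper.
//
//	g.Apply(tokens)
//	tok := pickToken(tokens)
//	g.Accept(tok.ID)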