package llama

/*
#cgo CFLAGS: -std=c11
#cgo windows CFLAGS: -Wno-dll-attribute-on-redeclaration
#cgo CXXFLAGS: -std=c++17
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/include
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/common
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/vendor
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/tools/mtmd
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/src
#cgo CPPFLAGS: -I${SRCDIR}/../ml/backend/ggml/ggml/include

#include <stdlib.h>
#include "ggml.h"
#include "llama.h"
#include "mtmd.h"
#include "mtmd-helper.h"
#include "gguf.h"

#include "sampling_ext.h"

extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);
*/
import "C"

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"runtime"
	"runtime/cgo"
	"slices"
	"strings"
	"sync"
	"unsafe"

	_ "github.com/ollama/ollama/llama/llama.cpp/common"
	_ "github.com/ollama/ollama/llama/llama.cpp/src"
	_ "github.com/ollama/ollama/llama/llama.cpp/tools/mtmd"
	"github.com/ollama/ollama/ml"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)

func init() {
	C.llama_log_set(C.ggml_log_callback(C.llamaLog), nil)
}

//export llamaLog
func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
	// slog levels place INFO at zero and step by 4, while ggml log levels are
	// consecutive integers, so shift by GGML_LOG_LEVEL_INFO and scale by 4
	if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
		fmt.Fprint(os.Stderr, C.GoString(text))
	}
}

func BackendInit() {
	ggml.OnceLoad()
	C.llama_backend_init()
}

func EnumerateGPUs() []ml.DeviceID {
	var ids []ml.DeviceID

	for i := range C.ggml_backend_dev_count() {
		device := C.ggml_backend_dev_get(i)

		if C.ggml_backend_dev_type(device) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
			var props C.struct_ggml_backend_dev_props
			C.ggml_backend_dev_get_props(device, &props)
			ids = append(ids, ml.DeviceID{
				ID:      C.GoString(props.id),
				Library: C.GoString(props.library),
			})
		}
	}

	return ids
}

func GetModelArch(modelPath string) (string, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
	if gguf_ctx == nil {
		return "", errors.New("unable to load model file")
	}
	defer C.gguf_free(gguf_ctx)

	key := C.CString("general.architecture")
	defer C.free(unsafe.Pointer(key))
	arch_index := C.gguf_find_key(gguf_ctx, key)
	if int(arch_index) < 0 {
		return "", errors.New("unknown model architecture")
	}

	arch := C.gguf_get_val_str(gguf_ctx, arch_index)

	return C.GoString(arch), nil
}
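
// Usage sketch (hypothetical caller, not part of the original file): read the
// architecture key before deciding how to handle a model.
//
//	arch, err := GetModelArch("/path/to/model.gguf")
//	if err != nil {
//		return err
//	}
//	slog.Info("model architecture", "arch", arch)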

type ContextParams struct {
	c C.struct_llama_context_params
}

func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	if flashAttention {
		params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED
	} else {
		params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED
	}
	params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))

	return ContextParams{c: params}
}
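
// Usage sketch (hypothetical values): a 4096-token context shared by two
// parallel sequences, with flash attention and a q8_0 quantized KV cache.
//
//	params := NewContextParams(4096, 512, 2, 8, true, "q8_0")
//	lc, err := NewContextWithModel(model, params)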

// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
	if s == "" {
		return C.GGML_TYPE_F16
	}

	switch s {
	case "q8_0":
		return C.GGML_TYPE_Q8_0
	case "q4_0":
		return C.GGML_TYPE_Q4_0
	default:
		return C.GGML_TYPE_F16
	}
}

type Context struct {
	c          *C.struct_llama_context
	numThreads int
}

var ErrKvCacheFull = errors.New("could not find a kv cache slot")

func (c *Context) Decode(batch *Batch) error {
	// A positive return value is a warning rather than a fatal error:
	//   0 - success
	//   1 - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)
	// < 0 - error
	code := int(C.llama_decode(c.c, batch.c))

	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}

	if code > 0 {
		return ErrKvCacheFull
	}

	return nil
}
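
// Usage sketch (hypothetical caller): ErrKvCacheFull is recoverable, so a
// runner would typically free cache space or shrink the batch and retry.
//
//	if err := lc.Decode(batch); errors.Is(err, ErrKvCacheFull) {
//		// evict a sequence (e.g. lc.KvCacheSeqRm(seq, 0, -1)) and retry
//	} else if err != nil {
//		return err
//	}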

func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_memory_seq_add(C.llama_get_memory(c.c), C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_memory_seq_rm(C.llama_get_memory(c.c), C.int(seqId), C.int(p0), C.int(p1)))
}

func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_memory_seq_cp(C.llama_get_memory(c.c), C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

func (c *Context) KvCacheClear() {
	C.llama_memory_clear(C.llama_get_memory(c.c), true)
}

func (c *Context) KvCacheCanShift() bool {
	return bool(C.llama_memory_can_shift(C.llama_get_memory(c.c)))
}

// GetEmbeddingsSeq gets the embeddings for a sequence id.
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

func (c *Context) GetEmbeddingsIth(i int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

type ModelParams struct {
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
	TensorSplit  []float32
	Progress     func(float32)
	VocabOnly    bool
}

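// llamaProgressCallback is called from C during model load; it recovers the
// Go progress function from the cgo.Handle passed as user_data and invokes
// it with the current progress fraction.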
//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}

func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams := C.llama_model_default_params()
	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
	cparams.main_gpu = C.int32_t(params.MainGpu)
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.vocab_only = C.bool(params.VocabOnly)

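	// llama.cpp holds on to the tensor_split pointer and the progress handle
	// while the model loads, so both are pinned to keep the Go runtime from
	// moving them while C retains the pointers.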
	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

		var tensorSplitPin runtime.Pinner
		tensorSplitPin.Pin(tensorSplitData)
		defer tensorSplitPin.Unpin()

		cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
	}

	if params.Progress != nil {
		handle := cgo.NewHandle(params.Progress)
		defer handle.Delete()

		var handlePin runtime.Pinner
		handlePin.Pin(&handle)
		defer handlePin.Unpin()

		cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
	}

	cmodelPath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cmodelPath))

	m := Model{c: C.llama_model_load_from_file(cmodelPath, cparams)}
	if m.c == nil {
		return nil, fmt.Errorf("unable to load model: %s", modelPath)
	}

	return &m, nil
}

func FreeModel(model *Model) {
	C.llama_model_free(model.c)
}

func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
	c := Context{
		c:          C.llama_init_from_model(model.c, params.c),
		numThreads: int(params.c.n_threads),
	}
	if c.c == nil {
		return nil, errors.New("unable to create llama context")
	}

	return &c, nil
}

func (m *Model) NumVocab() int {
	return int(C.llama_vocab_n_tokens(m.Vocab()))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_vocab_is_eog(m.Vocab(), C.llama_token(token)))
}

func (m *Model) AddBOSToken() bool {
	return bool(C.llama_vocab_get_add_bos(m.Vocab()))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

	loraAdapter := C.llama_adapter_lora_init(m.c, cLoraPath)
	if loraAdapter == nil {
		return errors.New("unable to load lora")
	}

	if err := int(C.llama_set_adapter_lora(context.c, loraAdapter, C.float(scale))); err != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}

func (m *Model) Vocab() *C.struct_llama_vocab {
	return C.llama_model_get_vocab(m.c)
}

type Batch struct {
	c         C.struct_llama_batch
	batchSize int
	maxSeq    int
	embedSize int
}

// NewBatch creates a new batch for either word tokens or image embeddings
// (if embedSize is non-zero). A batch cannot contain both types at the same
// time. batchSize is the maximum number of entries that can be added per
// sequence.
func NewBatch(batchSize int, maxSeq int, embedSize int) (*Batch, error) {
	b := Batch{
		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
		batchSize: batchSize,
		maxSeq:    maxSeq,
		embedSize: embedSize,
	}

	// Check to see if any of the allocations in llama_batch_init() failed
	nilPointer := (embedSize == 0 && b.c.token == nil) || (embedSize != 0 && b.c.embd == nil) ||
		b.c.pos == nil || b.c.n_seq_id == nil || b.c.seq_id == nil || b.c.logits == nil ||
		slices.Contains(unsafe.Slice(b.c.seq_id, b.allocSize()), nil)

	if nilPointer {
		C.llama_batch_free(b.c)
		return nil, fmt.Errorf("unable to allocate batch (batchSize=%v maxSeq=%v embedSize=%v)", batchSize, maxSeq, embedSize)
	}

	return &b, nil
}
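
// Usage sketch (hypothetical caller): a token batch for a single sequence,
// requesting logits only for the final position.
//
//	batch, err := NewBatch(512, 1, 0)
//	if err != nil {
//		return err
//	}
//	defer batch.Free()
//	for i, tok := range tokens {
//		batch.Add(tok, nil, i, i == len(tokens)-1, 0)
//	}
//	err = lc.Decode(batch)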

func (b *Batch) Size() int {
	return b.batchSize
}

func (b *Batch) allocSize() int {
	return b.batchSize * b.maxSeq
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}

// Add adds either a token or an image embedding to the batch, depending on
// the type the batch was initialized with; the other argument is ignored.
// The entry is added at the given position for the given sequence ids, and
// logits marks whether output should be produced for it.
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
	if !b.IsEmbedding() {
		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
	} else {
		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}
	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
	} else {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 0
	}

	b.c.n_tokens += 1
}

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}

type Model struct {
	c *C.struct_llama_model
}

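// TokenToPiece converts a token id to its text piece. llama_token_to_piece
// returns the negated required length when the initial 12-byte buffer is too
// small, so the call is retried once with an exactly-sized buffer.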
func (m *Model) TokenToPiece(token int) string {
	tokenLen := 12
	buf := make([]byte, tokenLen)
	tokenLen = int(C.llama_token_to_piece(
		m.Vocab(),
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(tokenLen),
		C.int32_t(0),
		C.bool(true),
	))
	if tokenLen < 0 {
		tokenLen = -tokenLen

		buf = make([]byte, tokenLen)
		C.llama_token_to_piece(
			m.Vocab(),
			C.int32_t(token),
			(*C.char)(unsafe.Pointer(&buf[0])),
			C.int32_t(tokenLen),
			C.int32_t(0),
			C.bool(true),
		)
	}
	return strings.TrimRight(string(buf), "\x00")
}

func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.Vocab(),
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	// if the result is negative, reallocate and retry with the correct buffer size
	if result < 0 {
		maxTokens = int(-result)
		cTokens = make([]C.llama_token, maxTokens)
		result = C.llama_tokenize(
			m.Vocab(),
			cText,
			C.int32_t(len(text)),
			&cTokens[0],
			C.int32_t(maxTokens),
			C.bool(addSpecial),
			C.bool(parseSpecial),
		)
		if result < 0 {
			return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
		}
	}

	tokens := make([]int, result)
	for i := range result {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}
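
// Usage sketch (hypothetical caller): honor the model's own BOS metadata and
// parse special tokens that appear in the prompt text.
//
//	tokens, err := m.Tokenize(prompt, m.AddBOSToken(), true)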

func (m *Model) NEmbd() int {
	return int(C.llama_model_n_embd(m.c))
}

// vision processing
type MtmdContext struct {
	c *C.struct_mtmd_context
}

func NewMtmdContext(llamaContext *Context, modelPath string) (*MtmdContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	// TODO: Support non-default params
	cp := C.mtmd_context_params_default()

	// NOTE: The model and projector embedding lengths are checked during init
	c := C.mtmd_init_from_file(mp, C.llama_get_model(llamaContext.c), cp)
	if c == nil {
		return nil, fmt.Errorf("unable to load mmtd model: %v", modelPath)
	}

	return &MtmdContext{c: c}, nil
}

func (c *MtmdContext) Free() {
	C.mtmd_free(c.c)
}

type MtmdChunk struct {
	Embed  []float32
	Tokens []int
}

func (c *MtmdContext) MultimodalTokenize(llamaContext *Context, data []byte) ([]MtmdChunk, error) {
	// Initialize the input chunks pointer
	ic := C.mtmd_input_chunks_init()
	defer C.mtmd_input_chunks_free(ic)

	// Initialize an empty text prompt so we can tokenize
	it := C.mtmd_input_text_init(C.mtmd_default_marker(), true, true)
	defer C.mtmd_input_text_free(it)

	// Initialize a bitmap with the image data
	bm := C.mtmd_helper_bitmap_init_from_buf(c.c, (*C.uchar)(unsafe.Pointer(&data[0])), C.size_t(len(data)))
	defer C.mtmd_bitmap_free(bm)

	// Tokenize the image
	if C.mtmd_tokenize(c.c, ic, it, &bm, 1) != 0 {
		return nil, errors.New("unable to tokenize mtmd embedding from image")
	}
	nChunks := C.mtmd_input_chunks_size(ic)
	numEmbed := llamaContext.Model().NEmbd()
	outChunks := make([]MtmdChunk, 0)
	for i := range int(nChunks) {
		chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
		numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
		slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)

		if C.mtmd_input_chunk_get_type(chunk) == C.MTMD_INPUT_CHUNK_TYPE_TEXT {
			// If this is a text chunk, add the tokens
			cNumTokens := C.size_t(0)
			cTokens := C.mtmd_input_chunk_get_tokens_text(chunk, &cNumTokens)
			cTokensArr := unsafe.Slice(cTokens, int(cNumTokens))
			tokens := make([]int, int(cNumTokens))
			for j := range int(cNumTokens) {
				tokens[j] = int(cTokensArr[j])
			}
			outChunks = append(outChunks, MtmdChunk{Tokens: tokens})
		} else {
			// Otherwise, encode the image chunk to embeddings

			// Encode the chunk
			if C.mtmd_encode_chunk(c.c, chunk) != 0 {
				return nil, errors.New("unable to encode mtmd image chunk")
			}

			// Get the embeddings for this chunk
			chunkEmbed := make([][]float32, numTokens)
			chunkEmbd := C.mtmd_get_output_embd(c.c)
			if nil == chunkEmbd {
				return nil, errors.New("no mtmd image embedding")
			}

			// Extend the embedding array for each token
			s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
			rows := make([]float32, len(s))
			copy(rows, s)
			for i := range numTokens {
				chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
			}
			for _, e := range chunkEmbed {
				outChunks = append(outChunks, MtmdChunk{Embed: e})
			}
		}
	}
	slog.Debug("image tokenization chunks", "totalChunks", len(outChunks))
	return outChunks, nil
}
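
// Usage sketch (hypothetical caller): text chunks carry tokens while image
// chunks carry one embedding per token, so each kind feeds a different batch.
//
//	chunks, err := mc.MultimodalTokenize(lc, imageBytes)
//	for _, chunk := range chunks {
//		if len(chunk.Tokens) > 0 {
//			// add chunk.Tokens to a token batch (NewBatch with embedSize 0)
//		} else {
//			// add chunk.Embed to an embedding batch (NewBatch with embedSize model.NEmbd())
//		}
//	}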

func (c *Context) Synchronize() {
	C.llama_synchronize(c.c)
}

// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_common_sampler
}

type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}

func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
	var cparams C.struct_common_sampler_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.min_p = C.float(params.MinP)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))

	cparams.grammar = grammar
	context := &SamplingContext{c: C.common_sampler_cinit(model.c, &cparams)}
	if context.c == nil {
		return nil, errors.New("unable to create sampling context")
	}

	runtime.SetFinalizer(context, func(s *SamplingContext) { C.common_sampler_cfree(s.c) })

	return context, nil
}
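
// Usage sketch (hypothetical values): sample from a decoded batch position
// and feed the choice back so penalty and grammar state stay current.
//
//	sc, err := NewSamplingContext(model, SamplingParams{TopK: 40, TopP: 0.9, Temp: 0.8, Seed: 42})
//	token := sc.Sample(lc, idx) // idx of the batch entry whose logits to sample
//	sc.Accept(token, true)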

func (s *SamplingContext) Reset() {
	C.common_sampler_creset(s.c)
}

func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
	return int(C.common_sampler_csample(s.c, llamaContext.c, C.int(idx)))
}

func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}

// SchemaToGrammar converts the provided JSON schema to a grammar. It returns
// nil if the provided schema is invalid JSON or an invalid JSON schema.
func SchemaToGrammar(schema []byte) []byte {
	cStr := C.CString(string(schema))
	defer C.free(unsafe.Pointer(cStr))

	// Size the grammar buffer from the schema length, clamped between 32 KiB and 1 MiB
	maxLen := max(32768, min(1024*1024, len(schema)*4))
	buf := make([]byte, maxLen)

	// Call C function to convert schema to grammar
	n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
	if n == 0 {
		// preserve nil
		return nil
	}
	return buf[:n]
}
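
// Usage sketch (hypothetical caller): derive a grammar for constrained
// decoding; a nil result means the schema was invalid.
//
//	g := SchemaToGrammar([]byte(`{"type": "object"}`))
//	if g == nil {
//		return errors.New("invalid JSON schema")
//	}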

type TokenData struct {
	ID    int32
	Logit float32
}

type Grammar struct {
	c  *C.struct_llama_grammar
	mu sync.Mutex
}

func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []int32) *Grammar {
	cGrammar := C.CString(grammar)
	defer C.free(unsafe.Pointer(cGrammar))

	cTokens := make([]C.uint32_t, len(vocabIds))
	for i, token := range vocabIds {
		cTokens[i] = C.uint32_t(token)
	}

	cPieces := make([]*C.char, len(vocabValues))
	for i, piece := range vocabValues {
		cPieces[i] = C.CString(piece)
		defer C.free(unsafe.Pointer(cPieces[i]))
	}

	cEogTokens := make([]C.uint32_t, len(eogTokens))
	for i, token := range eogTokens {
		cEogTokens[i] = C.uint32_t(token)
	}

	g := C.grammar_init(cGrammar, unsafe.SliceData(cTokens), C.size_t(len(cTokens)), unsafe.SliceData(cPieces), unsafe.SliceData(cEogTokens), C.size_t(len(cEogTokens)))
	if g == nil {
		return nil
	}

	return &Grammar{c: g}
}

func (g *Grammar) Free() {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.c != nil {
		C.grammar_free(g.c)
		g.c = nil
	}
}

func (g *Grammar) Apply(tokens []TokenData) {
	g.mu.Lock()
	defer g.mu.Unlock()

	if g.c == nil {
		return
	}

	tds := make([]C.struct_llama_token_data, len(tokens))
	for i, token := range tokens {
		tds[i] = C.struct_llama_token_data{
			id:    C.int32_t(token.ID),
			logit: C.float(token.Logit),
			p:     C.float(0.0),
		}
	}
	tda := &C.llama_token_data_array{
		data:     (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
		size:     C.size_t(len(tokens)),
		selected: C.int64_t(-1),
		sorted:   C.bool(false),
	}
	var pinner runtime.Pinner
	pinner.Pin(&tds[0])
	defer pinner.Unpin()

	C.grammar_apply(g.c, tda)
	for i := range tokens {
		tokens[i].Logit = float32(tds[i].logit)
	}
}

func (g *Grammar) Accept(token int32) {
	g.mu.Lock()
	defer g.mu.Unlock()

	// Check if grammar was freed
	if g.c == nil {
		return
	}

	C.grammar_accept(g.c, C.llama_token(token))
}
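
// Usage sketch (hypothetical caller): Apply masks the logits of tokens the
// grammar currently disallows, and Accept then advances the grammar past the
// token that was actually sampled.
//
//	g.Apply(tokenData)
//	chosen := pickToken(tokenData) // pickToken is a hypothetical sampling step
//	g.Accept(chosen.ID)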