// Package openai provides core transformation logic for partial compatibility with the OpenAI REST API.
package openai

import (
5
	"bytes"
6
	"encoding/base64"
7
	"encoding/binary"
8
	"encoding/json"
Michael Yang's avatar
lint  
Michael Yang committed
9
	"errors"
10
	"fmt"
royjhan's avatar
royjhan committed
11
	"log/slog"
12
13
	"math/rand"
	"net/http"
14
	"slices"
15
	"strings"
16
17
	"time"

18
	"github.com/ollama/ollama/api"
19
	"github.com/ollama/ollama/types/model"
20
21
)

22
23
// finishReasonToolCalls is the finish_reason value used for streamed chunks
// once the response has produced tool calls.
var finishReasonToolCalls = "tool_calls"

24
// Error is the body of an OpenAI-style error response. Param and Code are
// always encoded, as null when unset.
type Error struct {
	Message string  `json:"message"`
	Type    string  `json:"type"`
	Param   any     `json:"param"`
	Code    *string `json:"code"`
}

// ErrorResponse wraps an Error under the top-level "error" key, matching the
// OpenAI error envelope.
type ErrorResponse struct {
	Error Error `json:"error"`
}

type Message struct {
36
37
38
39
40
41
	Role       string     `json:"role"`
	Content    any        `json:"content"`
	Reasoning  string     `json:"reasoning,omitempty"`
	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
	Name       string     `json:"name,omitempty"`
	ToolCallID string     `json:"tool_call_id,omitempty"`
42
43
44
45
46
47
48
49
50
51
52
53
54
55
}

type (
	// Choice is one entry of a non-streaming chat completion's "choices"
	// array. A nil FinishReason is encoded as null.
	Choice struct {
		Index        int     `json:"index"`
		Message      Message `json:"message"`
		FinishReason *string `json:"finish_reason"`
	}

	// ChunkChoice is one entry of a streaming chunk's "choices" array. Delta
	// holds this chunk's incremental message; a nil FinishReason is encoded
	// as null.
	ChunkChoice struct {
		Index        int     `json:"index"`
		Delta        Message `json:"delta"`
		FinishReason *string `json:"finish_reason"`
	}
)

56
57
58
59
60
61
// CompleteChunkChoice is one entry of a legacy text completion's "choices"
// array, used for both full responses and streamed chunks. A nil
// FinishReason is encoded as null.
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports token counts for a completion or chat completion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ResponseFormat is the "response_format" request field. FromChatRequest
// recognises the types "json_object" and "json_schema" (case-insensitively);
// JsonSchema is only consulted for "json_schema".
type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

// JsonSchema carries the raw JSON schema of a "json_schema" response format.
// The schema bytes are kept undecoded and forwarded as the request format.
type JsonSchema struct {
	Schema json.RawMessage `json:"schema"`
}

77
// EmbedRequest is the request body of the OpenAI-compatible embeddings
// endpoint. Input is left as any; which shapes are accepted is decided by the
// handler that consumes it (not visible here).
type EmbedRequest struct {
	Input          any    `json:"input"`
	Model          string `json:"model"`
	Dimensions     int    `json:"dimensions,omitempty"`
	EncodingFormat string `json:"encoding_format,omitempty"` // "float" or "base64"
}

// StreamOptions is the "stream_options" request field.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

// Reasoning is the "reasoning" request field. FromChatRequest accepts an
// Effort of "high", "medium", "low" or "none" ("none" disables thinking).
type Reasoning struct {
	Effort string `json:"effort,omitempty"`
}

92
93
94
95
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
96
	StreamOptions    *StreamOptions  `json:"stream_options"`
97
98
99
100
101
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"`
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
102
	PresencePenalty  *float64        `json:"presence_penalty"`
103
104
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
royjhan's avatar
royjhan committed
105
	Tools            []api.Tool      `json:"tools"`
Michael Yang's avatar
Michael Yang committed
106
	Reasoning        *Reasoning      `json:"reasoning,omitempty"`
107
	ReasoningEffort  *string         `json:"reasoning_effort,omitempty"`
Devon Rifkin's avatar
Devon Rifkin committed
108
	DebugRenderOnly  bool            `json:"_debug_render_only"`
109
110
111
}

type ChatCompletion struct {
Devon Rifkin's avatar
Devon Rifkin committed
112
113
114
115
116
117
118
119
	Id                string         `json:"id"`
	Object            string         `json:"object"`
	Created           int64          `json:"created"`
	Model             string         `json:"model"`
	SystemFingerprint string         `json:"system_fingerprint"`
	Choices           []Choice       `json:"choices"`
	Usage             Usage          `json:"usage,omitempty"`
	DebugInfo         *api.DebugInfo `json:"_debug_info,omitempty"`
120
121
122
123
124
125
126
127
128
}

type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
129
	Usage             *Usage        `json:"usage,omitempty"`
130
131
}

132
133
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
134
135
136
137
138
139
140
141
142
143
144
145
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	FrequencyPenalty float32        `json:"frequency_penalty"`
	MaxTokens        *int           `json:"max_tokens"`
	PresencePenalty  float32        `json:"presence_penalty"`
	Seed             *int           `json:"seed"`
	Stop             any            `json:"stop"`
	Stream           bool           `json:"stream"`
	StreamOptions    *StreamOptions `json:"stream_options"`
	Temperature      *float32       `json:"temperature"`
	TopP             float32        `json:"top_p"`
	Suffix           string         `json:"suffix"`
Devon Rifkin's avatar
Devon Rifkin committed
146
	DebugRenderOnly  bool           `json:"_debug_render_only"`
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
}

type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
166
	Usage             *Usage                `json:"usage,omitempty"`
167
168
}

royjhan's avatar
royjhan committed
169
170
// ToolCall is a function tool call in OpenAI format. Function.Arguments
// holds the call arguments as a JSON-encoded string, not as a JSON object.
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

179
180
181
182
183
184
185
// Model describes a single model in OpenAI list/retrieve responses.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// Embedding is one entry of an embeddings response.
type Embedding struct {
	Object    string `json:"object"`
	Embedding any    `json:"embedding"` // Can be []float32 (float format) or string (base64 format)
	Index     int    `json:"index"`
}

// ListCompletion is the response body of the model listing endpoint.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// EmbeddingList is the response body of the embeddings endpoint. Usage is a
// struct value, so its `omitempty` option has no effect and "usage" is
// always encoded.
type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

// EmbeddingUsage reports token counts for an embeddings request.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

209
210
211
212
213
214
215
216
217
218
219
220
221
222
// NewError builds an OpenAI-style error envelope for an HTTP status code.
// 400 maps to "invalid_request_error", 404 to "not_found_error", and every
// other code to "api_error". Param and Code are left unset.
func NewError(code int, message string) ErrorResponse {
	kind := "api_error"
	if code == http.StatusBadRequest {
		kind = "invalid_request_error"
	} else if code == http.StatusNotFound {
		kind = "not_found_error"
	}

	return ErrorResponse{Error: Error{Message: message, Type: kind}}
}

223
224
// ToUsage converts an api.ChatResponse to Usage
func ToUsage(r api.ChatResponse) Usage {
225
	return Usage{
226
227
228
		PromptTokens:     r.Metrics.PromptEvalCount,
		CompletionTokens: r.Metrics.EvalCount,
		TotalTokens:      r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
229
230
231
	}
}

royjhan's avatar
royjhan committed
232
233
234
235
236
237
238
239
240
// toolCallId returns a pseudo-random tool call identifier of the form
// "call_" followed by 8 characters drawn from [a-z0-9].
//
// It draws from the math/rand package-level source, so the IDs are not
// suitable for anything security-sensitive.
func toolCallId() string {
	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
	b := make([]byte, 8)
	for i := range b {
		b[i] = letterBytes[rand.Intn(len(letterBytes))]
	}
	// letterBytes is already lowercase, so no case conversion is needed.
	return "call_" + string(b)
}

241
242
// ToToolCalls converts api.ToolCall to OpenAI ToolCall format
func ToToolCalls(tc []api.ToolCall) []ToolCall {
243
244
	toolCalls := make([]ToolCall, len(tc))
	for i, tc := range tc {
royjhan's avatar
royjhan committed
245
246
247
		toolCalls[i].ID = toolCallId()
		toolCalls[i].Type = "function"
		toolCalls[i].Function.Name = tc.Function.Name
248
		toolCalls[i].Index = tc.Function.Index
royjhan's avatar
royjhan committed
249
250
251
252
253
254
255
256
257

		args, err := json.Marshal(tc.Function.Arguments)
		if err != nil {
			slog.Error("could not marshall function arguments to json", "error", err)
			continue
		}

		toolCalls[i].Function.Arguments = string(args)
	}
258
259
	return toolCalls
}
royjhan's avatar
royjhan committed
260

261
262
// ToChatCompletion converts an api.ChatResponse to ChatCompletion
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
263
	toolCalls := ToToolCalls(r.Message.ToolCalls)
264
265
266
267
268
269
270
	return ChatCompletion{
		Id:                id,
		Object:            "chat.completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []Choice{{
271
			Index:   0,
Michael Yang's avatar
Michael Yang committed
272
			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
273
			FinishReason: func(reason string) *string {
274
275
276
				if len(toolCalls) > 0 {
					reason = "tool_calls"
				}
277
278
279
280
281
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
282
		}}, Usage: ToUsage(r),
Devon Rifkin's avatar
Devon Rifkin committed
283
		DebugInfo: r.DebugInfo,
284
285
286
	}
}

287
288
// ToChunk converts an api.ChatResponse to ChatCompletionChunk
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
289
	toolCalls := ToToolCalls(r.Message.ToolCalls)
290
291
292
293
294
295
	return ChatCompletionChunk{
		Id:                id,
		Object:            "chat.completion.chunk",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
296
297
		Choices: []ChunkChoice{{
			Index: 0,
Michael Yang's avatar
Michael Yang committed
298
			Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
299
300
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
Michael Yang's avatar
Michael Yang committed
301
					if toolCallSent || len(toolCalls) > 0 {
302
303
						return &finishReasonToolCalls
					}
304
305
306
307
308
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
309
310
311
	}
}

312
313
// ToUsageGenerate converts an api.GenerateResponse to Usage
func ToUsageGenerate(r api.GenerateResponse) Usage {
314
	return Usage{
315
316
317
		PromptTokens:     r.Metrics.PromptEvalCount,
		CompletionTokens: r.Metrics.EvalCount,
		TotalTokens:      r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
318
319
320
	}
}

321
322
// ToCompletion converts an api.GenerateResponse to Completion
func ToCompletion(id string, r api.GenerateResponse) Completion {
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
	return Completion{
		Id:                id,
		Object:            "text_completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
339
		Usage: ToUsageGenerate(r),
340
341
342
	}
}

343
344
// ToCompleteChunk converts a streaming api.GenerateResponse into a
// CompletionChunk with a single choice.
//
// The finish reason is r.DoneReason, or nil when that is empty. Created is
// the current time, and Usage is left nil (omitted from the JSON).
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
	var finishReason *string
	if r.DoneReason != "" {
		reason := r.DoneReason
		finishReason = &reason
	}

	return CompletionChunk{
		Id:                id,
		Object:            "text_completion",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:         r.Response,
			Index:        0,
			FinishReason: finishReason,
		}},
	}
}

364
365
// ToListCompletion converts an api.ListResponse into an OpenAI model list.
// Each model's OwnedBy is the namespace parsed from its name, and Created is
// its modification time.
func ToListCompletion(r api.ListResponse) ListCompletion {
	// Start from an empty, non-nil slice so that an empty model list encodes
	// as "data": [] rather than "data": null.
	data := make([]Model, 0, len(r.Models))
	for _, m := range r.Models {
		data = append(data, Model{
			Id:      m.Name,
			Object:  "model",
			Created: m.ModifiedAt.Unix(),
			OwnedBy: model.ParseName(m.Name).Namespace,
		})
	}

	return ListCompletion{
		Object: "list",
		Data:   data,
	}
}

382
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
383
384
// encodingFormat can be "float", "base64", or empty (defaults to "float")
func ToEmbeddingList(model string, r api.EmbedResponse, encodingFormat string) EmbeddingList {
385
386
387
	if r.Embeddings != nil {
		var data []Embedding
		for i, e := range r.Embeddings {
388
389
390
391
392
393
394
			var embedding any
			if strings.EqualFold(encodingFormat, "base64") {
				embedding = floatsToBase64(e)
			} else {
				embedding = e
			}

395
396
			data = append(data, Embedding{
				Object:    "embedding",
397
				Embedding: embedding,
398
399
400
401
402
403
404
405
				Index:     i,
			})
		}

		return EmbeddingList{
			Object: "list",
			Data:   data,
			Model:  model,
406
407
408
409
			Usage: EmbeddingUsage{
				PromptTokens: r.PromptEvalCount,
				TotalTokens:  r.PromptEvalCount,
			},
410
411
412
413
414
415
		}
	}

	return EmbeddingList{}
}

416
417
418
419
420
421
422
// floatsToBase64 serializes floats as consecutive little-endian float32
// values and returns the standard (padded) base64 encoding of those bytes.
func floatsToBase64(floats []float32) string {
	out := new(bytes.Buffer)
	// Writing a []float32 into an in-memory bytes.Buffer cannot fail, so the
	// error is deliberately discarded.
	_ = binary.Write(out, binary.LittleEndian, floats)
	encoded := base64.StdEncoding.EncodeToString(out.Bytes())
	return encoded
}

423
424
// ToModel converts the api.ShowResponse for the model named m into an
// OpenAI Model. OwnedBy is the namespace parsed from m, and Created is the
// model's modification time.
func ToModel(r api.ShowResponse, m string) Model {
	return Model{
		Id:      m,
		Object:  "model",
		Created: r.ModifiedAt.Unix(),
		OwnedBy: model.ParseName(m).Namespace,
	}
}

433
434
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
435
436
	var messages []api.Message
	for _, msg := range r.Messages {
437
438
439
440
441
442
443
		toolName := ""
		if strings.ToLower(msg.Role) == "tool" {
			toolName = msg.Name
			if toolName == "" && msg.ToolCallID != "" {
				toolName = nameFromToolCallID(r.Messages, msg.ToolCallID)
			}
		}
444
445
		switch content := msg.Content.(type) {
		case string:
446
			toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
447
448
449
			if err != nil {
				return nil, err
			}
450
			messages = append(messages, api.Message{Role: msg.Role, Content: content, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolName: toolName})
451
452
453
454
		case []any:
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
Michael Yang's avatar
lint  
Michael Yang committed
455
					return nil, errors.New("invalid message format")
456
457
458
459
460
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
Michael Yang's avatar
lint  
Michael Yang committed
461
						return nil, errors.New("invalid message format")
462
					}
463
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
464
465
466
467
				case "image_url":
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
Michael Yang's avatar
lint  
Michael Yang committed
468
							return nil, errors.New("invalid message format")
469
470
471
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
Michael Yang's avatar
lint  
Michael Yang committed
472
							return nil, errors.New("invalid message format")
473
474
475
						}
					}

476
					types := []string{"jpeg", "jpg", "png", "webp"}
477
					valid := false
478
479
480
481
482
					// support blank mime type to match api/chat taking just unadorned base64
					if strings.HasPrefix(url, "data:;base64,") {
						url = strings.TrimPrefix(url, "data:;base64,")
						valid = true
					}
483
484
485
486
487
488
489
490
491
492
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}

					if !valid {
Michael Yang's avatar
lint  
Michael Yang committed
493
						return nil, errors.New("invalid image input")
494
495
496
497
					}

					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
Michael Yang's avatar
lint  
Michael Yang committed
498
						return nil, errors.New("invalid message format")
499
					}
500
501

					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
502
				default:
Michael Yang's avatar
lint  
Michael Yang committed
503
					return nil, errors.New("invalid message format")
504
505
				}
			}
506
507
508
			// since we might have added multiple messages above, if we have tools
			// calls we'll add them to the last message
			if len(messages) > 0 && len(msg.ToolCalls) > 0 {
509
				toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
510
511
512
513
				if err != nil {
					return nil, err
				}
				messages[len(messages)-1].ToolCalls = toolCalls
514
515
516
				if toolName != "" {
					messages[len(messages)-1].ToolName = toolName
				}
517
				messages[len(messages)-1].Thinking = msg.Reasoning
518
			}
519
		default:
520
			// content is only optional if tool calls are present
royjhan's avatar
royjhan committed
521
522
523
524
525
526
527
528
529
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}

			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
			for i, tc := range msg.ToolCalls {
				toolCalls[i].Function.Name = tc.Function.Name
				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
				if err != nil {
Michael Yang's avatar
lint  
Michael Yang committed
530
					return nil, errors.New("invalid tool call arguments")
royjhan's avatar
royjhan committed
531
532
				}
			}
533
			messages = append(messages, api.Message{Role: msg.Role, Thinking: msg.Reasoning, ToolCalls: toolCalls})
534
		}
535
536
	}

537
	options := make(map[string]any)
538
539
540
541

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
542
	case []any:
543
544
545
546
547
548
549
550
551
552
553
554
555
556
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
557
		options["temperature"] = *r.Temperature
558
559
560
561
562
563
564
565
566
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
567
		options["frequency_penalty"] = *r.FrequencyPenalty
568
569
570
	}

	if r.PresencePenalty != nil {
571
		options["presence_penalty"] = *r.PresencePenalty
572
573
574
575
576
577
578
579
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

580
581
582
583
584
585
586
587
	var format json.RawMessage
	if r.ResponseFormat != nil {
		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
		// Support the old "json_object" type for OpenAI compatibility
		case "json_object":
			format = json.RawMessage(`"json"`)
		case "json_schema":
			if r.ResponseFormat.JsonSchema != nil {
588
				format = r.ResponseFormat.JsonSchema.Schema
589
590
			}
		}
591
592
	}

Michael Yang's avatar
Michael Yang committed
593
	var think *api.ThinkValue
594
595
	var effort string

Michael Yang's avatar
Michael Yang committed
596
	if r.Reasoning != nil {
597
598
599
600
601
602
603
604
		effort = r.Reasoning.Effort
	} else if r.ReasoningEffort != nil {
		effort = *r.ReasoningEffort
	}

	if effort != "" {
		if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
			return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
Michael Yang's avatar
Michael Yang committed
605
		}
606

607
		if effort == "none" {
608
609
			think = &api.ThinkValue{Value: false}
		} else {
610
			think = &api.ThinkValue{Value: effort}
611
		}
Michael Yang's avatar
Michael Yang committed
612
613
	}

614
	return &api.ChatRequest{
Devon Rifkin's avatar
Devon Rifkin committed
615
616
617
618
619
620
621
622
		Model:           r.Model,
		Messages:        messages,
		Format:          format,
		Options:         options,
		Stream:          &r.Stream,
		Tools:           r.Tools,
		Think:           think,
		DebugRenderOnly: r.DebugRenderOnly,
623
	}, nil
624
625
}

626
627
628
629
630
631
632
633
634
635
636
637
638
639
// nameFromToolCallID returns the function name of the tool call with the
// given ID, or "" when no message contains such a call. Messages are scanned
// from newest to oldest, so if an ID was reused the most recent call wins.
func nameFromToolCallID(messages []Message, toolCallID string) string {
	idx := len(messages)
	for idx > 0 {
		idx--
		calls := messages[idx].ToolCalls
		match := slices.IndexFunc(calls, func(call ToolCall) bool {
			return call.ID == toolCallID
		})
		if match >= 0 {
			return calls[match].Function.Name
		}
	}
	return ""
}

640
641
// FromCompletionToolCall converts OpenAI-format tool calls to api.ToolCall.
//
// Only the function name and arguments are carried over; ID, Index and Type
// are not. Each call's Arguments string must be valid JSON that decodes into
// api.ToolCall.Function.Arguments. Otherwise an "invalid tool call
// arguments" error is returned together with a nil slice.
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
	apiToolCalls := make([]api.ToolCall, len(toolCalls))
	for i, tc := range toolCalls {
		apiToolCalls[i].Function.Name = tc.Function.Name
		if err := json.Unmarshal([]byte(tc.Function.Arguments), &apiToolCalls[i].Function.Arguments); err != nil {
			return nil, errors.New("invalid tool call arguments")
		}
	}

	return apiToolCalls, nil
}

654
655
// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
656
657
658
659
660
	options := make(map[string]any)

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
661
662
663
664
665
666
667
668
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			} else {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
			}
669
		}
670
		options["stop"] = stops
671
672
673
674
675
676
677
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
678
		options["temperature"] = *r.Temperature
679
680
681
682
683
684
685
686
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

687
	options["frequency_penalty"] = r.FrequencyPenalty
688

689
	options["presence_penalty"] = r.PresencePenalty
690
691
692
693
694
695
696
697

	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	} else {
		options["top_p"] = 1.0
	}

	return api.GenerateRequest{
Devon Rifkin's avatar
Devon Rifkin committed
698
699
700
701
702
703
		Model:           r.Model,
		Prompt:          r.Prompt,
		Options:         options,
		Stream:          &r.Stream,
		Suffix:          r.Suffix,
		DebugRenderOnly: r.DebugRenderOnly,
704
705
	}, nil
}