openai.go 20.2 KB
Newer Older
1
// openai package provides core transformation logic for partial compatibility with the OpenAI REST API
2
3
4
package openai

import (
5
	"bytes"
6
	"encoding/base64"
7
	"encoding/binary"
8
	"encoding/json"
Michael Yang's avatar
lint  
Michael Yang committed
9
	"errors"
10
	"fmt"
royjhan's avatar
royjhan committed
11
	"log/slog"
12
	"net/http"
13
	"slices"
14
	"strings"
15
16
	"time"

17
	"github.com/ollama/ollama/api"
18
	"github.com/ollama/ollama/types/model"
19
20
)

21
22
// finishReasonToolCalls is the finish_reason value emitted when a response
// ends because the model produced tool calls. Kept as a var so its address
// can be returned directly (see ToChunk).
var finishReasonToolCalls = "tool_calls"

// Error mirrors the OpenAI API error object.
type Error struct {
	Message string  `json:"message"`
	Type    string  `json:"type"`
	Param   any     `json:"param"`
	Code    *string `json:"code"`
}

// ErrorResponse is the envelope the OpenAI API wraps errors in.
type ErrorResponse struct {
	Error Error `json:"error"`
}

// Message is a chat message in OpenAI wire format. Content is `any` because
// the API accepts either a plain string or an array of typed content parts
// (text / image_url) — see FromChatRequest for how both are handled.
type Message struct {
	Role       string     `json:"role"`
	Content    any        `json:"content"`
	Reasoning  string     `json:"reasoning,omitempty"`
	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
	Name       string     `json:"name,omitempty"`
	ToolCallID string     `json:"tool_call_id,omitempty"`
}

// ChoiceLogprobs carries per-token log probabilities for a choice.
type ChoiceLogprobs struct {
	Content []api.Logprob `json:"content"`
}

47
// Choice is a single completion choice in a non-streaming chat response.
type Choice struct {
	Index        int             `json:"index"`
	Message      Message         `json:"message"`
	FinishReason *string         `json:"finish_reason"`
	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
}

// ChunkChoice is a single choice in a streaming chat response chunk; the
// incremental content arrives in Delta rather than Message.
type ChunkChoice struct {
	Index        int             `json:"index"`
	Delta        Message         `json:"delta"`
	FinishReason *string         `json:"finish_reason"`
	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
}

// CompleteChunkChoice is a choice in a legacy text-completion response
// (both streaming and non-streaming).
type CompleteChunkChoice struct {
	Text         string          `json:"text"`
	Index        int             `json:"index"`
	FinishReason *string         `json:"finish_reason"`
	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
}

68
69
70
71
72
73
74
// Usage reports token accounting in OpenAI format.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ResponseFormat selects structured output: "json_object" for free-form
// JSON or "json_schema" with an attached schema.
type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

// JsonSchema holds a raw JSON schema, passed through undecoded.
type JsonSchema struct {
	Schema json.RawMessage `json:"schema"`
}

// EmbedRequest is an OpenAI embeddings request.
type EmbedRequest struct {
	Input          any    `json:"input"`
	Model          string `json:"model"`
	Dimensions     int    `json:"dimensions,omitempty"`
	EncodingFormat string `json:"encoding_format,omitempty"` // "float" or "base64"
}

// StreamOptions controls streaming behavior; IncludeUsage requests a final
// usage-only chunk.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

// Reasoning configures reasoning effort ("high", "medium", "low", or "none").
type Reasoning struct {
	Effort string `json:"effort,omitempty"`
}

98
99
100
101
// ChatCompletionRequest is an OpenAI /v1/chat/completions request body.
// Pointer fields distinguish "absent" from the zero value so defaults can
// be applied in FromChatRequest.
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
	StreamOptions    *StreamOptions  `json:"stream_options"`
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"` // string or []any of strings
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
	// Reasoning takes precedence over ReasoningEffort when both are set.
	Reasoning        *Reasoning      `json:"reasoning,omitempty"`
	ReasoningEffort  *string         `json:"reasoning_effort,omitempty"`
	Logprobs         *bool           `json:"logprobs"`
	TopLogprobs      int             `json:"top_logprobs"`
	// DebugRenderOnly is an Ollama extension: render the prompt without running it.
	DebugRenderOnly  bool            `json:"_debug_render_only"`
}

// ChatCompletion is a non-streaming chat completion response.
type ChatCompletion struct {
	Id                string         `json:"id"`
	Object            string         `json:"object"` // always "chat.completion"
	Created           int64          `json:"created"`
	Model             string         `json:"model"`
	SystemFingerprint string         `json:"system_fingerprint"`
	Choices           []Choice       `json:"choices"`
	Usage             Usage          `json:"usage,omitempty"`
	DebugInfo         *api.DebugInfo `json:"_debug_info,omitempty"`
}

// ChatCompletionChunk is one streamed chat completion chunk. Usage is a
// pointer so it is omitted except on the final chunk when requested via
// StreamOptions.IncludeUsage.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"` // always "chat.completion.chunk"
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
	Usage             *Usage        `json:"usage,omitempty"`
}

140
141
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
// CompletionRequest is a legacy OpenAI /v1/completions request body.
type CompletionRequest struct {
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	FrequencyPenalty float32        `json:"frequency_penalty"`
	MaxTokens        *int           `json:"max_tokens"`
	PresencePenalty  float32        `json:"presence_penalty"`
	Seed             *int           `json:"seed"`
	Stop             any            `json:"stop"` // string or []any of strings
	Stream           bool           `json:"stream"`
	StreamOptions    *StreamOptions `json:"stream_options"`
	Temperature      *float32       `json:"temperature"`
	TopP             float32        `json:"top_p"`
	Suffix           string         `json:"suffix"`
	// Logprobs here is an int (top-N count), unlike the chat API's bool.
	Logprobs         *int           `json:"logprobs"`
	// DebugRenderOnly is an Ollama extension: render the prompt without running it.
	DebugRenderOnly  bool           `json:"_debug_render_only"`
}

// Completion is a non-streaming text completion response.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"` // always "text_completion"
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

// CompletionChunk is one streamed text completion chunk; Usage is only set
// on the final chunk when requested.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Usage             *Usage                `json:"usage,omitempty"`
}

royjhan's avatar
royjhan committed
178
179
// ToolCall is a tool invocation in OpenAI wire format; Arguments is a JSON
// string (unlike api.ToolCall, which stores a structured value).
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"` // always "function"
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

// Model is one entry in the OpenAI /v1/models listing.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"` // always "model"
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// Embedding is a single embedding result.
type Embedding struct {
	Object    string `json:"object"` // always "embedding"
	Embedding any    `json:"embedding"` // Can be []float32 (float format) or string (base64 format)
	Index     int    `json:"index"`
}

201
202
203
204
205
// ListCompletion is the OpenAI /v1/models response envelope.
type ListCompletion struct {
	Object string  `json:"object"` // always "list"
	Data   []Model `json:"data"`
}

// EmbeddingList is the OpenAI embeddings response envelope.
type EmbeddingList struct {
	Object string         `json:"object"` // always "list"
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

// EmbeddingUsage reports token accounting for an embeddings request; there
// are no completion tokens, so total equals prompt tokens.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

218
219
220
221
222
223
224
225
226
227
228
229
230
231
// NewError builds an OpenAI-style ErrorResponse, mapping the HTTP status
// code to the closest OpenAI error type string.
func NewError(code int, message string) ErrorResponse {
	etype := "api_error"
	switch code {
	case http.StatusBadRequest:
		etype = "invalid_request_error"
	case http.StatusNotFound:
		etype = "not_found_error"
	}

	return ErrorResponse{Error{Type: etype, Message: message}}
}

232
233
// ToUsage converts an api.ChatResponse to Usage
func ToUsage(r api.ChatResponse) Usage {
234
	return Usage{
235
236
237
		PromptTokens:     r.Metrics.PromptEvalCount,
		CompletionTokens: r.Metrics.EvalCount,
		TotalTokens:      r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
238
239
240
	}
}

241
242
// ToToolCalls converts api.ToolCall to OpenAI ToolCall format
func ToToolCalls(tc []api.ToolCall) []ToolCall {
243
244
	toolCalls := make([]ToolCall, len(tc))
	for i, tc := range tc {
Grace's avatar
Grace committed
245
		toolCalls[i].ID = tc.ID
royjhan's avatar
royjhan committed
246
247
		toolCalls[i].Type = "function"
		toolCalls[i].Function.Name = tc.Function.Name
248
		toolCalls[i].Index = tc.Function.Index
royjhan's avatar
royjhan committed
249
250
251
252
253
254
255
256
257

		args, err := json.Marshal(tc.Function.Arguments)
		if err != nil {
			slog.Error("could not marshall function arguments to json", "error", err)
			continue
		}

		toolCalls[i].Function.Arguments = string(args)
	}
258
259
	return toolCalls
}
royjhan's avatar
royjhan committed
260

261
262
// ToChatCompletion converts an api.ChatResponse to ChatCompletion
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
263
	toolCalls := ToToolCalls(r.Message.ToolCalls)
264
265
266
267
268
269

	var logprobs *ChoiceLogprobs
	if len(r.Logprobs) > 0 {
		logprobs = &ChoiceLogprobs{Content: r.Logprobs}
	}

270
271
272
273
274
275
276
	return ChatCompletion{
		Id:                id,
		Object:            "chat.completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []Choice{{
277
			Index:   0,
Michael Yang's avatar
Michael Yang committed
278
			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
279
			FinishReason: func(reason string) *string {
280
281
282
				if len(toolCalls) > 0 {
					reason = "tool_calls"
				}
283
284
285
286
287
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
288
			Logprobs: logprobs,
289
		}}, Usage: ToUsage(r),
Devon Rifkin's avatar
Devon Rifkin committed
290
		DebugInfo: r.DebugInfo,
291
292
293
	}
}

294
295
// ToChunk converts an api.ChatResponse to ChatCompletionChunk
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
296
	toolCalls := ToToolCalls(r.Message.ToolCalls)
297
298
299
300
301
302

	var logprobs *ChoiceLogprobs
	if len(r.Logprobs) > 0 {
		logprobs = &ChoiceLogprobs{Content: r.Logprobs}
	}

303
304
305
306
307
308
	return ChatCompletionChunk{
		Id:                id,
		Object:            "chat.completion.chunk",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
309
310
		Choices: []ChunkChoice{{
			Index: 0,
Michael Yang's avatar
Michael Yang committed
311
			Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
312
313
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
Michael Yang's avatar
Michael Yang committed
314
					if toolCallSent || len(toolCalls) > 0 {
315
316
						return &finishReasonToolCalls
					}
317
318
319
320
					return &reason
				}
				return nil
			}(r.DoneReason),
321
			Logprobs: logprobs,
322
		}},
323
324
325
	}
}

326
327
// ToUsageGenerate converts an api.GenerateResponse to Usage
func ToUsageGenerate(r api.GenerateResponse) Usage {
328
	return Usage{
329
330
331
		PromptTokens:     r.Metrics.PromptEvalCount,
		CompletionTokens: r.Metrics.EvalCount,
		TotalTokens:      r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
332
333
334
	}
}

335
336
// ToCompletion converts an api.GenerateResponse to Completion
func ToCompletion(id string, r api.GenerateResponse) Completion {
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
	return Completion{
		Id:                id,
		Object:            "text_completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
353
		Usage: ToUsageGenerate(r),
354
355
356
	}
}

357
358
// ToCompleteChunk converts an api.GenerateResponse to a streaming CompletionChunk.
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
	// An empty done reason maps to a nil finish_reason.
	var finishReason *string
	if reason := r.DoneReason; len(reason) > 0 {
		finishReason = &reason
	}

	return CompletionChunk{
		Id:                id,
		Object:            "text_completion",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:         r.Response,
			Index:        0,
			FinishReason: finishReason,
		}},
	}
}

378
379
// ToListCompletion converts an api.ListResponse to ListCompletion.
func ToListCompletion(r api.ListResponse) ListCompletion {
	// Pre-size and keep the slice non-nil so an empty model list serializes
	// as `"data": []` (OpenAI behavior) rather than `"data": null`.
	data := make([]Model, 0, len(r.Models))
	for _, m := range r.Models {
		data = append(data, Model{
			Id:      m.Name,
			Object:  "model",
			Created: m.ModifiedAt.Unix(),
			OwnedBy: model.ParseName(m.Name).Namespace,
		})
	}

	return ListCompletion{
		Object: "list",
		Data:   data,
	}
}

396
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
397
398
// encodingFormat can be "float", "base64", or empty (defaults to "float")
func ToEmbeddingList(model string, r api.EmbedResponse, encodingFormat string) EmbeddingList {
399
400
401
	if r.Embeddings != nil {
		var data []Embedding
		for i, e := range r.Embeddings {
402
403
404
405
406
407
408
			var embedding any
			if strings.EqualFold(encodingFormat, "base64") {
				embedding = floatsToBase64(e)
			} else {
				embedding = e
			}

409
410
			data = append(data, Embedding{
				Object:    "embedding",
411
				Embedding: embedding,
412
413
414
415
416
417
418
419
				Index:     i,
			})
		}

		return EmbeddingList{
			Object: "list",
			Data:   data,
			Model:  model,
420
421
422
423
			Usage: EmbeddingUsage{
				PromptTokens: r.PromptEvalCount,
				TotalTokens:  r.PromptEvalCount,
			},
424
425
426
427
428
429
		}
	}

	return EmbeddingList{}
}

430
431
432
433
434
435
436
// floatsToBase64 encodes a []float32 as little-endian IEEE-754 bytes and
// returns the standard base64 encoding (OpenAI's "base64" embedding format).
func floatsToBase64(floats []float32) string {
	var buf bytes.Buffer
	// binary.Write to a bytes.Buffer cannot fail for a []float32, but the
	// error was previously discarded silently; log it defensively.
	if err := binary.Write(&buf, binary.LittleEndian, floats); err != nil {
		slog.Error("could not encode embedding to base64", "error", err)
		return ""
	}
	return base64.StdEncoding.EncodeToString(buf.Bytes())
}

437
438
// ToModel converts an api.ShowResponse plus its name to an OpenAI Model entry.
func ToModel(r api.ShowResponse, m string) Model {
	owner := model.ParseName(m).Namespace
	return Model{
		Id:      m,
		Object:  "model",
		Created: r.ModifiedAt.Unix(),
		OwnedBy: owner,
	}
}

447
448
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest.
// It flattens OpenAI's mixed string/array message content into api.Message
// values, maps sampling parameters into the options map (applying OpenAI's
// defaults of 1.0 for temperature and top_p), translates response_format,
// and validates reasoning effort. Returns an error on malformed message
// content, invalid images, or an unknown reasoning value.
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		// For tool-result messages, resolve the tool name: prefer the explicit
		// Name field, otherwise look it up from the originating tool call ID.
		toolName := ""
		if strings.ToLower(msg.Role) == "tool" {
			toolName = msg.Name
			if toolName == "" && msg.ToolCallID != "" {
				toolName = nameFromToolCallID(r.Messages, msg.ToolCallID)
			}
		}
		switch content := msg.Content.(type) {
		case string:
			// Simple string content maps to a single api.Message.
			toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
			if err != nil {
				return nil, err
			}
			messages = append(messages, api.Message{Role: msg.Role, Content: content, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolName: toolName, ToolCallID: msg.ToolCallID})
		case []any:
			// Array content: each part becomes its own api.Message.
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					// image_url may be either {"url": "..."} or a bare string.
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					}

					types := []string{"jpeg", "jpg", "png", "webp"}
					valid := false
					// support blank mime type to match api/chat taking just unadorned base64
					if strings.HasPrefix(url, "data:;base64,") {
						url = strings.TrimPrefix(url, "data:;base64,")
						valid = true
					}
					// Otherwise require a recognized data:image/<type>;base64, prefix.
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}

					if !valid {
						return nil, errors.New("invalid image input")
					}

					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, errors.New("invalid message format")
					}

					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}
			// since we might have added multiple messages above, if we have tools
			// calls we'll add them to the last message
			if len(messages) > 0 && len(msg.ToolCalls) > 0 {
				toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
				if err != nil {
					return nil, err
				}
				messages[len(messages)-1].ToolCalls = toolCalls
				messages[len(messages)-1].ToolName = toolName
				messages[len(messages)-1].ToolCallID = msg.ToolCallID
				messages[len(messages)-1].Thinking = msg.Reasoning
			}
		default:
			// content is only optional if tool calls are present
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}

			toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
			if err != nil {
				return nil, err
			}
			messages = append(messages, api.Message{Role: msg.Role, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolCallID: msg.ToolCallID})
		}
	}

	options := make(map[string]any)

	// "stop" may be a single string or an array; non-strings in the array
	// are silently skipped.
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	// OpenAI defaults: temperature 1.0 when unset.
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty
	}

	// OpenAI defaults: top_p 1.0 when unset.
	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format json.RawMessage
	if r.ResponseFormat != nil {
		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
		// Support the old "json_object" type for OpenAI compatibility
		case "json_object":
			format = json.RawMessage(`"json"`)
		case "json_schema":
			if r.ResponseFormat.JsonSchema != nil {
				format = r.ResponseFormat.JsonSchema.Schema
			}
		}
	}

	var think *api.ThinkValue
	var effort string

	// The structured Reasoning field takes precedence over the legacy
	// reasoning_effort string.
	if r.Reasoning != nil {
		effort = r.Reasoning.Effort
	} else if r.ReasoningEffort != nil {
		effort = *r.ReasoningEffort
	}

	if effort != "" {
		if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
			return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
		}

		// "none" disables thinking; any other valid effort is passed through.
		if effort == "none" {
			think = &api.ThinkValue{Value: false}
		} else {
			think = &api.ThinkValue{Value: effort}
		}
	}

	return &api.ChatRequest{
		Model:           r.Model,
		Messages:        messages,
		Format:          format,
		Options:         options,
		Stream:          &r.Stream,
		Tools:           r.Tools,
		Think:           think,
		Logprobs:        r.Logprobs != nil && *r.Logprobs,
		TopLogprobs:     r.TopLogprobs,
		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}

637
638
639
640
641
642
643
644
645
646
647
648
649
650
// nameFromToolCallID resolves a tool call ID to its function name by
// scanning the conversation's tool calls. It walks the messages newest
// first, so duplicate IDs follow "last one wins"; returns "" if no
// matching call exists.
func nameFromToolCallID(messages []Message, toolCallID string) string {
	for i := len(messages) - 1; i >= 0; i-- {
		for _, call := range messages[i].ToolCalls {
			if call.ID == toolCallID {
				return call.Function.Name
			}
		}
	}
	return ""
}

651
652
// FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall,
// decoding each JSON-string argument payload into a structured value.
// Returns an error if any arguments string is not valid JSON.
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
	converted := make([]api.ToolCall, len(toolCalls))
	for i, call := range toolCalls {
		converted[i].ID = call.ID
		converted[i].Function.Name = call.Function.Name
		if err := json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Function.Arguments); err != nil {
			return nil, errors.New("invalid tool call arguments")
		}
	}

	return converted, nil
}

666
667
// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest.
// Sampling parameters are copied into the options map, with OpenAI defaults
// of 1.0 for temperature and top_p when unset. Unlike FromChatRequest, a
// non-string entry in an array "stop" is an error here, and the penalties
// are plain float32 values (always set, zero when absent from the request).
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	options := make(map[string]any)

	// "stop" may be a single string or an array of strings.
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			} else {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	// OpenAI defaults: temperature 1.0 when unset.
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	options["frequency_penalty"] = r.FrequencyPenalty

	options["presence_penalty"] = r.PresencePenalty

	// top_p of 0 is treated as unset and defaults to 1.0.
	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	} else {
		options["top_p"] = 1.0
	}

	// The legacy API's integer logprobs doubles as both the enable flag and
	// the top-N count.
	var logprobs bool
	var topLogprobs int
	if r.Logprobs != nil && *r.Logprobs > 0 {
		logprobs = true
		topLogprobs = *r.Logprobs
	}

	return api.GenerateRequest{
		Model:           r.Model,
		Prompt:          r.Prompt,
		Options:         options,
		Stream:          &r.Stream,
		Suffix:          r.Suffix,
		Logprobs:        logprobs,
		TopLogprobs:     topLogprobs,
		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}