Commit 160cecc8 authored by Devon Rifkin's avatar Devon Rifkin
Browse files

openai: make tool call conversion fns public

parent 8b6e5bae
...@@ -235,7 +235,8 @@ func toolCallId() string { ...@@ -235,7 +235,8 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b)) return "call_" + strings.ToLower(string(b))
} }
func toToolCalls(tc []api.ToolCall) []ToolCall { // ToToolCalls converts api.ToolCall to OpenAI ToolCall format
func ToToolCalls(tc []api.ToolCall) []ToolCall {
toolCalls := make([]ToolCall, len(tc)) toolCalls := make([]ToolCall, len(tc))
for i, tc := range tc { for i, tc := range tc {
toolCalls[i].ID = toolCallId() toolCalls[i].ID = toolCallId()
...@@ -256,7 +257,7 @@ func toToolCalls(tc []api.ToolCall) []ToolCall { ...@@ -256,7 +257,7 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
// ToChatCompletion converts an api.ChatResponse to ChatCompletion // ToChatCompletion converts an api.ChatResponse to ChatCompletion
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion { func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := toToolCalls(r.Message.ToolCalls) toolCalls := ToToolCalls(r.Message.ToolCalls)
return ChatCompletion{ return ChatCompletion{
Id: id, Id: id,
Object: "chat.completion", Object: "chat.completion",
...@@ -282,7 +283,7 @@ func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion { ...@@ -282,7 +283,7 @@ func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
// ToChunk converts an api.ChatResponse to ChatCompletionChunk // ToChunk converts an api.ChatResponse to ChatCompletionChunk
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk { func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
toolCalls := toToolCalls(r.Message.ToolCalls) toolCalls := ToToolCalls(r.Message.ToolCalls)
return ChatCompletionChunk{ return ChatCompletionChunk{
Id: id, Id: id,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
...@@ -424,7 +425,7 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { ...@@ -424,7 +425,7 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
} }
switch content := msg.Content.(type) { switch content := msg.Content.(type) {
case string: case string:
toolCalls, err := fromCompletionToolCall(msg.ToolCalls) toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -487,7 +488,7 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { ...@@ -487,7 +488,7 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
// since we might have added multiple messages above, if we have tools // since we might have added multiple messages above, if we have tools
// calls we'll add them to the last message // calls we'll add them to the last message
if len(messages) > 0 && len(msg.ToolCalls) > 0 { if len(messages) > 0 && len(msg.ToolCalls) > 0 {
toolCalls, err := fromCompletionToolCall(msg.ToolCalls) toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -618,7 +619,8 @@ func nameFromToolCallID(messages []Message, toolCallID string) string { ...@@ -618,7 +619,8 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
return "" return ""
} }
func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) { // FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
apiToolCalls := make([]api.ToolCall, len(toolCalls)) apiToolCalls := make([]api.ToolCall, len(toolCalls))
for i, tc := range toolCalls { for i, tc := range toolCalls {
apiToolCalls[i].Function.Name = tc.Function.Name apiToolCalls[i].Function.Name = tc.Function.Name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment