Commit 2c2f4dea authored by Devon Rifkin's avatar Devon Rifkin
Browse files

openai: refactor to split compat layer and middleware

This makes the core openai compat layer independent of the middleware
that adapts it to our particular gin routes
parent e4340667
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
)
type BaseWriter struct {
gin.ResponseWriter
}
type ChatWriter struct {
stream bool
streamOptions *openai.StreamOptions
id string
toolCallSent bool
BaseWriter
}
type CompleteWriter struct {
stream bool
streamOptions *openai.StreamOptions
id string
BaseWriter
}
type ListWriter struct {
BaseWriter
}
type RetrieveWriter struct {
BaseWriter
model string
}
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.NewError(http.StatusInternalServerError, serr.Error()))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
// chat chunk
if w.stream {
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
w.toolCallSent = true
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if chatResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := openai.ToUsage(chatResponse)
c.Usage = &u
c.Choices = []openai.ChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// chat completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToChatCompletion(w.id, chatResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
err := json.Unmarshal(data, &generateResponse)
if err != nil {
return 0, err
}
// completion chunk
if w.stream {
c := openai.ToCompleteChunk(w.id, generateResponse)
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &openai.Usage{}
}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if generateResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := openai.ToUsageGenerate(generateResponse)
c.Usage = &u
c.Choices = []openai.CompleteChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToCompletion(w.id, generateResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *ListWriter) writeResponse(data []byte) (int, error) {
var listResponse api.ListResponse
err := json.Unmarshal(data, &listResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToListCompletion(listResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
var showResponse api.ShowResponse
err := json.Unmarshal(data, &showResponse)
if err != nil {
return 0, err
}
// retrieve completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToModel(showResponse, w.model))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}
func RetrieveMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &RetrieveWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: c.Param("model"),
}
c.Writer = w
c.Next()
}
}
func CompletionsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.CompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
genReq, err := openai.FromCompleteRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &CompleteWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w
c.Next()
}
}
func ChatMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ChatCompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
return
}
var b bytes.Buffer
chatReq, err := openai.FromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}
This diff is collapsed.
// openai package provides middleware for partial compatibility with the OpenAI REST API // openai package provides core transformation logic for partial compatibility with the OpenAI REST API
package openai package openai
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"log/slog" "log/slog"
"math/rand" "math/rand"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
...@@ -220,11 +216,12 @@ func NewError(code int, message string) ErrorResponse { ...@@ -220,11 +216,12 @@ func NewError(code int, message string) ErrorResponse {
return ErrorResponse{Error{Type: etype, Message: message}} return ErrorResponse{Error{Type: etype, Message: message}}
} }
func toUsage(r api.ChatResponse) Usage { // ToUsage converts an api.ChatResponse to Usage
func ToUsage(r api.ChatResponse) Usage {
return Usage{ return Usage{
PromptTokens: r.PromptEvalCount, PromptTokens: r.Metrics.PromptEvalCount,
CompletionTokens: r.EvalCount, CompletionTokens: r.Metrics.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount, TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
} }
} }
...@@ -256,7 +253,8 @@ func toToolCalls(tc []api.ToolCall) []ToolCall { ...@@ -256,7 +253,8 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
return toolCalls return toolCalls
} }
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { // ToChatCompletion converts an api.ChatResponse to ChatCompletion
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := toToolCalls(r.Message.ToolCalls) toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletion{ return ChatCompletion{
Id: id, Id: id,
...@@ -276,12 +274,13 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { ...@@ -276,12 +274,13 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
} }
return nil return nil
}(r.DoneReason), }(r.DoneReason),
}}, Usage: toUsage(r), }}, Usage: ToUsage(r),
DebugInfo: r.DebugInfo, DebugInfo: r.DebugInfo,
} }
} }
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk { // ToChunk converts an api.ChatResponse to ChatCompletionChunk
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
toolCalls := toToolCalls(r.Message.ToolCalls) toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletionChunk{ return ChatCompletionChunk{
Id: id, Id: id,
...@@ -305,15 +304,17 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu ...@@ -305,15 +304,17 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
} }
} }
func toUsageGenerate(r api.GenerateResponse) Usage { // ToUsageGenerate converts an api.GenerateResponse to Usage
func ToUsageGenerate(r api.GenerateResponse) Usage {
return Usage{ return Usage{
PromptTokens: r.PromptEvalCount, PromptTokens: r.Metrics.PromptEvalCount,
CompletionTokens: r.EvalCount, CompletionTokens: r.Metrics.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount, TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
} }
} }
func toCompletion(id string, r api.GenerateResponse) Completion { // ToCompletion converts an api.GenerateResponse to Completion
func ToCompletion(id string, r api.GenerateResponse) Completion {
return Completion{ return Completion{
Id: id, Id: id,
Object: "text_completion", Object: "text_completion",
...@@ -330,11 +331,12 @@ func toCompletion(id string, r api.GenerateResponse) Completion { ...@@ -330,11 +331,12 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
return nil return nil
}(r.DoneReason), }(r.DoneReason),
}}, }},
Usage: toUsageGenerate(r), Usage: ToUsageGenerate(r),
} }
} }
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { // ToCompleteChunk converts an api.GenerateResponse to CompletionChunk
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
return CompletionChunk{ return CompletionChunk{
Id: id, Id: id,
Object: "text_completion", Object: "text_completion",
...@@ -354,7 +356,8 @@ func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { ...@@ -354,7 +356,8 @@ func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
} }
} }
func toListCompletion(r api.ListResponse) ListCompletion { // ToListCompletion converts an api.ListResponse to ListCompletion
func ToListCompletion(r api.ListResponse) ListCompletion {
var data []Model var data []Model
for _, m := range r.Models { for _, m := range r.Models {
data = append(data, Model{ data = append(data, Model{
...@@ -371,7 +374,8 @@ func toListCompletion(r api.ListResponse) ListCompletion { ...@@ -371,7 +374,8 @@ func toListCompletion(r api.ListResponse) ListCompletion {
} }
} }
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { // ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
if r.Embeddings != nil { if r.Embeddings != nil {
var data []Embedding var data []Embedding
for i, e := range r.Embeddings { for i, e := range r.Embeddings {
...@@ -396,7 +400,8 @@ func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { ...@@ -396,7 +400,8 @@ func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
return EmbeddingList{} return EmbeddingList{}
} }
func toModel(r api.ShowResponse, m string) Model { // ToModel converts an api.ShowResponse to Model
func ToModel(r api.ShowResponse, m string) Model {
return Model{ return Model{
Id: m, Id: m,
Object: "model", Object: "model",
...@@ -405,7 +410,8 @@ func toModel(r api.ShowResponse, m string) Model { ...@@ -405,7 +410,8 @@ func toModel(r api.ShowResponse, m string) Model {
} }
} }
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { // FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
var messages []api.Message var messages []api.Message
for _, msg := range r.Messages { for _, msg := range r.Messages {
toolName := "" toolName := ""
...@@ -609,7 +615,8 @@ func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) { ...@@ -609,7 +615,8 @@ func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
return apiToolCalls, nil return apiToolCalls, nil
} }
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { // FromCompleteRequest converts a CompletionRequest to api.GenerateRequest
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
options := make(map[string]any) options := make(map[string]any)
switch stop := r.Stop.(type) { switch stop := r.Stop.(type) {
...@@ -660,413 +667,3 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { ...@@ -660,413 +667,3 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
DebugRenderOnly: r.DebugRenderOnly, DebugRenderOnly: r.DebugRenderOnly,
}, nil }, nil
} }
type BaseWriter struct {
gin.ResponseWriter
}
type ChatWriter struct {
stream bool
streamOptions *StreamOptions
id string
toolCallSent bool
BaseWriter
}
type CompleteWriter struct {
stream bool
streamOptions *StreamOptions
id string
BaseWriter
}
type ListWriter struct {
BaseWriter
}
type RetrieveWriter struct {
BaseWriter
model string
}
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
// chat chunk
if w.stream {
c := toChunk(w.id, chatResponse, w.toolCallSent)
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
w.toolCallSent = true
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if chatResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsage(chatResponse)
c.Usage = &u
c.Choices = []ChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// chat completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
err := json.Unmarshal(data, &generateResponse)
if err != nil {
return 0, err
}
// completion chunk
if w.stream {
c := toCompleteChunk(w.id, generateResponse)
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &Usage{}
}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if generateResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsageGenerate(generateResponse)
c.Usage = &u
c.Choices = []CompleteChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *ListWriter) writeResponse(data []byte) (int, error) {
var listResponse api.ListResponse
err := json.Unmarshal(data, &listResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
var showResponse api.ShowResponse
err := json.Unmarshal(data, &showResponse)
if err != nil {
return 0, err
}
// retrieve completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}
func RetrieveMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
// response writer
w := &RetrieveWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: c.Param("model"),
}
c.Writer = w
c.Next()
}
}
func CompletionsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req CompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
genReq, err := fromCompleteRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &CompleteWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w
c.Next()
}
}
func ChatMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req ChatCompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
return
}
var b bytes.Buffer
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}
This diff is collapsed.
...@@ -37,8 +37,8 @@ import ( ...@@ -37,8 +37,8 @@ import (
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/middleware"
"github.com/ollama/ollama/model/parsers" "github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
...@@ -1449,11 +1449,11 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { ...@@ -1449,11 +1449,11 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/embeddings", s.EmbeddingsHandler) r.POST("/api/embeddings", s.EmbeddingsHandler)
// Inference (OpenAI compatibility) // Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler) r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler) r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler) r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
if rc != nil { if rc != nil {
// wrap old with new // wrap old with new
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment