Unverified Commit 0aff6787 authored by royjhan's avatar royjhan Committed by GitHub
Browse files

separate request tests (#5578)

parent 9544a57e
...@@ -3,7 +3,6 @@ package openai ...@@ -3,7 +3,6 @@ package openai
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
...@@ -16,49 +15,33 @@ import ( ...@@ -16,49 +15,33 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestMiddleware(t *testing.T) { func TestMiddlewareRequests(t *testing.T) {
type testCase struct { type testCase struct {
Name string Name string
Method string Method string
Path string Path string
TestPath string
Handler func() gin.HandlerFunc Handler func() gin.HandlerFunc
Endpoint func(c *gin.Context)
Setup func(t *testing.T, req *http.Request) Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *httptest.ResponseRecorder) Expected func(t *testing.T, req *http.Request)
} }
testCases := []testCase{ var capturedRequest *http.Request
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
TestPath: "/api/chat",
Handler: ChatMiddleware,
Endpoint: func(c *gin.Context) {
var chatReq api.ChatRequest
if err := c.ShouldBindJSON(&chatReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
userMessage := chatReq.Messages[0].Content
var assistantMessage string
switch userMessage { captureRequestMiddleware := func() gin.HandlerFunc {
case "Hello": return func(c *gin.Context) {
assistantMessage = "Hello!" bodyBytes, _ := io.ReadAll(c.Request.Body)
default: c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
assistantMessage = "I'm not sure how to respond to that." capturedRequest = c.Request
} c.Next()
}
}
c.JSON(http.StatusOK, api.ChatResponse{ testCases := []testCase{
Message: api.Message{ {
Role: "assistant", Name: "chat handler",
Content: assistantMessage, Method: http.MethodPost,
}, Path: "/api/chat",
}) Handler: ChatMiddleware,
},
Setup: func(t *testing.T, req *http.Request) { Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{ body := ChatCompletionRequest{
Model: "test-model", Model: "test-model",
...@@ -70,38 +53,32 @@ func TestMiddleware(t *testing.T) { ...@@ -70,38 +53,32 @@ func TestMiddleware(t *testing.T) {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
}, },
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { Expected: func(t *testing.T, req *http.Request) {
assert.Equal(t, http.StatusOK, resp.Code) var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
var chatResp ChatCompletion
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if chatResp.Object != "chat.completion" { if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected chat.completion, got %s", chatResp.Object) t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
} }
if chatResp.Choices[0].Message.Content != "Hello!" { if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content) t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
} }
}, },
}, },
{ {
Name: "completions handler", Name: "completions handler",
Method: http.MethodPost, Method: http.MethodPost,
Path: "/api/generate", Path: "/api/generate",
TestPath: "/api/generate", Handler: CompletionsMiddleware,
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.GenerateResponse{
Response: "Hello!",
})
},
Setup: func(t *testing.T, req *http.Request) { Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{ body := CompletionRequest{
Model: "test-model", Model: "test-model",
Prompt: "Hello", Prompt: "Hello",
Temperature: &temp,
} }
bodyBytes, _ := json.Marshal(body) bodyBytes, _ := json.Marshal(body)
...@@ -109,80 +86,65 @@ func TestMiddleware(t *testing.T) { ...@@ -109,80 +86,65 @@ func TestMiddleware(t *testing.T) {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
}, },
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { Expected: func(t *testing.T, req *http.Request) {
assert.Equal(t, http.StatusOK, resp.Code) var genReq api.GenerateRequest
var completionResp Completion if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if completionResp.Object != "text_completion" { if genReq.Prompt != "Hello" {
t.Fatalf("expected text_completion, got %s", completionResp.Object) t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
} }
if completionResp.Choices[0].Text != "Hello!" { if genReq.Options["temperature"] != 1.6 {
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text) t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
} }
}, },
}, },
{ }
Name: "completions handler with params",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
var generateReq api.GenerateRequest
if err := c.ShouldBindJSON(&generateReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
temperature := generateReq.Options["temperature"].(float64) gin.SetMode(gin.TestMode)
var assistantMessage string router := gin.New()
switch temperature { endpoint := func(c *gin.Context) {
case 1.6: c.Status(http.StatusOK)
assistantMessage = "Received temperature of 1.6" }
default:
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
}
c.JSON(http.StatusOK, api.GenerateResponse{ for _, tc := range testCases {
Response: assistantMessage, t.Run(tc.Name, func(t *testing.T) {
}) router = gin.New()
}, router.Use(captureRequestMiddleware())
Setup: func(t *testing.T, req *http.Request) { router.Use(tc.Handler())
temp := float32(0.8) router.Handle(tc.Method, tc.Path, endpoint)
body := CompletionRequest{ req, _ := http.NewRequest(tc.Method, tc.Path, nil)
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
}
bodyBytes, _ := json.Marshal(body) if tc.Setup != nil {
tc.Setup(t, req)
}
req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) resp := httptest.NewRecorder()
req.Header.Set("Content-Type", "application/json") router.ServeHTTP(resp, req)
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" { tc.Expected(t, capturedRequest)
t.Fatalf("expected text_completion, got %s", completionResp.Object) })
} }
}
if completionResp.Choices[0].Text != "Received temperature of 1.6" { func TestMiddlewareResponses(t *testing.T) {
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text) type testCase struct {
} Name string
}, Method string
}, Path string
TestPath string
Handler func() gin.HandlerFunc
Endpoint func(c *gin.Context)
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
}
testCases := []testCase{
{ {
Name: "completions handler with error", Name: "completions handler error forwarding",
Method: http.MethodPost, Method: http.MethodPost,
Path: "/api/generate", Path: "/api/generate",
TestPath: "/api/generate", TestPath: "/api/generate",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment