Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
c57317cb
Unverified
Commit
c57317cb
authored
Jul 19, 2024
by
royjhan
Committed by
GitHub
Jul 19, 2024
Browse files
OpenAI: Function Based Testing (#5752)
* distinguish error forwarding * more coverage * rm comment
parent
51b2fd29
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
266 additions
and
170 deletions
+266
-170
openai/openai.go
openai/openai.go
+1
-0
openai/openai_test.go
openai/openai_test.go
+265
-170
No files found.
openai/openai.go
View file @
c57317cb
...
@@ -877,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc {
...
@@ -877,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc {
chatReq
,
err
:=
fromChatRequest
(
req
)
chatReq
,
err
:=
fromChatRequest
(
req
)
if
err
!=
nil
{
if
err
!=
nil
{
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
NewError
(
http
.
StatusBadRequest
,
err
.
Error
()))
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
NewError
(
http
.
StatusBadRequest
,
err
.
Error
()))
return
}
}
if
err
:=
json
.
NewEncoder
(
&
b
)
.
Encode
(
chatReq
);
err
!=
nil
{
if
err
:=
json
.
NewEncoder
(
&
b
)
.
Encode
(
chatReq
);
err
!=
nil
{
...
...
openai/openai_test.go
View file @
c57317cb
...
@@ -20,64 +20,195 @@ const prefix = `data:image/jpeg;base64,`
...
@@ -20,64 +20,195 @@ const prefix = `data:image/jpeg;base64,`
const
image
=
`iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const
image
=
`iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const
imageURL
=
prefix
+
image
const
imageURL
=
prefix
+
image
func
TestMiddlewareRequests
(
t
*
testing
.
T
)
{
func
prepareRequest
(
req
*
http
.
Request
,
body
any
)
{
bodyBytes
,
_
:=
json
.
Marshal
(
body
)
req
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
bodyBytes
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
}
func
captureRequestMiddleware
(
capturedRequest
any
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
bodyBytes
,
_
:=
io
.
ReadAll
(
c
.
Request
.
Body
)
c
.
Request
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
bodyBytes
))
err
:=
json
.
Unmarshal
(
bodyBytes
,
capturedRequest
)
if
err
!=
nil
{
c
.
AbortWithStatusJSON
(
http
.
StatusInternalServerError
,
"failed to unmarshal request"
)
}
c
.
Next
()
}
}
func
TestChatMiddleware
(
t
*
testing
.
T
)
{
type
testCase
struct
{
type
testCase
struct
{
Name
string
Name
string
Method
string
Path
string
Handler
func
()
gin
.
HandlerFunc
Setup
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
Setup
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
Expected
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
Expected
func
(
t
*
testing
.
T
,
req
*
api
.
ChatRequest
,
resp
*
httptest
.
ResponseRecorder
)
}
}
var
capturedRequest
*
http
.
Request
var
capturedRequest
*
api
.
ChatRequest
captureRequestMiddleware
:=
func
()
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
bodyBytes
,
_
:=
io
.
ReadAll
(
c
.
Request
.
Body
)
c
.
Request
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
bodyBytes
))
capturedRequest
=
c
.
Request
c
.
Next
()
}
}
testCases
:=
[]
testCase
{
testCases
:=
[]
testCase
{
{
{
Name
:
"chat handler"
,
Name
:
"chat handler"
,
Method
:
http
.
MethodPost
,
Path
:
"/api/chat"
,
Handler
:
ChatMiddleware
,
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
body
:=
ChatCompletionRequest
{
body
:=
ChatCompletionRequest
{
Model
:
"test-model"
,
Model
:
"test-model"
,
Messages
:
[]
Message
{{
Role
:
"user"
,
Content
:
"Hello"
}},
Messages
:
[]
Message
{{
Role
:
"user"
,
Content
:
"Hello"
}},
}
}
prepareRequest
(
req
,
body
)
},
Expected
:
func
(
t
*
testing
.
T
,
req
*
api
.
ChatRequest
,
resp
*
httptest
.
ResponseRecorder
)
{
if
resp
.
Code
!=
http
.
StatusOK
{
t
.
Fatalf
(
"expected 200, got %d"
,
resp
.
Code
)
}
bodyBytes
,
_
:=
json
.
Marshal
(
body
)
if
req
.
Messages
[
0
]
.
Role
!=
"user"
{
t
.
Fatalf
(
"expected 'user', got %s"
,
req
.
Messages
[
0
]
.
Role
)
}
req
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
bodyBytes
))
if
req
.
Messages
[
0
]
.
Content
!=
"Hello"
{
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
t
.
Fatalf
(
"expected 'Hello', got %s"
,
req
.
Messages
[
0
]
.
Content
)
}
},
},
Expected
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
},
var
chatReq
api
.
ChatRequest
{
if
err
:=
json
.
NewDecoder
(
req
.
Body
)
.
Decode
(
&
chatReq
);
err
!=
nil
{
Name
:
"chat handler with image content"
,
t
.
Fatal
(
err
)
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
body
:=
ChatCompletionRequest
{
Model
:
"test-model"
,
Messages
:
[]
Message
{
{
Role
:
"user"
,
Content
:
[]
map
[
string
]
any
{
{
"type"
:
"text"
,
"text"
:
"Hello"
},
{
"type"
:
"image_url"
,
"image_url"
:
map
[
string
]
string
{
"url"
:
imageURL
}},
},
},
},
}
prepareRequest
(
req
,
body
)
},
Expected
:
func
(
t
*
testing
.
T
,
req
*
api
.
ChatRequest
,
resp
*
httptest
.
ResponseRecorder
)
{
if
resp
.
Code
!=
http
.
StatusOK
{
t
.
Fatalf
(
"expected 200, got %d"
,
resp
.
Code
)
}
if
req
.
Messages
[
0
]
.
Role
!=
"user"
{
t
.
Fatalf
(
"expected 'user', got %s"
,
req
.
Messages
[
0
]
.
Role
)
}
if
req
.
Messages
[
0
]
.
Content
!=
"Hello"
{
t
.
Fatalf
(
"expected 'Hello', got %s"
,
req
.
Messages
[
0
]
.
Content
)
}
img
,
_
:=
base64
.
StdEncoding
.
DecodeString
(
imageURL
[
len
(
prefix
)
:
])
if
req
.
Messages
[
1
]
.
Role
!=
"user"
{
t
.
Fatalf
(
"expected 'user', got %s"
,
req
.
Messages
[
1
]
.
Role
)
}
if
!
bytes
.
Equal
(
req
.
Messages
[
1
]
.
Images
[
0
],
img
)
{
t
.
Fatalf
(
"expected image encoding, got %s"
,
req
.
Messages
[
1
]
.
Images
[
0
])
}
},
},
{
Name
:
"chat handler with tools"
,
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
body
:=
ChatCompletionRequest
{
Model
:
"test-model"
,
Messages
:
[]
Message
{
{
Role
:
"user"
,
Content
:
"What's the weather like in Paris Today?"
},
{
Role
:
"assistant"
,
ToolCalls
:
[]
ToolCall
{{
ID
:
"id"
,
Type
:
"function"
,
Function
:
struct
{
Name
string
`json:"name"`
Arguments
string
`json:"arguments"`
}{
Name
:
"get_current_weather"
,
Arguments
:
"{
\"
location
\"
:
\"
Paris, France
\"
,
\"
format
\"
:
\"
celsius
\"
}"
,
},
}}},
},
}
prepareRequest
(
req
,
body
)
},
Expected
:
func
(
t
*
testing
.
T
,
req
*
api
.
ChatRequest
,
resp
*
httptest
.
ResponseRecorder
)
{
if
resp
.
Code
!=
200
{
t
.
Fatalf
(
"expected 200, got %d"
,
resp
.
Code
)
}
if
req
.
Messages
[
0
]
.
Content
!=
"What's the weather like in Paris Today?"
{
t
.
Fatalf
(
"expected What's the weather like in Paris Today?, got %s"
,
req
.
Messages
[
0
]
.
Content
)
}
}
if
chatR
eq
.
Messages
[
0
]
.
Role
!=
"user
"
{
if
r
eq
.
Messages
[
1
]
.
ToolCalls
[
0
]
.
Function
.
Arguments
[
"location"
]
!=
"Paris, France
"
{
t
.
Fatalf
(
"expected '
user
', got %
s
"
,
chatR
eq
.
Messages
[
0
]
.
Role
)
t
.
Fatalf
(
"expected '
Paris, France
', got %
v
"
,
r
eq
.
Messages
[
1
]
.
ToolCalls
[
0
]
.
Function
.
Arguments
[
"location"
]
)
}
}
if
chatR
eq
.
Messages
[
0
]
.
Content
!=
"
H
el
lo
"
{
if
r
eq
.
Messages
[
1
]
.
ToolCalls
[
0
]
.
Function
.
Arguments
[
"format"
]
!=
"
c
el
sius
"
{
t
.
Fatalf
(
"expected
'Hello'
, got %
s
"
,
chatR
eq
.
Messages
[
0
]
.
Content
)
t
.
Fatalf
(
"expected
celsius
, got %
v
"
,
r
eq
.
Messages
[
1
]
.
ToolCalls
[
0
]
.
Function
.
Arguments
[
"format"
]
)
}
}
},
},
},
},
{
{
Name
:
"completions handler"
,
Name
:
"chat handler error forwarding"
,
Method
:
http
.
MethodPost
,
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
Path
:
"/api/generate"
,
body
:=
ChatCompletionRequest
{
Handler
:
CompletionsMiddleware
,
Model
:
"test-model"
,
Messages
:
[]
Message
{{
Role
:
"user"
,
Content
:
2
}},
}
prepareRequest
(
req
,
body
)
},
Expected
:
func
(
t
*
testing
.
T
,
req
*
api
.
ChatRequest
,
resp
*
httptest
.
ResponseRecorder
)
{
if
resp
.
Code
!=
http
.
StatusBadRequest
{
t
.
Fatalf
(
"expected 400, got %d"
,
resp
.
Code
)
}
if
!
strings
.
Contains
(
resp
.
Body
.
String
(),
"invalid message content type"
)
{
t
.
Fatalf
(
"error was not forwarded"
)
}
},
},
}
endpoint
:=
func
(
c
*
gin
.
Context
)
{
c
.
Status
(
http
.
StatusOK
)
}
gin
.
SetMode
(
gin
.
TestMode
)
router
:=
gin
.
New
()
router
.
Use
(
ChatMiddleware
(),
captureRequestMiddleware
(
&
capturedRequest
))
router
.
Handle
(
http
.
MethodPost
,
"/api/chat"
,
endpoint
)
for
_
,
tc
:=
range
testCases
{
t
.
Run
(
tc
.
Name
,
func
(
t
*
testing
.
T
)
{
req
,
_
:=
http
.
NewRequest
(
http
.
MethodPost
,
"/api/chat"
,
nil
)
tc
.
Setup
(
t
,
req
)
resp
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
resp
,
req
)
tc
.
Expected
(
t
,
capturedRequest
,
resp
)
capturedRequest
=
nil
})
}
}
func
TestCompletionsMiddleware
(
t
*
testing
.
T
)
{
type
testCase
struct
{
Name
string
Setup
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
Expected
func
(
t
*
testing
.
T
,
req
*
api
.
GenerateRequest
,
resp
*
httptest
.
ResponseRecorder
)
}
var
capturedRequest
*
api
.
GenerateRequest
testCases
:=
[]
testCase
{
{
Name
:
"completions handler"
,
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
temp
:=
float32
(
0.8
)
temp
:=
float32
(
0.8
)
body
:=
CompletionRequest
{
body
:=
CompletionRequest
{
...
@@ -87,27 +218,18 @@ func TestMiddlewareRequests(t *testing.T) {
...
@@ -87,27 +218,18 @@ func TestMiddlewareRequests(t *testing.T) {
Stop
:
[]
string
{
"
\n
"
,
"stop"
},
Stop
:
[]
string
{
"
\n
"
,
"stop"
},
Suffix
:
"suffix"
,
Suffix
:
"suffix"
,
}
}
prepareRequest
(
req
,
body
)
bodyBytes
,
_
:=
json
.
Marshal
(
body
)
req
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
bodyBytes
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
},
},
Expected
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
Expected
:
func
(
t
*
testing
.
T
,
req
*
api
.
GenerateRequest
,
resp
*
httptest
.
ResponseRecorder
)
{
var
genReq
api
.
GenerateRequest
if
req
.
Prompt
!=
"Hello"
{
if
err
:=
json
.
NewDecoder
(
req
.
Body
)
.
Decode
(
&
genReq
);
err
!=
nil
{
t
.
Fatalf
(
"expected 'Hello', got %s"
,
req
.
Prompt
)
t
.
Fatal
(
err
)
}
if
genReq
.
Prompt
!=
"Hello"
{
t
.
Fatalf
(
"expected 'Hello', got %s"
,
genReq
.
Prompt
)
}
}
if
genR
eq
.
Options
[
"temperature"
]
!=
1.6
{
if
r
eq
.
Options
[
"temperature"
]
!=
1.6
{
t
.
Fatalf
(
"expected 1.6, got %f"
,
genR
eq
.
Options
[
"temperature"
])
t
.
Fatalf
(
"expected 1.6, got %f"
,
r
eq
.
Options
[
"temperature"
])
}
}
stopTokens
,
ok
:=
genR
eq
.
Options
[
"stop"
]
.
([]
any
)
stopTokens
,
ok
:=
r
eq
.
Options
[
"stop"
]
.
([]
any
)
if
!
ok
{
if
!
ok
{
t
.
Fatalf
(
"expected stop tokens to be a list"
)
t
.
Fatalf
(
"expected stop tokens to be a list"
)
...
@@ -117,113 +239,100 @@ func TestMiddlewareRequests(t *testing.T) {
...
@@ -117,113 +239,100 @@ func TestMiddlewareRequests(t *testing.T) {
t
.
Fatalf
(
"expected ['
\\
n', 'stop'], got %v"
,
stopTokens
)
t
.
Fatalf
(
"expected ['
\\
n', 'stop'], got %v"
,
stopTokens
)
}
}
if
genR
eq
.
Suffix
!=
"suffix"
{
if
r
eq
.
Suffix
!=
"suffix"
{
t
.
Fatalf
(
"expected 'suffix', got %s"
,
genR
eq
.
Suffix
)
t
.
Fatalf
(
"expected 'suffix', got %s"
,
r
eq
.
Suffix
)
}
}
},
},
},
},
{
{
Name
:
"chat handler with image content"
,
Name
:
"completions handler error forwarding"
,
Method
:
http
.
MethodPost
,
Path
:
"/api/chat"
,
Handler
:
ChatMiddleware
,
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
body
:=
ChatCompletionRequest
{
body
:=
CompletionRequest
{
Model
:
"test-model"
,
Model
:
"test-model"
,
Messages
:
[]
Message
{
Prompt
:
"Hello"
,
{
Temperature
:
nil
,
Role
:
"user"
,
Content
:
[]
map
[
string
]
any
{
Stop
:
[]
int
{
1
,
2
},
{
"type"
:
"text"
,
"text"
:
"Hello"
},
Suffix
:
"suffix"
,
{
"type"
:
"image_url"
,
"image_url"
:
map
[
string
]
string
{
"url"
:
imageURL
}},
},
},
},
}
}
prepareRequest
(
req
,
body
)
bodyBytes
,
_
:=
json
.
Marshal
(
body
)
req
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
bodyBytes
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
},
},
Expected
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
Expected
:
func
(
t
*
testing
.
T
,
req
*
api
.
GenerateRequest
,
resp
*
httptest
.
ResponseRecorder
)
{
var
chatReq
api
.
ChatRequest
if
resp
.
Code
!=
http
.
StatusBadRequest
{
if
err
:=
json
.
NewDecoder
(
req
.
Body
)
.
Decode
(
&
chatReq
);
err
!=
nil
{
t
.
Fatalf
(
"expected 400, got %d"
,
resp
.
Code
)
t
.
Fatal
(
err
)
}
}
if
chatReq
.
Messages
[
0
]
.
Role
!=
"user"
{
if
!
strings
.
Contains
(
resp
.
Body
.
String
(),
"invalid type for 'stop' field"
)
{
t
.
Fatalf
(
"e
xpected 'user', got %s"
,
chatReq
.
Messages
[
0
]
.
Role
)
t
.
Fatalf
(
"e
rror was not forwarded"
)
}
}
},
},
}
if
chatReq
.
Messages
[
0
]
.
Content
!=
"Hello"
{
endpoint
:=
func
(
c
*
gin
.
Context
)
{
t
.
Fatalf
(
"expected 'Hello', got %s"
,
chatReq
.
Messages
[
0
]
.
Content
)
c
.
Status
(
http
.
StatusOK
)
}
}
img
,
_
:=
base64
.
StdEncoding
.
DecodeString
(
imageURL
[
len
(
prefix
)
:
])
gin
.
SetMode
(
gin
.
TestMode
)
router
:=
gin
.
New
()
router
.
Use
(
CompletionsMiddleware
(),
captureRequestMiddleware
(
&
capturedRequest
))
router
.
Handle
(
http
.
MethodPost
,
"/api/generate"
,
endpoint
)
if
chatReq
.
Messages
[
1
]
.
Role
!=
"user"
{
for
_
,
tc
:=
range
testCases
{
t
.
Fatalf
(
"expected 'user', got %s"
,
chatReq
.
Messages
[
1
]
.
Role
)
t
.
Run
(
tc
.
Name
,
func
(
t
*
testing
.
T
)
{
}
req
,
_
:=
http
.
NewRequest
(
http
.
MethodPost
,
"/api/generate"
,
nil
)
if
!
bytes
.
Equal
(
chatReq
.
Messages
[
1
]
.
Images
[
0
],
img
)
{
tc
.
Setup
(
t
,
req
)
t
.
Fatalf
(
"expected image encoding, got %s"
,
chatReq
.
Messages
[
1
]
.
Images
[
0
])
}
resp
:=
httptest
.
NewRecorder
()
},
router
.
ServeHTTP
(
resp
,
req
)
},
tc
.
Expected
(
t
,
capturedRequest
,
resp
)
capturedRequest
=
nil
})
}
}
func
TestEmbeddingsMiddleware
(
t
*
testing
.
T
)
{
type
testCase
struct
{
Name
string
Setup
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
Expected
func
(
t
*
testing
.
T
,
req
*
api
.
EmbedRequest
,
resp
*
httptest
.
ResponseRecorder
)
}
var
capturedRequest
*
api
.
EmbedRequest
testCases
:=
[]
testCase
{
{
{
Name
:
"embed handler single input"
,
Name
:
"embed handler single input"
,
Method
:
http
.
MethodPost
,
Path
:
"/api/embed"
,
Handler
:
EmbeddingsMiddleware
,
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
body
:=
EmbedRequest
{
body
:=
EmbedRequest
{
Input
:
"Hello"
,
Input
:
"Hello"
,
Model
:
"test-model"
,
Model
:
"test-model"
,
}
}
prepareRequest
(
req
,
body
)
bodyBytes
,
_
:=
json
.
Marshal
(
body
)
req
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
bodyBytes
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
},
},
Expected
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
Expected
:
func
(
t
*
testing
.
T
,
req
*
api
.
EmbedRequest
,
resp
*
httptest
.
ResponseRecorder
)
{
var
embedReq
api
.
EmbedRequest
if
req
.
Input
!=
"Hello"
{
if
err
:=
json
.
NewDecoder
(
req
.
Body
)
.
Decode
(
&
embedReq
);
err
!=
nil
{
t
.
Fatalf
(
"expected 'Hello', got %s"
,
req
.
Input
)
t
.
Fatal
(
err
)
}
}
if
embedReq
.
Input
!=
"Hello"
{
if
req
.
Model
!=
"test-model"
{
t
.
Fatalf
(
"expected 'Hello', got %s"
,
embedReq
.
Input
)
t
.
Fatalf
(
"expected 'test-model', got %s"
,
req
.
Model
)
}
if
embedReq
.
Model
!=
"test-model"
{
t
.
Fatalf
(
"expected 'test-model', got %s"
,
embedReq
.
Model
)
}
}
},
},
},
},
{
{
Name
:
"embed handler batch input"
,
Name
:
"embed handler batch input"
,
Method
:
http
.
MethodPost
,
Path
:
"/api/embed"
,
Handler
:
EmbeddingsMiddleware
,
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
body
:=
EmbedRequest
{
body
:=
EmbedRequest
{
Input
:
[]
string
{
"Hello"
,
"World"
},
Input
:
[]
string
{
"Hello"
,
"World"
},
Model
:
"test-model"
,
Model
:
"test-model"
,
}
}
prepareRequest
(
req
,
body
)
bodyBytes
,
_
:=
json
.
Marshal
(
body
)
req
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
bodyBytes
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
},
},
Expected
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
Expected
:
func
(
t
*
testing
.
T
,
req
*
api
.
EmbedRequest
,
resp
*
httptest
.
ResponseRecorder
)
{
var
embedReq
api
.
EmbedRequest
input
,
ok
:=
req
.
Input
.
([]
any
)
if
err
:=
json
.
NewDecoder
(
req
.
Body
)
.
Decode
(
&
embedReq
);
err
!=
nil
{
t
.
Fatal
(
err
)
}
input
,
ok
:=
embedReq
.
Input
.
([]
any
)
if
!
ok
{
if
!
ok
{
t
.
Fatalf
(
"expected input to be a list"
)
t
.
Fatalf
(
"expected input to be a list"
)
...
@@ -237,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) {
...
@@ -237,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) {
t
.
Fatalf
(
"expected 'World', got %s"
,
input
[
1
])
t
.
Fatalf
(
"expected 'World', got %s"
,
input
[
1
])
}
}
if
embedR
eq
.
Model
!=
"test-model"
{
if
r
eq
.
Model
!=
"test-model"
{
t
.
Fatalf
(
"expected 'test-model', got %s"
,
embedR
eq
.
Model
)
t
.
Fatalf
(
"expected 'test-model', got %s"
,
r
eq
.
Model
)
}
}
},
},
},
},
}
{
Name
:
"embed handler error forwarding"
,
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
body
:=
EmbedRequest
{
Model
:
"test-model"
,
}
prepareRequest
(
req
,
body
)
},
Expected
:
func
(
t
*
testing
.
T
,
req
*
api
.
EmbedRequest
,
resp
*
httptest
.
ResponseRecorder
)
{
if
resp
.
Code
!=
http
.
StatusBadRequest
{
t
.
Fatalf
(
"expected 400, got %d"
,
resp
.
Code
)
}
gin
.
SetMode
(
gin
.
TestMode
)
if
!
strings
.
Contains
(
resp
.
Body
.
String
(),
"invalid input"
)
{
router
:=
gin
.
New
()
t
.
Fatalf
(
"error was not forwarded"
)
}
},
},
}
endpoint
:=
func
(
c
*
gin
.
Context
)
{
endpoint
:=
func
(
c
*
gin
.
Context
)
{
c
.
Status
(
http
.
StatusOK
)
c
.
Status
(
http
.
StatusOK
)
}
}
gin
.
SetMode
(
gin
.
TestMode
)
router
:=
gin
.
New
()
router
.
Use
(
EmbeddingsMiddleware
(),
captureRequestMiddleware
(
&
capturedRequest
))
router
.
Handle
(
http
.
MethodPost
,
"/api/embed"
,
endpoint
)
for
_
,
tc
:=
range
testCases
{
for
_
,
tc
:=
range
testCases
{
t
.
Run
(
tc
.
Name
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
tc
.
Name
,
func
(
t
*
testing
.
T
)
{
router
=
gin
.
New
()
req
,
_
:=
http
.
NewRequest
(
http
.
MethodPost
,
"/api/embed"
,
nil
)
router
.
Use
(
captureRequestMiddleware
())
router
.
Use
(
tc
.
Handler
())
router
.
Handle
(
tc
.
Method
,
tc
.
Path
,
endpoint
)
req
,
_
:=
http
.
NewRequest
(
tc
.
Method
,
tc
.
Path
,
nil
)
if
tc
.
Setup
!=
nil
{
tc
.
Setup
(
t
,
req
)
tc
.
Setup
(
t
,
req
)
}
resp
:=
httptest
.
NewRecorder
()
resp
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
resp
,
req
)
router
.
ServeHTTP
(
resp
,
req
)
tc
.
Expected
(
t
,
capturedRequest
)
tc
.
Expected
(
t
,
capturedRequest
,
resp
)
capturedRequest
=
nil
})
})
}
}
}
}
...
@@ -284,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) {
...
@@ -284,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) {
}
}
testCases
:=
[]
testCase
{
testCases
:=
[]
testCase
{
{
Name
:
"completions handler error forwarding"
,
Method
:
http
.
MethodPost
,
Path
:
"/api/generate"
,
TestPath
:
"/api/generate"
,
Handler
:
CompletionsMiddleware
,
Endpoint
:
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"invalid request"
})
},
Setup
:
func
(
t
*
testing
.
T
,
req
*
http
.
Request
)
{
body
:=
CompletionRequest
{
Model
:
"test-model"
,
Prompt
:
"Hello"
,
}
bodyBytes
,
_
:=
json
.
Marshal
(
body
)
req
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
bodyBytes
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
},
Expected
:
func
(
t
*
testing
.
T
,
resp
*
httptest
.
ResponseRecorder
)
{
if
resp
.
Code
!=
http
.
StatusBadRequest
{
t
.
Fatalf
(
"expected 400, got %d"
,
resp
.
Code
)
}
if
!
strings
.
Contains
(
resp
.
Body
.
String
(),
`"invalid request"`
)
{
t
.
Fatalf
(
"error was not forwarded"
)
}
},
},
{
{
Name
:
"list handler"
,
Name
:
"list handler"
,
Method
:
http
.
MethodGet
,
Method
:
http
.
MethodGet
,
...
@@ -330,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) {
...
@@ -330,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) {
})
})
},
},
Expected
:
func
(
t
*
testing
.
T
,
resp
*
httptest
.
ResponseRecorder
)
{
Expected
:
func
(
t
*
testing
.
T
,
resp
*
httptest
.
ResponseRecorder
)
{
assert
.
Equal
(
t
,
http
.
StatusOK
,
resp
.
Code
)
var
listResp
ListCompletion
var
listResp
ListCompletion
if
err
:=
json
.
NewDecoder
(
resp
.
Body
)
.
Decode
(
&
listResp
);
err
!=
nil
{
if
err
:=
json
.
NewDecoder
(
resp
.
Body
)
.
Decode
(
&
listResp
);
err
!=
nil
{
t
.
Fatal
(
err
)
t
.
Fatal
(
err
)
...
@@ -395,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) {
...
@@ -395,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) {
resp
:=
httptest
.
NewRecorder
()
resp
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
resp
,
req
)
router
.
ServeHTTP
(
resp
,
req
)
assert
.
Equal
(
t
,
http
.
StatusOK
,
resp
.
Code
)
tc
.
Expected
(
t
,
resp
)
tc
.
Expected
(
t
,
resp
)
})
})
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment