Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
ac33aa7d
Unverified
Commit
ac33aa7d
authored
Jul 24, 2024
by
royjhan
Committed by
GitHub
Jul 24, 2024
Browse files
Fix Embed Test Flakes (#5893)
* float cmp * increase tolerance
parent
a6cd8f61
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
54 additions
and
5 deletions
+54
-5
integration/embed_test.go
integration/embed_test.go
+54
-5
No files found.
integration/embed_test.go
View file @
ac33aa7d
...
@@ -4,12 +4,45 @@ package integration
...
@@ -4,12 +4,45 @@ package integration
import
(
import
(
"context"
"context"
"math"
"testing"
"testing"
"time"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/api"
)
)
func
floatsEqual32
(
a
,
b
float32
)
bool
{
return
math
.
Abs
(
float64
(
a
-
b
))
<=
1e-4
}
func
floatsEqual64
(
a
,
b
float64
)
bool
{
return
math
.
Abs
(
a
-
b
)
<=
1e-4
}
func
TestAllMiniLMEmbeddings
(
t
*
testing
.
T
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Minute
)
defer
cancel
()
req
:=
api
.
EmbeddingRequest
{
Model
:
"all-minilm"
,
Prompt
:
"why is the sky blue?"
,
}
res
,
err
:=
embeddingTestHelper
(
ctx
,
t
,
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"error: %v"
,
err
)
}
if
len
(
res
.
Embedding
)
!=
384
{
t
.
Fatalf
(
"expected 384 floats, got %d"
,
len
(
res
.
Embedding
))
}
if
!
floatsEqual64
(
res
.
Embedding
[
0
],
0.06642947345972061
)
{
t
.
Fatalf
(
"expected 0.06642947345972061, got %.16f"
,
res
.
Embedding
[
0
])
}
}
func
TestAllMiniLMEmbed
(
t
*
testing
.
T
)
{
func
TestAllMiniLMEmbed
(
t
*
testing
.
T
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Minute
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Minute
)
defer
cancel
()
defer
cancel
()
...
@@ -33,8 +66,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
...
@@ -33,8 +66,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
t
.
Fatalf
(
"expected 384 floats, got %d"
,
len
(
res
.
Embeddings
[
0
]))
t
.
Fatalf
(
"expected 384 floats, got %d"
,
len
(
res
.
Embeddings
[
0
]))
}
}
if
res
.
Embeddings
[
0
][
0
]
!=
0.010071031
{
if
!
floatsEqual32
(
res
.
Embeddings
[
0
][
0
]
,
0.010071031
)
{
t
.
Fatalf
(
"expected 0.010071031, got %f"
,
res
.
Embeddings
[
0
][
0
])
t
.
Fatalf
(
"expected 0.010071031, got %
.8
f"
,
res
.
Embeddings
[
0
][
0
])
}
}
}
}
...
@@ -61,12 +94,12 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
...
@@ -61,12 +94,12 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
t
.
Fatalf
(
"expected 384 floats, got %d"
,
len
(
res
.
Embeddings
[
0
]))
t
.
Fatalf
(
"expected 384 floats, got %d"
,
len
(
res
.
Embeddings
[
0
]))
}
}
if
res
.
Embeddings
[
0
][
0
]
!=
0.010071031
||
res
.
Embeddings
[
1
][
0
]
!=
-
0.009802706
{
if
!
floatsEqual32
(
res
.
Embeddings
[
0
][
0
]
,
0.010071031
)
||
!
floatsEqual32
(
res
.
Embeddings
[
1
][
0
]
,
-
0.009802706
)
{
t
.
Fatalf
(
"expected 0.010071031 and -0.009802706, got %f and %f"
,
res
.
Embeddings
[
0
][
0
],
res
.
Embeddings
[
1
][
0
])
t
.
Fatalf
(
"expected 0.010071031 and -0.009802706, got %
.8
f and %
.8
f"
,
res
.
Embeddings
[
0
][
0
],
res
.
Embeddings
[
1
][
0
])
}
}
}
}
func
TestAllMiniL
m
EmbedTruncate
(
t
*
testing
.
T
)
{
func
TestAllMiniL
M
EmbedTruncate
(
t
*
testing
.
T
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Minute
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Minute
)
defer
cancel
()
defer
cancel
()
...
@@ -135,6 +168,22 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
...
@@ -135,6 +168,22 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
}
}
}
}
func
embeddingTestHelper
(
ctx
context
.
Context
,
t
*
testing
.
T
,
req
api
.
EmbeddingRequest
)
(
*
api
.
EmbeddingResponse
,
error
)
{
client
,
_
,
cleanup
:=
InitServerConnection
(
ctx
,
t
)
defer
cleanup
()
if
err
:=
PullIfMissing
(
ctx
,
client
,
req
.
Model
);
err
!=
nil
{
t
.
Fatalf
(
"failed to pull model %s: %v"
,
req
.
Model
,
err
)
}
response
,
err
:=
client
.
Embeddings
(
ctx
,
&
req
)
if
err
!=
nil
{
return
nil
,
err
}
return
response
,
nil
}
func
embedTestHelper
(
ctx
context
.
Context
,
t
*
testing
.
T
,
req
api
.
EmbedRequest
)
(
*
api
.
EmbedResponse
,
error
)
{
func
embedTestHelper
(
ctx
context
.
Context
,
t
*
testing
.
T
,
req
api
.
EmbedRequest
)
(
*
api
.
EmbedResponse
,
error
)
{
client
,
_
,
cleanup
:=
InitServerConnection
(
ctx
,
t
)
client
,
_
,
cleanup
:=
InitServerConnection
(
ctx
,
t
)
defer
cleanup
()
defer
cleanup
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment