Unverified commit 37f6f3af, authored by lif and committed by GitHub
Browse files

server: return error when embedding contains NaN or Inf values (#13599)



The normalize function now checks for NaN and Inf values in the
embedding vector before processing. This prevents JSON encoding
failures when models produce invalid floating-point values.

Fixes #13572
Signed-off-by: majiayu000 <1835304752@qq.com>
parent e1bdc23d
...@@ -752,9 +752,15 @@ func (s *Server) EmbedHandler(c *gin.Context) { ...@@ -752,9 +752,15 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return err return err
} }
// TODO: this first normalization should be done by the model // TODO: this first normalization should be done by the model
embedding = normalize(embedding) embedding, err = normalize(embedding)
if err != nil {
return err
}
if req.Dimensions > 0 && req.Dimensions < len(embedding) { if req.Dimensions > 0 && req.Dimensions < len(embedding) {
embedding = normalize(embedding[:req.Dimensions]) embedding, err = normalize(embedding[:req.Dimensions])
if err != nil {
return err
}
} }
embeddings[i] = embedding embeddings[i] = embedding
atomic.AddUint64(&totalTokens, uint64(tokenCount)) atomic.AddUint64(&totalTokens, uint64(tokenCount))
...@@ -787,9 +793,12 @@ func (s *Server) EmbedHandler(c *gin.Context) { ...@@ -787,9 +793,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
func normalize(vec []float32) []float32 { func normalize(vec []float32) ([]float32, error) {
var sum float32 var sum float32
for _, v := range vec { for _, v := range vec {
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
return nil, errors.New("embedding contains NaN or Inf values")
}
sum += v * v sum += v * v
} }
...@@ -797,7 +806,7 @@ func normalize(vec []float32) []float32 { ...@@ -797,7 +806,7 @@ func normalize(vec []float32) []float32 {
for i := range vec { for i := range vec {
vec[i] *= norm vec[i] *= norm
} }
return vec return vec, nil
} }
func (s *Server) EmbeddingsHandler(c *gin.Context) { func (s *Server) EmbeddingsHandler(c *gin.Context) {
...@@ -2395,4 +2404,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message { ...@@ -2395,4 +2404,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
} }
return msgs return msgs
} }
...@@ -723,15 +723,20 @@ func TestShow(t *testing.T) { ...@@ -723,15 +723,20 @@ func TestShow(t *testing.T) {
func TestNormalize(t *testing.T) { func TestNormalize(t *testing.T) {
type testCase struct { type testCase struct {
input []float32 input []float32
expectError bool
} }
testCases := []testCase{ testCases := []testCase{
{input: []float32{1}}, {input: []float32{1}, expectError: false},
{input: []float32{0, 1, 2, 3}}, {input: []float32{0, 1, 2, 3}, expectError: false},
{input: []float32{0.1, 0.2, 0.3}}, {input: []float32{0.1, 0.2, 0.3}, expectError: false},
{input: []float32{-0.1, 0.2, 0.3, -0.4}}, {input: []float32{-0.1, 0.2, 0.3, -0.4}, expectError: false},
{input: []float32{0, 0, 0}}, {input: []float32{0, 0, 0}, expectError: false},
{input: []float32{float32(math.NaN()), 0.2, 0.3}, expectError: true},
{input: []float32{0.1, float32(math.NaN()), 0.3}, expectError: true},
{input: []float32{float32(math.Inf(1)), 0.2, 0.3}, expectError: true},
{input: []float32{float32(math.Inf(-1)), 0.2, 0.3}, expectError: true},
} }
isNormalized := func(vec []float32) (res bool) { isNormalized := func(vec []float32) (res bool) {
...@@ -748,9 +753,18 @@ func TestNormalize(t *testing.T) { ...@@ -748,9 +753,18 @@ func TestNormalize(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
normalized := normalize(tc.input) normalized, err := normalize(tc.input)
if !isNormalized(normalized) { if tc.expectError {
t.Errorf("Vector %v is not normalized", tc.input) if err == nil {
t.Errorf("Expected error for input %v, but got none", tc.input)
}
} else {
if err != nil {
t.Errorf("Unexpected error for input %v: %v", tc.input, err)
}
if !isNormalized(normalized) {
t.Errorf("Vector %v is not normalized", tc.input)
}
} }
}) })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment