Commit e49dc9f3 authored by Michael Yang's avatar Michael Yang
Browse files

fix tests

parent d125510b
...@@ -238,18 +238,37 @@ func chatHistoryEqual(a, b ChatHistory) bool { ...@@ -238,18 +238,37 @@ func chatHistoryEqual(a, b ChatHistory) bool {
if len(a.Prompts) != len(b.Prompts) { if len(a.Prompts) != len(b.Prompts) {
return false return false
} }
if len(a.CurrentImages) != len(b.CurrentImages) {
return false
}
for i, v := range a.Prompts { for i, v := range a.Prompts {
if v != b.Prompts[i] {
if v.First != b.Prompts[i].First {
return false return false
} }
}
for i, v := range a.CurrentImages { if v.Response != b.Prompts[i].Response {
if !bytes.Equal(v, b.CurrentImages[i]) {
return false return false
} }
if v.Prompt != b.Prompts[i].Prompt {
return false
}
if v.System != b.Prompts[i].System {
return false
}
if len(v.Images) != len(b.Prompts[i].Images) {
return false
}
for j, img := range v.Images {
if img.ID != b.Prompts[i].Images[j].ID {
return false
}
if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
return false
}
}
} }
return a.LastSystem == b.LastSystem return a.LastSystem == b.LastSystem
} }
......
...@@ -455,7 +455,8 @@ func Test_ChatPrompt(t *testing.T) { ...@@ -455,7 +455,8 @@ func Test_ChatPrompt(t *testing.T) {
NumCtx: tt.numCtx, NumCtx: tt.numCtx,
}, },
} }
got, err := trimmedPrompt(context.Background(), tt.chat, m) // TODO: add tests for trimming images
got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
if tt.wantErr != "" { if tt.wantErr != "" {
if err == nil { if err == nil {
t.Errorf("ChatPrompt() expected error, got nil") t.Errorf("ChatPrompt() expected error, got nil")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment