"src/turbomind/models/llama/LlamaContextDecoder.cc" did not exist on "4d42a781254e85176bd91f943a28b2d0360e7768"
Commit 7b6cbc10 authored by Daniel Hiltgen's avatar Daniel Hiltgen
Browse files

Integration tests conditionally pull

If images aren't present, pull them.
Also fixes the expected responses
parent acfa2b94
......@@ -12,7 +12,7 @@ import (
)
func TestOrcaMiniBlueSky(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
......
......@@ -30,7 +30,7 @@ func TestIntegrationMultimodal(t *testing.T) {
}
resp := "the ollamas"
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
}
......
......@@ -40,16 +40,16 @@ var (
},
},
}
resp = [2]string{
"scattering",
"united states thanksgiving",
resp = [2][]string{
[]string{"sunlight"},
[]string{"england", "english", "massachusetts", "pilgrims"},
}
)
func TestIntegrationSimpleOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req[0], []string{resp[0]})
GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
}
// TODO
......@@ -59,12 +59,12 @@ func TestIntegrationSimpleOrcaMini(t *testing.T) {
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, &http.Client{}, req[i], []string{resp[i]})
GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
}(i)
}
wg.Wait()
......
......@@ -125,6 +125,55 @@ func StartServer(ctx context.Context, ollamaHost string) error {
return nil
}
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
slog.Debug("checking status of model", "model", modelName)
showReq := &api.ShowRequest{Name: modelName}
requestJSON, err := json.Marshal(showReq)
if err != nil {
return err
}
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
if err != nil {
return err
}
// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode == 200 {
slog.Info("model already present", "model", modelName)
return nil
}
slog.Info("model missing", "status", response.StatusCode)
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
requestJSON, err = json.Marshal(pullReq)
if err != nil {
return err
}
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
if err != nil {
return err
}
slog.Info("pulling", "model", modelName)
response, err = client.Do(req.WithContext(ctx))
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != 200 {
return fmt.Errorf("failed to pull model") // TODO more details perhaps
}
slog.Info("model pulled", "model", modelName)
return nil
}
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
requestJSON, err := json.Marshal(genReq)
if err != nil {
......@@ -158,6 +207,11 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
assert.NoError(t, StartServer(ctx, testEndpoint))
}
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
if err != nil {
t.Fatalf("Error pulling model: %v", err)
}
// Make the request and get the response
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
if err != nil {
......@@ -172,6 +226,7 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
if err != nil {
t.Fatalf("Error making request: %v", err)
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
assert.NoError(t, err)
assert.Equal(t, response.StatusCode, 200, string(body))
......@@ -184,7 +239,12 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
}
// Verify the response contains the expected data
atLeastOne := false
for _, resp := range anyResp {
assert.Contains(t, strings.ToLower(payload.Response), resp)
if strings.Contains(strings.ToLower(payload.Response), resp) {
atLeastOne = true
break
}
}
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment