package server

import (
	"context"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llm"
)

// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
//        package to avoid circular dependencies

// WARNING - these tests will fail on macOS unless you manually copy ggml-metal.metal into this dir (./server)
//
// TODO - Fix this ^^
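//
// One possible manual workaround, sketched here for reference only (the source path is
// a placeholder - point it at wherever ggml-metal.metal lives in your checkout):
//
//	cp <path-to-llama.cpp-sources>/ggml-metal.metal ./server/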

var (
	req = [2]api.GenerateRequest{
		{
			Model:   "orca-mini",
			Prompt:  "tell me a short story about agi?",
			Options: map[string]interface{}{},
		}, {
			Model:   "orca-mini",
			Prompt:  "what is the origin of the us thanksgiving holiday?",
			Options: map[string]interface{}{},
		},
	}
	resp = [2]string{
		"once upon a time",
		"united states thanksgiving",
	}
)

func TestIntegrationSimpleOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()
	response := OneShotPromptResponse(t, ctx, req[0], model, llmRunner)
	assert.Contains(t, strings.ToLower(response), resp[0])
}

// TODO
// The server always loads a new runner and closes the old one, which forces serial
// execution. At present this test case fails with concurrency problems, so it is
// skipped. Eventually we should try to get true concurrency working with n_parallel
// support in the backend.
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	t.Skip("concurrent prediction on single runner not currently supported")

	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	var wg sync.WaitGroup
	wg.Add(len(req))
	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[i].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}

func TestIntegrationConcurrentRunnersOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	var wg sync.WaitGroup
	wg.Add(len(req))

	t.Logf("Running %d concurrently", len(req))
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			model, llmRunner := PrepareModelForPrompts(t, req[i].Model, opts)
			defer llmRunner.Close()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[i].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}

// TODO - create a parallel test with 2 different models once we support concurrency
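//
// The sketch below is one possible shape for that test, kept skipped until the backend
// supports it. It reuses the helpers already used in this file (SkipIFNoTestData,
// PrepareModelForPrompts, OneShotPromptResponse); the second model name ("llama2") and
// its expected response fragment are placeholders, not verified outputs.
func TestIntegrationConcurrentTwoModelsSketch(t *testing.T) {
	SkipIFNoTestData(t)
	t.Skip("running two different models concurrently is not currently supported")

	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0

	// Placeholder second model and expected substring - adjust to whatever models are
	// actually present in the test data.
	reqs := []api.GenerateRequest{
		req[0],
		{
			Model:   "llama2",
			Prompt:  req[1].Prompt,
			Options: map[string]interface{}{},
		},
	}
	expect := []string{resp[0], "thanksgiving"}

	var wg sync.WaitGroup
	wg.Add(len(reqs))
	for i := 0; i < len(reqs); i++ {
		go func(i int) {
			defer wg.Done()
			model, llmRunner := PrepareModelForPrompts(t, reqs[i].Model, opts)
			defer llmRunner.Close()
			response := OneShotPromptResponse(t, ctx, reqs[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", reqs[i].Prompt, response)
			assert.Contains(t, strings.ToLower(response), expect[i], "error in thread %d (%s)", i, reqs[i].Prompt)
		}(i)
	}
	wg.Wait()
}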