Commit ceb0e26e authored by Bryce Reitano's avatar Bryce Reitano
Browse files

Provide variable ggml for TestLoad

parent 284e02be
...@@ -47,6 +47,7 @@ func TestLoad(t *testing.T) { ...@@ -47,6 +47,7 @@ func TestLoad(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond) ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done() defer done()
s := InitScheduler(ctx) s := InitScheduler(ctx)
ggml := nil // value not used in tests
req := &LlmRequest{ req := &LlmRequest{
ctx: ctx, ctx: ctx,
model: &Model{ModelPath: "foo"}, model: &Model{ModelPath: "foo"},
...@@ -59,7 +60,7 @@ func TestLoad(t *testing.T) { ...@@ -59,7 +60,7 @@ func TestLoad(t *testing.T) {
return nil, fmt.Errorf("something failed to load model blah") return nil, fmt.Errorf("something failed to load model blah")
} }
gpus := gpu.GpuInfoList{} gpus := gpu.GpuInfoList{}
s.load(req, nil, gpus) s.load(req, ggml, gpus)
require.Len(t, req.successCh, 0) require.Len(t, req.successCh, 0)
require.Len(t, req.errCh, 1) require.Len(t, req.errCh, 1)
require.Len(t, s.loaded, 0) require.Len(t, s.loaded, 0)
...@@ -70,7 +71,7 @@ func TestLoad(t *testing.T) { ...@@ -70,7 +71,7 @@ func TestLoad(t *testing.T) {
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) { s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return server, nil return server, nil
} }
s.load(req, nil, gpus) s.load(req, ggml, gpus)
select { select {
case err := <-req.errCh: case err := <-req.errCh:
require.NoError(t, err) require.NoError(t, err)
...@@ -82,7 +83,7 @@ func TestLoad(t *testing.T) { ...@@ -82,7 +83,7 @@ func TestLoad(t *testing.T) {
req.model.ModelPath = "dummy_model_path" req.model.ModelPath = "dummy_model_path"
server.waitResp = fmt.Errorf("wait failure") server.waitResp = fmt.Errorf("wait failure")
s.load(req, nil, gpus) s.load(req, ggml, gpus)
select { select {
case err := <-req.errCh: case err := <-req.errCh:
require.Contains(t, err.Error(), "wait failure") require.Contains(t, err.Error(), "wait failure")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment