Commit e119783e authored by Jesse Gross, committed by Jesse Gross

llm: Clamp batch size to context size

The context must always be able to store the current batch, so
if the user requests a small context then we should also shrink
the batch to match. This also fixes the TestLongInputContext
test on the new engine. (The old engine already has this behavior.)
parent 1a558f98
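For illustration, a minimal sketch (not part of this commit; the values are hypothetical) of the clamping behavior described above: if the user requests a context smaller than the configured batch size, the batch is shrunk so that a single batch can always fit in the context window.

```go
package main

import "fmt"

func main() {
	// Hypothetical values for illustration: a small user-requested context
	// combined with a larger configured batch size.
	numCtx := 256   // user-requested context size
	numBatch := 512 // configured batch size

	// Clamp the batch to the context so one batch never exceeds what the
	// context window can hold (min for ints is a builtin as of Go 1.21).
	numBatch = min(numBatch, numCtx)

	fmt.Println(numBatch) // 256
}
```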
@@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
 	if err := PullIfMissing(ctx, client, req.Model); err != nil {
 		t.Fatalf("PullIfMissing failed: %v", err)
 	}
-	DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
+	DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
 }

 func TestContextExhaustion(t *testing.T) {
...
@@ -173,6 +173,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 		opts.NumCtx = int(trainCtx)
 	}

+	opts.NumBatch = min(opts.NumBatch, opts.NumCtx)
+
 	loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}

 	defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount()
...
@@ -34,8 +34,8 @@ type InputCache struct {
 func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
 	numCtx := kvSize / int32(numSlots)

-	if numCtx < 1 {
-		return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
+	if int(numCtx) < batchSize {
+		return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots)
 	}

 	slots := make([]InputCacheSlot, numSlots)
...
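A rough restatement of the runner-side validation above, with made-up numbers (not from the commit): the KV cache is divided evenly across parallel slots, and each slot's share must be large enough to hold at least one full batch, otherwise the cache is rejected at load time.

```go
package main

import "fmt"

func main() {
	// Hypothetical sizes for illustration only.
	kvSize := int32(8192) // total KV cache entries
	numSlots := 4         // parallel sequences
	batchSize := 2048     // tokens per batch

	numCtx := kvSize / int32(numSlots) // per-slot context: 2048

	// Mirrors the new check in NewInputCache: each slot's share of the KV
	// cache must be at least one batch.
	if int(numCtx) < batchSize {
		fmt.Printf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)\n",
			kvSize, batchSize, numSlots)
		return
	}
	fmt.Println("ok: per-slot context", numCtx, ">= batch", batchSize)
}
```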