Commit 8306248c authored by Jesse Gross, committed by Michael Yang

ggml: Multiply by numParallel for gptoss sliding window

When computing the graph size estimate, the context size is already
multiplied by numParallel so estimates reflect that. However, since
sliding window models use a smaller, fixed context size, they need
to manually take numParallel into account.
parent f5fd7cc1
@@ -670,7 +670,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
     for i := range kv {
         kv[i] = uint64(float64((embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
         if i%2 == 0 {
-            kv[i] *= (4096 + batch)
+            kv[i] *= (uint64(numParallel)*4096 + batch)
         } else {
             kv[i] *= context
         }
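
The effect of the change can be sketched in isolation. This is a minimal illustration, not the project's code: `kvLayerBytes`, `oldKVLayerBytes`, the `bytesPerToken` value, and the context and batch sizes are hypothetical stand-ins chosen for the example. Only the 4096-token window, the even/odd layer split, and the two formulas come from the diff.

```go
package main

import "fmt"

// slidingWindow is the fixed window size that appears in the diff.
const slidingWindow = 4096

// oldKVLayerBytes mirrors the pre-change formula (hypothetical helper):
// even-indexed (sliding window) layers ignore numParallel.
func oldKVLayerBytes(layer int, bytesPerToken, context, batch uint64) uint64 {
	if layer%2 == 0 {
		return bytesPerToken * (slidingWindow + batch)
	}
	return bytesPerToken * context
}

// kvLayerBytes mirrors the post-change formula (hypothetical helper):
// the fixed window is scaled by numParallel, while context is assumed to
// have been multiplied by numParallel already by the caller.
func kvLayerBytes(layer int, bytesPerToken, context, batch uint64, numParallel int) uint64 {
	if layer%2 == 0 {
		return bytesPerToken * (uint64(numParallel)*slidingWindow + batch)
	}
	return bytesPerToken * context
}

func main() {
	// All values below are assumptions chosen for illustration.
	const bytesPerToken = 1024
	numParallel := 4
	context := uint64(numParallel) * 8192 // already includes numParallel
	batch := uint64(512)

	fmt.Println("old sliding-window layer:", oldKVLayerBytes(0, bytesPerToken, context, batch))
	fmt.Println("new sliding-window layer:", kvLayerBytes(0, bytesPerToken, context, batch, numParallel))
	fmt.Println("full-context layer:      ", kvLayerBytes(1, bytesPerToken, context, batch, numParallel))
}
```

With these assumed numbers, the sliding window layer estimate grows from 4718592 to 17301504 bytes, while the full-context layer is unchanged because its context size already accounted for the parallel sequences.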