Commit 756c78cf authored by Jesse Gross's avatar Jesse Gross Committed by Jesse Gross
Browse files

ggml: Support closing backends

In order to iteratively find the best memory allocation, we need to
be able to free backend memory so we can try again.
parent d7f4f788
......@@ -15,6 +15,9 @@ import (
)
type Backend interface {
// Close frees all memory associated with this backend
Close()
Load(ctx context.Context, progress func(float32)) error
// BackendMemory returns the memory allocations that were made for this model
......
......@@ -19,6 +19,7 @@ import (
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"unicode"
"unsafe"
......@@ -33,15 +34,33 @@ import (
"golang.org/x/sync/errgroup"
)
func devices() []C.ggml_backend_dev_t {
var (
cpus, accels, gpus []C.ggml_backend_dev_t
backends map[C.ggml_backend_dev_t]C.ggml_backend_t
)
var initDevices = sync.OnceFunc(func() {
ggml.OnceLoad()
ds := make([]C.ggml_backend_dev_t, C.ggml_backend_dev_count())
for i := range ds {
ds[i] = C.ggml_backend_dev_get(C.size_t(i))
}
return ds
}
backends = make(map[C.ggml_backend_dev_t]C.ggml_backend_t)
for i := range C.ggml_backend_dev_count() {
d := C.ggml_backend_dev_get(i)
switch C.ggml_backend_dev_type(d) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
if len(cpus) == 0 {
// only the first cpu device should be used
cpus = append(cpus, d)
}
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
accels = append(accels, d)
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
gpus = append(gpus, d)
}
backends[d] = C.ggml_backend_dev_init(d, nil)
}
})
type Backend struct {
// modelPath is the location of the model data
......@@ -75,6 +94,9 @@ type Backend struct {
// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
maxGraphNodes int
// weightBuffers are the GGML contexts and buffers for allocating weights
weightBuffers map[*C.struct_ggml_context]C.ggml_backend_buffer_t
}
func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
......@@ -99,6 +121,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
"num_key_values", len(meta.KV()),
)
initDevices()
var requiredMemory ml.BackendMemory
btDeviceMemory := make(map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory)
......@@ -107,21 +131,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
bts []C.ggml_backend_buffer_type_t
}
var cpus, accels, gpus []C.ggml_backend_dev_t
for _, d := range devices() {
switch C.ggml_backend_dev_type(d) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
if len(cpus) == 0 {
// only the first cpu device should be used
cpus = append(cpus, d)
}
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
accels = append(accels, d)
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
gpus = append(gpus, d)
}
}
blocks := int(meta.KV().BlockCount())
// create list of buffer types for the cpu
......@@ -348,6 +357,14 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
if b == nil {
for _, b := range bbs {
C.ggml_backend_buffer_free(b)
}
for _, ctx := range ctxs {
C.ggml_free(ctx)
}
panic(ml.ErrNoMem{BackendMemory: requiredMemory})
}
......@@ -394,7 +411,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
var schedBackends []C.ggml_backend_t
var schedBufts []C.ggml_backend_buffer_type_t
for _, d := range append(gpus, append(accels, cpus...)...) {
b := C.ggml_backend_dev_init(d, nil)
b := backends[d]
bt := C.ggml_backend_get_default_buffer_type(b)
deviceBufferTypes[d] = bt
......@@ -436,6 +453,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
requiredMemory: &requiredMemory,
btDeviceMemory: btDeviceMemory,
maxGraphNodes: maxGraphNodes,
weightBuffers: bbs,
}, nil
}
......@@ -443,6 +461,19 @@ func init() {
ml.RegisterBackend("ggml", New)
}
// Close frees all memory associated with this backend: every weight
// buffer and its owning GGML context, plus the scheduler.
// Safe to call on a nil receiver.
func (b *Backend) Close() {
	if b == nil {
		return
	}

	// Free each weight buffer and the GGML context that owns it.
	// (Renamed from the shadowing loop variable `b` so the receiver
	// remains unambiguous inside the loop.)
	for c, buf := range b.weightBuffers {
		C.ggml_backend_buffer_free(buf)
		C.ggml_free(c)
	}

	C.ggml_backend_sched_free(b.sched)
}
func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
var doneBytes atomic.Uint64
totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset
......
......@@ -70,6 +70,10 @@ func kvCacheTypeFromStr(s string) ml.DType {
}
// Close releases the resources held by the underlying cache.
// A nil receiver is a no-op.
func (c *InputCache) Close() {
	if c != nil {
		c.cache.Close()
	}
}
......
......@@ -877,6 +877,15 @@ func (s *Server) load(
) {
err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache)
if err != nil {
var noMem ml.ErrNoMem
if errors.As(err, &noMem) {
// We can't yet handle this but in the future we will
s.cache.Close()
if s.model != nil {
s.model.Backend().Close()
}
}
panic(err)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment