Commit f50d6912 authored and committed by Jesse Gross

ggml: Fix memory leak on input tensors

For every forward pass through the model, we need to allocate input
tensors: tokens, images, positions, outputs and masks. These get
allocated in system memory.

However, when we close the context that the tensors were allocated
through, the metadata gets freed but the actual backend memory does
not. This results in a significant memory leak.

This change ensures that all memory allocated through a context is freed
when the context is closed.

Fixes #10040
parent 34c3b68f
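
The pattern of the fix, distilled into a standalone Go sketch. The `backendBuffer` type and its `free` method here are hypothetical stand-ins for the cgo `*C.struct_ggml_backend_buffer` and `C.ggml_backend_buffer_free`; the actual change is in the diff below.

```go
package main

import "fmt"

// backendBuffer is a hypothetical stand-in for *C.struct_ggml_backend_buffer:
// memory owned by the backend, invisible to the Go garbage collector.
type backendBuffer struct{ size int }

// free stands in for C.ggml_backend_buffer_free.
func (b *backendBuffer) free() { fmt.Printf("freed %d bytes\n", b.size) }

// Context records every buffer it allocates so Close can release them.
// The slice is held by pointer so that derived contexts (Input, Layer)
// share one list with their parent.
type Context struct {
	allocatedBuffers *[]*backendBuffer
}

func NewContext() *Context {
	var bufs []*backendBuffer
	return &Context{allocatedBuffers: &bufs}
}

// newTensor allocates backend memory and remembers the buffer. Before the
// fix, nothing remembered it, so closing the context leaked it.
func (c *Context) newTensor(size int) *backendBuffer {
	b := &backendBuffer{size: size}
	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
	return b
}

// Close frees every buffer allocated through this context; the real code
// then frees the context metadata itself with C.ggml_free.
func (c *Context) Close() {
	if c == nil {
		return
	}
	for _, b := range *c.allocatedBuffers {
		b.free()
	}
	*c.allocatedBuffers = nil
}

func main() {
	ctx := NewContext()
	ctx.newTensor(4096) // e.g. a token input tensor
	ctx.newTensor(8192) // e.g. a position tensor
	ctx.Close()         // frees both buffers
}
```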
@@ -447,6 +447,8 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
 	}
 
+	var allocatedBuffers []*C.struct_ggml_backend_buffer
+
 	return &Context{
 		b: b,
 		maxGraphNodes: n,
@@ -454,6 +456,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
 			mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
 			no_alloc: true,
 		}),
+		allocatedBuffers: &allocatedBuffers,
 	}
 }
 
@@ -474,6 +477,10 @@ type Context struct {
 	// buft is the buffer type used for new tensors
 	buft *C.struct_ggml_backend_buffer_type
 
+	// allocatedBuffers are buffers for tensors that we have allocated in this context
+	// so that we can free them when we close the context
+	allocatedBuffers *[]*C.struct_ggml_backend_buffer
+
 	// maxGraphNodes is the maximum allowed number of graph nodes in this context
 	maxGraphNodes int
 }
@@ -484,6 +491,7 @@ func (c *Context) Input() ml.Context {
 			b: c.b,
 			ctx: c.ctx,
 			buft: c.b.input,
+			allocatedBuffers: c.allocatedBuffers,
 			maxGraphNodes: c.maxGraphNodes,
 		}
 	}
@@ -497,6 +505,7 @@ func (c *Context) Layer(i int) ml.Context {
 			b: c.b,
 			ctx: c.ctx,
 			buft: buft,
+			allocatedBuffers: c.allocatedBuffers,
 			maxGraphNodes: c.maxGraphNodes,
 		}
 	}
@@ -610,6 +619,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
 	if b == nil {
 		return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
 	}
 
+	*c.allocatedBuffers = append(*c.allocatedBuffers, b)
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
 	return &Tensor{b: c.b, t: t}, nil
@@ -688,6 +698,11 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 
 func (c *Context) Close() {
 	if c != nil {
+		for _, b := range *c.allocatedBuffers {
+			C.ggml_backend_buffer_free(b)
+		}
+		*c.allocatedBuffers = nil
+
 		C.ggml_free(c.ctx)
 	}
 }
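
One detail worth noting: `allocatedBuffers` is a pointer to a slice rather than a plain slice. `Input()` and `Layer()` construct new `Context` values, and an `append` on a plain slice field can grow a new backing array that only that copy sees; indirecting through a pointer guarantees the parent's `Close` observes every buffer appended by a derived context. A minimal sketch of the difference, using a hypothetical `buf` type rather than the ggml code:

```go
package main

import "fmt"

type buf struct{ id int }

type byValue struct{ bufs []*buf }    // plain slice field
type byPointer struct{ bufs *[]*buf } // shared pointer to a slice

func main() {
	// Plain slice: the derived copy's append re-backs its own slice;
	// the parent never sees the new buffer.
	parent := byValue{}
	child := parent // copy, as Input()/Layer() copy Context fields
	child.bufs = append(child.bufs, &buf{id: 1})
	fmt.Println(len(parent.bufs)) // 0 — invisible to the parent

	// Pointer to slice: both copies append through the same pointer,
	// so the parent's Close would free everything.
	var shared []*buf
	p := byPointer{bufs: &shared}
	c := p // copy; the pointer still targets the same slice
	*c.bufs = append(*c.bufs, &buf{id: 1})
	fmt.Println(len(*p.bufs)) // 1 — visible to the parent
}
```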