Commit 60830695 authored by Jesse Gross's avatar Jesse Gross Committed by Jesse Gross
Browse files

ggml-backend: Ensure data is available after async computation

We need to sync before retrieving data after async computation.
It is also important to ensure that the Go buffer is not moved by
the GC across function calls so we do a synchronous copy.
parent 01d9a468
...@@ -9,8 +9,6 @@ package ggml ...@@ -9,8 +9,6 @@ package ggml
import "C" import "C"
import ( import (
"bytes"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
...@@ -245,12 +243,17 @@ func (c *Context) Forward(t ml.Tensor) { ...@@ -245,12 +243,17 @@ func (c *Context) Forward(t ml.Tensor) {
func (c *Context) Compute(tensors ...ml.Tensor) { func (c *Context) Compute(tensors ...ml.Tensor) {
C.ggml_backend_sched_graph_compute_async(c.sched, c.graph) C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
for _, t := range tensors { needSync := true
if C.ggml_nbytes(t.(*Tensor).t) != 0 { sync := func() {
backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t) if needSync {
C.ggml_backend_sched_synchronize(c.sched)
needSync = false
}
}
t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t)) for _, t := range tensors {
C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) if C.ggml_nbytes(t.(*Tensor).t) > 0 {
t.(*Tensor).sync = sync
} }
} }
} }
...@@ -330,7 +333,7 @@ func (c *Context) Close() { ...@@ -330,7 +333,7 @@ func (c *Context) Close() {
type Tensor struct { type Tensor struct {
t *C.struct_ggml_tensor t *C.struct_ggml_tensor
data []byte sync func()
} }
func (t *Tensor) LogValue() slog.Value { func (t *Tensor) LogValue() slog.Value {
...@@ -358,14 +361,23 @@ func (t *Tensor) Shape() []int { ...@@ -358,14 +361,23 @@ func (t *Tensor) Shape() []int {
return shape return shape
} }
func (t *Tensor) Bytes() []byte { func (t *Tensor) Bytes() (data []byte) {
return t.data if t.sync != nil {
data = make([]byte, C.ggml_nbytes(t.t))
t.sync()
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
}
return
} }
func (t *Tensor) Floats() (f32s []float32) { func (t *Tensor) Floats() (data []float32) {
if t.data != nil { if t.sync != nil {
f32s = make([]float32, C.ggml_nelements(t.t)) data = make([]float32, C.ggml_nelements(t.t))
_ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s)
t.sync()
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
} }
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment