Commit f33ccd5d authored by Jesse Gross's avatar Jesse Gross Committed by Jesse Gross
Browse files

ggml: Use pointer receivers for Context

Context is currently mixed between pointer and value receivers. Change
this to be all pointer receivers so don't have to reason about whether
the things we are updating in the struct will be retained.
parent bc108b9a
......@@ -484,7 +484,7 @@ type Context struct {
maxGraphNodes int
}
func (c Context) Input() ml.Context {
func (c *Context) Input() ml.Context {
if c.b.input != nil {
return &Context{
b: c.b,
......@@ -494,10 +494,10 @@ func (c Context) Input() ml.Context {
}
}
return &c
return c
}
func (c Context) Layer(i int) ml.Context {
func (c *Context) Layer(i int) ml.Context {
if buft, ok := c.b.layers[i]; ok {
return &Context{
b: c.b,
......@@ -507,7 +507,7 @@ func (c Context) Layer(i int) ml.Context {
}
}
return &c
return c
}
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
......@@ -522,7 +522,7 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
return c
}
func (c Context) Compute(tensors ...ml.Tensor) {
func (c *Context) Compute(tensors ...ml.Tensor) {
C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
C.ggml_backend_sched_reset(c.b.sched)
......@@ -541,7 +541,7 @@ func (c Context) Compute(tensors ...ml.Tensor) {
}
}
func (c Context) Reserve() error {
func (c *Context) Reserve() error {
if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
C.ggml_backend_sched_reset(c.b.sched)
return errors.New("failed to reserve graph")
......@@ -559,7 +559,7 @@ func (c Context) Reserve() error {
return nil
}
func (c Context) MaxGraphNodes() int {
func (c *Context) MaxGraphNodes() int {
return c.maxGraphNodes
}
......@@ -576,7 +576,7 @@ func pad(length, pad C.size_t) C.size_t {
return ((length + pad - 1) / pad) * pad
}
func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
if c.buft == nil {
panic("set Input or Layer before creating tensors")
}
......@@ -621,7 +621,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
return &Tensor{b: c.b, t: t}, nil
}
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
t, err := c.newTensor(dtype, shape)
if err != nil {
panic(err)
......@@ -630,7 +630,7 @@ func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
return t
}
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
t, err := c.newTensor(dtype, shape)
if err != nil {
panic(err)
......@@ -658,7 +658,7 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
return nil
}
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
if err := checkShape(s, shape...); err != nil {
return nil, err
}
......@@ -675,7 +675,7 @@ func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
return t, nil
}
func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
if err := checkShape(s, shape...); err != nil {
return nil, err
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment