Commit 3e8b8a19 authored by Michael Yang's avatar Michael Yang
Browse files

ml: update Context.Forward interface

update Context.Forward to accept multiple tensors to match
Context.Compute signature

update Context.Forward to return Context such that it can be chained
with Context.Compute
parent 41dc2804
......@@ -330,8 +330,10 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
}
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
ctx.Forward(
key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
)
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
......
......@@ -280,9 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
out, _, mask := cache.Get(context)
context.Forward(out)
context.Forward(mask)
context.Compute(out, mask)
context.Forward(out, mask).Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
......
......@@ -80,8 +80,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
}
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
ctx.Forward(
key.Copy(ctx, c.keys[c.curLayer]),
value.Copy(ctx, c.values[c.curLayer]),
)
}
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
......
......@@ -65,7 +65,7 @@ type Context interface {
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
FromIntSlice(s []int32, shape ...int) (Tensor, error)
Forward(Tensor)
Forward(...Tensor) Context
Compute(...Tensor)
MaxTensors() int
Close()
......@@ -186,8 +186,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
if t.Bytes() == nil {
ctx.Forward(t)
ctx.Compute(t)
ctx.Forward(t).Compute(t)
}
s := make(S, mul(t.Shape()...))
......
......@@ -256,12 +256,16 @@ type Context struct {
nodes int
}
func (c *Context) Forward(t ml.Tensor) {
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
if c.graph == nil {
c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
}
C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
for _, tensor := range tensors {
C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
}
return c
}
func (c *Context) Compute(tensors ...ml.Tensor) {
......
......@@ -248,8 +248,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
return nil, err
}
ctx.Forward(t)
ctx.Compute(t)
ctx.Forward(t).Compute(t)
return t, nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment