Commit 1f371ea9 authored by Jesse Gross, committed by Jesse Gross

ml: Panic rather than return error on tensor allocation failure

FromFloatSlice and FromIntSlice return an error if the shape doesn't
match the passed data or if memory can't be allocated. Since these
are inputs, the memory being allocated is system memory rather than VRAM.

In many cases, the caller can't really handle the error and panics.

Empty and Zeros directly panic if they can't allocate memory.

This makes things consistent by panicking for the first two cases as well,
removing a fair amount of error handling code. This is also consistent
with how Go typically handles these situations.
parent 73d6a82c
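For illustration, here is the shape of the change at a typical call site, using the kShift/offsets lines from the diff below (a before/after sketch, not part of the commit itself):

    // Before: creating an input tensor carried error plumbing at every call site.
    kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
    if err != nil {
        return err
    }

    // After: FromIntSlice panics on a shape mismatch or allocation failure,
    // so the happy path collapses to a single line.
    kShift := ctx.Input().FromIntSlice(offsets, len(offsets))

Since input tensors are backed by system memory, a failed allocation here is effectively an out-of-memory condition, which Go programs conventionally treat as unrecoverable.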
@@ -211,10 +211,9 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
 		c.curCellRange.max = len(c.cells) - 1
 	}
 
-	var err error
-	c.curMask, err = c.buildMask(ctx)
+	c.curMask = c.buildMask(ctx)
 
-	return err
+	return nil
 }
 
 func newRange() cellRange {
@@ -297,7 +296,7 @@ func roundUp(length, pad int) int {
 // Builds a mask of history x batch indicating whether for each token in the batch the
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
-func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
+func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 	// Align and pad the two dimensions as required by the backend
 	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
@@ -325,10 +324,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 		mask[i] = float32(math.Inf(-1))
 	}
 
-	maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
-	if err != nil {
-		return nil, err
-	}
+	maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
 
 	if c.config.MaskDType != ml.DTypeF32 {
 		out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
@@ -336,7 +332,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 		maskTensor = out
 	}
 
-	return maskTensor, nil
+	return maskTensor
 }
 
 func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
@@ -491,12 +487,7 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
 	if !slices.Equal(c.opts.Except, opts.Except) {
 		c.opts = opts
 		if ctx != nil {
-			var err error
-			c.curMask, err = c.buildMask(ctx)
-			if err != nil {
-				// This error should never occur because we have previously built a mask with the same shape
-				panic(fmt.Errorf("SetCausal: %w", err))
-			}
+			c.curMask = c.buildMask(ctx)
 		}
 	}
 }
@@ -652,10 +643,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 		}
 	}
 
-	kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
-	if err != nil {
-		return err
-	}
+	kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
 
 	for i, key := range c.keys {
 		if key == nil {
...
@@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 			}
 
 			cache.SetLayer(0)
-			tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
+			tensor := context.FromFloatSlice(test.in, test.inShape...)
 			cache.Put(context, tensor, tensor)
 
 			out, _, mask := cache.Get(context)
@@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) {
 	}
 
 	cache.SetLayer(0)
-	tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
+	tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
 	cache.Put(context, tensor, tensor)
 
 	// with window size 4, nothing has slid out of the window yet
@@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) {
 	}
 
 	cache.SetLayer(0)
-	tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
+	tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
 	cache.Put(context, tensor, tensor)
 
 	// only the latest position has overlapping windows
@@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
 	return c.Empty(dtype, shape...)
 }
 
-func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
+func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
 	t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
 
 	copy(t.data, s)
 
-	return t, nil
+	return t
 }
 
-func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
+func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor {
 	f := make([]float32, len(s))
 	for i := range f {
 		f[i] = float32(s[i])
 	}
 
-	out, _ := c.FromFloatSlice(f, shape...)
+	out := c.FromFloatSlice(f, shape...)
 	out.(*testTensor).dtype = ml.DTypeI32
 
-	return out, nil
+	return out
 }
 
 func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
@@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso
 		s = append(s, i)
 	}
 
-	out, _ := c.FromFloatSlice(s, len(s))
+	out := c.FromFloatSlice(s, len(s))
 	out.(*testTensor).dtype = dtype
 
 	return out
 }
...
@@ -171,8 +171,8 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) {
 type Context interface {
 	Empty(dtype DType, shape ...int) Tensor
 	Zeros(dtype DType, shape ...int) Tensor
-	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
-	FromIntSlice(s []int32, shape ...int) (Tensor, error)
+	FromFloatSlice(s []float32, shape ...int) Tensor
+	FromIntSlice(s []int32, shape ...int) Tensor
 
 	// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
 	Arange(start, stop, step float32, dtype DType) Tensor
...
@@ -729,11 +729,11 @@ func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
 	return t
 }
 
-func checkShape[S ~[]E, E any](s S, shape ...int) error {
+func checkShape[S ~[]E, E any](s S, shape ...int) {
 	n := len(s)
 
 	if n == 0 {
-		return nil
+		return
 	}
 
 	for _, v := range shape {
@@ -741,16 +741,12 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
 	}
 
 	if n != 1 {
-		return fmt.Errorf("invalid shape: %v", shape)
+		panic(fmt.Errorf("invalid shape: %v", shape))
 	}
-
-	return nil
 }
 
-func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
-	if err := checkShape(s, shape...); err != nil {
-		return nil, err
-	}
+func (c *Context) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
+	checkShape(s, shape...)
 
 	t := c.newTensor(ml.DTypeF32, shape)
@@ -758,13 +754,11 @@ func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
 		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
 	}
 
-	return t, nil
+	return t
 }
 
-func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
-	if err := checkShape(s, shape...); err != nil {
-		return nil, err
-	}
+func (c *Context) FromIntSlice(s []int32, shape ...int) ml.Tensor {
+	checkShape(s, shape...)
 
 	t := c.newTensor(ml.DTypeI32, shape)
@@ -772,7 +766,7 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
 	}
 
-	return t, nil
+	return t
 }
 
 func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
@@ -790,12 +784,7 @@ func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
 			arange = append(arange, int32(i))
 		}
 
-		t, err := c.Input().FromIntSlice(arange, len(arange))
-		if err != nil {
-			panic(err)
-		}
-
-		return t
+		return c.Input().FromIntSlice(arange, len(arange))
 	default:
 		panic("unsupported dtype for arange")
 	}
...
@@ -287,11 +287,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
 		return nil, errors.New("batch size cannot be less than 1")
 	}
 
-	var err error
-	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
-	if err != nil {
-		return nil, err
-	}
+	batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
 
 	cache := m.Config().Cache
 	if cache != nil {
...
@@ -175,15 +175,8 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
...
@@ -101,14 +101,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 		return nil, err
 	}
 
-	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
+	pixelValues := ctx.Input().FromFloatSlice(f32s,
 		m.ImageProcessor.imageSize,
 		m.ImageProcessor.imageSize,
 		m.ImageProcessor.numChannels,
 	)
-	if err != nil {
-		return nil, err
-	}
 
 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
 	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
@@ -144,15 +141,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
 }
...
@@ -142,10 +142,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 
 	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
@@ -154,10 +151,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-			if err != nil {
-				return nil, err
-			}
+			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 		}
 
 		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
...
@@ -77,10 +77,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 		return nil, err
 	}
 
-	tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
-	if err != nil {
-		return nil, err
-	}
+	tilesLocal := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
 
 	ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize
@@ -91,11 +88,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	pixelValues := tilesLocal
 	if len(pixelsGlobal) > 0 {
-		tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
-		if err != nil {
-			return nil, err
-		}
+		tilesGlobal := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
 		pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3)
 	}
@@ -182,15 +175,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
 }
...
@@ -223,11 +223,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 			scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0)
 		}
 
-		var err error
-		attentionScales, err = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
-		if err != nil {
-			panic(err)
-		}
+		attentionScales = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
 	}
 
 	for i, layer := range m.Layers {
...
@@ -245,10 +245,7 @@ func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
 		}
 	}
 
-	ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
-	if err != nil {
-		panic(err)
-	}
+	ropeFreqs := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
 
 	ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)
...
@@ -114,10 +114,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 		return nil, err
 	}
 
-	pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
-	if err != nil {
-		return nil, err
-	}
+	pixelValues := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
 
 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
 	features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
@@ -161,15 +158,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
 }
...
@@ -110,15 +110,8 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor)
 		}
 	}
 
-	h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
-	if err != nil {
-		panic(err)
-	}
-
-	w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
-	if err != nil {
-		panic(err)
-	}
+	h := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
+	w := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
 
 	h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 	w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
@@ -151,10 +144,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 		}
 	}
 
-	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
-	if err != nil {
-		panic(err)
-	}
+	positionIDs := ctx.Input().FromIntSlice(positions, len(positions))
 
 	positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
 	cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
...
@@ -80,15 +80,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 		f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles]
 	}
 
-	pixelValues, err := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles)
-	if err != nil {
-		return nil, err
-	}
-
-	aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1)
-	if err != nil {
-		return nil, err
-	}
+	pixelValues := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles)
+	aspectRatio := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1)
 
 	positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
 	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
@@ -113,15 +106,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor
 	}
 
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	// TODO: attention mask, cross attention mask
 	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
...
@@ -100,10 +100,7 @@ type Model struct {
 // Forward implements model.Model.
 func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 
 	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
@@ -112,10 +109,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-			if err != nil {
-				return nil, err
-			}
+			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 		}
 
 		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
...
@@ -69,10 +69,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
 		m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
 	numPatches := grid.Temporal * grid.Height * grid.Width
 
-	pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
-	if err != nil {
-		return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err)
-	}
+	pixelValues := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
 
 	return pixelValues, grid, nil
 }
@@ -142,15 +139,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
+	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
 }
...
 package qwen25vl
 
 import (
-	"fmt"
 	"math"
 	"slices"
@@ -44,10 +43,8 @@ func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int
 		}
 	}
 
-	mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
-	if err != nil {
-		panic(err)
-	}
+	mask := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
 
 	// Reshape to match [seqLength, seqLength, 1] for broadcasting
 	mask = mask.Reshape(ctx, seqLength, seqLength, 1)
@@ -303,10 +300,7 @@ func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int)
 		}
 	}
 
-	t, err := ctx.Input().FromIntSlice(index, len(index))
-	if err != nil {
-		panic(err)
-	}
+	t := ctx.Input().FromIntSlice(index, len(index))
 
 	return t, bounds
 }
@@ -326,10 +320,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
 			freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
 		}
 	}
-	freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize)
-	if err != nil {
-		panic(fmt.Errorf("failed to create tensor from frequencies: %w", err))
-	}
+	freqs := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize)
 
 	// Create position coordinates (y,x pairs) for the grid
 	// In PyTorch: Equivalent to generating position ids with torch.arange()
@@ -339,10 +330,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
 			coords = append(coords, int32(y), int32(x))
 		}
 	}
-	pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height)
-	if err != nil {
-		panic(fmt.Errorf("failed to create tensor from positions: %w", err))
-	}
+	pos := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height)
 
 	// Reshape and permute positions to match spatial merging pattern
 	pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
...
@@ -156,10 +156,7 @@ type Model struct {
 // Forward implements model.Model.
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	if err != nil {
-		return nil, err
-	}
+	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 
 	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
@@ -168,10 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-			if err != nil {
-				return nil, err
-			}
+			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 		}
 
 		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
...
@@ -102,7 +102,7 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
 	for i, t := range entry.mm {
 		if in == t.Tensor {
 			if !reserve {
-				return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+				return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...), nil
 			} else {
 				return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
 			}
...
@@ -808,10 +808,7 @@ func (s *Server) reserveWorstCaseGraph() error {
 		batch.Outputs[i] = int32(i)
 	}
 
-	batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
-	if err != nil {
-		return err
-	}
+	batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
 
 	cache := s.model.Config().Cache
 	if cache != nil {
@@ -876,7 +873,8 @@ func (s *Server) load(
 	parallel int,
 	kvCacheType string,
 	kvSize int,
-	multiUserCache bool) {
+	multiUserCache bool,
+) {
 	err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache)
 	if err != nil {
 		panic(err)
...