Commit 6bd0a983 authored by Bruce MacDonald's avatar Bruce MacDonald Committed by Michael Yang
Browse files

model: support for mistral-small in the ollama runner

Mistral is a popular research lab making open source models. This updates
the forward pass of llama architecture models to support both llama models
and mistral models by accounting for additional metadata present in mistral
models, and finding the correct dimensions for the output projection.
parent 1861fbde
package mistral3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
return x2.Neg(ctx).Concat(ctx, x1, 0)
}
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type VisionEncoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *VisionSelfAttention
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *VisionMLP
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = e.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
intermediateSize int
imageSize int
patchSize int
numChannels int
eps float32
ropeBase float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_conv"`
EncoderNorm *nn.RMSNorm `gguf:"encoder_norm"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
}
func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) ml.Tensor {
maxPatchesPerSide := m.imageSize / m.patchSize
frequencies := m.headDim / 2
frequenciesHeight := make([]float32, frequencies/2*maxPatchesPerSide)
frequenciesWidth := make([]float32, frequencies/2*maxPatchesPerSide)
for i := range frequencies {
for j := range maxPatchesPerSide {
frequency := float32(j) / float32(math.Pow(float64(m.ropeBase), float64(i)*2/float64(m.headDim)))
if i%2 == 0 {
frequenciesHeight[i/2*maxPatchesPerSide+j] = frequency
} else {
frequenciesWidth[i/2*maxPatchesPerSide+j] = frequency
}
}
}
h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
if err != nil {
panic(err)
}
w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
if err != nil {
panic(err)
}
h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
h = h.Repeat(ctx, 1, maxPatchesPerSide)
h = h.Reshape(ctx, frequencies/2, maxPatchesPerSide, maxPatchesPerSide).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
w = w.Repeat(ctx, 2, maxPatchesPerSide)
inverseFrequencies := h.Concat(ctx, w, 0).Reshape(ctx, frequencies, maxPatchesPerSide*maxPatchesPerSide)
inverseFrequencies = inverseFrequencies.Concat(ctx, inverseFrequencies, 0)
return inverseFrequencies.Rows(ctx, positionIDs)
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatchesW := pixelValues.Dim(0) / m.patchSize
numPatchesH := pixelValues.Dim(1) / m.patchSize
numPatches := numPatchesW * numPatchesH
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)
// Prepare position IDs for 2D rope
positions := make([]int32, numPatches)
for h := range numPatchesH {
for w := range numPatchesW {
idx := h*numPatchesW + w
positions[idx] = int32(h*m.imageSize/m.patchSize + w)
}
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
panic(err)
}
positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
}
return hiddenStates
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length", 1024)),
numHeads: int(c.Uint("vision.attention.head_count", 16)),
headDim: int(c.Uint("vision.attention.key_length", 64)),
intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
imageSize: int(c.Uint("vision.image_size", 1540)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
ropeBase: c.Float("vision.rope.freq_base", 10000.0),
},
}
}
...@@ -186,7 +186,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa ...@@ -186,7 +186,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions) hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1) hiddenState = m.ClassEmbedding.Repeat(ctx, 2, m.numTiles).Concat(ctx, hiddenState, 1)
hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions) hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions)
hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps)
......
...@@ -4,5 +4,6 @@ import ( ...@@ -4,5 +4,6 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama" _ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama" _ "github.com/ollama/ollama/model/models/mllama"
) )
package pixtral
import (
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"io"
"math"
"github.com/ollama/ollama/model/imageproc"
)
func getNumImageTokens(imageSize, patchSize image.Point) image.Point {
return image.Point{
(imageSize.X-1)/patchSize.X + 1,
(imageSize.Y-1)/patchSize.Y + 1,
}
}
func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
b := img.Bounds()
le := float64(longestEdge)
ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le)
newSize := img.Bounds().Max
if ratio > 1.0 {
newSize = image.Point{
int(math.Ceil(float64(b.Max.X) / ratio)),
int(math.Ceil(float64(b.Max.Y) / ratio)),
}
}
tokens := getNumImageTokens(newSize, patchSize)
return image.Point{
tokens.X * patchSize.X,
tokens.Y * patchSize.Y,
}
}
func resizeImage(img image.Image, format string, longestEdge int, patchSize image.Point) image.Image {
if format == "png" {
img = imageproc.Composite(img)
}
newSize := getResizeOutputImageSize(img, longestEdge, patchSize)
// todo should be ResizeBicubic, but it doesn't exist
return imageproc.Resize(img, newSize, imageproc.ResizeBilinear)
}
func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
img, format, err := image.Decode(imageData)
if err != nil {
return nil, nil, fmt.Errorf("failed to decode image: %w", err)
}
longestEdge := 1024
patchSize := image.Point{16, 16}
img = resizeImage(img, format, longestEdge, patchSize)
data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
opts := map[string]any{}
return data, opts, nil
}
package pixtral
import (
"bytes"
"encoding/binary"
"image"
"image/png"
"math"
"os"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestGetNumImageTokens(t *testing.T) {
type numImageTokensCase struct {
ImageSize image.Point
PatchSize image.Point
Expected image.Point
}
cases := []numImageTokensCase{
{
ImageSize: image.Point{1024, 764},
PatchSize: image.Point{16, 16},
Expected: image.Point{64, 48},
},
{
ImageSize: image.Point{800, 600},
PatchSize: image.Point{16, 16},
Expected: image.Point{50, 38},
},
{
ImageSize: image.Point{640, 480},
PatchSize: image.Point{16, 16},
Expected: image.Point{40, 30},
},
{
ImageSize: image.Point{320, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{20, 13},
},
{
ImageSize: image.Point{1320, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{83, 13},
},
{
ImageSize: image.Point{2000, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{125, 13},
},
{
ImageSize: image.Point{10000, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{625, 13},
},
{
ImageSize: image.Point{1131, 577},
PatchSize: image.Point{16, 16},
Expected: image.Point{71, 37},
},
{
ImageSize: image.Point{16, 16},
PatchSize: image.Point{16, 16},
Expected: image.Point{1, 1},
},
}
for _, c := range cases {
actual := getNumImageTokens(c.ImageSize, c.PatchSize)
if diff := cmp.Diff(actual, c.Expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func TestGetResizeOutputImageSize(t *testing.T) {
type resizeCase struct {
Image image.Image
LongestEdge int
PatchSize image.Point
Expected image.Point
}
cases := []resizeCase{
{
Image: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 768},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 1162, 690)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 624},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 300, 200)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{304, 208},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 288},
},
}
for _, c := range cases {
actual := getResizeOutputImageSize(c.Image, c.LongestEdge, c.PatchSize)
if diff := cmp.Diff(actual, c.Expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func TestResize(t *testing.T) {
type resizeCase struct {
Image image.Image
LongestEdge int
PatchSize image.Point
Expected image.Image
}
cases := []resizeCase{
{
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.NewRGBA(image.Rect(0, 0, 1024, 288)),
},
{
Image: image.NewRGBA(image.Rect(0, 0, 10, 10)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.NewRGBA(image.Rect(0, 0, 16, 16)),
},
}
for _, c := range cases {
actual := resizeImage(c.Image, "png", c.LongestEdge, c.PatchSize)
if actual.Bounds() != c.Expected.Bounds() {
t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
}
}
}
func TestPreprocess(t *testing.T) {
type preprocessCase struct {
TestImage image.Image
ExpectedLen int
}
cases := []preprocessCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)),
ExpectedLen: 16 * 16 * 3 * 1,
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
ExpectedLen: 1024 * 1024 * 3 * 1,
},
}
for _, c := range cases {
var buf bytes.Buffer
err := png.Encode(&buf, c.TestImage)
if err != nil {
t.Fatal(err)
}
imgData, _, err := Preprocess(&buf)
if err != nil {
t.Fatalf("error processing: %q", err)
}
switch len(imgData) {
case 0:
t.Errorf("no image data returned")
case c.ExpectedLen:
// ok
default:
t.Errorf("unexpected image data length: %d, expected: %d", len(imgData), c.ExpectedLen)
}
}
}
func TestPreprocessImages(t *testing.T) {
for _, testFile := range []string{"flight.png", "sportsball.png"} {
f, err := os.Open(testFile)
if err != nil {
t.Skipf("skipping test, no test image found at %s", testFile)
}
defer f.Close()
imgData, _, err := Preprocess(f)
if err != nil {
t.Fatalf("error processing: %q", err)
}
byteData := make([]byte, len(imgData)*4) // float32 is 4 bytes
for i, f := range imgData {
binary.LittleEndian.PutUint32(byteData[i*4:], math.Float32bits(f))
}
outputPath := "processed_" + testFile + ".bin"
err = os.WriteFile(outputPath, byteData, 0o644)
if err != nil {
t.Fatalf("error writing processed image: %q", err)
}
}
}
...@@ -263,6 +263,10 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { ...@@ -263,6 +263,10 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
continue continue
} }
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...) merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil merges[pair.b].runes = nil
......
...@@ -211,16 +211,10 @@ func filesForModel(path string) ([]string, error) { ...@@ -211,16 +211,10 @@ func filesForModel(path string) ([]string, error) {
} }
var files []string var files []string
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 { if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are // safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...) files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapters.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapter_model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are // pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment