Commit 0df18004 authored by Michael Yang's avatar Michael Yang
Browse files

set non-causal attention

parent 631fecc6
...@@ -58,9 +58,6 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV { ...@@ -58,9 +58,6 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
kv["tokenizer.ggml.bos_token_id"] = uint32(2)
kv["tokenizer.ggml.eot_token_id"] = uint32(1)
return kv return kv
} }
......
...@@ -148,6 +148,7 @@ type Tensor interface { ...@@ -148,6 +148,7 @@ type Tensor interface {
View(ctx Context, offset int, shape ...int) Tensor View(ctx Context, offset int, shape ...int) Tensor
Permute(ctx Context, shape ...int) Tensor Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context) Tensor Contiguous(ctx Context) Tensor
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
Pad(ctx Context, shape ...int) Tensor Pad(ctx Context, shape ...int) Tensor
Unpad(ctx Context, shape ...int) Tensor Unpad(ctx Context, shape ...int) Tensor
......
...@@ -954,6 +954,20 @@ func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor { ...@@ -954,6 +954,20 @@ func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
} }
} }
// Set writes t2 into t in place starting at the given offset and returns the
// resulting tensor. Dispatch depends on how many strides are supplied:
// none selects ggml's 1D set, one selects the 2D set with strides[0] as the
// row stride. offset and strides follow the ggml C API convention
// (size_t byte offsets — see ggml_set_1d/ggml_set_2d); callers are expected
// to pass byte-scaled values. More than one stride panics.
func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
	var tt *C.struct_ggml_tensor
	switch len(strides) {
	case 0:
		// 1D in-place set: t2 laid down contiguously at offset.
		tt = C.ggml_set_1d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
	case 1:
		// 2D in-place set: strides[0] is the destination row stride.
		tt = C.ggml_set_2d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
	default:
		// ggml only exposes 1D/2D set variants; higher ranks are unsupported.
		panic("unsupported number of dimensions")
	}
	// Wrap the raw ggml tensor with the same backend as the receiver.
	return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor { func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor var kqMask *C.struct_ggml_tensor
if mask != nil { if mask != nil {
......
...@@ -51,8 +51,10 @@ func New(c ml.Config) (model.Model, error) { ...@@ -51,8 +51,10 @@ func New(c ml.Config) (model.Model, error) {
Types: c.Uints("tokenizer.ggml.token_type"), Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), EOS: int32(1),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOT: int32(106),
AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
}, },
), ),
ImageProcessor: newImageProcessor(c), ImageProcessor: newImageProcessor(c),
...@@ -109,35 +111,46 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu ...@@ -109,35 +111,46 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
for i := range inputs { for i := range inputs {
if inputs[i].Multimodal == nil { if inputs[i].Multimodal == nil {
if len(images) > 0 { for j := range images {
inputs[i].Multimodal = images[0].Multimodal if j == 0 {
inputs[i].MultimodalHash = images[0].MultimodalHash inputs[i].Multimodal = images[j].Multimodal
for j := 1; j < len(images); j++ { inputs[i].MultimodalHash = images[j].MultimodalHash
} else {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3) inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset() fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash) binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash) binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64() inputs[i].MultimodalHash = fnvHash.Sum64()
} }
images = nil
} }
images = nil
} else { } else {
images = append(images, inputs[i]) images = append(images, inputs[i])
inputs[i].Token = -1 inputs[i].Token = -1
} }
} }
inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 }) for i := range inputs {
if inputs[i].Token == -1 {
imageInputs := []input.Input{
{Token: 108}, // "\n\n"
{Token: 255999}, // "<start_of_image>"
}
// <image_soft_token>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 262144}}, 256)...)
// <end_of_image>
imageInputs = append(imageInputs, input.Input{Token: 256000})
inputs = append(inputs[:i], append(imageInputs, inputs[i+1:]...)...)
}
}
return inputs, nil return inputs, nil
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var embeddings ml.Tensor
if opts.Multimodal != nil {
embeddings = opts.Multimodal[0].Multimodal.(ml.Tensor)
}
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -153,7 +166,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { ...@@ -153,7 +166,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err return nil, err
} }
return m.TextModel.Forward(ctx, inputs, positions, embeddings, outputs, m.Cache), nil return m.TextModel.Forward(ctx, inputs, positions, outputs, opts.Multimodal, m.Cache), nil
} }
func init() { func init() {
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
) )
type TextOptions struct { type TextOptions struct {
...@@ -165,12 +166,15 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, ...@@ -165,12 +166,15 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor { func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor {
if embeddings == nil { hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
embeddings = m.TokenEmbedding.Forward(ctx, inputs) if multimodal != nil {
visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(0))
} }
hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount { if len(m.Layers) == gemma27BLayerCount {
m.TextOptions.largeModelScaling = true m.TextOptions.largeModelScaling = true
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"cmp" "cmp"
"iter" "iter"
"log/slog" "log/slog"
"slices"
"strings" "strings"
"sync" "sync"
...@@ -39,8 +40,8 @@ type Vocabulary struct { ...@@ -39,8 +40,8 @@ type Vocabulary struct {
Scores []float32 Scores []float32
Merges []string Merges []string
BOS, EOS int32 BOS, EOS, EOT int32
AddBOS, AddEOS bool AddBOS, AddEOS, AddEOT bool
specialOnce sync.Once specialOnce sync.Once
special []string special []string
...@@ -57,7 +58,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool { ...@@ -57,7 +58,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
case SpecialBOS: case SpecialBOS:
return id == v.BOS return id == v.BOS
case SpecialEOS: case SpecialEOS:
return id == v.EOS return id == v.EOS || id == v.EOT
default: default:
return false return false
} }
...@@ -85,7 +86,9 @@ func (v *Vocabulary) Decode(id int32) string { ...@@ -85,7 +86,9 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string { func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() { v.specialOnce.Do(func() {
for i := range v.Values { for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL { if slices.Contains([]int{105, 106}, i) {
v.special = append(v.special, v.Values[i])
} else if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i]) v.special = append(v.special, v.Values[i])
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment