Commit 0df18004 authored by Michael Yang
Browse files

set non-causal attention

parent 631fecc6
......@@ -58,9 +58,6 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
kv["tokenizer.ggml.bos_token_id"] = uint32(2)
kv["tokenizer.ggml.eot_token_id"] = uint32(1)
return kv
}
......
......@@ -148,6 +148,7 @@ type Tensor interface {
View(ctx Context, offset int, shape ...int) Tensor
Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context) Tensor
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
Pad(ctx Context, shape ...int) Tensor
Unpad(ctx Context, shape ...int) Tensor
......
......@@ -954,6 +954,20 @@ func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
}
}
// Set writes t2 into t starting at the given byte offset, in place, and
// returns the resulting tensor. With no strides the write is 1-D; with a
// single stride it is 2-D. Any other number of strides panics, since the
// underlying ggml API only exposes 1-D and 2-D set operations here.
func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
	gctx := ctx.(*Context).ctx
	src := t2.(*Tensor).t

	var result *C.struct_ggml_tensor
	switch len(strides) {
	case 0:
		result = C.ggml_set_1d_inplace(gctx, t.t, src, C.size_t(offset))
	case 1:
		result = C.ggml_set_2d_inplace(gctx, t.t, src, C.size_t(offset), C.size_t(strides[0]))
	default:
		panic("unsupported number of dimensions")
	}

	return &Tensor{b: t.b, t: result}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {
......
......@@ -51,8 +51,10 @@ func New(c ml.Config) (model.Model, error) {
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
EOS: int32(1),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOT: int32(106),
AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
},
),
ImageProcessor: newImageProcessor(c),
......@@ -109,35 +111,46 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
inputs[i].Multimodal = images[0].Multimodal
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
for j := range images {
if j == 0 {
inputs[i].Multimodal = images[j].Multimodal
inputs[i].MultimodalHash = images[j].MultimodalHash
} else {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64()
}
images = nil
}
images = nil
} else {
images = append(images, inputs[i])
inputs[i].Token = -1
}
}
inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
for i := range inputs {
if inputs[i].Token == -1 {
imageInputs := []input.Input{
{Token: 108}, // "\n\n"
{Token: 255999}, // "<start_of_image>""
}
// <image_soft_token>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 262144}}, 256)...)
// <end_of_image>
imageInputs = append(imageInputs, input.Input{Token: 256000})
inputs = append(inputs[:i], append(imageInputs, inputs[i+1:]...)...)
}
}
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var embeddings ml.Tensor
if opts.Multimodal != nil {
embeddings = opts.Multimodal[0].Multimodal.(ml.Tensor)
}
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
......@@ -153,7 +166,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, embeddings, outputs, m.Cache), nil
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts.Multimodal, m.Cache), nil
}
func init() {
......
......@@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
......@@ -165,12 +166,15 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
if embeddings == nil {
embeddings = m.TokenEmbedding.Forward(ctx, inputs)
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
if multimodal != nil {
visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(0))
}
hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
m.TextOptions.largeModelScaling = true
......
......@@ -4,6 +4,7 @@ import (
"cmp"
"iter"
"log/slog"
"slices"
"strings"
"sync"
......@@ -39,8 +40,8 @@ type Vocabulary struct {
Scores []float32
Merges []string
BOS, EOS int32
AddBOS, AddEOS bool
BOS, EOS, EOT int32
AddBOS, AddEOS, AddEOT bool
specialOnce sync.Once
special []string
......@@ -57,7 +58,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
case SpecialBOS:
return id == v.BOS
case SpecialEOS:
return id == v.EOS
return id == v.EOS || id == v.EOT
default:
return false
}
......@@ -85,7 +86,9 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL {
if slices.Contains([]int{105, 106}, i) {
v.special = append(v.special, v.Values[i])
} else if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i])
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment