Unverified commit adff143b, authored by Michael Yang, committed by GitHub

fix: mllama quality (#10807)

* fix mllama convert

- transform attn_gate and ffn_gate
- swap attention heads for vision models

* fix mllama

the mlp gate was applied in the wrong place; it now gates only the MLP output, before the residual add
parent fbe6ae28
@@ -94,7 +94,9 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
 	var out []*ggml.Tensor
 	var text []Tensor
 	for _, t := range ts {
-		if t.Name() == "v.position_embd.gate" {
+		if !strings.HasPrefix(t.Name(), "v.") && !strings.HasPrefix(t.Name(), "mm.") {
+			text = append(text, t)
+		} else if t.Name() == "v.position_embd.gate" {
 			for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} {
 				tt := t.Clone()
 				tt.SetRepacker(m.repack(name))
@@ -105,23 +107,21 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
 					WriterTo: tt,
 				})
 			}
-		} else if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
-			t.SetRepacker(m.repack(t.Name()))
-			out = append(out, &ggml.Tensor{
-				Name:     t.Name(),
-				Kind:     t.Kind(),
-				Shape:    t.Shape(),
-				WriterTo: t,
-			})
-		} else if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
+		} else {
+			if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
+				t.SetRepacker(m.repack(t.Name()))
+			} else if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
+				t.SetRepacker(m.repack(t.Name()))
+			} else if strings.HasSuffix(t.Name(), "attn_gate") || strings.HasSuffix(t.Name(), "ffn_gate") {
+				t.SetRepacker(m.repack(t.Name()))
+			}
+
 			out = append(out, &ggml.Tensor{
 				Name:     t.Name(),
 				Kind:     t.Kind(),
 				Shape:    t.Shape(),
 				WriterTo: t,
 			})
-		} else {
-			text = append(text, t)
 		}
 	}
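For context on the hunk above: the reworked loop first routes anything outside the "v." and "mm." namespaces to the text converter; everything else is emitted as a vision tensor, with a repacker attached when the tensor needs a convert-time transform (the position-embedding gates, attn_q/attn_k weights, and attn_gate/ffn_gate). A minimal sketch of that routing rule, using invented helper names isVisionTensor and needsRepack that are not part of the patch:

// Hypothetical sketch, not part of the patch: the routing rule the new loop
// implements, expressed as standalone helpers over tensor names.
package main

import (
	"fmt"
	"strings"
)

// isVisionTensor mirrors the prefix check used in the converter: anything not
// under "v." or "mm." belongs to the text model.
func isVisionTensor(name string) bool {
	return strings.HasPrefix(name, "v.") || strings.HasPrefix(name, "mm.")
}

// needsRepack lists the vision tensors the patch now transforms at convert time.
func needsRepack(name string) bool {
	switch {
	case name == "v.position_embd.gate",
		name == "v.tile_position_embd.gate",
		name == "v.pre_tile_position_embd.gate",
		name == "v.post_tile_position_embd.gate":
		return true
	case strings.HasSuffix(name, "attn_q.weight"),
		strings.HasSuffix(name, "attn_k.weight"),
		strings.HasSuffix(name, "attn_gate"),
		strings.HasSuffix(name, "ffn_gate"):
		return true
	}
	return false
}

func main() {
	// Illustrative tensor names only.
	for _, n := range []string{"output.weight", "v.blk.0.attn_q.weight", "v.blk.0.ffn_gate"} {
		fmt.Println(n, isVisionTensor(n), needsRepack(n))
	}
}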
@@ -137,16 +137,35 @@ func (m *mllamaModel) repack(name string) Repacker {
 		var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
-		t, err = tensor.Tanh(t)
-		if err != nil {
-			return nil, err
-		}
 
-		if name == "v.position_embd.gate" {
-			t, err = tensor.Sub(float32(1), t)
+		if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_k.weight") {
+			heads := m.VisionModel.AttentionHeads
+			if err := t.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
+				return nil, err
+			}
+
+			if err := t.T(0, 2, 1, 3); err != nil {
+				return nil, err
+			}
+
+			if err := t.Reshape(dims...); err != nil {
+				return nil, err
+			}
+
+			if err := t.Transpose(); err != nil {
+				return nil, err
+			}
+		} else {
+			t, err = tensor.Tanh(t)
 			if err != nil {
 				return nil, err
 			}
+
+			if name == "v.position_embd.gate" {
+				t, err = tensor.Sub(float32(1), t)
+				if err != nil {
+					return nil, err
+				}
+			}
 		}
 
 		t = tensor.Materialize(t)
...
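The attn_q.weight/attn_k.weight branch above is what "swap attention heads for vision models" refers to: the first dimension is reshaped to [heads, 2, headDim/2, ...] and axes 1 and 2 are swapped, which interleaves the two halves of each head's rows; the repacker then reshapes back to the original dims and transposes the whole matrix. A hypothetical plain-slice sketch of just the interleaving step (swapHeadHalves is an invented name; the final whole-matrix Transpose is omitted):

// Hypothetical sketch, not the converter code: the row permutation the
// attn_q/attn_k repack performs within each head, shown on plain slices.
// Assumes heads divides the row count and headDim is even.
package main

import "fmt"

// swapHeadHalves reorders the rows of a [rows x cols] matrix so that, within
// each head, the two halves of the head dimension are interleaved:
// (r0, r1, ..., rD/2-1, rD/2, ..., rD-1) -> (r0, rD/2, r1, rD/2+1, ...).
func swapHeadHalves(w [][]float32, heads int) [][]float32 {
	rows := len(w)
	headDim := rows / heads
	half := headDim / 2

	out := make([][]float32, 0, rows)
	for h := 0; h < heads; h++ {
		base := h * headDim
		for i := 0; i < half; i++ {
			out = append(out, w[base+i], w[base+half+i])
		}
	}
	return out
}

func main() {
	// 1 head, headDim = 4: rows 0,1,2,3 come out as 0,2,1,3.
	w := [][]float32{{0}, {1}, {2}, {3}}
	for _, row := range swapHeadHalves(w, 1) {
		fmt.Println(row[0])
	}
}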
@@ -16,8 +16,6 @@ type VisionSelfAttention struct {
 	Key    *nn.Linear `gguf:"attn_k"`
 	Value  *nn.Linear `gguf:"attn_v"`
 	Output *nn.Linear `gguf:"attn_output"`
-
-	Gate ml.Tensor `gguf:"attn_gate"`
 }
 
 func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
@@ -25,27 +23,16 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 
-	hiddenState = sa.Output.Forward(ctx, attention)
-	return hiddenState
+	return sa.Output.Forward(ctx, attention)
 }
 
 type VisionMLP struct {
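In the hunk above, the hand-rolled permute/Mulmat/Scale/Softmax pipeline is replaced by a single call to the shared nn.Attention helper, which also removes the per-tensor permutes. For reference, a hypothetical single-head sketch of what scaled dot-product attention computes, on plain slices (not ollama's nn.Attention, which operates on ml.Tensor and may fuse these steps):

// Hypothetical reference: softmax(q·kᵀ/√d)·v for one head.
// q, k, v are [seq][headDim]; the result is [seq][headDim].
package main

import (
	"fmt"
	"math"
)

func attention(q, k, v [][]float64) [][]float64 {
	seq, d := len(q), len(q[0])
	scale := 1 / math.Sqrt(float64(d))

	out := make([][]float64, seq)
	for i := range q {
		// scores[j] = (q_i · k_j) * scale
		scores := make([]float64, seq)
		max := math.Inf(-1)
		for j := range k {
			for c := 0; c < d; c++ {
				scores[j] += q[i][c] * k[j][c]
			}
			scores[j] *= scale
			if scores[j] > max {
				max = scores[j]
			}
		}
		// softmax over j (subtract max for numerical stability)
		var sum float64
		for j := range scores {
			scores[j] = math.Exp(scores[j] - max)
			sum += scores[j]
		}
		// weighted sum of value rows
		out[i] = make([]float64, d)
		for j := range v {
			w := scores[j] / sum
			for c := 0; c < d; c++ {
				out[i][c] += w * v[j][c]
			}
		}
	}
	return out
}

func main() {
	q := [][]float64{{1, 0}, {0, 1}}
	fmt.Println(attention(q, q, q))
}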
@@ -76,21 +63,18 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	// self attention
 	hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
 	if e.AttentionGate != nil {
 		hiddenState = hiddenState.Mul(ctx, e.AttentionGate)
 	}
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
 	// feed forward
 	hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
-	hiddenState = hiddenState.Add(ctx, residual)
 	if e.MLPGate != nil {
 		hiddenState = hiddenState.Mul(ctx, e.MLPGate)
 	}
+	hiddenState = hiddenState.Add(ctx, residual)
 
 	return hiddenState
 }
...
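The change above is the "mlp gate applied in the wrong place" fix: previously the residual was added before the gate multiply, so the gate scaled the residual stream as well; now only the MLP branch is gated and the residual is added afterwards. A tiny scalar sketch of the difference (the values are made up):

// Hypothetical scalar sketch, not the model code: order-of-operations fix for
// the MLP gate in the encoder layer.
package main

import "fmt"

func main() {
	residual, mlpOut, gate := 1.0, 0.5, 0.1

	old := (mlpOut + residual) * gate // gate wrongly scales the residual too
	fixed := mlpOut*gate + residual   // only the MLP output is gated

	fmt.Println(old, fixed) // 0.15 vs 1.05
}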