Unverified commit 333203d8, authored by Michael Yang, committed by GitHub

chore: update models to use slice/chunk/chunksections (#12934)

* use slice/chunks

* bert

* llama4

* gemma3n

* gptoss

* mistral3

* qwen3vl

* qwen25vl

* deepseek2

* remove unused ops
parent c1149875
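
Across these files the commit swaps hand-maintained View(offset, size, stride, ...) arithmetic for three helpers: Slice(ctx, dim, low, high, step), Chunk(ctx, dim, size), and ChunkSections(ctx, dim, sections...). The sketch below models their apparent one-dimensional semantics on plain Go slices; it is inferred from the call sites in this diff, not taken from the ml package, so treat the names and argument meanings as assumptions.

```go
package main

import "fmt"

// slice models Tensor.Slice(ctx, dim, low, high, step) on one dimension:
// elements [low, high) taken every step elements (assumed semantics).
func slice(t []float32, low, high, step int) []float32 {
	var out []float32
	for i := low; i < high; i += step {
		out = append(out, t[i])
	}
	return out
}

// chunk models Tensor.Chunk(ctx, dim, size): consecutive pieces of a fixed
// size (assumed from Chunk(ctx, 2, 1) yielding four chunks of an extent-4 axis).
func chunk(t []float32, size int) [][]float32 {
	var out [][]float32
	for i := 0; i < len(t); i += size {
		out = append(out, t[i:i+size])
	}
	return out
}

// chunkSections models Tensor.ChunkSections(ctx, dim, sections...):
// consecutive pieces with explicit, possibly unequal lengths.
func chunkSections(t []float32, sections ...int) [][]float32 {
	var out [][]float32
	i := 0
	for _, s := range sections {
		out = append(out, t[i:i+s])
		i += s
	}
	return out
}

func main() {
	t := []float32{0, 1, 2, 3, 4, 5}
	fmt.Println(slice(t, 0, len(t), 2)) // [0 2 4], i.e. t[0::2]
	fmt.Println(chunk(t, 3))            // [[0 1 2] [3 4 5]]
	fmt.Println(chunkSections(t, 4, 2)) // [[0 1 2 3] [4 5]]
}
```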
@@ -110,9 +110,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
 	for name, mxfp4 := range mxfp4s {
 		dims := mxfp4.blocks.Shape()
 
+		if !strings.HasSuffix(name, ".weight") {
+			name = name + ".weight"
+		}
 		if strings.Contains(name, "ffn_down_exps") {
 			out = append(out, &ggml.Tensor{
-				Name:     name + ".weight",
+				Name:     name,
 				Kind:     uint32(ggml.TensorTypeMXFP4),
 				Shape:    []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
 				WriterTo: mxfp4,
@@ -121,12 +124,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
 			// gate_up_exps is interleaved, need to split into gate_exps and up_exps
 			// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
 			out = append(out, &ggml.Tensor{
-				Name:     strings.Replace(name, "gate_up", "gate", 1) + ".weight",
+				Name:     strings.Replace(name, "gate_up", "gate", 1),
 				Kind:     uint32(ggml.TensorTypeMXFP4),
 				Shape:    []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
 				WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
 			}, &ggml.Tensor{
-				Name:     strings.Replace(name, "gate_up", "up", 1) + ".weight",
+				Name:     strings.Replace(name, "gate_up", "up", 1),
 				Kind:     uint32(ggml.TensorTypeMXFP4),
 				Shape:    []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
 				WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),
...
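
The even/odd split above is the de-interleave that the Python-style comment describes. A hypothetical standalone illustration of what mxfp4.slice(1, 0, int(dims[1]), 2) and mxfp4.slice(1, 1, int(dims[1]), 2) select along the second dimension, with made-up row data:

```go
package main

import "fmt"

// deinterleave recovers gate and up rows from an interleaved matrix, the
// same even/odd split the two stride-2 slices perform on the blocks tensor.
func deinterleave(rows [][]float32) (gate, up [][]float32) {
	for i, row := range rows {
		if i%2 == 0 {
			gate = append(gate, row) // gate_up[:, 0::2, ...]
		} else {
			up = append(up, row) // gate_up[:, 1::2, ...]
		}
	}
	return gate, up
}

func main() {
	gate, up := deinterleave([][]float32{{1}, {2}, {3}, {4}})
	fmt.Println(gate, up) // [[1] [3]] [[2] [4]]
}
```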
@@ -146,7 +146,6 @@ type Tensor interface {
 	FromFloats([]float32)
 	FromInts([]int32)
 
-	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
 	Sub(ctx Context, t2 Tensor) Tensor
 	Mul(ctx Context, t2 Tensor) Tensor
@@ -185,7 +184,6 @@ type Tensor interface {
 	View(ctx Context, offset int, shape ...int) Tensor
 	Permute(ctx Context, shape ...int) Tensor
 	Contiguous(ctx Context, shape ...int) Tensor
-	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
 	Pad(ctx Context, shape ...int) Tensor
@@ -209,7 +207,6 @@ type Tensor interface {
 	Stddev(ctx Context) Tensor
 	Sqr(ctx Context) Tensor
 	Sqrt(ctx Context) Tensor
-	Clamp(ctx Context, min, max float32) Tensor
 }
 
 // ScaledDotProductAttention implements a fused attention
...
@@ -1137,13 +1137,6 @@ func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_neg(ctx.(*Context).ctx, t.t),
-	}
-}
-
 func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
@@ -1632,20 +1625,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
-	var tt *C.struct_ggml_tensor
-	switch len(strides) {
-	case 0:
-		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
-	case 1:
-		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
-	default:
-		panic("unsupported number of dimensions")
-	}
-
-	return &Tensor{b: t.b, t: tt}
-}
-
 func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {
@@ -1732,13 +1711,6 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
-	}
-}
-
 // Slice returns a view of the tensor sliced along dim from low to high in step steps.
 // Slice panics if the dimension is invalid or the slice parameters are out of range.
 // If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.
...
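
For orientation, here is the offset/stride arithmetic that the removed manual View calls spelled out at every call site, and how Slice's low/high/step parameters subsume it. This is an illustrative simplification of a one-dimensional strided view on a plain Go slice, not the ml API:

```go
package main

import "fmt"

// view models the element arithmetic behind the removed hand-written View
// calls: a window of n elements starting at offset, read with a stride.
// A simplification of ggml's byte-offset views (assumed for illustration).
func view(data []float32, offset, n, stride int) []float32 {
	out := make([]float32, n)
	for i := range out {
		out[i] = data[offset+i*stride]
	}
	return out
}

// sliceDim expresses the same window the way Slice(ctx, dim, low, high, step)
// does: the offset and length fall out of low/high/step instead of being
// hand-computed at each call site.
func sliceDim(data []float32, low, high, step int) []float32 {
	return view(data, low, (high-low+step-1)/step, step)
}

func main() {
	data := []float32{0, 1, 2, 3, 4, 5, 6, 7}
	fmt.Println(view(data, 2, 3, 1))     // [2 3 4]
	fmt.Println(sliceDim(data, 2, 5, 1)) // [2 3 4], same window
	fmt.Println(sliceDim(data, 0, 8, 2)) // [0 2 4 6], stride-2 slice
}
```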
@@ -32,10 +32,9 @@ func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
 		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
 		return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 	case TypeCLS:
-		return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
+		return hiddenStates.Slice(ctx, 1, 0, 1, 1)
 	case TypeLast:
-		hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
-		return hiddenStates
+		return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
 	default:
 		panic("unknown pooling type")
 	}
...
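
The two new Slice calls read as "first token" and "last token" along dim 1, the token axis. A rough equivalent on a plain Go matrix, with the outer index standing in for dim 1:

```go
package main

import "fmt"

// cls and last mirror the two pooling cases: Slice(ctx, 1, 0, 1, 1) keeps
// only the first token and Slice(ctx, 1, n-1, n, 1) keeps only the last,
// each still shaped as a one-row matrix rather than a bare vector.
func cls(h [][]float32) [][]float32  { return h[0:1] }
func last(h [][]float32) [][]float32 { return h[len(h)-1:] }

func main() {
	h := [][]float32{{1, 2}, {3, 4}, {5, 6}} // three tokens, two features
	fmt.Println(cls(h))  // [[1 2]]
	fmt.Println(last(h)) // [[5 6]]
}
```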
@@ -29,7 +29,7 @@ type Model struct {
 // Forward implements model.Model.
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
-	hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
+	hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.Slice(ctx, 1, 0, 1, 1))
 	hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromInts(batch.Positions, len(batch.Positions))))
 	hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
...
@@ -78,44 +78,31 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 	}
 
 	query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
-
-	qPass := query.View(ctx, 0,
-		opts.qkNopeHeadDim, query.Stride(1),
-		query.Dim(1), query.Stride(2),
-		query.Dim(2))
-
-	qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0),
-		opts.qkRopeHeadDim, query.Stride(1),
-		query.Dim(1), query.Stride(2),
-		query.Dim(2))
+	queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
 
 	compressedKV := attn.KVA.Forward(ctx, hiddenStates)
-
-	kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1))
-	kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0),
-		opts.qkRopeHeadDim, compressedKV.Stride(1),
-		1, compressedKV.Stride(1),
-		compressedKV.Dim(1))
+	kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
+	kRot := compressedKV.View(ctx,
+		opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
+		compressedKV.Stride(1), 1,
+		compressedKV.Stride(1), compressedKV.Dim(1),
+	)
 
 	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
 	kPass = attn.KVB.Forward(ctx, kPass)
 
 	kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
+	kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
 
-	kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2))
-	value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0),
-		opts.vHeadDim, kv.Stride(1),
-		kv.Dim(1), kv.Stride(2),
-		kv.Dim(2)).Contiguous(ctx)
-
-	qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
+	qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
 	kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
 
-	kRot = kRot.Repeat(ctx, 1, qPass.Dim(1))
+	kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
 
-	query = qRot.Concat(ctx, qPass, 0)
-	key := kRot.Concat(ctx, kPass, 0)
+	query = qRot.Concat(ctx, queryChunks[0], 0)
+	key := kRot.Concat(ctx, kvChunks[0], 0)
 
-	attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
+	attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
 	return attn.Output.Forward(ctx, attention)
 }
@@ -142,6 +129,7 @@ func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml
 	experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
 	experts = experts.Mul(ctx, topKWeights)
 
 	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
+
 	for i := 1; i < opts.numExpertsUsed; i++ {
 		nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
...
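
ChunkSections carves one dimension into consecutive, unequally sized pieces, which is exactly the no-RoPE/RoPE head split above. A toy single-vector version, with illustrative dimensions rather than the model's real sizes:

```go
package main

import "fmt"

// splitHead mirrors query.ChunkSections(ctx, 0, qkNopeHeadDim, qkRopeHeadDim)
// on one head vector: the first section bypasses RoPE, the second receives it.
func splitHead(q []float32, nopeDim, ropeDim int) (qPass, qRot []float32) {
	return q[:nopeDim], q[nopeDim : nopeDim+ropeDim]
}

func main() {
	q := []float32{1, 2, 3, 4, 5, 6} // pretend nopeDim=4, ropeDim=2
	qPass, qRot := splitHead(q, 4, 2)
	fmt.Println(qPass, qRot) // [1 2 3 4] [5 6]
}
```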
@@ -64,18 +64,18 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 		cache.(*kvcache.WrapperCache).SetLayerType(layerType)
 
-		// inputPerLayer = inputsPerLayer[:, i, :]
-		inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2)).Contiguous(ctx)
+		// inputPerLayer = inputsPerLayer[:, i, :].squeeze(1)
+		inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
 		hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
 	}
 
 	// hiddenStates = hiddenStates[:, :, 0]
-	hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
+	hiddenStates0 := hiddenStates.Slice(ctx, 2, 0, 1, 1)
 	targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
 	targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
 
 	// hiddenState = hiddenStates[:, :, 1:]
-	hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
+	hiddenState = hiddenStates.Slice(ctx, 2, 1, hiddenStates.Dim(2), 1)
 	altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
 	altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
@@ -176,10 +176,10 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
 	active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
 
 	// inactive := predictions[:, :, 1:]
-	inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
+	inactive := predictions.Slice(ctx, 2, 1, predictions.Dim(2), 1)
 	active = inactive.Add(ctx, active)
 
-	predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
+	predictions0 := predictions.Slice(ctx, 2, 0, 1, 1)
 	return predictions0.Concat(ctx, active, 2)
 }
@@ -319,7 +319,7 @@ type TextOptions struct {
 func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
 	// t[:, :, o.altupActiveIndex]
-	return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
+	return t.Slice(ctx, 2, o.altupActiveIndex, o.altupActiveIndex+1, 1)
 }
 
 func (o *TextOptions) headDim() int {
...
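
All three gemma3n changes are the same pattern: indexing or slicing the altup axis (dim 2) with Slice instead of raw stride math. A compact stand-in on nested Go slices, with the outer index playing the role of dim 2 and illustrative sizes:

```go
package main

import "fmt"

func main() {
	// predictions stands in for a tensor whose dim 2 (here the outer Go
	// index) is the altup axis.
	predictions := [][][]float32{{{1}}, {{2}}, {{3}}}

	predictions0 := predictions[0:1] // predictions.Slice(ctx, 2, 0, 1, 1)
	inactive := predictions[1:]      // predictions.Slice(ctx, 2, 1, predictions.Dim(2), 1)

	// predictions0.Concat(ctx, active, 2) reassembles the pieces along the
	// same axis (inactive is used directly here for brevity).
	recombined := append(append([][][]float32{}, predictions0...), inactive...)
	fmt.Println(recombined) // [[[1]] [[2]] [[3]]]
}
```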
@@ -121,30 +121,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
 	var query, key, value ml.Tensor
 	if attn.QKV != nil {
 		qkv := attn.QKV.Forward(ctx, hiddenStates)
-
-		// query = qkv[..., : num_attention_heads * head_dim].reshape(batch_size, num_attention_heads, head_dim)
-		query = qkv.View(ctx,
-			0,
-			opts.headDim(), qkv.Stride(0)*opts.headDim(),
-			opts.numHeads, qkv.Stride(1),
-			batchSize,
-		)
-		// key = qkv[..., num_attention_heads * head_dim:(num_attention_heads + num_key_value_heads) * head_dim].reshape(batch_size, num_key_value_heads, head_dim)
-		key = qkv.View(ctx,
-			qkv.Stride(0)*opts.headDim()*opts.numHeads,
-			opts.headDim(), qkv.Stride(0)*opts.headDim(),
-			opts.numKVHeads, qkv.Stride(1),
-			batchSize,
-		)
-		// value = qkv[..., (num_attention_heads + num_key_value_heads) * head_dim:].reshape(batch_size, num_key_value_heads, head_dim)
-		value = qkv.View(ctx,
-			qkv.Stride(0)*opts.headDim()*(opts.numHeads+opts.numKVHeads),
-			opts.headDim(), qkv.Stride(0)*opts.headDim(),
-			opts.numKVHeads, qkv.Stride(1),
-			batchSize,
-		)
+		qkv = qkv.Reshape(ctx, opts.headDim(), -1, batchSize)
+		chunks := qkv.ChunkSections(ctx, 1, opts.numHeads, opts.numKVHeads, opts.numKVHeads)
+		query, key, value = chunks[0], chunks[1], chunks[2]
 	} else {
 		query = attn.Query.Forward(ctx, hiddenStates)
 		query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
@@ -195,15 +174,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
 	var gate, up ml.Tensor
 	if mlp.GateUp != nil {
 		hiddenStates = mlp.GateUp.Forward(ctx, hiddenStates, selectedExperts)
-		hiddenStates = hiddenStates.Reshape(ctx, 2, hiddenStates.Dim(0)/2, hiddenStates.Dim(1), hiddenStates.Dim(2))
-
-		dimStride := []int{hiddenStates.Dim(0) / 2, hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), hiddenStates.Dim(2), hiddenStates.Stride(3), hiddenStates.Dim(3)}
-		gate = hiddenStates.View(ctx, 0, dimStride...)
-		gate = gate.Contiguous(ctx, gate.Dim(0)*gate.Dim(1), gate.Dim(2), gate.Dim(3))
-		up = hiddenStates.View(ctx, hiddenStates.Stride(0), dimStride...)
-		up = up.Contiguous(ctx, up.Dim(0)*up.Dim(1), up.Dim(2), up.Dim(3))
+		gate = hiddenStates.Slice(ctx, 0, 0, hiddenStates.Dim(0), 2)
+		up = hiddenStates.Slice(ctx, 0, 1, hiddenStates.Dim(0), 2)
 	} else {
 		gate = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts)
 		up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
...
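
The fused QKV projection is now reshaped to (headDim, totalHeads, batch) and split along the head axis with ChunkSections; the interleaved gate/up projection likewise collapses to two stride-2 slices. A sketch of the head split with made-up head counts, each inner slice standing in for one head:

```go
package main

import "fmt"

// splitQKV mirrors qkv.ChunkSections(ctx, 1, numHeads, numKVHeads, numKVHeads):
// the head axis is carved into query, key, and value heads.
func splitQKV(heads [][]float32, nHeads, nKVHeads int) (q, k, v [][]float32) {
	q = heads[:nHeads]
	k = heads[nHeads : nHeads+nKVHeads]
	v = heads[nHeads+nKVHeads:]
	return q, k, v
}

func main() {
	heads := [][]float32{{1}, {2}, {3}, {4}} // 2 query + 1 key + 1 value head
	q, k, v := splitQKV(heads, 2, 1)
	fmt.Println(q, k, v) // [[1] [2]] [[3]] [[4]]
}
```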
@@ -105,9 +105,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	for range aspectRatio.Y {
 		for x := range aspectRatio.X {
-			view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-				projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-				patchesPerChunk)
+			view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
 			var separator separator
 			if x < aspectRatio.X-1 {
 				separator.x = true // <|tile_x_separator|>
@@ -120,9 +118,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 		}
 	}
 
-	view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-		projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-		patchesPerChunk)
+	view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
 	multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
 
 	return multimodal, nil
...
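
Both mistral3 changes replace a byte-offset View with a contiguous window along dim 1: patches offset through offset+patchesPerChunk. The same selection on a plain patch list, outer index standing in for dim 1:

```go
package main

import "fmt"

// window mirrors projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1):
// a contiguous run of n patches starting at offset, previously expressed as a
// View at byte offset projectedOutputs.Stride(1)*offset.
func window(patches [][]float32, offset, n int) [][]float32 {
	return patches[offset : offset+n]
}

func main() {
	patches := [][]float32{{0}, {1}, {2}, {3}, {4}}
	fmt.Println(window(patches, 1, 3)) // [[1] [2] [3]]
}
```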
@@ -37,27 +37,23 @@ type VisionAttention struct {
 func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
 	width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)
-	t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
 
 	// t1 = t[..., 0::2]
-	t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-	t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
+	t1 := t.Slice(ctx, 0, 0, t.Dim(0), 2)
 
 	// t2 = t[..., 1::2]
-	t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-	t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
+	t2 := t.Slice(ctx, 0, 1, t.Dim(0), 2)
 
 	// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
 	cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
-	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
-	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
+	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, -1)
+	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3)
+	cosOut = cosOut.Contiguous(ctx, width, height, channels, tiles)
 
 	// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
-	sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
-	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
-	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
+	sinOut := t2.Scale(ctx, -1).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
+	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, -1)
+	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3)
+	sinOut = sinOut.Contiguous(ctx, width, height, channels, tiles)
 
 	return cosOut.Add(ctx, sinOut)
 }
...
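
With Neg gone from the Tensor interface, negation is expressed as Scale(ctx, -1), a plain multiply by a scalar; that is the only semantic-looking change in this hunk, and the two forms are equivalent. On plain floats:

```go
package main

import "fmt"

// scale stands in for Tensor.Scale(ctx, -1), the replacement for the removed
// Neg op: negation is just multiplication by -1, so the dedicated binding
// could be dropped.
func scale(t []float32, s float32) []float32 {
	out := make([]float32, len(t))
	for i, v := range t {
		out[i] = s * v
	}
	return out
}

func main() {
	fmt.Println(scale([]float32{1, -2, 3}, -1)) // [-1 2 -3]
}
```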
@@ -11,9 +11,9 @@ import (
 var batchSize int = 1
 
 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-	x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-	x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
-	return x2.Neg(ctx).Concat(ctx, x1, 0)
+	x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+	x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
+	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }
 
 func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
...
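
rotateHalf is the standard RoPE rotate-half construction: split the vector in two, negate the second half, and concatenate it in front of the first. A plain-vector version of the tensor code above:

```go
package main

import "fmt"

// rotateHalf mirrors the tensor version: x1 := t[:d/2], x2 := t[d/2:],
// result = concat(-x2, x1).
func rotateHalf(t []float32) []float32 {
	d := len(t) / 2
	out := make([]float32, 0, len(t))
	for _, v := range t[d:] {
		out = append(out, -v) // x2.Scale(ctx, -1)
	}
	return append(out, t[:d]...) // .Concat(ctx, x1, 0)
}

func main() {
	fmt.Println(rotateHalf([]float32{1, 2, 3, 4})) // [-3 -4 1 2]
}
```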
@@ -13,9 +13,9 @@ import (
 var batchSize int = 1
 
 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-	x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-	x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
-	return x2.Neg(ctx).Concat(ctx, x1, 0)
+	x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+	x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
+	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }
 
 func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
...
@@ -18,8 +18,8 @@ type VisionAttention struct {
 }
 
 func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
-	x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
-	x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
+	x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
+	x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
 	return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
 }
@@ -160,10 +160,11 @@ func (m *VisionPositionEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor
 	positionEmbeds = positionEmbeds.Mul(ctx, weights)
 	positionEmbeds = positionEmbeds.Reshape(ctx, n, -1, 4)
-	positionEmbeds = positionEmbeds.View(ctx, 0, n, positionEmbeds.Stride(1), grid.Height*grid.Width).
-		Add(ctx, positionEmbeds.View(ctx, 1*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
-		Add(ctx, positionEmbeds.View(ctx, 2*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
-		Add(ctx, positionEmbeds.View(ctx, 3*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width))
+	positionEmbedsChunks := positionEmbeds.Chunk(ctx, 2, 1)
+	positionEmbeds = positionEmbedsChunks[0].
+		Add(ctx, positionEmbedsChunks[1]).
+		Add(ctx, positionEmbedsChunks[2]).
+		Add(ctx, positionEmbedsChunks[3])
 
 	positionEmbeds = positionEmbeds.Reshape(ctx, -1, grid.Width/opts.spatialMergeSize, opts.spatialMergeSize, grid.Height/opts.spatialMergeSize)
 	positionEmbeds = positionEmbeds.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, n, -1)
...
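
Chunk(ctx, 2, 1) on an axis of extent 4 yields four chunks, which the chained Adds then sum; judging from the indices used, the second argument appears to be the chunk size rather than the chunk count. A stand-in that sums four chunks elementwise, outer index playing the role of dim 2:

```go
package main

import "fmt"

// sumChunks mirrors positionEmbeds.Chunk(ctx, 2, 1) followed by the three
// Adds: the extent-4 axis is split into four pieces, summed elementwise.
func sumChunks(chunks [][]float32) []float32 {
	out := make([]float32, len(chunks[0]))
	for _, c := range chunks {
		for i, v := range c {
			out[i] += v
		}
	}
	return out
}

func main() {
	chunks := [][]float32{{1, 1}, {2, 2}, {3, 3}, {4, 4}}
	fmt.Println(sumChunks(chunks)) // [10 10]
}
```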