Commit a8e83a76 authored by Jesse Gross, committed by Michael Yang
Browse files

Disable causal attention based on batch index

Currently we are using positions, which are relative to a
sequence and may not be unique.
parent 47500550
...@@ -144,6 +144,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error { ...@@ -144,6 +144,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
c.curBatchSize = len(opts.Positions) c.curBatchSize = len(opts.Positions)
c.curSequences = opts.Sequences c.curSequences = opts.Sequences
c.curPositions = opts.Positions c.curPositions = opts.Positions
c.opts.Except = nil
var err error var err error
c.curLoc, err = c.findStartLoc() c.curLoc, err = c.findStartLoc()
...@@ -234,7 +235,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { ...@@ -234,7 +235,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
mask := make([]float32, batchSize*length) mask := make([]float32, batchSize*length)
for i := range c.curBatchSize { for i := range c.curBatchSize {
enabled := !slices.Contains(c.opts.Except, c.curPositions[i]) enabled := !slices.Contains(c.opts.Except, i)
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
(enabled && c.cells[j].pos > c.curPositions[i]) || (enabled && c.cells[j].pos > c.curPositions[i]) ||
...@@ -405,15 +406,12 @@ func (c *Causal) SetLayer(layer int) { ...@@ -405,15 +406,12 @@ func (c *Causal) SetLayer(layer int) {
} }
type CausalOptions struct { type CausalOptions struct {
// Enabled controls whether the causal mask is generated for a particular position. // Enabled controls whether the causal mask is generated for a particular index in a batch
Except []int32 Except []int
} }
// SetCausal enables or disables causal mask generation for subsequent calls to Get. // SetCausal disables causal mask generation for a particular range of indices in
// This state carries over to future forward passes. The default value is true. // the current batch for subsequent calls to Get. The state resets for the next forward pass.
//
// ctx may be set to nil if this is called from outside of a forward pass, for
// example, when initializing the cache.
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) { func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
if !slices.Equal(c.opts.Except, opts.Except) { if !slices.Equal(c.opts.Except, opts.Except) {
c.opts = opts c.opts = opts
......
...@@ -173,10 +173,10 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, ...@@ -173,10 +173,10 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 { func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
var embedding ml.Tensor var embedding ml.Tensor
var src, dst, length int var src, dst, length int
var except []int32 var except []int
for _, image := range multimodal { for _, image := range multimodal {
imageToken := image.Multimodal.(imageToken) imageToken := image.Multimodal.(imageToken)
...@@ -204,7 +204,7 @@ func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []inpu ...@@ -204,7 +204,7 @@ func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []inpu
length = 1 length = 1
} }
except = append(except, positions[imageDst]) except = append(except, imageDst)
} }
if embedding != nil { if embedding != nil {
...@@ -219,7 +219,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor ...@@ -219,7 +219,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
hiddenState := m.TokenEmbedding.Forward(ctx, inputs) hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions) except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
for i, layer := range m.Layers { for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global) // gemma alternates between the sliding window (local) and causal (global)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment