Commit a8e83a76 authored by Jesse Gross's avatar Jesse Gross Committed by Michael Yang
Browse files

Disable causal attention based on batch index

Currently we are using positions, which are relative to a
sequence and may not be unique.
parent 47500550
......@@ -144,6 +144,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
c.curBatchSize = len(opts.Positions)
c.curSequences = opts.Sequences
c.curPositions = opts.Positions
c.opts.Except = nil
var err error
c.curLoc, err = c.findStartLoc()
......@@ -234,7 +235,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {
enabled := !slices.Contains(c.opts.Except, c.curPositions[i])
enabled := !slices.Contains(c.opts.Except, i)
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
......@@ -405,15 +406,12 @@ func (c *Causal) SetLayer(layer int) {
}
type CausalOptions struct {
// Enabled controls whether the causal mask is generated for a particular position.
Except []int32
// Except lists batch indices for which causal mask generation is disabled
Except []int
}
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
// This state carries over to future forward passes. The default value is true.
//
// ctx may be set to nil if this is called from outside of a forward pass, for
// example, when initializing the cache.
// SetCausal disables causal mask generation for a particular range of indices in
// the current batch for subsequent calls to Get. The state resets for the next forward pass.
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
if !slices.Equal(c.opts.Except, opts.Except) {
c.opts = opts
......
......@@ -173,10 +173,10 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual)
}
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 {
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
var embedding ml.Tensor
var src, dst, length int
var except []int32
var except []int
for _, image := range multimodal {
imageToken := image.Multimodal.(imageToken)
......@@ -204,7 +204,7 @@ func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []inpu
length = 1
}
except = append(except, positions[imageDst])
except = append(except, imageDst)
}
if embedding != nil {
......@@ -219,7 +219,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions)
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment