"...git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "031a4157e2d2b7aed4f88b8543c5689dfbbf53f9"
Commit 26acdcf4 authored by Jesse Gross, committed by Jesse Gross

runner.go: Don't set cross attention before sending embeddings

Currently, if an input has embeddings at any point, we set cross
attention to true from the beginning. This means that any tokens
sent before the embeddings will incorrectly have cross-attention
layers applied.

This change sets cross attention only when we have an embedding,
either previously in this sequence or in the cache. It also makes
cross attention capable of supporting parallelism at the runner
level, though the mllama implementation doesn't support that yet.
parent 921779bb
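
The core of the fix is deriving cross attention from the inputs actually seen, rather than from the request as a whole. Below is a minimal, self-contained sketch of that check with simplified stand-in types, not the real runner's (the actual NeedCrossAttention in the diff also guards against a nil ImageContext and non-mllama models):

package main

import (
	"fmt"
	"slices"
)

// input is a simplified stand-in for the runner's input type:
// each entry is either a text token or an image embedding.
type input struct {
	token int
	embed []float32
}

// needCrossAttention reports whether any input carries an image
// embedding, mirroring the check ImageContext.NeedCrossAttention
// performs in this change.
func needCrossAttention(inputs ...input) bool {
	return slices.ContainsFunc(inputs, func(in input) bool {
		return in.embed != nil
	})
}

func main() {
	textOnly := []input{{token: 1}, {token: 2}}
	withImage := []input{{token: 1}, {embed: []float32{0.1, 0.2}}}

	fmt.Println(needCrossAttention(textOnly...))  // false: no embedding seen yet
	fmt.Println(needCrossAttention(withImage...)) // true: an embedding is present
}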
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"hash/maphash"
 	"log/slog"
+	"slices"
 	"sync"
 	"time"
@@ -96,6 +97,16 @@ func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
 	}
 }
 
+func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
+	if c == nil || c.mllama == nil {
+		return false
+	}
+
+	return slices.ContainsFunc(inputs, func(input input) bool {
+		return input.embed != nil
+	})
+}
+
 type imageCache struct {
 	key uint64
 	val [][]float32
...
@@ -52,6 +52,10 @@ type Sequence struct
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
+	// does this sequence require cross-attention layers to be processed? - if we have seen
+	// an image for certain multi-modal models
+	crossAttention bool
+
 	// channel to send responses over
 	responses chan string
@@ -287,7 +291,6 @@ func flushPending(seq *Sequence) bool {
 func (s *Server) removeSequence(seqIndex int, reason string) {
 	seq := s.seqs[seqIndex]
 
-	s.lc.SetCrossAttention(false)
 	flushPending(seq)
 	seq.doneReason = reason
 	close(seq.responses)
@@ -334,6 +337,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	defer s.mu.Unlock()
 
 	var batch *llama.Batch
+	crossAttention := false
 
 	seqIdx := s.nextSeq - 1
 	for range s.seqs {
@@ -367,8 +371,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				batch = tokenBatch
 			} else {
 				batch = embedBatch
+				seq.crossAttention = s.image.NeedCrossAttention(input)
 			}
-		} else if embedding != batch.IsEmbedding() {
+		} else if embedding != batch.IsEmbedding() || crossAttention != seq.crossAttention {
 			s.nextSeq = seqIdx
 			break
 		}
@@ -378,6 +383,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			break
 		}
 
+		crossAttention = seq.crossAttention
 		batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
 		seq.numPast++
 		numInputsProcessed++
@@ -394,6 +400,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		return
 	}
 
+	s.lc.SetCrossAttention(crossAttention)
+
 	err := s.lc.Decode(batch)
 	if err != nil {
 		slog.Error("failed to decode batch", "error", err)
@@ -605,13 +613,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
-			for _, input := range seq.inputs {
-				if input.embed != nil {
-					s.lc.SetCrossAttention(true)
-					break
-				}
-			}
-
 			seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
 			if err != nil {
 				s.mu.Unlock()
@@ -619,6 +620,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				return
 			}
 
+			seq.crossAttention = s.image.NeedCrossAttention(seq.cache.Inputs...)
+
 			s.seqs[i] = seq
 			s.cond.Signal()
 			break
...
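
Note the ordering in processBatch: because SetCrossAttention applies to the whole llama context rather than to a single sequence, every sequence in one decode batch must agree on the flag, which is why batch building stops when crossAttention != seq.crossAttention. A rough sketch of that grouping rule, again with hypothetical simplified types (the real loop also separates token and embedding batches, and resumes from the deferred sequence on the next pass via s.nextSeq):

package main

import "fmt"

// seq is a simplified stand-in for the runner's Sequence type.
type seq struct {
	id             int
	crossAttention bool
}

// fillBatch adds sequences to a batch until one disagrees with the
// batch's cross-attention setting; that sequence starts the next batch.
func fillBatch(seqs []seq) (batch []seq, next int) {
	crossAttention := false
	for i, s := range seqs {
		if len(batch) > 0 && s.crossAttention != crossAttention {
			return batch, i // deferred to a later batch
		}
		crossAttention = s.crossAttention
		batch = append(batch, s)
	}
	return batch, len(seqs)
}

func main() {
	seqs := []seq{{id: 0}, {id: 1, crossAttention: true}, {id: 2}}
	batch, next := fillBatch(seqs)
	fmt.Println(len(batch), next) // 1 1: the image sequence waits for its own batch
}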