Commit 5b3393b6 authored by Michael Yang
Browse files

fix(mllama): sync backend between batches

parent c2e8cbaa
...@@ -598,6 +598,10 @@ func (c *Context) SetCrossAttention(state bool) { ...@@ -598,6 +598,10 @@ func (c *Context) SetCrossAttention(state bool) {
C.llama_set_cross_attention(c.c, C.bool(state)) C.llama_set_cross_attention(c.c, C.bool(state))
} }
// Synchronize calls llama_synchronize on the underlying C llama context.
//
// NOTE(review): presumably this blocks until all backend work queued on the
// context has finished; confirm against the llama_synchronize declaration in
// llama.h.
func (c *Context) Synchronize() {
C.llama_synchronize(c.c)
}
// sampling // sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo // TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct { type SamplingContext struct {
......
...@@ -427,6 +427,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) ...@@ -427,6 +427,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return return
} }
if crossAttention {
// Synchronize state to ensure the cross attention batch is complete.
// This is needed specifically for multi-GPU systems; otherwise an in-flight
// task may be incorrectly invalidated, causing a crash.
s.lc.Synchronize()
}
for i, seq := range s.seqs { for i, seq := range s.seqs {
if seq == nil { if seq == nil {
continue continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment