Commit 9d2a20a7 authored by Michael Yang
Browse files

use non-causal mask for inputs with images

parent 2e54d72f
...@@ -181,6 +181,11 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor ...@@ -181,6 +181,11 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
visionOutputs := multimodal[0].Multimodal.(ml.Tensor) visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
offset := multimodal[0].Index - 1 - visionOutputs.Dim(1) offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1)) hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, false)
defer causal.SetCausal(ctx, true)
}
} }
for i, layer := range m.Layers { for i, layer := range m.Layers {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment