Unverified commit 6567e137, authored by Travis Johnson, committed by GitHub
Browse files

[Bugfix] Fix crash with llama 3.2 vision models and guided decoding (#9631)


Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: pavlo-ruban <pavlo.ruban@servicenow.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
parent 228cfbd0
...@@ -15,11 +15,11 @@ ...@@ -15,11 +15,11 @@
# limitations under the License. # limitations under the License.
import copy import copy
import json import json
import math
from collections import defaultdict from collections import defaultdict
from functools import lru_cache from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union from typing import Callable, DefaultDict, Dict, List, Union
import numpy as np
import torch import torch
from lark import Lark from lark import Lark
from outlines import grammars from outlines import grammars
...@@ -77,9 +77,17 @@ class BaseLogitsProcessor: ...@@ -77,9 +77,17 @@ class BaseLogitsProcessor:
f"Unsupported instruction type {type(instruction)}") f"Unsupported instruction type {type(instruction)}")
mask = torch.full((scores.shape[-1], ), mask = torch.full((scores.shape[-1], ),
-math.inf, -torch.inf,
device=scores.device) device=scores.device)
mask[allowed_tokens] = 0 # The tokenizer may support more token ids than the model can generate,
# eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256
# but scores.shape == torch.Size([128256])
# Using NumPy is faster for filtering token ids
allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
allowed_tokens = allowed_tokens.masked_select(
allowed_tokens < scores.shape[-1])
mask.index_fill_(0, allowed_tokens, 0)
scores.add_(mask) scores.add_(mask)
return scores return scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment