"vscode:/vscode.git/clone" did not exist on "8f80b7116aa1e9ded2ad4e3d5bf64755ba6ccbd9"
Unverified Commit 6567e137 authored by Travis Johnson's avatar Travis Johnson Committed by GitHub
Browse files

[Bugfix] Fix crash with llama 3.2 vision models and guided decoding (#9631)


Signed-off-by: default avatarTravis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: default avatarpavlo-ruban <pavlo.ruban@servicenow.com>
Co-authored-by: default avatarNick Hill <nickhill@us.ibm.com>
parent 228cfbd0
......@@ -15,11 +15,11 @@
# limitations under the License.
import copy
import json
import math
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union
import numpy as np
import torch
from lark import Lark
from outlines import grammars
......@@ -77,9 +77,17 @@ class BaseLogitsProcessor:
f"Unsupported instruction type {type(instruction)}")
mask = torch.full((scores.shape[-1], ),
-math.inf,
-torch.inf,
device=scores.device)
mask[allowed_tokens] = 0
# The tokenizer may support more token ids than the model can generate,
# eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256
# but scores.shape == torch.Size([128256])
# Using NumPy is faster for filtering token ids
allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
allowed_tokens = allowed_tokens.masked_select(
allowed_tokens < scores.shape[-1])
mask.index_fill_(0, allowed_tokens, 0)
scores.add_(mask)
return scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment