Unverified Commit c9e2d644 authored by Yu-Zhou's avatar Yu-Zhou Committed by GitHub
Browse files

[Hardware][Gaudi][Bugfix] Fix error for guided decoding (#12317)

parent 7734e9a2
...@@ -32,6 +32,8 @@ from outlines_core.fsm.json_schema import build_regex_from_schema ...@@ -32,6 +32,8 @@ from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.platforms import current_platform
class BaseLogitsProcessor: class BaseLogitsProcessor:
...@@ -91,7 +93,14 @@ class BaseLogitsProcessor: ...@@ -91,7 +93,14 @@ class BaseLogitsProcessor:
allowed_tokens = allowed_tokens.masked_select( allowed_tokens = allowed_tokens.masked_select(
allowed_tokens < scores.shape[-1]) allowed_tokens < scores.shape[-1])
mask.index_fill_(0, allowed_tokens, 0) mask.index_fill_(0, allowed_tokens, 0)
scores.add_(mask) if current_platform.is_hpu():
# Workaround for an HPU bug where add_() raises RuntimeError:
# synNodeCreateWithId failed for node: strided_insert
# with synStatus 1 [Invalid argument]. Hopefully it will
# be fixed in a future release of the HPU runtime.
scores = scores.add(mask)
else:
scores.add_(mask)
return scores return scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment