Commit 110a6598 (unverified), authored by datdo-msft, committed by GitHub
Browse files

[MTP] Force greedy sampling on AMD (#9127)

parent 49f9d025
......@@ -49,6 +49,8 @@ SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
@dataclass
class EagleDraftInput:
......@@ -423,8 +425,15 @@ class EagleVerifyInput:
logits=logits_output.next_token_logits, vocab_mask=vocab_mask
)
# Sample tokens
if batch.sampling_info.is_all_greedy:
# Sample tokens. Force greedy sampling on AMD
is_all_greedy = sampling_info.is_all_greedy
if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
logger.warning(
"Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
"Falling back to greedy verification."
)
if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
target_predict = target_predict.reshape(bs, self.draft_token_num)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment