Unverified Commit 110a6598 authored by datdo-msft's avatar datdo-msft Committed by GitHub
Browse files

[MTP] Force greedy sampling on AMD (#9127)

parent 49f9d025
...@@ -49,6 +49,8 @@ SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial") ...@@ -49,6 +49,8 @@ SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
@dataclass @dataclass
class EagleDraftInput: class EagleDraftInput:
...@@ -423,8 +425,15 @@ class EagleVerifyInput: ...@@ -423,8 +425,15 @@ class EagleVerifyInput:
logits=logits_output.next_token_logits, vocab_mask=vocab_mask logits=logits_output.next_token_logits, vocab_mask=vocab_mask
) )
# Sample tokens # Sample tokens. Force greedy sampling on AMD
if batch.sampling_info.is_all_greedy: is_all_greedy = sampling_info.is_all_greedy
if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
logger.warning(
"Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
"Falling back to greedy verification."
)
if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1) target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
target_predict = target_predict.reshape(bs, self.draft_token_num) target_predict = target_predict.reshape(bs, self.draft_token_num)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment