Commit 6ac0fcf4 authored by Pleaplusone's avatar Pleaplusone Committed by Kevin H. Luu
Browse files

[ROCm][Bugfix] Disable hip sampler to fix deepseek's accuracy issue on ROCm (#32413)


Signed-off-by: default avatarganyi <ygan@amd.com>
(cherry picked from commit 77c16df3)
parent b6224972
...@@ -174,6 +174,8 @@ class TopKTopPSampler(nn.Module): ...@@ -174,6 +174,8 @@ class TopKTopPSampler(nn.Module):
k: torch.Tensor | None, k: torch.Tensor | None,
p: torch.Tensor | None, p: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
# FIXME: Fix aiter_sampler's accuracy issue and remove this flag
DISABLE_AITER_SAMPLER = True
"""Optimized ROCm/aiter path (same structure as forward_cuda).""" """Optimized ROCm/aiter path (same structure as forward_cuda)."""
if (k is None and p is None) or generators: if (k is None and p is None) or generators:
if generators: if generators:
...@@ -186,6 +188,8 @@ class TopKTopPSampler(nn.Module): ...@@ -186,6 +188,8 @@ class TopKTopPSampler(nn.Module):
"processed_logits", "processed_logits",
"processed_logprobs", "processed_logprobs",
), "aiter sampler does not support returning logits/logprobs." ), "aiter sampler does not support returning logits/logprobs."
if DISABLE_AITER_SAMPLER:
return self.forward_native(logits, generators, k, p)
return self.aiter_sample(logits, k, p, generators), None return self.aiter_sample(logits, k, p, generators), None
def aiter_sample( def aiter_sample(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment