Unverified commit 7bdb42b2, authored by Zhang Xiangze, committed by GitHub
Browse files

[CPU] Avoid repeated random sample compile (#28260)


Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
parent 315068eb
......@@ -127,15 +127,6 @@ class TopKTopPSampler(nn.Module):
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@torch.compile(dynamic=True)
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
return probs.div(q).argmax(dim=-1).view(-1)
if len(generators) != logits.shape[0]:
return compiled_random_sample(logits), logits_to_return
else:
......@@ -148,6 +139,16 @@ class TopKTopPSampler(nn.Module):
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@torch.compile(dynamic=True)
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
return probs.div(q).argmax(dim=-1).view(-1)
def apply_top_k_top_p(
logits: torch.Tensor,
k: torch.Tensor | None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment