Unverified Commit a3432f18 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[BugFix][Spec Decode] Use float64 for uniform_probs (#23803)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 67cee40d
...@@ -138,7 +138,7 @@ def main(): ...@@ -138,7 +138,7 @@ def main():
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
if not args.custom_mm_prompts: if not args.custom_mm_prompts:
outputs = llm.generate( outputs = llm.generate(
TokensPrompt(prompt_token_ids=prompt_ids), [TokensPrompt(prompt_token_ids=x) for x in prompt_ids],
sampling_params=sampling_params, sampling_params=sampling_params,
) )
else: else:
......
...@@ -365,9 +365,14 @@ def generate_uniform_probs( ...@@ -365,9 +365,14 @@ def generate_uniform_probs(
A tensor of shape `(num_tokens, )` containing uniform A tensor of shape `(num_tokens, )` containing uniform
random values in the range [0, 1). random values in the range [0, 1).
""" """
# NOTE(woosuk): We deliberately use float64 instead of float32 here
# because when using float32, there's a non-negligible chance that
# uniform_prob is sampled to be exact 0.0 as reported in
# https://github.com/pytorch/pytorch/issues/16706. Using float64
# mitigates the issue.
uniform_probs = torch.rand( uniform_probs = torch.rand(
(num_tokens, ), (num_tokens, ),
dtype=torch.float32, dtype=torch.float64,
device=device, device=device,
) )
start_idx = 0 start_idx = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment