Unverified Commit 1fa020c5 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[V1][BugFix] Fix Generator construction in greedy + seed case (#10097)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent e7b84c39
...@@ -146,7 +146,7 @@ class GPUModelRunner: ...@@ -146,7 +146,7 @@ class GPUModelRunner:
for req_data in scheduler_output.scheduled_new_reqs: for req_data in scheduler_output.scheduled_new_reqs:
req_id = req_data.req_id req_id = req_data.req_id
sampling_params = req_data.sampling_params sampling_params = req_data.sampling_params
if sampling_params.seed is not None: if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device) generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed) generator.manual_seed(sampling_params.seed)
else: else:
...@@ -382,7 +382,8 @@ class GPUModelRunner: ...@@ -382,7 +382,8 @@ class GPUModelRunner:
# Rewind the generator state as if the token was not sampled. # Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i) generator = self.input_batch.generators.get(i)
if generator is not None: if generator is not None:
generator.set_offset(generator.get_offset() - 1) # This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
if sampler_output.logprob_token_ids is None: if sampler_output.logprob_token_ids is None:
logprob_token_ids = None logprob_token_ids = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment