Unverified commit d01b9214, authored by Alex Chi Z and committed by GitHub
Browse files

fix sampling_seed handling when deterministic is enabled (#11096)


Signed-off-by: Alex Chi <iskyzh@gmail.com>
parent c70e58e8
...@@ -142,6 +142,9 @@ class SamplingParams: ...@@ -142,6 +142,9 @@ class SamplingParams:
f"logit_bias must has keys in [0, {vocab_size - 1}], got " f"logit_bias must has keys in [0, {vocab_size - 1}], got "
f"{token_id}." f"{token_id}."
) )
if self.sampling_seed is None:
raise ValueError("sampling_seed should not be None")
grammars = [ grammars = [
self.json_schema, self.json_schema,
self.regex, self.regex,
......
...@@ -96,12 +96,15 @@ def send_single( ...@@ -96,12 +96,15 @@ def send_single(
"max_new_tokens": args.max_new_tokens, "max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty, "frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty, "presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
}, },
"return_logprob": args.return_logprob, "return_logprob": args.return_logprob,
"stream": args.stream, "stream": args.stream,
} }
if args.sampling_seed is not None:
# sglang server cannot parse None value for sampling_seed
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
if profile: if profile:
run_profile( run_profile(
base_url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage base_url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
...@@ -145,12 +148,14 @@ def send_mixed(args, batch_size: int): ...@@ -145,12 +148,14 @@ def send_mixed(args, batch_size: int):
"max_new_tokens": args.max_new_tokens, "max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty, "frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty, "presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
}, },
"return_logprob": args.return_logprob, "return_logprob": args.return_logprob,
"stream": args.stream, "stream": args.stream,
} }
if args.sampling_seed is not None:
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
response = requests.post( response = requests.post(
f"http://{args.host}:{args.port}/generate", f"http://{args.host}:{args.port}/generate",
json=json_data, json=json_data,
...@@ -192,12 +197,14 @@ def send_prefix(args, batch_size: int, prompts: List[str]): ...@@ -192,12 +197,14 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
"max_new_tokens": args.max_new_tokens, "max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty, "frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty, "presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
}, },
"return_logprob": args.return_logprob, "return_logprob": args.return_logprob,
"stream": args.stream, "stream": args.stream,
} }
if args.sampling_seed is not None:
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
response = requests.post( response = requests.post(
f"http://{args.host}:{args.port}/generate", f"http://{args.host}:{args.port}/generate",
json=json_data, json=json_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment