Unverified commit 38c8d02f authored by Baber Abbasi, committed by GitHub
Browse files

modified default gen_kwargs to work better with CLI; changed prompt_logprobs=1 (#1345)

parent 081deb8b
......@@ -175,7 +175,7 @@ class VLLM(LM):
sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
else:
sampling_params = SamplingParams(
temperature=0, prompt_logprobs=2, max_tokens=1
temperature=0, prompt_logprobs=1, max_tokens=1
)
if self.data_parallel_size > 1:
requests = [list(x) for x in divide(requests, self.data_parallel_size)]
......@@ -436,8 +436,8 @@ class VLLM(LM):
@staticmethod
def modify_gen_kwargs(kwargs: dict) -> dict:
# sampling_params
do_sample = kwargs.pop("do_sample", False)
if do_sample is not True:
do_sample = kwargs.pop("do_sample", None)
if do_sample is False or "temperature" not in kwargs:
kwargs["temperature"] = 0.0
# hf defaults
kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment