Unverified Commit 081deb8b authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

manage default (greedy) gen_kwargs in vllm (#1341)

* manage default (greedy) gen_kwargs in vllm better

* mirror HF `do_sample`

* just need to set temp=0 for greedy
parent 969b48bf
......@@ -170,14 +170,8 @@ class VLLM(LM):
stop: Optional[List[str]] = None,
**kwargs,
):
if "do_sample" in kwargs.keys():
kwargs.pop("do_sample")
if generate:
# hf defaults
kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
kwargs["spaces_between_special_tokens"] = kwargs.get(
"spaces_between_special_tokens", False
)
kwargs = self.modify_gen_kwargs(kwargs)
sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
else:
sampling_params = SamplingParams(
......@@ -438,3 +432,16 @@ class VLLM(LM):
break
return continuation_logprobs, is_greedy
@staticmethod
def modify_gen_kwargs(kwargs: dict) -> dict:
# sampling_params
do_sample = kwargs.pop("do_sample", False)
if do_sample is not True:
kwargs["temperature"] = 0.0
# hf defaults
kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
kwargs["spaces_between_special_tokens"] = kwargs.get(
"spaces_between_special_tokens", False
)
return kwargs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment