Commit 081deb8b (unverified), authored by Baber Abbasi, committed by GitHub
Browse files

manage default (greedy) gen_kwargs in vllm (#1341)

* manage default (greedy) gen_kwargs in vllm better

* mirror HF `do_sample`

* just need to set temp=0 for greedy
parent 969b48bf
...@@ -170,14 +170,8 @@ class VLLM(LM): ...@@ -170,14 +170,8 @@ class VLLM(LM):
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
**kwargs, **kwargs,
): ):
if "do_sample" in kwargs.keys():
kwargs.pop("do_sample")
if generate: if generate:
# hf defaults kwargs = self.modify_gen_kwargs(kwargs)
kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
kwargs["spaces_between_special_tokens"] = kwargs.get(
"spaces_between_special_tokens", False
)
sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs) sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
else: else:
sampling_params = SamplingParams( sampling_params = SamplingParams(
...@@ -438,3 +432,16 @@ class VLLM(LM): ...@@ -438,3 +432,16 @@ class VLLM(LM):
break break
return continuation_logprobs, is_greedy return continuation_logprobs, is_greedy
@staticmethod
def modify_gen_kwargs(kwargs: dict) -> dict:
    """Normalize generation kwargs before they are turned into sampling params.

    The dict is updated in place and the same object is returned.

    Args:
        kwargs: Generation keyword arguments. May contain an HF-style
            ``do_sample`` flag, which is always removed.

    Returns:
        The same ``kwargs`` dict, with ``do_sample`` removed, ``temperature``
        forced to ``0.0`` unless ``do_sample`` was exactly ``True``, and the
        two special-token flags defaulted to ``False`` when absent.
    """
    # Only the literal ``True`` counts as a request to sample; a missing flag
    # or any other value selects greedy decoding, expressed as temperature 0
    # (this overrides any temperature supplied by the caller).
    sampling_requested = kwargs.pop("do_sample", False) is True
    if not sampling_requested:
        kwargs["temperature"] = 0.0

    # HF defaults: keep a caller-supplied value, otherwise fall back to False.
    for flag_name in ("skip_special_tokens", "spaces_between_special_tokens"):
        kwargs.setdefault(flag_name, False)

    return kwargs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment