"llm/vscode:/vscode.git/clone" did not exist on "a4564232a480159d97376579d08c8b36743b60d1"
Unverified Commit 38c8d02f authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

modified default gen_kwargs to work better with CLI; changed prompt_logprobs=1 (#1345)

parent 081deb8b
...@@ -175,7 +175,7 @@ class VLLM(LM): ...@@ -175,7 +175,7 @@ class VLLM(LM):
sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs) sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
else: else:
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, prompt_logprobs=2, max_tokens=1 temperature=0, prompt_logprobs=1, max_tokens=1
) )
if self.data_parallel_size > 1: if self.data_parallel_size > 1:
requests = [list(x) for x in divide(requests, self.data_parallel_size)] requests = [list(x) for x in divide(requests, self.data_parallel_size)]
...@@ -436,8 +436,8 @@ class VLLM(LM): ...@@ -436,8 +436,8 @@ class VLLM(LM):
@staticmethod @staticmethod
def modify_gen_kwargs(kwargs: dict) -> dict: def modify_gen_kwargs(kwargs: dict) -> dict:
# sampling_params # sampling_params
do_sample = kwargs.pop("do_sample", False) do_sample = kwargs.pop("do_sample", None)
if do_sample is not True: if do_sample is False or "temperature" not in kwargs:
kwargs["temperature"] = 0.0 kwargs["temperature"] = 0.0
# hf defaults # hf defaults
kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False) kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment