Unverified Commit 63e76e89 authored by Baber Abbasi, committed by GitHub

[Bugfix] add temperature=0 to logprobs and seed args to API models (#2149)

* add temperature for log probs

* add seed

* nit

* add new args to test

* added warning for API chat models
parent b70af4f5
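The core of the fix is the loglikelihood request built for OpenAI-compatible completions endpoints: it now pins temperature=0 (some backends apply sampling parameters to the logprobs they return) and carries an explicit seed, so repeated scoring runs give consistent results. A minimal before/after sketch of that payload; the model name and prompt are placeholders, the remaining fields mirror the updated tests further down:

    # Illustrative only; placeholder model/prompt, defaults taken from the diff below.
    before = {
        "model": "<model-name>",
        "prompt": "the quick brown fox",
        "max_tokens": 1,
        "logprobs": 1,
        "echo": True,
    }
    after = {
        **before,
        "temperature": 0,  # pin sampling so returned logprobs are consistent
        "seed": 1234,      # default seed forwarded from TemplateAPI
    }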
@@ -160,6 +160,7 @@ class TemplateAPI(TemplateLM):
*,
generate: bool = True,
gen_kwargs: Optional[dict] = None,
seed: int = 1234,
**kwargs,
) -> dict:
"""This method is responsible for creating the json payload that will be sent to the API."""
@@ -334,6 +335,7 @@ class TemplateAPI(TemplateLM):
self.create_message(messages),
generate=generate,
gen_kwargs=gen_kwargs,
seed=self._seed,
**kwargs,
),
headers=self.header,
@@ -367,6 +369,7 @@ class TemplateAPI(TemplateLM):
self.create_message(messages),
generate=generate,
gen_kwargs=gen_kwargs,
seed=self._seed,
**kwargs,
)
cache_method = "generate_until" if generate else "loglikelihood"
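Because model_call and amodel_call now always forward seed=self._seed, any subclass that overrides _create_payload has to accept the new keyword (or absorb it via **kwargs). A minimal sketch of a compatible override, using a hypothetical CustomAPI subclass:

    # Hypothetical subclass, for illustration only.
    from typing import List, Optional

    class CustomAPI(TemplateAPI):
        def _create_payload(
            self,
            messages: List[str],
            generate: bool = True,
            gen_kwargs: Optional[dict] = None,
            seed: int = 1234,  # must be accepted now that the base class passes it
            **kwargs,
        ) -> dict:
            gen_kwargs = dict(gen_kwargs or {})
            return {
                "model": self.model,  # assumes the model attribute set up by TemplateAPI
                "prompt": messages,
                "seed": seed,
                **gen_kwargs,
            }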
@@ -24,6 +24,7 @@ class LocalCompletionsAPI(TemplateAPI):
messages: Union[List[List[int]], List[dict], List[str], str],
generate=False,
gen_kwargs: Optional[dict] = None,
seed: int = 1234,
**kwargs,
) -> dict:
if generate:
@@ -37,14 +38,17 @@ class LocalCompletionsAPI(TemplateAPI):
"max_tokens": max_tokens,
"temperature": temperature,
"stop": stop,
"seed": seed,
**gen_kwargs,
}
else:
return {
"model": self.model,
"prompt": messages,
"temperature": 0,
"max_tokens": 1,
"logprobs": 1,
"seed": seed,
"echo": True,
}
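In the generate branch the default "seed": seed is written before **gen_kwargs is spread, so a task that already supplies a seed in its generation kwargs overrides the 1234 default; the loglikelihood branch spreads no gen_kwargs, so it always goes out with temperature=0 and the configured seed. A dict-only sketch of that precedence:

    defaults = {"seed": 1234}
    gen_kwargs = {"seed": 42}             # hypothetical per-task override
    payload = {**defaults, **gen_kwargs}  # later keys win
    assert payload["seed"] == 42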
@@ -96,6 +100,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
tokenized_requests=False,
**kwargs,
):
eval_logger.warning(
"chat-completions endpoint requires the `--apply_chat_template` flag."
)
super().__init__(
base_url=base_url,
tokenizer_backend=tokenizer_backend,
@@ -109,7 +116,12 @@ class LocalChatCompletion(LocalCompletionsAPI):
self._batch_size = 1
def _create_payload(
self, messages: List[Dict], generate=False, gen_kwargs: dict = None, **kwargs
self,
messages: List[Dict],
generate=False,
gen_kwargs: dict = None,
seed=1234,
**kwargs,
) -> dict:
gen_kwargs.pop("do_sample", False)
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
@@ -123,6 +135,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
"max_tokens": max_tokens,
"temperature": temperature,
"stop": stop[:4],
"seed": seed,
**gen_kwargs,
}
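The chat payload picks up the same seed field, and as before at most four stop sequences are forwarded (stop[:4]), in line with the OpenAI chat-completions limit. A sketch of the resulting request body, showing only the fields visible in this hunk; model and messages are assumed to be filled in earlier in the method:

    payload = {
        # "model" and "messages" omitted here (not shown in this hunk)
        "max_tokens": 256,
        "temperature": 0,
        "stop": ["<|endoftext|>"][:4],  # at most four stop sequences are sent
        "seed": 1234,
    }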
@@ -28,6 +28,7 @@ def test_create_payload_generate(api):
"temperature": 0.7,
"until": ["The End"],
"do_sample": True,
"seed": 1234,
}
payload = api._create_payload(messages, generate=True, gen_kwargs=gen_kwargs)
@@ -37,6 +38,7 @@ def test_create_payload_generate(api):
"max_tokens": 100,
"temperature": 0.7,
"stop": ["The End"],
"seed": 1234,
}
@@ -50,6 +52,8 @@ def test_create_payload_loglikelihood(api):
"max_tokens": 1,
"logprobs": 1,
"echo": True,
"temperature": 0,
"seed": 1234,
}
@@ -66,6 +70,7 @@ def test_create_payload_loglikelihood(api):
"max_tokens": 100,
"temperature": 0.7,
"stop": ["<|endoftext|>"],
"seed": 1234,
},
),
(
@@ -78,6 +83,7 @@ def test_create_payload_loglikelihood(api):
"max_tokens": 256,
"temperature": 0,
"stop": ["<|endoftext|>"],
"seed": 1234,
},
),
],
@@ -116,6 +122,8 @@ def test_model_generate_call_usage(
"max_tokens": 1,
"logprobs": 1,
"echo": True,
"seed": 1234,
"temperature": 0,
},
),
],