Unverified Commit 63e76e89 authored by Baber Abbasi, committed by GitHub

[Bugfix] add temperature=0 to logprobs and seed args to API models (#2149)

* add temperature for log probs

* add seed

* nit

* add new args to test

* added warning for api chat models
parent b70af4f5
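
Reviewer note: the core of the fix is in the loglikelihood branch of `_create_payload`. Without an explicit `temperature`, an OpenAI-compatible server may apply its own sampling default when echoing prompt logprobs, and without a `seed` repeated runs are not reproducible; both are now set. A minimal sketch of the two payload shapes after this change, with field values mirroring the updated tests below (model name and prompts are placeholders, not taken from this diff):

# Loglikelihood request: echo the prompt, generate a single token, score greedily.
loglikelihood_payload = {
    "model": "my-model",       # placeholder
    "prompt": [[15496, 995]],  # tokenized context + continuation (backend-dependent)
    "temperature": 0,          # new: pinned to greedy so the logprobs are deterministic
    "max_tokens": 1,
    "logprobs": 1,
    "seed": 1234,              # new: forwarded from the model's seed argument
    "echo": True,
}

# Generation request: gen_kwargs still take precedence, but a seed is now sent by default.
generate_payload = {
    "model": "my-model",           # placeholder
    "prompt": "Generate a story",  # placeholder
    "max_tokens": 100,
    "temperature": 0.7,
    "stop": ["The End"],
    "seed": 1234,
}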
@@ -160,6 +160,7 @@ class TemplateAPI(TemplateLM):
         *,
         generate: bool = True,
         gen_kwargs: Optional[dict] = None,
+        seed: int = 1234,
         **kwargs,
     ) -> dict:
         """This method is responsible for creating the json payload that will be sent to the API."""
@@ -334,6 +335,7 @@ class TemplateAPI(TemplateLM):
                         self.create_message(messages),
                         generate=generate,
                         gen_kwargs=gen_kwargs,
+                        seed=self._seed,
                         **kwargs,
                     ),
                     headers=self.header,
@@ -367,6 +369,7 @@ class TemplateAPI(TemplateLM):
                 self.create_message(messages),
                 generate=generate,
                 gen_kwargs=gen_kwargs,
+                seed=self._seed,
                 **kwargs,
             )
             cache_method = "generate_until" if generate else "loglikelihood"
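
Both call sites above pass `seed=self._seed`, which implies `TemplateAPI` now keeps a `seed` argument on the instance (defaulting to 1234, matching the new `_create_payload` signature); that constructor wiring is not part of the hunks shown here. A hedged sketch of overriding it, under that assumption:

# Assumptions: `seed` is accepted by the constructor and stored on self._seed; the
# module path and the other constructor arguments are illustrative, not confirmed here.
from lm_eval.models.openai_completions import LocalCompletionsAPI

api = LocalCompletionsAPI(
    base_url="http://localhost:8000/v1/completions",
    model="my-model",
    seed=42,  # overrides the 1234 default
)
payload = api._create_payload(["The quick brown fox"], generate=False, seed=api._seed)
assert payload["seed"] == 42 and payload["temperature"] == 0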
@@ -24,6 +24,7 @@ class LocalCompletionsAPI(TemplateAPI):
         messages: Union[List[List[int]], List[dict], List[str], str],
         generate=False,
         gen_kwargs: Optional[dict] = None,
+        seed: int = 1234,
         **kwargs,
     ) -> dict:
         if generate:
@@ -37,14 +38,17 @@ class LocalCompletionsAPI(TemplateAPI):
                 "max_tokens": max_tokens,
                 "temperature": temperature,
                 "stop": stop,
+                "seed": seed,
                 **gen_kwargs,
             }
         else:
             return {
                 "model": self.model,
                 "prompt": messages,
+                "temperature": 0,
                 "max_tokens": 1,
                 "logprobs": 1,
+                "seed": seed,
                 "echo": True,
             }
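
For context on the `else` branch: `echo=True` together with `logprobs` asks an OpenAI-style completions endpoint to return per-token logprobs for the prompt itself, and `max_tokens=1` keeps generation to a minimum; the continuation's loglikelihood is then the sum of its tokens' logprobs. A hedged sketch of consuming such a response (legacy completions API shape; exact fields can vary by server, and this is not necessarily the harness's own parsing code):

def continuation_logprob(response: dict, ctx_len: int) -> float:
    """Sum the logprobs of the echoed tokens that follow the first ctx_len context tokens."""
    choice = response["choices"][0]
    token_logprobs = choice["logprobs"]["token_logprobs"]  # one entry per echoed token; the first is None
    return sum(lp for lp in token_logprobs[ctx_len:] if lp is not None)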
@@ -96,6 +100,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
         tokenized_requests=False,
         **kwargs,
     ):
+        eval_logger.warning(
+            "chat-completions endpoint requires the `--apply_chat_template` flag."
+        )
         super().__init__(
             base_url=base_url,
             tokenizer_backend=tokenizer_backend,
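
The warning added above fires unconditionally on construction, since a chat endpoint expects role-tagged messages rather than a raw prompt string and therefore needs the harness to render prompts through the model's chat template. A hedged sketch of the effect (module path and constructor arguments are assumptions, not taken from this diff):

from lm_eval.models.openai_completions import LocalChatCompletion

lm = LocalChatCompletion(base_url="http://localhost:8000/v1/chat/completions")
# logs: WARNING chat-completions endpoint requires the `--apply_chat_template` flag.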
@@ -109,7 +116,12 @@ class LocalChatCompletion(LocalCompletionsAPI):
             self._batch_size = 1
 
     def _create_payload(
-        self, messages: List[Dict], generate=False, gen_kwargs: dict = None, **kwargs
+        self,
+        messages: List[Dict],
+        generate=False,
+        gen_kwargs: dict = None,
+        seed=1234,
+        **kwargs,
     ) -> dict:
         gen_kwargs.pop("do_sample", False)
         max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
@@ -123,6 +135,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
             "max_tokens": max_tokens,
             "temperature": temperature,
             "stop": stop[:4],
+            "seed": seed,
             **gen_kwargs,
         }
@@ -28,6 +28,7 @@ def test_create_payload_generate(api):
         "temperature": 0.7,
         "until": ["The End"],
         "do_sample": True,
+        "seed": 1234,
     }
     payload = api._create_payload(messages, generate=True, gen_kwargs=gen_kwargs)
@@ -37,6 +38,7 @@ def test_create_payload_generate(api):
         "max_tokens": 100,
         "temperature": 0.7,
         "stop": ["The End"],
+        "seed": 1234,
     }
@@ -50,6 +52,8 @@ def test_create_payload_loglikelihood(api):
         "max_tokens": 1,
         "logprobs": 1,
         "echo": True,
+        "temperature": 0,
+        "seed": 1234,
     }
@@ -66,6 +70,7 @@ def test_create_payload_loglikelihood(api):
                 "max_tokens": 100,
                 "temperature": 0.7,
                 "stop": ["<|endoftext|>"],
+                "seed": 1234,
             },
         ),
         (
@@ -78,6 +83,7 @@ def test_create_payload_loglikelihood(api):
                 "max_tokens": 256,
                 "temperature": 0,
                 "stop": ["<|endoftext|>"],
+                "seed": 1234,
             },
         ),
     ],
@@ -116,6 +122,8 @@ def test_model_generate_call_usage(
                 "max_tokens": 1,
                 "logprobs": 1,
                 "echo": True,
+                "seed": 1234,
+                "temperature": 0,
             },
         ),
     ],
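
Beyond the unit tests, the practical check is reproducibility: with `temperature=0` and a fixed `seed`, two identical loglikelihood requests against a backend that honors seeding should return identical logprobs. A hedged end-to-end sketch (URL and model name are placeholders; the harness issues these requests itself, so this is only a manual verification idea):

import requests  # illustrative; used here only to demonstrate the determinism expectation

payload = {
    "model": "my-model",
    "prompt": "The quick brown fox",
    "temperature": 0,
    "max_tokens": 1,
    "logprobs": 1,
    "seed": 1234,
    "echo": True,
}
url = "http://localhost:8000/v1/completions"
r1 = requests.post(url, json=payload).json()
r2 = requests.post(url, json=payload).json()
assert (
    r1["choices"][0]["logprobs"]["token_logprobs"]
    == r2["choices"][0]["logprobs"]["token_logprobs"]
)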