Unverified Commit 060e8761 authored by Baber Abbasi, committed by GitHub

OpenAI ChatCompletions: switch `max_tokens` (#2443)

* switch `max_tokens` to `max_completion_tokens` for OpenAI ChatCompletions (see the sketch below)

* remove `stop` and pin `temperature=1` for o1 models

* add an assertion that chat completions receive a templated message list

* set `HF_DATASETS_TRUST_REMOTE_CODE = True` for the task tests

* move the warning
parent 4155ec7f
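The rename tracks the OpenAI API, which deprecates `max_tokens` on chat completions in favor of `max_completion_tokens` (the only form o1 models accept). A minimal sketch of the new request shape, assuming a recent `openai` SDK; the model name and prompt are placeholders, not taken from this diff:

```python
# Sketch only: demonstrates the renamed parameter.
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
resp = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "Say hi"}],
    max_completion_tokens=256,  # was: max_tokens=256
    temperature=0,
    seed=1234,
)
print(resp.choices[0].message.content)
```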
@@ -126,6 +126,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
        seed=1234,
        **kwargs,
    ) -> dict:
        assert (
            type(messages) is not str
        ), "chat-completions require the --apply_chat_template flag."
        gen_kwargs.pop("do_sample", False)
        if "max_tokens" in gen_kwargs:
            max_tokens = gen_kwargs.pop("max_tokens")
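For context (not part of the diff): the new assertion catches the case where `--apply_chat_template` was omitted, in which the harness hands the endpoint a raw prompt string rather than a list of message dicts. A toy illustration, where `check` is a hypothetical stand-in for the type check above:

```python
# Toy illustration of the new guard.
def check(messages):
    assert (
        type(messages) is not str
    ), "chat-completions require the --apply_chat_template flag."

check([{"role": "user", "content": "hi"}])  # ok: templated message list
try:
    check("hi")  # raw prompt string, as seen without --apply_chat_template
except AssertionError as err:
    print(err)  # chat-completions require the --apply_chat_template flag.
```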
@@ -217,6 +220,10 @@ class OpenAIChatCompletion(LocalChatCompletion):
        tokenized_requests=False,
        **kwargs,
    ):
        if "o1" in kwargs.get("model", ""):
            eval_logger.warning(
                "o1 models do not support `stop` and only support temperature=1"
            )
        super().__init__(
            base_url=base_url,
            tokenizer_backend=tokenizer_backend,
@@ -238,3 +245,37 @@ class OpenAIChatCompletion(LocalChatCompletion):
        raise NotImplementedError(
            "Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation."
        )

    def _create_payload(
        self,
        messages: List[Dict],
        generate=False,
        gen_kwargs: dict = None,
        seed=1234,
        **kwargs,
    ) -> dict:
        assert (
            type(messages) is not str
        ), "chat-completions require the --apply_chat_template flag."
        gen_kwargs.pop("do_sample", False)
        if "max_tokens" in gen_kwargs:
            max_tokens = gen_kwargs.pop("max_tokens")
        else:
            max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
        temperature = gen_kwargs.pop("temperature", 0)
        stop = gen_kwargs.pop("until", ["<|endoftext|>"])
        if not isinstance(stop, (list, tuple)):
            stop = [stop]
        output = {
            "messages": messages,
            "model": self.model,
            "max_completion_tokens": max_tokens,
            "temperature": temperature,
            "stop": stop[:4],
            "seed": seed,
            **gen_kwargs,
        }
        if "o1" in self.model:
            output.pop("stop")
            output["temperature"] = 1
        return output
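A worked example of the o1 branch above (illustrative, not from the diff): for any model name containing "o1" the payload loses `stop` and pins `temperature` to 1, while the remaining keys pass through unchanged.

```python
# Illustrative payload transformation for a hypothetical o1 request.
payload = {
    "messages": [{"role": "user", "content": "2 + 2?"}],
    "model": "o1-mini",  # placeholder; any name containing "o1" triggers it
    "max_completion_tokens": 256,
    "temperature": 0,
    "stop": ["<|endoftext|>"],
    "seed": 1234,
}
if "o1" in payload["model"]:
    payload.pop("stop")         # o1 endpoints reject `stop`
    payload["temperature"] = 1  # o1 only supports temperature=1

assert "stop" not in payload and payload["temperature"] == 1
```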
import os
from itertools import islice

import datasets
import pytest

import lm_eval.tasks as tasks
@@ -10,6 +11,7 @@ from lm_eval.evaluator_utils import get_task_list
from .utils import new_tasks

datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
os.environ["TOKENIZERS_PARALLELISM"] = "false"

task_manager = tasks.TaskManager()

# Default Task
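Side note on the test change: setting `datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True` flips the library-wide default so script-based datasets load without an interactive prompt. A sketch of the equivalent alternatives, assuming the Hugging Face `datasets` library; the dataset name is a placeholder:

```python
# Sketch: equivalent ways to allow dataset loading scripts in tests.
import os

# 1) Environment variable, read when `datasets` initializes its config.
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"

import datasets

# 2) Per-call opt-in; "org/script_dataset" is a placeholder for a dataset
#    that ships a loading script (commented out to avoid a network call).
# ds = datasets.load_dataset("org/script_dataset", trust_remote_code=True)
```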