Unverified Commit 75ac1f47 authored by Amine Elhattami's avatar Amine Elhattami Committed by GitHub
Browse files

Fixed generation args issue affecting OpenAI completion model (#1458)



* Fixed generation args issue affecting openai completion model

* Fixed hf unit test; removed pop attributes in OpenAi completion.

* fix format

* fix format

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 2683fbbb
...@@ -4,6 +4,7 @@ import logging ...@@ -4,6 +4,7 @@ import logging
import random import random
import re import re
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from inspect import getsource from inspect import getsource
from typing import Any, List, Literal, Tuple, Union from typing import Any, List, Literal, Tuple, Union
...@@ -1064,7 +1065,7 @@ class ConfigurableTask(Task): ...@@ -1064,7 +1065,7 @@ class ConfigurableTask(Task):
return request_list return request_list
elif self.OUTPUT_TYPE == "generate_until": elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, self.config.generation_kwargs) arguments = (ctx, deepcopy(self.config.generation_kwargs))
return Instance( return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
......
...@@ -261,14 +261,13 @@ class OpenaiCompletionsLM(TemplateLM): ...@@ -261,14 +261,13 @@ class OpenaiCompletionsLM(TemplateLM):
list(sameuntil_chunks(re_ord.get_reordered(), self.batch_size)) list(sameuntil_chunks(re_ord.get_reordered(), self.batch_size))
): ):
inps = [] inps = []
self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks) self._max_gen_toks = request_args.get("max_gen_toks", self.max_gen_toks)
for context, _ in chunk: for context, _ in chunk:
context_enc = self.tok_encode(context) context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks) :] inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp) inps.append(inp)
until = request_args.pop("until", ["<|endoftext|>"]) until = request_args.get("until", ["<|endoftext|>"])
request_args.pop("do_sample", None)
request_args["temperature"] = request_args.get("temperature", 0) request_args["temperature"] = request_args.get("temperature", 0)
response = oa_completion( response = oa_completion(
...@@ -278,7 +277,11 @@ class OpenaiCompletionsLM(TemplateLM): ...@@ -278,7 +277,11 @@ class OpenaiCompletionsLM(TemplateLM):
max_tokens=self.max_gen_toks, max_tokens=self.max_gen_toks,
stop=until, stop=until,
seed=self.seed, seed=self.seed,
**request_args, **{
k: v
for k, v in request_args.items()
if k not in ["do_sample", "max_gen_toks"]
},
) )
for resp, (context, args_) in zip(response.choices, chunk): for resp, (context, args_) in zip(response.choices, chunk):
s = getattr(resp, "text") s = getattr(resp, "text")
......
...@@ -22,8 +22,8 @@ class Test_HFLM: ...@@ -22,8 +22,8 @@ class Test_HFLM:
multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1) multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
MULTIPLE_CH: list[Instance] = multiple_choice_task.instances MULTIPLE_CH: list[Instance] = multiple_choice_task.instances
generate_until_task = task_list["gsm8k"] # type: ignore generate_until_task = task_list["gsm8k"] # type: ignore
generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
generate_until_task._config.generation_kwargs["max_gen_toks"] = 10 generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
generate_until: list[Instance] = generate_until_task.instances generate_until: list[Instance] = generate_until_task.instances
rolling_task = task_list["wikitext"] # type: ignore rolling_task = task_list["wikitext"] # type: ignore
rolling_task.build_all_requests(limit=10, rank=0, world_size=1) rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment