Unverified Commit 75ac1f47 authored by Amine Elhattami's avatar Amine Elhattami Committed by GitHub
Browse files

Fixed generation args issue affection OpenAI completion model (#1458)



* Fixed generation args issue affection openai completion model

* Fixed hf unit test; removed pop attributes in OpenAi completion.

* fix format

* fix format

---------
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 2683fbbb
......@@ -4,6 +4,7 @@ import logging
import random
import re
from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, List, Literal, Tuple, Union
......@@ -1064,7 +1065,7 @@ class ConfigurableTask(Task):
return request_list
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, self.config.generation_kwargs)
arguments = (ctx, deepcopy(self.config.generation_kwargs))
return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
......
......@@ -261,14 +261,13 @@ class OpenaiCompletionsLM(TemplateLM):
list(sameuntil_chunks(re_ord.get_reordered(), self.batch_size))
):
inps = []
self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks)
self._max_gen_toks = request_args.get("max_gen_toks", self.max_gen_toks)
for context, _ in chunk:
context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)
until = request_args.pop("until", ["<|endoftext|>"])
request_args.pop("do_sample", None)
until = request_args.get("until", ["<|endoftext|>"])
request_args["temperature"] = request_args.get("temperature", 0)
response = oa_completion(
......@@ -278,7 +277,11 @@ class OpenaiCompletionsLM(TemplateLM):
max_tokens=self.max_gen_toks,
stop=until,
seed=self.seed,
**request_args,
**{
k: v
for k, v in request_args.items()
if k not in ["do_sample", "max_gen_toks"]
},
)
for resp, (context, args_) in zip(response.choices, chunk):
s = getattr(resp, "text")
......
......@@ -22,8 +22,8 @@ class Test_HFLM:
multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
MULTIPLE_CH: list[Instance] = multiple_choice_task.instances
generate_until_task = task_list["gsm8k"] # type: ignore
generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
generate_until: list[Instance] = generate_until_task.instances
rolling_task = task_list["wikitext"] # type: ignore
rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment