Unverified commit 6a1c19ed, authored by Baber Abbasi and committed by GitHub
Browse files

nits + fix siqa (#1216)

* fix group

* siqa: default.yml -> default.yaml

* max_gen_toks -> self.max_gen_toks

* add ids to task tests

* fix siqa

* fix gen_kwargs for openai-chat
parent f2853995
...@@ -254,6 +254,7 @@ class OpenaiCompletionsLM(LM): ...@@ -254,6 +254,7 @@ class OpenaiCompletionsLM(LM):
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)) list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
): ):
inps = [] inps = []
self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks)
for context, _ in chunk: for context, _ in chunk:
context_enc = self.tok_encode(context) context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks) :] inp = context_enc[-(self.max_length - self.max_gen_toks) :]
...@@ -441,8 +442,7 @@ class OpenaiChatCompletionsLM(LM): ...@@ -441,8 +442,7 @@ class OpenaiChatCompletionsLM(LM):
gen_kwargs = all_gen_kwargs[0] gen_kwargs = all_gen_kwargs[0]
until = None until = None
if isinstance(gen_kwargs, dict): if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "do_sample" in kwargs.keys(): if "do_sample" in kwargs.keys():
kwargs.pop("do_sample") kwargs.pop("do_sample")
if "until" in kwargs.keys(): if "until" in kwargs.keys():
...@@ -453,6 +453,8 @@ class OpenaiChatCompletionsLM(LM): ...@@ -453,6 +453,8 @@ class OpenaiChatCompletionsLM(LM):
raise ValueError( raise ValueError(
f"Expected repr(kwargs['until']) to be of type Union[str, list] but got {until}" f"Expected repr(kwargs['until']) to be of type Union[str, list] but got {until}"
) )
kwargs["stop"] = until
kwargs["max_tokens"] = kwargs.pop("max_gen_toks", self.max_gen_toks)
else: else:
raise ValueError( raise ValueError(
f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}" f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}"
......
...@@ -6,8 +6,11 @@ training_split: train ...@@ -6,8 +6,11 @@ training_split: train
validation_split: validation validation_split: validation
doc_to_text: "Q: {{context}} {{question}}\nA:" doc_to_text: "Q: {{context}} {{question}}\nA:"
target_delimiter: " " target_delimiter: " "
doc_to_choice: ["{{answerA}}", "{{answerB}}", "{{answerC}}"] doc_to_choice:
doc_to_target: "{{label}}" - "{{answerA}}"
- "{{answerB}}"
- "{{answerC}}"
doc_to_target: "{{ (label|int) - 1 }}"
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
...@@ -859,7 +859,8 @@ class Collator: ...@@ -859,7 +859,8 @@ class Collator:
def __len__(self): def __len__(self):
return self.size return self.size
def group(self, arr: Iterable, fn: Callable, values: bool = False) -> Iterable: @staticmethod
def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
""" """
Groups elements of an iterable based on a provided function. Groups elements of an iterable based on a provided function.
...@@ -875,8 +876,13 @@ class Collator: ...@@ -875,8 +876,13 @@ class Collator:
for ob in arr: for ob in arr:
try: try:
hashable_dict = tuple( hashable_dict = tuple(
(key, tuple(value) if isinstance(value, list) else value) (
for key, value in sorted(ob[1][1].items()) key,
tuple(value)
if isinstance(value, collections.abc.Iterable)
else value,
)
for key, value in sorted(fn(ob).items())
) )
res[hashable_dict].append(ob) res[hashable_dict].append(ob)
except TypeError: except TypeError:
...@@ -885,7 +891,8 @@ class Collator: ...@@ -885,7 +891,8 @@ class Collator:
return res return res
return res.values() return res.values()
def get_chunks(self, iter, n: int = 0, fn=None): @staticmethod
def get_chunks(_iter, n: int = 0, fn=None):
""" """
Divides an iterable into chunks of specified size or based on a given function. Divides an iterable into chunks of specified size or based on a given function.
Useful for batching Useful for batching
...@@ -913,9 +920,9 @@ class Collator: ...@@ -913,9 +920,9 @@ class Collator:
``` ```
""" """
arr = [] arr = []
for i, x in enumerate(iter): for i, x in enumerate(_iter):
arr.append(x) arr.append(x)
if len(arr) == (fn(i, iter) if fn else n): if len(arr) == (fn(i, _iter) if fn else n):
yield arr yield arr
arr = [] arr = []
......
...@@ -30,7 +30,7 @@ def limit() -> int: ...@@ -30,7 +30,7 @@ def limit() -> int:
# Tests # Tests
@pytest.mark.parametrize("task_class", task_class()) @pytest.mark.parametrize("task_class", task_class(), ids=lambda x: f"{x.config.task}")
class TestNewTasks: class TestNewTasks:
def test_download(self, task_class: ConfigurableTask): def test_download(self, task_class: ConfigurableTask):
task_class.download() task_class.download()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment