Unverified commit 6a1c19ed, authored by Baber Abbasi, committed by GitHub
Browse files

nits + fix siqa (#1216)

* fix group

* siqa: default.yml -> default.yaml

* max_gen_toks -> self.max_gen_toks

* add ids to task tests

* fix siqa

* fix gen_kwargs for openai-chat
parent f2853995
......@@ -254,6 +254,7 @@ class OpenaiCompletionsLM(LM):
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = []
self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks)
for context, _ in chunk:
context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
......@@ -441,8 +442,7 @@ class OpenaiChatCompletionsLM(LM):
gen_kwargs = all_gen_kwargs[0]
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict):
if "do_sample" in kwargs.keys():
kwargs.pop("do_sample")
if "until" in kwargs.keys():
......@@ -453,6 +453,8 @@ class OpenaiChatCompletionsLM(LM):
raise ValueError(
f"Expected repr(kwargs['until']) to be of type Union[str, list] but got {until}"
)
kwargs["stop"] = until
kwargs["max_tokens"] = kwargs.pop("max_gen_toks", self.max_gen_toks)
else:
raise ValueError(
f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}"
......
......@@ -6,8 +6,11 @@ training_split: train
validation_split: validation
doc_to_text: "Q: {{context}} {{question}}\nA:"
target_delimiter: " "
doc_to_choice: ["{{answerA}}", "{{answerB}}", "{{answerC}}"]
doc_to_target: "{{label}}"
doc_to_choice:
- "{{answerA}}"
- "{{answerB}}"
- "{{answerC}}"
doc_to_target: "{{ (label|int) - 1 }}"
metric_list:
- metric: acc
aggregation: mean
......
......@@ -859,7 +859,8 @@ class Collator:
def __len__(self):
return self.size
def group(self, arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
@staticmethod
def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
"""
Groups elements of an iterable based on a provided function.
......@@ -875,8 +876,13 @@ class Collator:
for ob in arr:
try:
hashable_dict = tuple(
(key, tuple(value) if isinstance(value, list) else value)
for key, value in sorted(ob[1][1].items())
(
key,
tuple(value)
if isinstance(value, collections.abc.Iterable)
else value,
)
for key, value in sorted(fn(ob).items())
)
res[hashable_dict].append(ob)
except TypeError:
......@@ -885,7 +891,8 @@ class Collator:
return res
return res.values()
def get_chunks(self, iter, n: int = 0, fn=None):
@staticmethod
def get_chunks(_iter, n: int = 0, fn=None):
"""
Divides an iterable into chunks of specified size or based on a given function.
Useful for batching
......@@ -913,9 +920,9 @@ class Collator:
```
"""
arr = []
for i, x in enumerate(iter):
for i, x in enumerate(_iter):
arr.append(x)
if len(arr) == (fn(i, iter) if fn else n):
if len(arr) == (fn(i, _iter) if fn else n):
yield arr
arr = []
......
......@@ -30,7 +30,7 @@ def limit() -> int:
# Tests
@pytest.mark.parametrize("task_class", task_class())
@pytest.mark.parametrize("task_class", task_class(), ids=lambda x: f"{x.config.task}")
class TestNewTasks:
def test_download(self, task_class: ConfigurableTask):
task_class.download()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment