Commit b9614a3e authored by Baber
Browse files

nit

parent 352127ae
......@@ -172,7 +172,10 @@ def get_dataset(pretrained, seq=None, **kwargs):
def get_cw_dataset(**kwargs):
kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (get_dataset(pretrained, seq=seq) for seq in DEFAULT_SEQ_LENGTHS)
df = (
get_dataset(pretrained, seq=seq)
for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
)
return {
"test": datasets.Dataset.from_list(
......
......@@ -159,7 +159,10 @@ def get_dataset(pretrained, max_seq_length=None, **kwargs):
def fwe_download(**kwargs):
kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (get_dataset(pretrained, max_seq_length=seq) for seq in DEFAULT_SEQ_LENGTHS)
df = (
get_dataset(pretrained, max_seq_length=seq)
for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
)
return {
"test": datasets.Dataset.from_list(
......
......@@ -223,7 +223,7 @@ def get_qa_dataset(ds, **kwargs) -> dict[str, datasets.Dataset]:
qas, docs = read_hotpotqa()
df = (
get_dataset(pretrained=pretrained, docs=docs, qas=qas, max_seq_length=seq)
for seq in DEFAULT_SEQ_LENGTHS
for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
)
return {
......
......@@ -239,7 +239,10 @@ def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]:
def get_vt_dataset(**kwargs) -> dict[str, datasets.Dataset]:
kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (get_dataset(pretrained, seq=seq) for seq in DEFAULT_SEQ_LENGTHS)
df = (
get_dataset(pretrained, seq=seq)
for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
)
return {
"test": datasets.Dataset.from_list(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment