Commit b9614a3e authored by Baber's avatar Baber
Browse files

nit

parent 352127ae
...@@ -172,7 +172,10 @@ def get_dataset(pretrained, seq=None, **kwargs): ...@@ -172,7 +172,10 @@ def get_dataset(pretrained, seq=None, **kwargs):
def get_cw_dataset(**kwargs): def get_cw_dataset(**kwargs):
kwargs = kwargs.get("metadata", {}) kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {})) pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (get_dataset(pretrained, seq=seq) for seq in DEFAULT_SEQ_LENGTHS) df = (
get_dataset(pretrained, seq=seq)
for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
)
return { return {
"test": datasets.Dataset.from_list( "test": datasets.Dataset.from_list(
......
...@@ -159,7 +159,10 @@ def get_dataset(pretrained, max_seq_length=None, **kwargs): ...@@ -159,7 +159,10 @@ def get_dataset(pretrained, max_seq_length=None, **kwargs):
def fwe_download(**kwargs): def fwe_download(**kwargs):
kwargs = kwargs.get("metadata", {}) kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {})) pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (get_dataset(pretrained, max_seq_length=seq) for seq in DEFAULT_SEQ_LENGTHS) df = (
get_dataset(pretrained, max_seq_length=seq)
for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
)
return { return {
"test": datasets.Dataset.from_list( "test": datasets.Dataset.from_list(
......
...@@ -223,7 +223,7 @@ def get_qa_dataset(ds, **kwargs) -> dict[str, datasets.Dataset]: ...@@ -223,7 +223,7 @@ def get_qa_dataset(ds, **kwargs) -> dict[str, datasets.Dataset]:
qas, docs = read_hotpotqa() qas, docs = read_hotpotqa()
df = ( df = (
get_dataset(pretrained=pretrained, docs=docs, qas=qas, max_seq_length=seq) get_dataset(pretrained=pretrained, docs=docs, qas=qas, max_seq_length=seq)
for seq in DEFAULT_SEQ_LENGTHS for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
) )
return { return {
......
...@@ -239,7 +239,10 @@ def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]: ...@@ -239,7 +239,10 @@ def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]:
def get_vt_dataset(**kwargs) -> dict[str, datasets.Dataset]: def get_vt_dataset(**kwargs) -> dict[str, datasets.Dataset]:
kwargs = kwargs.get("metadata", {}) kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {})) pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (get_dataset(pretrained, seq=seq) for seq in DEFAULT_SEQ_LENGTHS) df = (
get_dataset(pretrained, seq=seq)
for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
)
return { return {
"test": datasets.Dataset.from_list( "test": datasets.Dataset.from_list(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment