Commit d832a218 authored by Casper
Browse files

Move shuffling up for datasets loaded with load_dataset

parent a9cef34b
from typing import List, Union
import torch
import logging
from typing import List, Union
from datasets import load_dataset
def get_calib_dataset(data: Union[str, List[str]] = "pileval",
......@@ -11,6 +11,9 @@ def get_calib_dataset(data: Union[str, List[str]] = "pileval",
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
else:
dataset = load_dataset(data, split=split)
dataset = dataset.shuffle(seed=42)
elif isinstance(data, list):
dataset = [{text_column: text} for text in data]
else:
......@@ -18,7 +21,6 @@ def get_calib_dataset(data: Union[str, List[str]] = "pileval",
"Either pass a string to a huggingface dataset or a list"
"that is preprocessed with one sample of text per element.")
dataset = dataset.shuffle(seed=42)
samples = []
n_run = 0
for data in dataset:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment