Commit d832a218 authored by Casper's avatar Casper
Browse files

Move shuffling up for datasets loaded with load_dataset

parent a9cef34b
from typing import List, Union
import torch import torch
import logging import logging
from typing import List, Union
from datasets import load_dataset from datasets import load_dataset
def get_calib_dataset(data: Union[str, List[str]] = "pileval", def get_calib_dataset(data: Union[str, List[str]] = "pileval",
...@@ -11,14 +11,16 @@ def get_calib_dataset(data: Union[str, List[str]] = "pileval", ...@@ -11,14 +11,16 @@ def get_calib_dataset(data: Union[str, List[str]] = "pileval",
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
else: else:
dataset = load_dataset(data, split=split) dataset = load_dataset(data, split=split)
dataset = dataset.shuffle(seed=42)
elif isinstance(data, list): elif isinstance(data, list):
dataset = [{text_column: text} for text in data] dataset = [{text_column: text} for text in data]
else: else:
raise NotImplementedError( raise NotImplementedError(
"Either pass a string to a huggingface dataset or a list" "Either pass a string to a huggingface dataset or a list"
"that is preprocessed with one sample of text per element.") "that is preprocessed with one sample of text per element.")
dataset = dataset.shuffle(seed=42)
samples = [] samples = []
n_run = 0 n_run = 0
for data in dataset: for data in dataset:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment