Commit a1aa84cc authored by Casper's avatar Casper
Browse files

Merge remote-tracking branch 'upstream/main' into refactor-models

parents 2a243105 f3a90b77
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from datasets.builder import DatasetGenerationError
def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512): def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
if data == "pileval": if data == "pileval":
try: dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
dataset = load_dataset("json", data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", split="train")
except DatasetGenerationError:
print('The Pile URL is down, using wikitext-103-raw-v1 instead')
dataset = load_dataset('wikitext', "wikitext-103-raw-v1", split="train")
else: else:
raise NotImplementedError raise NotImplementedError
dataset = dataset.shuffle(seed=42) dataset = dataset.shuffle(seed=42)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment