Commit af4e0622 authored by Casper Hansen's avatar Casper Hansen
Browse files

Implement fallback calibration dataset (wikitext).

parent ed347704
import torch
from datasets import load_dataset
from datasets.builder import DatasetGenerationError
def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
if data == "pileval":
dataset = load_dataset("json", data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", split="train")
try:
dataset = load_dataset("json", data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", split="train")
except DatasetGenerationError:
print('The Pile URL is down, using wikitext-103-raw-v1 instead')
dataset = load_dataset('wikitext', "wikitext-103-raw-v1", split="train")
else:
raise NotImplementedError
dataset = dataset.shuffle(seed=42)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment