Unverified commit cde17e73, authored by humu789, committed by GitHub

[Fix] Fix bug for issue #141 (#145)

* fix get_dataset error

* fix lint

* add datasets to requirements.txt

* update some misc
parent 289ffa3c
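In short: the script previously loaded checkpoints with `AutoModel`, which returns the bare transformer without the language-model head, and toggled `use_cache` as a plain attribute on the module, which transformers never reads. The diff below switches to `AutoModelForCausalLM` (with `trust_remote_code=True`, so checkpoints that ship custom modeling code also load) and sets the flag on `model.config`, where it actually lives. A minimal sketch of the corrected loading pattern; the checkpoint path is a placeholder, not something from this repo:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = 'huggyllama/llama-7b'  # placeholder checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_path,
                                          use_fast=False,
                                          trust_remote_code=True)

# AutoModel would return the base transformer (no lm_head); the
# causal-LM class is what quantization/calibration code expects.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             trust_remote_code=True)

# `use_cache` is a config field; `model.use_cache = True` would only
# create an unused attribute on the nn.Module.
model.config.use_cache = True
```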
@@ -5,7 +5,7 @@ from typing import List, Tuple
 import fire
 import torch
 from tqdm import tqdm
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.models.llama.modeling_llama import (LlamaDecoderLayer,
                                                       LlamaForCausalLM)
@@ -109,9 +109,11 @@ def main(model: str,
     assert calib_dataset in ['c4', 'ptb', 'wikitext2', 'pileval'], \
         'Currently, only support `c4`, `ptb`, `wikitext2`, or `pileval`.'
-    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
-    model = AutoModel.from_pretrained(model)
-    model.use_cache = True
+    tokenizer = AutoTokenizer.from_pretrained(model,
+                                              use_fast=False,
+                                              trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(model, trust_remote_code=True)
+    model.config.use_cache = True
     print('Loading calibrate dataset ...')
     calib_loader, _ = get_calib_loaders(calib_dataset,
......
@@ -8,7 +8,18 @@ def set_seed(seed):
     torch.random.manual_seed(seed)


-def get_wikitext2(tokenizer, nsamples, seed, seqlen, model):
+def get_wikitext2(tokenizer, nsamples, seed, seqlen):
+    """Load Wikitext-2 train and test datasets and tokenize.
+
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: Full tokenized Wikitext-2 test set.
+    """
     from datasets import load_dataset
     traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
     testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
@@ -20,7 +31,7 @@ def get_wikitext2(tokenizer, nsamples, seed, seqlen, model):
     random.seed(seed)
     trainloader = []
     for _ in range(nsamples):
-        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
         j = i + seqlen
         inp = trainenc.input_ids[:, i:j]
         tar = inp.clone()
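The recurring one-line change in this and the later sampling hunks fixes an off-by-one: `random.randint(a, b)` is inclusive on both ends, so the last valid start offset of a `seqlen`-token window inside an `n`-token sequence is `n - seqlen`, not `n - seqlen - 1`. The old bound never sampled the final window and, when a sample was exactly `seqlen` tokens long (a case the C4 loaders below explicitly allow with their `>= seqlen` check), it crashed with an empty-range `ValueError`. A self-contained sketch of the boundary:

```python
import random

n, seqlen = 2048, 2048  # a tokenized sample that is exactly seqlen long

# Old bound: randint(0, n - seqlen - 1) == randint(0, -1)
# -> ValueError (empty range), since randint requires a <= b.

# Fixed bound: randint is inclusive, so n - seqlen is the last valid
# start; here the only legal choice is i = 0.
i = random.randint(0, n - seqlen)
j = i + seqlen
assert 0 <= i <= n - seqlen and j <= n  # window stays in bounds
```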
@@ -29,7 +40,18 @@ def get_wikitext2(tokenizer, nsamples, seed, seqlen, model):
     return trainloader, testenc


-def get_ptb(tokenizer, nsamples, seed, seqlen, model):
+def get_ptb(tokenizer, nsamples, seed, seqlen):
+    """Load PTB train and validation datasets and tokenize.
+
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: Full tokenized PTB validation set.
+    """
     from datasets import load_dataset
     traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
     valdata = load_dataset('ptb_text_only',
@@ -44,7 +66,7 @@ def get_ptb(tokenizer, nsamples, seed, seqlen, model):
     random.seed(seed)
     trainloader = []
     for _ in range(nsamples):
-        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
         j = i + seqlen
         inp = trainenc.input_ids[:, i:j]
         tar = inp.clone()
@@ -53,7 +75,18 @@ def get_ptb(tokenizer, nsamples, seed, seqlen, model):
     return trainloader, testenc


-def get_c4(tokenizer, nsamples, seed, seqlen, model):
+def get_c4(tokenizer, nsamples, seed, seqlen):
+    """Load C4 train and validation datasets and tokenize.
+
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: Sampled tokenized windows from the C4 validation set.
+    """
     from datasets import load_dataset
     traindata = load_dataset(
         'allenai/c4',
@@ -77,7 +110,7 @@ def get_c4(tokenizer, nsamples, seed, seqlen, model):
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] >= seqlen:
                break
-        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
         j = i + seqlen
         inp = trainenc.input_ids[:, i:j]
         tar = inp.clone()
@@ -93,7 +126,7 @@ def get_c4(tokenizer, nsamples, seed, seqlen, model):
            tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
            if tmp.input_ids.shape[1] >= seqlen:
                break
-        i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
+        i = random.randint(0, tmp.input_ids.shape[1] - seqlen)
         j = i + seqlen
         valenc.append(tmp.input_ids[:, i:j])
     valenc = torch.hstack(valenc)
@@ -108,7 +141,18 @@ def get_c4(tokenizer, nsamples, seed, seqlen, model):
     return trainloader, valenc


-def get_ptb_new(tokenizer, nsamples, seed, seqlen, model):
+def get_ptb_new(tokenizer, nsamples, seed, seqlen):
+    """Load PTB (new variant) train and test datasets and tokenize.
+
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: Full tokenized PTB test set.
+    """
     from datasets import load_dataset
     traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
     testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
@@ -120,7 +164,7 @@ def get_ptb_new(tokenizer, nsamples, seed, seqlen, model):
     random.seed(seed)
     trainloader = []
     for _ in range(nsamples):
-        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
         j = i + seqlen
         inp = trainenc.input_ids[:, i:j]
         tar = inp.clone()
@@ -129,7 +173,18 @@ def get_ptb_new(tokenizer, nsamples, seed, seqlen, model):
     return trainloader, testenc


-def get_c4_new(tokenizer, nsamples, seed, seqlen, model):
+def get_c4_new(tokenizer, nsamples, seed, seqlen):
+    """Load C4 (new variant) train and validation datasets and tokenize.
+
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: Tokenized C4 validation set.
+    """
     from datasets import load_dataset
     traindata = load_dataset(
         'allenai/c4',
@@ -151,7 +206,7 @@ def get_c4_new(tokenizer, nsamples, seed, seqlen, model):
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] >= seqlen:
                break
-        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
         j = i + seqlen
         inp = trainenc.input_ids[:, i:j]
         tar = inp.clone()
@@ -172,6 +227,17 @@ def get_c4_new(tokenizer, nsamples, seed, seqlen, model):
 def get_pileval(tokenizer, nsamples, seed, seqlen=512):
+    """Load pileval train dataset and tokenize.
+
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: None; pileval provides no separate test split here.
+    """
     from datasets import load_dataset
     from datasets.builder import DatasetGenerationError
     try:
@@ -210,23 +276,29 @@ def get_pileval(tokenizer, nsamples, seed, seqlen=512):
     ], None


-def get_calib_loaders(name,
-                      tokenizer,
-                      nsamples=128,
-                      seed=0,
-                      seqlen=2048,
-                      model=''):
+def get_calib_loaders(name, tokenizer, nsamples=128, seed=0, seqlen=2048):
+    """Get calibration data loaders for a dataset.
+
+    Args:
+        name: Dataset name ('wikitext2', 'ptb', 'c4', 'pileval'; names
+            containing 'new' select the alternative PTB/C4 variants).
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_data: Tokenized validation or test data (None for pileval).
+    """
     if 'wikitext2' in name:
-        return get_wikitext2(tokenizer, nsamples, seed, seqlen, model)
+        return get_wikitext2(tokenizer, nsamples, seed, seqlen)
     if 'ptb' in name:
         if 'new' in name:
-            return get_ptb_new(tokenizer, nsamples, seed, seqlen, model)
-        return get_ptb(tokenizer, nsamples, seed, seqlen, model)
+            return get_ptb_new(tokenizer, nsamples, seed, seqlen)
+        return get_ptb(tokenizer, nsamples, seed, seqlen)
     if 'c4' in name:
         if 'new' in name:
-            return get_c4_new(tokenizer, nsamples, seed, seqlen, model)
-        return get_c4(tokenizer, nsamples, seed, seqlen, model)
+            return get_c4_new(tokenizer, nsamples, seed, seqlen)
+        return get_c4(tokenizer, nsamples, seed, seqlen)
     if 'pileval' in name:
         return get_pileval(tokenizer, nsamples, seed, seqlen)
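With the unused `model` argument gone, callers pass only the dataset name, tokenizer, and sampling parameters. A usage sketch under stated assumptions: the import path and checkpoint name are placeholders, and each train-loader entry is taken to be an `(input_ids, target)` pair as assembled by the sampling loops above:

```python
from transformers import AutoTokenizer

from calib_dataloader import get_calib_loaders  # placeholder import path

tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b',  # placeholder
                                          use_fast=False,
                                          trust_remote_code=True)

# 128 random windows of 2048 tokens from the WikiText-2 train split,
# plus the fully tokenized test split for perplexity evaluation.
calib_loader, test_enc = get_calib_loaders('wikitext2',
                                           tokenizer,
                                           nsamples=128,
                                           seed=0,
                                           seqlen=2048)

for inp, tar in calib_loader:
    print(inp.shape)  # expected: torch.Size([1, 2048])
```

Note that `get_pileval` returns `None` in place of the test data, so callers evaluating perplexity should guard against `test_enc is None`.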