Commit e58cc441 authored by jiaruifang

polish code and fix dataloader bugs

parent a4b75b78
dataset/webtext.py:

 import json
 import os
+from typing import Optional
 import torch
 from torch.utils.data import Dataset
...
@@ -11,26 +12,29 @@ from colossalai.registry import DATASETS
 @DATASETS.register_module
 class WebtextDataset(Dataset):

-    def __init__(self, path, seq_len=1024) -> None:
+    def __init__(self, path: Optional[str] = None, seq_len=1024) -> None:
         super().__init__()
-        root = os.path.dirname(path)
-        encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
-        if os.path.isfile(encoded_data_cache_path):
-            seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
-            if seq_len_ == seq_len:
-                self.data = data
-                self.attention_mask = attention_mask
-                return
-        raw_data = []
-        with open(path) as f:
-            for line in f.readlines():
-                raw_data.append(json.loads(line)['text'])
-        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-        tokenizer.pad_token = tokenizer.unk_token
-        encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
-        self.data = encoded_data['input_ids']
-        self.attention_mask = encoded_data['attention_mask']
-        torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)
+        if path is not None:
+            root = os.path.dirname(path)
+            encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
+            if os.path.isfile(encoded_data_cache_path):
+                seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
+                if seq_len_ == seq_len:
+                    self.data = data
+                    self.attention_mask = attention_mask
+                    return
+            raw_data = []
+            with open(path) as f:
+                for line in f.readlines():
+                    raw_data.append(json.loads(line)['text'])
+            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+            tokenizer.pad_token = tokenizer.unk_token
+            encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
+            self.data = encoded_data['input_ids']
+            self.attention_mask = encoded_data['attention_mask']
+        else:
+            self.data = torch.randint(0, 50257, (10240, seq_len))
+            self.attention_mask = torch.ones_like(self.data)

     def __len__(self):
         return len(self.data)
...
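For reference, a minimal usage sketch of the reworked dataset (not part of this commit): with a real path it reuses an existing gpt_webtext_{seq_len}.pt cache next to the JSON-lines file or tokenizes it with GPT2Tokenizer; with path=None it falls back to random GPT-2 token ids. The item format returned by __getitem__ is not shown in this diff, so the loader below is only constructed, not iterated.

# Usage sketch only -- assumes dataset/webtext.py is importable and the DATA
# json file exists; nothing here is part of the committed change itself.
from torch.utils.data import DataLoader

from dataset.webtext import WebtextDataset

# Real data: reuse gpt_webtext_1024.pt beside the json file if present,
# otherwise tokenize each line's 'text' field with GPT2Tokenizer.
real_ds = WebtextDataset(path='/data/scratch/gpt_data/small-gpt-dataset.json', seq_len=1024)

# Dummy data: no path, so 10240 random sequences over the 50257-token GPT-2 vocab.
dummy_ds = WebtextDataset(path=None, seq_len=1024)
print(len(dummy_ds))  # 10240

# A plain PyTorch DataLoader also works outside the colossalai utils helper.
loader = DataLoader(dummy_ds, batch_size=4, shuffle=True, drop_last=True)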
 export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
-colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch
+DUMMY_DATA=--use_dummy_dataset
+colossalai run --nproc_per_node=2 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch $DUMMY_DATA
train_gpt.py:

@@ -3,6 +3,7 @@ import os
 import torch
 import torch.nn as nn
+from dataset.webtext import WebtextDataset
 from titans.model.gpt import GPTLMLoss

 import colossalai
...
@@ -39,52 +40,16 @@ def main():
         colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)

     logger = get_dist_logger()

-    if not args.use_dummy_dataset:
-        data_path = os.environ['DATA']
-        logger.info(f'Build data loader from path {data_path}', ranks=[0])
-        from dataset.webtext import WebtextDataset
-        train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
-        train_dataloader = utils.get_dataloader(train_ds,
-                                                seed=42,
-                                                batch_size=gpc.config.BATCH_SIZE,
-                                                pin_memory=True,
-                                                shuffle=True,
-                                                drop_last=True)
-    else:
-        # build a dummy train_dataloader
-        logger.info('Build data loader using dummy data', ranks=[0])
-
-        def get_data(batch_size, seq_len, vocab_size):
-            input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
-            attention_mask = torch.ones_like(input_ids)
-            return input_ids, attention_mask
-
-        # 10 iterations
-        input_ids, attn_mask = get_data(gpc.config.BATCH_SIZE * 10, gpc.config.SEQ_LEN, VOCAB_SIZE)
-
-        from torch.utils.data import DataLoader, Dataset
-
-        class TextSamplerDataset(Dataset):
-
-            def __init__(self, data, seq_len):
-                super().__init__()
-                self.data = data
-                self.seq_len = seq_len
-
-            def __getitem__(self, index):
-                rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
-                full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
-                return full_seq.cuda()
-
-            def __len__(self):
-                return self.data.size(0) // self.seq_len
-
-        def cycle(loader):
-            while True:
-                for data in loader:
-                    yield data
-
-        train_dataset = TextSamplerDataset(input_ids, gpc.config.SEQ_LEN)
-        train_dataloader = DataLoader(train_dataset, batch_size=gpc.config.BATCH_SIZE)
+    data_path = None if args.use_dummy_dataset else os.environ['DATA']
+    logger.info(f'Build data loader from path {data_path}', ranks=[0])
+    train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
+    train_dataloader = utils.get_dataloader(train_ds,
+                                            seed=42,
+                                            batch_size=gpc.config.BATCH_SIZE,
+                                            pin_memory=True,
+                                            shuffle=True,
+                                            drop_last=True)

     logger.info('Build model', ranks=[0])
     use_pipeline = is_using_pp()
...
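The argument parser itself is outside the hunks shown above; below is a hypothetical sketch of how the --use_dummy_dataset flag could feed the simplified dataloader path. Only the flag name and the data_path expression come from this commit; the parser setup is an assumption, and other flags (--config, --from_torch, --host) are omitted.

# Hypothetical wiring for --use_dummy_dataset; the real parser in train_gpt.py
# is not part of this diff.
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--use_dummy_dataset', action='store_true',
                    help='train on random token ids instead of the $DATA corpus')
args = parser.parse_args()

# Same selection as the new train_gpt.py: no if/else around the dataloader,
# just a None path that WebtextDataset turns into dummy data.
# (DATA must be exported when the flag is not set, mirroring the run command above.)
data_path = None if args.use_dummy_dataset else os.environ['DATA']
print(f'Build data loader from path {data_path}')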