Unverified Commit e327e951 authored by Jiarui Fang, committed by GitHub

[hotfix] gpt example titans bug #2493 (#2494)

parent d565a248
@@ -12,11 +12,11 @@ TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
 # if you do no want zero, just comment out this dictionary
 zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
-            optimizer_config=dict(initial_scale=2**16))
+            optimizer_config=dict(initial_scale=2**5))
 optimizer = dict(
     type=HybridAdam,
-    lr=0.00015,
+    lr=0.000015,
     weight_decay=1e-2,
 )
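For reference, this is how the touched part of the config reads after the change, written out as a standalone sketch rather than a quote of the full file: the size constants and the import paths are assumptions added only so the snippet stands on its own, and are not part of the commit.

# Minimal sketch of the updated config section (not the full file).
# Import paths below are assumed for the ColossalAI release of that time.
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero.shard_utils import TensorShardStrategy

# Hypothetical sizes, defined only so the snippet is self-contained.
BATCH_SIZE = 8
NUM_MICRO_BATCHES = 4
SEQ_LEN = 1024
HIDDEN_SIZE = 768

TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)

# if you do not want zero, just comment out this dictionary
zero = dict(model_config=dict(tensor_placement_policy='cuda',
                              shard_strategy=TensorShardStrategy()),
            optimizer_config=dict(initial_scale=2**5))    # initial loss scale lowered from 2**16

optimizer = dict(
    type=HybridAdam,
    lr=0.000015,          # learning rate lowered from 0.00015
    weight_decay=1e-2,
)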
import json
import os
from typing import Optional

import torch
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer

from colossalai.registry import DATASETS


@DATASETS.register_module
class WebtextDataset(Dataset):

    def __init__(self, path: Optional[str] = None, seq_len=1024) -> None:
        super().__init__()
        if path is not None:
            root = os.path.dirname(path)
            encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
            if os.path.isfile(encoded_data_cache_path):
                seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
                if seq_len_ == seq_len:
                    self.data = data
                    self.attention_mask = attention_mask
                    return
            raw_data = []
            with open(path) as f:
                for line in f.readlines():
                    raw_data.append(json.loads(line)['text'])
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            tokenizer.pad_token = tokenizer.unk_token
            encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
            self.data = encoded_data['input_ids']
            self.attention_mask = encoded_data['attention_mask']
        else:
            self.data = torch.randint(0, 50257, (10240, seq_len))
            self.attention_mask = torch.ones_like(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index]
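A quick usage sketch of the dataset above: passing path=None takes the dummy branch and yields random token IDs, which is exactly the fallback the --use_dummy_dataset flag relies on. This uses a plain torch DataLoader rather than the get_dataloader helper from the training script, and the batch size and sequence length are arbitrary.

from torch.utils.data import DataLoader

from dataset.webtext import WebtextDataset  # module path as imported by train_gpt.py

# path=None takes the dummy branch: 10240 random sequences plus an all-ones attention mask.
train_ds = WebtextDataset(path=None, seq_len=1024)

loader = DataLoader(train_ds, batch_size=4, shuffle=True, drop_last=True)
inputs, labels = next(iter(loader))
print(inputs['input_ids'].shape)       # torch.Size([4, 1024])
print(inputs['attention_mask'].shape)  # torch.Size([4, 1024])
print(labels.shape)                    # torch.Size([4, 1024]); labels mirror input_ids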
 export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
-colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch
+DUMMY_DATA=--use_dummy_dataset
+colossalai run --nproc_per_node=2 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch $DUMMY_DATA
@@ -3,6 +3,7 @@ import os
 
 import torch
 import torch.nn as nn
+from dataset.webtext import WebtextDataset
 from titans.model.gpt import GPTLMLoss
 
 import colossalai
@@ -30,7 +31,7 @@ VOCAB_SIZE = 50257
 def main():
     parser = colossalai.get_default_parser()
     parser.add_argument('--from_torch', default=False, action='store_true')
-    parser.add_argument('--use_dummy_dataset', default=True, action='store_true')
+    parser.add_argument('--use_dummy_dataset', default=False, action='store_true')
     args = parser.parse_args()
     disable_existing_loggers()
     if args.from_torch:
@@ -39,52 +40,16 @@ def main():
         colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
     logger = get_dist_logger()
-    if not args.use_dummy_dataset:
-        data_path = os.environ['DATA']
-        logger.info(f'Build data loader from path {data_path}', ranks=[0])
-        from dataset.webtext import WebtextDataset
-        train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
-        train_dataloader = utils.get_dataloader(train_ds,
-                                                seed=42,
-                                                batch_size=gpc.config.BATCH_SIZE,
-                                                pin_memory=True,
-                                                shuffle=True,
-                                                drop_last=True)
-    else:
-        # build a dummy train_dataloader
-        logger.info('Build data loader using dummy data', ranks=[0])
-
-        def get_data(batch_size, seq_len, vocab_size):
-            input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
-            attention_mask = torch.ones_like(input_ids)
-            return input_ids, attention_mask
-
-        # 10 iterations
-        input_ids, attn_mask = get_data(gpc.config.BATCH_SIZE * 10, gpc.config.SEQ_LEN, VOCAB_SIZE)
-
-        from torch.utils.data import DataLoader, Dataset
-
-        class TextSamplerDataset(Dataset):
-
-            def __init__(self, data, seq_len):
-                super().__init__()
-                self.data = data
-                self.seq_len = seq_len
-
-            def __getitem__(self, index):
-                rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
-                full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
-                return full_seq.cuda()
-
-            def __len__(self):
-                return self.data.size(0) // self.seq_len
-
-        def cycle(loader):
-            while True:
-                for data in loader:
-                    yield data
-
-        train_dataset = TextSamplerDataset(input_ids, gpc.config.SEQ_LEN)
-        train_dataloader = DataLoader(train_dataset, batch_size=gpc.config.BATCH_SIZE)
+    data_path = None if args.use_dummy_dataset else os.environ['DATA']
+    logger.info(f'Build data loader from path {data_path}', ranks=[0])
+    train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
+    train_dataloader = utils.get_dataloader(train_ds,
+                                            seed=42,
+                                            batch_size=gpc.config.BATCH_SIZE,
+                                            pin_memory=True,
+                                            shuffle=True,
+                                            drop_last=True)
     logger.info('Build model', ranks=[0])
     use_pipeline = is_using_pp()
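The heart of the fix is the --use_dummy_dataset default shown above: with action='store_true', a default of True means the flag is True whether or not it is passed on the command line, so the real-data branch could never run. A small standalone sketch of that argparse behavior (the argument name is reused from the script; everything else is illustrative):

import argparse

def parse(argv, dummy_default):
    parser = argparse.ArgumentParser()
    parser.add_argument('--use_dummy_dataset', default=dummy_default, action='store_true')
    return parser.parse_args(argv)

# Old setting: default=True, so the flag cannot be switched off from the command line.
print(parse([], dummy_default=True).use_dummy_dataset)                        # True

# New setting: default=False turns it into a genuine opt-in.
print(parse([], dummy_default=False).use_dummy_dataset)                       # False -> read os.environ['DATA']
print(parse(['--use_dummy_dataset'], dummy_default=False).use_dummy_dataset)  # True  -> dummy data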