Unverified Commit e5b9f9a0 authored by アマデウス's avatar アマデウス Committed by GitHub
Browse files

added gpt model & benchmark (#95)

parent 01a80cd8
...@@ -11,12 +11,7 @@ from colossalai.core import global_context as gpc ...@@ -11,12 +11,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer from colossalai.trainer import Trainer, hooks
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook,
LogMetricByEpochHook,
LogMetricByStepHook,
LogTimingByEpochHook, LossHook,
LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer, get_dataloader from colossalai.utils import MultiTimer, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32 from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms from torchvision import transforms
...@@ -100,22 +95,22 @@ def train_cifar(): ...@@ -100,22 +95,22 @@ def train_cifar():
trainer = Trainer(engine=engine, logger=logger, timer=timer) trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("Trainer is built", ranks=[0]) logger.info("Trainer is built", ranks=[0])
hooks = [ hook_list = [
LogMetricByEpochHook(logger=logger), hooks.LogMetricByEpochHook(logger=logger),
LogMetricByStepHook(), hooks.LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger), # hooks.LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger), # hooks.LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy()), hooks.AccuracyHook(accuracy_func=Accuracy()),
LossHook(), hooks.LossHook(),
ThroughputHook(), hooks.ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False) hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
] ]
logger.info("Train start", ranks=[0]) logger.info("Train start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader, trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader, test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS, epochs=gpc.config.NUM_EPOCHS,
hooks=hooks, hooks=hook_list,
display_progress=True, display_progress=True,
test_interval=1) test_interval=1)
......
from colossalai.amp import AMP_TYPE
# GPT-2 benchmark config: 1D tensor parallelism across 2 devices.
VOCAB_SIZE = 50304
SEQ_LENGTH = 1024
# Effective (global) batch size; split by gradient accumulation below.
TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 60
# Warm the LR schedule up over the first 36% of training epochs.
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)
parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
# Mixed-precision mode (AMP_TYPE.TORCH -- presumably PyTorch-native AMP).
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
# Per-step micro-batch; accumulation restores TOTAL_BATCH_SIZE overall.
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
# Log directory name encodes the run's hyper-parameters.
LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"
from colossalai.amp import AMP_TYPE
# GPT-2 benchmark config: 2D tensor parallelism across 4 devices.
VOCAB_SIZE = 50304
SEQ_LENGTH = 1024
TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'
NUM_EPOCHS = 60
# Warm the LR schedule up over the first 36% of training epochs.
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)
parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
# Mixed-precision mode (AMP_TYPE.TORCH -- presumably PyTorch-native AMP).
fp16 = dict(mode=AMP_TYPE.TORCH, )
# No accumulation here, so BATCH_SIZE == TOTAL_BATCH_SIZE.
gradient_accumulation = 1
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
# Log directory name encodes the run's hyper-parameters.
LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"
from colossalai.amp import AMP_TYPE
# GPT-2 benchmark config: 2.5D tensor parallelism (4 devices, depth 1).
VOCAB_SIZE = 50304
SEQ_LENGTH = 1024
TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2
TENSOR_PARALLEL_SIZE = 4
# Extra depth dimension of the 2.5D decomposition.
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'
NUM_EPOCHS = 60
# Warm the LR schedule up over the first 36% of training epochs.
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)
parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)
# Mixed-precision mode (AMP_TYPE.TORCH -- presumably PyTorch-native AMP).
fp16 = dict(mode=AMP_TYPE.TORCH, )
# No accumulation here, so BATCH_SIZE == TOTAL_BATCH_SIZE.
gradient_accumulation = 1
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
# Log directory name encodes the run's hyper-parameters.
LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"
from colossalai.amp import AMP_TYPE
# GPT-2 benchmark config: 3D tensor parallelism across 8 devices.
VOCAB_SIZE = 50304
SEQ_LENGTH = 1024
TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2
TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'
NUM_EPOCHS = 60
# Warm the LR schedule up over the first 36% of training epochs.
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)
parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
# Mixed-precision mode (AMP_TYPE.TORCH -- presumably PyTorch-native AMP).
fp16 = dict(mode=AMP_TYPE.TORCH, )
# No accumulation here, so BATCH_SIZE == TOTAL_BATCH_SIZE.
gradient_accumulation = 1
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
# Log directory name encodes the run's hyper-parameters.
LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"
from colossalai.amp import AMP_TYPE
# GPT-2 benchmark baseline config: tensor parallelism disabled.
VOCAB_SIZE = 50304
SEQ_LENGTH = 1024
TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2
TENSOR_PARALLEL_SIZE = 1
# mode=None means no tensor-parallel partitioning.
TENSOR_PARALLEL_MODE = None
NUM_EPOCHS = 60
# Warm the LR schedule up over the first 36% of training epochs.
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)
parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
# Mixed-precision mode (AMP_TYPE.TORCH -- presumably PyTorch-native AMP).
fp16 = dict(mode=AMP_TYPE.TORCH, )
# No accumulation here, so BATCH_SIZE == TOTAL_BATCH_SIZE.
gradient_accumulation = 1
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
# Log directory name encodes the run's hyper-parameters.
LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"
import json
import os
import torch
from colossalai.registry import DATASETS
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer
@DATASETS.register_module
class WebtextDataset(Dataset):
    """Tokenized webtext dataset for GPT language modeling.

    Each line of the source file must be a JSON object with a ``'text'``
    field. Tokenized tensors are cached next to the source file so the
    slow tokenization pass runs only once per sequence length.
    """

    def __init__(self, path, seq_len=1024) -> None:
        """
        :param path: path to the raw webtext file (one JSON object per line)
        :param seq_len: maximum sequence length; longer samples are truncated
        """
        super().__init__()
        root = os.path.dirname(path)
        encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
        # Reuse the cached encoding only when it matches the requested seq_len.
        if os.path.isfile(encoded_data_cache_path):
            seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
            if seq_len_ == seq_len:
                self.data = data
                self.attention_mask = attention_mask
                return
        # Stream the file line by line instead of materializing readlines(),
        # and pin the text encoding rather than relying on the locale.
        raw_data = []
        with open(path, encoding='utf-8') as f:
            for line in f:
                raw_data.append(json.loads(line)['text'])
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        # GPT-2 ships no dedicated pad token; reuse <unk> for padding.
        tokenizer.pad_token = tokenizer.unk_token
        encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
        self.data = encoded_data['input_ids']
        self.attention_mask = encoded_data['attention_mask']
        torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # ((input_ids, attention_mask), labels); labels are the unshifted
        # input ids -- GPTLMLoss performs the causal shift itself.
        return (self.data[index], self.attention_mask[index]), self.data[index]
import contextlib
import os
import colossalai
import torch
from colossalai.core import global_context as gpc
from colossalai.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule)
from colossalai.logging import get_dist_logger
from colossalai.nn import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.zero import zero3_model_context
from model_zoo.gpt import GPTLMLoss, gpt2_small, gpt2_medium, gpt2_large, gpt2_xl
from data import WebtextDataset
def train_gpt():
    """Train GPT-2 medium on webtext as a ColossalAI benchmark.

    Hyper-parameters come from the config file given via ``--config``;
    the dataset path comes from the ``DATA`` environment variable.
    """
    args = colossalai.get_default_parser().parse_args()
    # standard launch
    # colossalai.launch(config=args.config,
    #                   rank=args.rank,
    #                   world_size=args.world_size,
    #                   local_rank=args.local_rank,
    #                   host=args.host,
    #                   port=args.port)

    # launch from torchrun
    colossalai.launch_from_torch(config=args.config)

    logger = get_dist_logger()
    if hasattr(gpc.config, 'LOG_PATH'):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            # makedirs with exist_ok is restart-safe and creates missing
            # parents, unlike the former bare os.mkdir.
            os.makedirs(log_path, exist_ok=True)
            logger.log_to_file(log_path)

    train_dataset = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LENGTH)
    # Each data-parallel rank loads an equal shard of the global batch.
    train_dataloader = get_dataloader(train_dataset,
                                      seed=42,
                                      batch_size=gpc.config.BATCH_SIZE // gpc.data_parallel_size,
                                      pin_memory=True,
                                      shuffle=True,
                                      drop_last=True)
    logger.info(f'Loaded {len(train_dataset)}/{len(train_dataloader)} samples/batches', ranks=[0])

    # zero3 under test
    # use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
    # cm = zero3_model_context() if use_zero3 else contextlib.nullcontext()
    # with cm:
    #     model = gpc.config.model.pop('type')(**gpc.config.model)
    model = gpt2_medium(vocab_size=gpc.config.VOCAB_SIZE,
                        max_position_embeddings=gpc.config.SEQ_LENGTH,
                        checkpoint=True)

    criterion = GPTLMLoss()

    # Read LR / weight decay from the config instead of hard-coding them,
    # so the LEARNING_RATE / WEIGHT_DECAY config entries take effect.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=gpc.config.LEARNING_RATE,
                                 weight_decay=gpc.config.WEIGHT_DECAY)

    # The scheduler steps once per optimizer step, hence the division by
    # the gradient accumulation factor.
    steps_per_epoch = len(train_dataloader) // gpc.config.gradient_accumulation

    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS * steps_per_epoch,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS * steps_per_epoch,
                                           eta_min=1e-5)

    engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model=model,
                                                                      optimizer=optimizer,
                                                                      criterion=criterion,
                                                                      train_dataloader=train_dataloader,
                                                                      lr_scheduler=lr_scheduler)

    # pipeline under test
    # num_model_chunks = getattr(gpc.config.model, 'num_chunks', 1)
    # if num_model_chunks > 1:
    #     logger.info('Build InterleavedPipelineSchedule', ranks=[0])
    #     schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, num_model_chunks)
    # else:
    #     logger.info('Build PipelineSchedule', ranks=[0])
    #     schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES)

    timer = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timer)

    hook_list = [
        hooks.LogMetricByEpochHook(logger=logger),
        hooks.LogMetricByStepHook(),
        hooks.LossHook(),
        hooks.ThroughputHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
        # hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
        # hooks.LogMemoryByEpochHook(logger),
        # hooks.LogTimingByEpochHook(timer, logger),
        # hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')
    ]

    logger.info("Training start", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader, epochs=gpc.config.NUM_EPOCHS, hooks=hook_list, display_progress=True)


if __name__ == '__main__':
    train_gpt()
...@@ -14,9 +14,7 @@ from colossalai.core import global_context as gpc ...@@ -14,9 +14,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer from colossalai.trainer import Trainer, hooks
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook,
LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer from colossalai.utils import MultiTimer
from model_zoo.vit import vit_small_patch16_224 from model_zoo.vit import vit_small_patch16_224
from nvidia.dali import types from nvidia.dali import types
...@@ -185,22 +183,22 @@ def train_imagenet(): ...@@ -185,22 +183,22 @@ def train_imagenet():
trainer = Trainer(engine=engine, logger=logger, timer=timer) trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("Trainer is built", ranks=[0]) logger.info("Trainer is built", ranks=[0])
hooks = [ hook_list = [
LogMetricByEpochHook(logger=logger), hooks.LogMetricByEpochHook(logger=logger),
LogMetricByStepHook(), hooks.LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger), # hooks.LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger), # hooks.LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy()), hooks.AccuracyHook(accuracy_func=Accuracy()),
LossHook(), hooks.LossHook(),
ThroughputHook(), hooks.ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True) hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
] ]
logger.info("Train start", ranks=[0]) logger.info("Train start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader, trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader, test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS, epochs=gpc.config.NUM_EPOCHS,
hooks=hooks, hooks=hook_list,
display_progress=True, display_progress=True,
test_interval=1) test_interval=1)
......
...@@ -14,9 +14,7 @@ from colossalai.core import global_context as gpc ...@@ -14,9 +14,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer from colossalai.trainer import Trainer, hooks
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook,
LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer from colossalai.utils import MultiTimer
from model_zoo.vit import vit_small_patch16_224 from model_zoo.vit import vit_small_patch16_224
from nvidia.dali import types from nvidia.dali import types
...@@ -185,22 +183,22 @@ def train_imagenet(): ...@@ -185,22 +183,22 @@ def train_imagenet():
trainer = Trainer(engine=engine, logger=logger, timer=timer) trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("Trainer is built", ranks=[0]) logger.info("Trainer is built", ranks=[0])
hooks = [ hook_list = [
LogMetricByEpochHook(logger=logger), hooks.LogMetricByEpochHook(logger=logger),
LogMetricByStepHook(), hooks.LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger), # hooks.LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger), # hooks.LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy()), hooks.AccuracyHook(accuracy_func=Accuracy()),
LossHook(), hooks.LossHook(),
ThroughputHook(), hooks.ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True) hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
] ]
logger.info("Train start", ranks=[0]) logger.info("Train start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader, trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader, test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS, epochs=gpc.config.NUM_EPOCHS,
hooks=hooks, hooks=hook_list,
display_progress=True, display_progress=True,
test_interval=1) test_interval=1)
......
from .gpt import *
\ No newline at end of file
import math
from typing import Callable
import torch
from colossalai import nn as col_nn
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.registry import LAYERS, MODELS, LOSSES
from colossalai.utils import get_current_device
from torch import dtype, nn
__all__ = ['GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt3']
@LAYERS.register_module
class GPTEmbedding(nn.Module):
    """Sum of word and learned position embeddings, with optional
    token-type embeddings, followed by dropout."""

    def __init__(self,
                 embedding_dim: int,
                 vocab_size: int,
                 max_position_embeddings: int,
                 num_tokentypes: int = 0,
                 padding_idx: int = 0,
                 dropout: float = 0.,
                 dtype: dtype = None) -> None:
        super().__init__()
        self.word_embeddings = col_nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype)
        self.position_embeddings = col_nn.Embedding(max_position_embeddings, embedding_dim, dtype=dtype)
        # Token-type embeddings are disabled when num_tokentypes == 0.
        self.tokentype_embeddings = (col_nn.Embedding(num_tokentypes, embedding_dim, dtype=dtype)
                                     if num_tokentypes > 0 else None)
        self.dropout = col_nn.Dropout(dropout)

    @property
    def word_embedding_weight(self):
        # Exposed so the LM head can tie its classifier to these weights.
        return self.word_embeddings.weight

    def forward(self, input_ids, position_ids=None, tokentype_ids=None):
        if position_ids is None:
            # Default positions 0..seq_len-1, broadcast over the batch.
            num_positions = input_ids.size(1)
            position_ids = torch.arange(num_positions, dtype=torch.long,
                                        device=get_current_device()).unsqueeze(0)
        embeddings = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        if tokentype_ids is not None and self.tokentype_embeddings is not None:
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        return self.dropout(embeddings)
@LAYERS.register_module
class GPTSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    The head count is recovered at runtime from the projected width
    rather than stored, so the module keeps working when the projection
    width differs from ``dim`` (e.g. under tensor parallelism -- TODO
    confirm against col_nn.Linear's partitioning).
    """
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 attention_dropout: float,
                 dropout: float,
                 bias: bool = True,
                 dtype: dtype = None) -> None:
        super().__init__()
        # Per-head width; assumes dim is divisible by num_heads.
        self.attention_head_size = dim // num_heads
        # One fused projection producing Q, K and V side by side.
        self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias)
        self.attention_dropout = col_nn.Dropout(attention_dropout)
        self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True)
        self.dropout = col_nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x, attention_mask=None):
        # x: (batch, seq, dim) -> qkv: (batch, seq, 3 * width)
        qkv = self.query_key_value(x)
        # Recover the head count from the actual projected width.
        all_head_size = qkv.shape[-1] // 3
        num_attention_heads = all_head_size // self.attention_head_size
        new_qkv_shape = qkv.shape[:-1] + \
            (num_attention_heads, 3 * self.attention_head_size)
        qkv = qkv.view(new_qkv_shape)
        # (batch, seq, heads, 3*head) -> (batch, heads, seq, 3*head)
        qkv = qkv.permute((0, 2, 1, 3))
        # Each of q, k, v: (batch, heads, seq, head).
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Scaled dot-product scores: (batch, heads, seq, seq).
        x = torch.matmul(q, k.transpose(-1, -2))
        x = x / math.sqrt(self.attention_head_size)
        # Causal mask: position i may only attend to positions <= i.
        q_len, k_len = q.size(-2), k.size(-2)
        causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
                                            device=get_current_device())).view(1, 1, q_len, k_len).bool()
        # Future positions get a large-but-finite negative logit (-1e4),
        # which stays representable under fp16 unlike -inf arithmetic.
        x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device()))
        if attention_mask is not None:
            # Additive padding mask (0 = keep, -10000 = pad), built in GPT.forward.
            x = x + attention_mask
        x = self.softmax(x)
        x = self.attention_dropout(x)
        # Weighted sum of values: (batch, heads, seq, head).
        x = torch.matmul(x, v)
        # Merge heads back: (batch, seq, heads * head).
        x = x.transpose(1, 2)
        new_context_layer_shape = x.size()[:-2] + (all_head_size, )
        x = x.reshape(new_context_layer_shape)
        x = self.dense(x)
        x = self.dropout(x)
        return x
@LAYERS.register_module
class GPTMLP(nn.Module):
    """Position-wise feed-forward block: dim -> mlp_ratio*dim -> dim."""

    def __init__(self,
                 dim: int,
                 mlp_ratio: int,
                 activation: Callable,
                 dropout: float,
                 dtype: dtype = None,
                 bias: bool = True):
        super().__init__()
        hidden_dim = mlp_ratio * dim
        self.dense_1 = col_nn.Linear(dim, hidden_dim, dtype=dtype, bias=bias)
        self.activation = activation
        self.dense_2 = col_nn.Linear(hidden_dim, dim, dtype=dtype, bias=bias)
        self.dropout = col_nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.activation(self.dense_1(x))
        return self.dropout(self.dense_2(hidden))
@LAYERS.register_module
class GPTBlock(CheckpointModule):
    """Pre-norm transformer block: ``x + attn(ln(x))`` then ``x + mlp(ln(x))``.

    Returns ``(hidden_states, attention_mask)`` so stacked blocks can be
    chained with the mask threaded through.
    """
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: int,
                 activation: Callable,
                 attention_dropout: float = 0.,
                 dropout: float = 0.,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False):
        # FIX: forward the checkpoint flag to CheckpointModule -- it was
        # accepted but dropped (bare super().__init__()), silently disabling
        # activation checkpointing. Mirrors ViTBlock's super().__init__(checkpoint).
        super().__init__(checkpoint)
        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
        self.attn = GPTSelfAttention(dim=dim,
                                     num_heads=num_heads,
                                     attention_dropout=attention_dropout,
                                     dropout=dropout,
                                     bias=bias,
                                     dtype=dtype)
        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
        self.mlp = GPTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias)

    def _forward(self, x, attention_mask=None):
        x = x + self.attn(self.norm1(x), attention_mask)
        x = x + self.mlp(self.norm2(x))
        return x, attention_mask
@LAYERS.register_module
class GPTLMHead(nn.Module):
    """Language-model head projecting hidden states to vocabulary logits.

    When ``word_embeeding_weight`` (sic -- keyword name kept because
    callers pass it by keyword) is given, the classifier reuses the
    embedding weights (weight tying).
    """

    def __init__(self,
                 dim: int,
                 vocab_size: int,
                 word_embeeding_weight: nn.Parameter = None,
                 bias: bool = False,
                 dtype: dtype = None) -> None:
        super().__init__()
        self.dense = col_nn.Classifier(dim, vocab_size, word_embeeding_weight, bias=bias, dtype=dtype)

    def forward(self, x):
        return self.dense(x)
@LOSSES.register_module
class GPTLMLoss(nn.Module):
    """Causal language-modeling loss: token t predicts token t + 1."""

    def __init__(self):
        super().__init__()
        self.loss = col_nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        # Drop the last logit and the first label so positions align.
        shifted_logits = logits[..., :-1, :].contiguous()
        shifted_labels = labels[..., 1:].contiguous()
        # Flatten (batch, seq) into one token axis for cross-entropy.
        flat_logits = shifted_logits.view(-1, shifted_logits.size(-1))
        return self.loss(flat_logits, shifted_labels.view(-1))
@MODELS.register_module
class GPT(nn.Module):
    """GPT language model: token/position embedding, ``depth`` transformer
    blocks, a final LayerNorm and an LM head tied to the word-embedding
    weights.

    :param vocab_size: size of the token vocabulary
    :param max_position_embeddings: maximum supported sequence length
    :param dim: hidden size
    :param num_heads: attention heads per block
    :param depth: number of transformer blocks
    :param mlp_ratio: MLP hidden-size multiplier
    :param dropout: dropout inside attention output and MLP
    :param embedding_dropout: dropout applied after the embedding sum
    :param attention_dropout: dropout on attention probabilities
    :param layernorm_epsilon: eps for the FINAL LayerNorm only
        (NOTE(review): GPTBlock hardcodes eps=1e-6 for its internal
        norms -- confirm this asymmetry is intended)
    :param activation: MLP activation function
    :param checkpoint: forwarded to each GPTBlock's ``checkpoint`` argument
    :param dtype: parameter dtype
    :param bias: whether linear layers use bias
    :param padding_idx: padding token id for the word embedding
    """
    def __init__(self,
                 vocab_size: int = 50304,
                 max_position_embeddings: int = 1024,
                 dim: int = 768,
                 num_heads: int = 12,
                 depth: int = 12,
                 mlp_ratio: int = 4,
                 dropout: float = 0.1,
                 embedding_dropout: float = 0.1,
                 attention_dropout: float = 0.1,
                 layernorm_epsilon: float = 1e-5,
                 activation: Callable = nn.functional.gelu,
                 checkpoint: bool = False,
                 dtype: dtype = None,
                 bias: bool = True,
                 padding_idx: int = 0) -> None:
        super().__init__()
        # Kept so forward() can cast the attention mask to the model dtype.
        self.dtype = dtype
        self.embed = GPTEmbedding(embedding_dim=dim,
                                  vocab_size=vocab_size,
                                  max_position_embeddings=max_position_embeddings,
                                  padding_idx=padding_idx,
                                  dropout=embedding_dropout,
                                  dtype=dtype)
        self.blocks = nn.ModuleList([
            GPTBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                activation=activation,
                attention_dropout=attention_dropout,
                dropout=dropout,
                dtype=dtype,
                bias=bias,
                checkpoint=checkpoint,
            ) for _ in range(depth)
        ])
        self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        # Weight tying: the head's classifier reuses the embedding matrix.
        self.head = GPTLMHead(dim=dim,
                              vocab_size=vocab_size,
                              word_embeeding_weight=self.embed.word_embedding_weight,
                              bias=bias,
                              dtype=dtype)
    def forward(self, input_ids, attention_mask=None):
        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # Adapted from huggingface
        if attention_mask is not None:
            batch_size = input_ids.shape[0]
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = attention_mask.to(dtype=self.dtype)    # fp16 compatibility
            # Convert 1/0 keep/pad flags into additive logits: 0 or -10000.
            attention_mask = (1.0 - attention_mask) * -10000.0
        x = self.embed(input_ids)
        # Each block returns (hidden, mask) so the mask threads through the stack.
        for block in self.blocks:
            x, attention_mask = block(x, attention_mask)
        x = self.head(self.norm(x))
        return x
def _create_gpt_model(**model_kwargs):
    """Instantiate a GPT model from the assembled keyword arguments."""
    return GPT(**model_kwargs)
@MODELS.register_module
def gpt2_small(**kwargs):
    """GPT-2 small preset: dim=768, 12 layers, 12 heads."""
    return _create_gpt_model(dim=768, depth=12, num_heads=12, **kwargs)
@MODELS.register_module
def gpt2_medium(**kwargs):
    """GPT-2 medium preset: dim=1024, 24 layers, 16 heads."""
    return _create_gpt_model(dim=1024, depth=24, num_heads=16, **kwargs)
@MODELS.register_module
def gpt2_large(**kwargs):
    """GPT-2 large preset: dim=1280, 36 layers, 20 heads."""
    return _create_gpt_model(dim=1280, depth=36, num_heads=20, **kwargs)
@MODELS.register_module
def gpt2_xl(**kwargs):
    """GPT-2 XL preset: dim=1600, 48 layers, 25 heads."""
    return _create_gpt_model(dim=1600, depth=48, num_heads=25, **kwargs)
@MODELS.register_module
def gpt3(**kwargs):
    """GPT-3 preset: dim=12288, 96 layers, 96 heads, 2048 positions."""
    return _create_gpt_model(dim=12288, max_position_embeddings=2048,
                             depth=96, num_heads=96, **kwargs)
...@@ -89,7 +89,7 @@ class ViTEmbedding(nn.Module): ...@@ -89,7 +89,7 @@ class ViTEmbedding(nn.Module):
@LAYERS.register_module @LAYERS.register_module
class ViTSelfAttention(CheckpointModule): class ViTSelfAttention(nn.Module):
def __init__(self, def __init__(self,
dim: int, dim: int,
num_heads: int, num_heads: int,
...@@ -97,9 +97,8 @@ class ViTSelfAttention(CheckpointModule): ...@@ -97,9 +97,8 @@ class ViTSelfAttention(CheckpointModule):
dropout: float, dropout: float,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: dtype = None,
checkpoint: bool = False,
init_method: str = 'torch'): init_method: str = 'torch'):
super().__init__(checkpoint) super().__init__()
self.attention_head_size = dim // num_heads self.attention_head_size = dim // num_heads
self.query_key_value = col_nn.Linear(dim, self.query_key_value = col_nn.Linear(dim,
3 * dim, 3 * dim,
...@@ -111,7 +110,7 @@ class ViTSelfAttention(CheckpointModule): ...@@ -111,7 +110,7 @@ class ViTSelfAttention(CheckpointModule):
self.dropout = col_nn.Dropout(dropout) self.dropout = col_nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1) self.softmax = nn.Softmax(dim=-1)
def _forward(self, x): def forward(self, x):
qkv = self.query_key_value(x) qkv = self.query_key_value(x)
all_head_size = qkv.shape[-1] // 3 all_head_size = qkv.shape[-1] // 3
num_attention_heads = all_head_size // self.attention_head_size num_attention_heads = all_head_size // self.attention_head_size
...@@ -138,7 +137,7 @@ class ViTSelfAttention(CheckpointModule): ...@@ -138,7 +137,7 @@ class ViTSelfAttention(CheckpointModule):
@LAYERS.register_module @LAYERS.register_module
class ViTMLP(CheckpointModule): class ViTMLP(nn.Module):
def __init__(self, def __init__(self,
dim: int, dim: int,
mlp_ratio: int, mlp_ratio: int,
...@@ -146,9 +145,8 @@ class ViTMLP(CheckpointModule): ...@@ -146,9 +145,8 @@ class ViTMLP(CheckpointModule):
dropout: float, dropout: float,
dtype: dtype = None, dtype: dtype = None,
bias: bool = True, bias: bool = True,
checkpoint: bool = False,
init_method: str = 'torch'): init_method: str = 'torch'):
super().__init__(checkpoint) super().__init__()
self.dense_1 = col_nn.Linear(dim, self.dense_1 = col_nn.Linear(dim,
mlp_ratio * dim, mlp_ratio * dim,
dtype=dtype, dtype=dtype,
...@@ -163,7 +161,7 @@ class ViTMLP(CheckpointModule): ...@@ -163,7 +161,7 @@ class ViTMLP(CheckpointModule):
**_init_rules[init_method]['transformer']) **_init_rules[init_method]['transformer'])
self.dropout_2 = col_nn.Dropout(dropout) self.dropout_2 = col_nn.Dropout(dropout)
def _forward(self, x): def forward(self, x):
x = self.dense_1(x) x = self.dense_1(x)
x = self.activation(x) x = self.activation(x)
x = self.dropout_1(x) x = self.dropout_1(x)
...@@ -192,22 +190,22 @@ class ViTHead(nn.Module): ...@@ -192,22 +190,22 @@ class ViTHead(nn.Module):
self.representation = None self.representation = None
representation_size = dim representation_size = dim
self.linear = col_nn.Classifier(representation_size, self.dense = col_nn.Classifier(representation_size,
num_classes, num_classes,
dtype=dtype, dtype=dtype,
bias=bias, bias=bias,
**_init_rules[init_method]['head']) **_init_rules[init_method]['head'])
def forward(self, x): def forward(self, x):
x = x[:, 0] x = x[:, 0]
if self.representation is not None: if self.representation is not None:
x = self.representation(x) x = self.representation(x)
x = self.linear(x) x = self.dense(x)
return x return x
@LAYERS.register_module @LAYERS.register_module
class ViTBlock(nn.Module): class ViTBlock(CheckpointModule):
def __init__(self, def __init__(self,
dim: int, dim: int,
num_heads: int, num_heads: int,
...@@ -216,32 +214,31 @@ class ViTBlock(nn.Module): ...@@ -216,32 +214,31 @@ class ViTBlock(nn.Module):
attention_dropout: float = 0., attention_dropout: float = 0.,
dropout: float = 0., dropout: float = 0.,
drop_path: float = 0., drop_path: float = 0.,
layernorm_epsilon: float = 1e-6,
dtype: dtype = None, dtype: dtype = None,
bias: bool = True, bias: bool = True,
checkpoint: bool = False, checkpoint: bool = False,
init_method: str = 'torch'): init_method: str = 'torch'):
super().__init__() super().__init__(checkpoint)
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
self.attn = ViTSelfAttention(dim=dim, self.attn = ViTSelfAttention(dim=dim,
num_heads=num_heads, num_heads=num_heads,
attention_dropout=attention_dropout, attention_dropout=attention_dropout,
dropout=dropout, dropout=dropout,
bias=bias, bias=bias,
dtype=dtype, dtype=dtype,
checkpoint=checkpoint,
init_method=init_method) init_method=init_method)
self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
self.mlp = ViTMLP(dim=dim, self.mlp = ViTMLP(dim=dim,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
activation=activation, activation=activation,
dropout=dropout, dropout=dropout,
dtype=dtype, dtype=dtype,
bias=bias, bias=bias,
checkpoint=checkpoint,
init_method=init_method) init_method=init_method)
def forward(self, x): def _forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x))) x = x + self.drop_path(self.mlp(self.norm2(x)))
return x return x
...@@ -261,6 +258,7 @@ class VisionTransformer(nn.Module): ...@@ -261,6 +258,7 @@ class VisionTransformer(nn.Module):
attention_dropout: float = 0., attention_dropout: float = 0.,
dropout: float = 0.1, dropout: float = 0.1,
drop_path: float = 0., drop_path: float = 0.,
layernorm_epsilon: float = 1e-6,
activation: Callable = nn.functional.gelu, activation: Callable = nn.functional.gelu,
representation_size: int = None, representation_size: int = None,
dtype: dtype = None, dtype: dtype = None,
...@@ -295,7 +293,7 @@ class VisionTransformer(nn.Module): ...@@ -295,7 +293,7 @@ class VisionTransformer(nn.Module):
) for i in range(depth) ) for i in range(depth)
] ]
norm = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
head = ViTHead(dim=dim, head = ViTHead(dim=dim,
num_classes=num_classes, num_classes=num_classes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment