Commit 9ee197d0 authored by アマデウス, committed by Frank Lee

moved env variables to global variables; (#215)

added branch context;
added vocab parallel layers;
moved split_batch from load_batch to tensor parallel embedding layers;
updated gpt model;
updated unit test cases;
fixed a few collective communicator bugs
parent b82d60be
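A minimal usage sketch of the vocab-parallel layers and loss introduced by this commit, mirroring the 1D unit test added below. It is not part of the diff: it assumes a 1D tensor-parallel group has already been launched via colossalai, and the vocabulary size, hidden size, and batch shape are illustrative.

```python
# Hedged sketch, not part of this commit: wiring the new vocab-parallel pieces together.
import torch
from colossalai.nn import (VocabParallelClassifier1D, VocabParallelCrossEntropyLoss1D,
                           VocabParallelEmbedding1D)

embed = VocabParallelEmbedding1D(50304, 768)                    # vocab dim is sharded across ranks
head = VocabParallelClassifier1D(768, 50304, weight=embed.weight, bias=False)  # tied weights
criterion = VocabParallelCrossEntropyLoss1D()

tokens = torch.randint(0, 50304, (4, 128))                      # full batch on every rank
logits = head(embed(tokens))                                    # [4, 128, 50304 / tp_world_size]
loss = criterion(logits, tokens)                                # reduces over the vocab-parallel group
loss.backward()
```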
import torch
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_tensor_2p5d
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
@@ -29,8 +35,104 @@ class CrossEntropyLoss2p5D(_Loss):
:param logits: Output logits of model
:param targets: True targets from data
"""
targets = split_tensor_2p5d(targets)
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean:
loss = loss.mean()
loss = reduce_by_batch_2p5d(loss, True)
return loss
class _VocabParallelCrossEntropy2p5D(torch.autograd.Function):
### Modified based on megatron.mpu.cross_entropy ###
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, logits, targets):
# logits: [b/dq, h/q]
# loss: [b/dq]
# targets: [b/dq, h/q]
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
# Subtract the maximum value.
logits = logits - logits_max.unsqueeze(dim=-1)
vocab_size = logits.size(-1)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
vocab_start = rank * (vocab_size)
vocab_end = (rank + 1) * (vocab_size) - 1
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
arange_1d = torch.arange(start=0, end=logits.size()[0])
predicted_logits = logits[arange_1d, masked_target]
predicted_logits[target_mask] = 0.
dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
exp_logits = torch.exp(logits)
sum_exp_logits = exp_logits.sum(dim=1)
dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
loss = torch.log(sum_exp_logits) - predicted_logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target)
return loss
@staticmethod
@custom_bwd
def backward(ctx, output_grad):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(output_grad.unsqueeze(dim=-1))
return grad_input, None
@LOSSES.register_module
class VocabParallelCrossEntropyLoss2p5D(_Loss):
"""
Vocab parallel cross entropy loss for 2.5D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
self.reduction_mean = reduction
def forward(self, logits, targets):
"""Calculate loss between logits and targets
:param logits: Output logits of model
:param targets: True targets from data
"""
targets = split_tensor_2p5d(targets)
loss = _VocabParallelCrossEntropy2p5D.apply(logits, targets)
if self.reduction_mean:
loss = loss.mean()
loss = reduce_by_batch_2p5d(loss, True)
return loss
import torch
import torch.distributed as dist
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
@LOSSES.register_module
class CrossEntropyLoss3D(_Loss):
"""
Cross entropy loss for 3D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
:param args: Args for loss function
:param kwargs: Kwargs for loss function
"""
def __init__(self, reduction=True, *args, **kwargs):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -32,8 +37,103 @@ class CrossEntropyLoss3D(_Loss):
:param logits: Output logits of model
:param targets: True targets from data
"""
targets = split_tensor_3d(targets, 0, self.weight_parallel_mode)
targets = split_tensor_3d(targets, 0, self.input_parallel_mode)
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean:
loss = loss.mean()
loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True)
return loss
class _VocabParallelCrossEntropy3D(torch.autograd.Function):
# Adapted from megatron.mpu.cross_entropy
# loss[i] = -logits[i][targets] + log(sum(exp(logits[i])))
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, logits, targets, output_parallel_mode):
# logits: [b/q^2, c/q]
# labels: [b/q^2]
# loss: [b/q^2]
logits_max = torch.max(logits, dim=-1)[0]
dist.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(output_parallel_mode))
# Subtract the maximum value.
logits = logits - logits_max.unsqueeze(dim=-1)
vocab_size_per_partition = logits.size()[-1]
rank = gpc.get_local_rank(output_parallel_mode)
vocab_start = rank * vocab_size_per_partition
vocab_end = (rank + 1) * vocab_size_per_partition - 1
# loss[i] = 0 if targets[i] < vocab_start or targets[i] > vocab_end
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device())
predicted_logits = logits[arange_1d, masked_target]
predicted_logits = predicted_logits.clone().contiguous().view_as(targets)
predicted_logits[target_mask] = 0.
dist.all_reduce(predicted_logits, group=gpc.get_group(output_parallel_mode))
# Loss = log(sum(exp(logits))) - predicted-logit.
exp_logits = torch.exp(logits)
sum_exp_logits = exp_logits.sum(dim=-1)
dist.all_reduce(sum_exp_logits, group=gpc.get_group(output_parallel_mode))
loss = torch.log(sum_exp_logits) - predicted_logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target)
return loss
@staticmethod
@custom_bwd
def backward(ctx, output_grad):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
# All the inputs have softmax as their gradient.
input_grad = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = input_grad.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float())
input_grad.mul_(output_grad.unsqueeze(dim=-1))
return input_grad, None, None, None
@LOSSES.register_module
class VocabParallelCrossEntropyLoss3D(_Loss):
"""
Vocab parallel cross entropy loss for 3D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.reduction_mean = reduction
def forward(self, logits, targets):
"""Calculate loss between logits and targets
:param logits: Output logits of model
:param targets: True targets from data
"""
targets = split_tensor_3d(targets, 0, self.weight_parallel_mode)
targets = split_tensor_3d(targets, 0, self.input_parallel_mode)
loss = _VocabParallelCrossEntropy3D.apply(logits, targets, self.output_parallel_mode)
if self.reduction_mean:
loss = loss.mean()
loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True)
return loss
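The two autograd functions above follow the Megatron-style masked-target trick: each rank gathers the logit of targets that fall inside its vocabulary shard, zeroes the rest, and an all-reduce over the output-parallel group recovers the full log-sum-exp and predicted logit. The single-process toy below (tensor slices standing in for the all-reduce; shard count and sizes are illustrative) reproduces the same loss and is not part of the diff.

```python
import torch

logits = torch.randn(4, 8)                       # [batch, full vocab]
targets = torch.randint(0, 8, (4,))
shards = torch.chunk(logits, 2, dim=-1)          # pretend 2 ranks, 4 vocab entries each

predicted, sum_exp = torch.zeros(4), torch.zeros(4)
for rank, part in enumerate(shards):             # this loop plays the role of dist.all_reduce
    vocab_start = rank * part.size(-1)
    vocab_end = vocab_start + part.size(-1) - 1
    mask = (targets < vocab_start) | (targets > vocab_end)
    local = (targets - vocab_start).masked_fill(mask, 0)
    pred = part[torch.arange(4), local]
    pred[mask] = 0.                               # out-of-shard targets contribute nothing
    predicted += pred
    sum_exp += part.exp().sum(dim=-1)

loss = torch.log(sum_exp) - predicted            # same formula as the forward passes above
print(torch.allclose(loss, torch.nn.functional.cross_entropy(logits, targets, reduction='none')))
```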
@@ -17,7 +17,7 @@ class Accuracy(nn.Module):
def __init__(self):
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel not in _parallel_accuracy:
self.acc = calc_acc
else:
self.acc = _parallel_accuracy[tensor_parallel]()
...
import torch
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d
from torch import nn
from ._utils import calc_acc
@@ -18,6 +18,7 @@ class Accuracy2D(nn.Module):
:param targets: True labels from data
"""
with torch.no_grad():
targets = split_tensor_2d(targets)
correct = calc_acc(logits, targets)
correct = reduce_by_batch_2d(correct)
return correct
@@ -18,6 +18,7 @@ class Accuracy2p5D(nn.Module):
:param targets: True labels from data
"""
with torch.no_grad():
targets = split_tensor_2p5d(targets)
correct = calc_acc(logits, targets)
correct = reduce_by_batch_2p5d(correct)
return correct
import torch
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from torch import nn
@@ -22,6 +22,8 @@ class Accuracy3D(nn.Module):
:param targets: True labels from data
"""
with torch.no_grad():
targets = split_tensor_3d(targets, 0, self.weight_parallel_mode)
targets = split_tensor_3d(targets, 0, self.input_parallel_mode)
correct = calc_acc(logits, targets)
correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode)
return correct
@@ -224,7 +224,7 @@ class LogTimingByEpochHook(LogByEpochHook):
super().__init__(logger=logger, interval=interval, priority=priority)
self._timer = timer
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
# extra handling to avoid the unstable readings of the first
# few training steps to affect the history mean time
@@ -256,7 +256,7 @@ class LogTimingByEpochHook(LogByEpochHook):
"""
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
msg = self._get_message('Train')
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}')
def after_test_epoch(self, trainer):
"""Writes log after finishing a testing epoch.
...
@@ -317,22 +317,27 @@ class ThroughputMetric(Metric):
:param epoch_only: epoch only
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool, ignored_steps: int = 0):
super().__init__(epoch_only=epoch_only)
self.ignored_steps = ignored_steps
self.cur_steps = 0
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
self.last_step_used_time = torch.zeros(1, device=get_current_device())
def reset(self) -> None:
# self.cur_steps = 0
self.accumulated_num_samples.zero_()
self.accumulated_used_time.zero_()
self.last_step_num_samples.zero_()
self.last_step_used_time.zero_()
def update(self, num_samples, time) -> None:
self.cur_steps += 1
self.last_step_num_samples.fill_(num_samples)
self.last_step_used_time.fill_(time)
if self.cur_steps >= self.ignored_steps:
self.accumulated_num_samples += self.last_step_num_samples
self.accumulated_used_time += self.last_step_used_time
@@ -360,13 +365,14 @@ class ThroughputHook(MetricHook):
:param priority: priority of throughput hook, defaults to 10
:type priority: int, optional
"""
def __init__(self, ignored_steps: int = 0, priority: int = 10):
super().__init__(priority)
self.ignored_steps = ignored_steps
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = ThroughputMetric(epoch_only=True, ignored_steps=self.ignored_steps)
# register the metric
trainer.states['metrics']['train']['Throughput'] = self.metric
...
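The new ignored_steps parameter lets the first few (typically warm-up) iterations be excluded from the accumulated throughput statistics. A hedged sketch of how a trainer hook list might opt in; the hooks module path and the value 5 are illustrative rather than taken from this diff.

```python
from colossalai.trainer import hooks

hook_list = [
    hooks.ThroughputHook(ignored_steps=5),   # skip the first 5 steps when accumulating samples/time
]
# trainer.fit(train_dataloader, epochs=..., hooks=hook_list)
```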
from .activation_checkpoint import checkpoint
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
is_using_ddp, is_using_pp, is_using_sequence, model_branch_context, multi_tensor_applier,
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
sync_model_param)
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient
@@ -11,9 +12,9 @@ from .timer import MultiTimer, Timer
__all__ = [
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'model_branch_context',
'conditional_context', 'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32',
'copy_tensor_parallel_attributes', 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize',
'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier',
'accumulate_gradient', 'DataParallelSampler', 'get_dataloader', 'switch_virtual_pipeline_parallel_rank'
]
@@ -6,8 +6,6 @@ import socket
import torch
from torch._six import inf
try:
import colossal_C
except:
@@ -20,6 +18,7 @@ from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARA
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import moe_env
from colossalai.global_variables import tensor_parallel_env as env
from .multi_tensor_apply import multi_tensor_applier
@@ -62,8 +61,7 @@ def sync_model_param(model, parallel_mode):
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
for param in model.parameters():
ranks = gpc.get_ranks_in_group(parallel_mode)
dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
def is_dp_rank_0():
@@ -99,6 +97,15 @@ def conditional_context(context_manager, enable=True):
yield
class model_branch_context(object):
def __enter__(self):
self.env_status = env.save()
def __exit__(self, *exc_info):
env.load(**self.env_status)
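The new model_branch_context saves the global tensor_parallel_env on entry and restores it on exit, so a side branch of a model can temporarily change parallel settings without leaking them into the rest of the build. A hedged usage sketch; the flag toggled inside the block is only an example taken from the unit tests below.

```python
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.utils import model_branch_context

with model_branch_context():
    env.parallel_input_1d = False   # branch-local change, e.g. while building an auxiliary head
    # ... construct or run the branch here ...
# previous env values are restored on exit
```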
def is_model_parallel_parameter(p):
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
@@ -124,9 +131,10 @@ def _calc_lp(grads, norm_type):
norm = 0.0
for grad in grads:
grad_norm = torch.norm(grad, norm_type)
norm += grad_norm**norm_type
return norm
# ======== Gradient Clipping =========
@@ -183,7 +191,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
moe_parallel_grads = []  # used to collect moe tensor parallel gradients
for p in params:
if is_model_parallel_parameter(p):
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
tensor_parallel_grads.append(p.grad.data / reductor)
elif is_moe_parallel_parameter(p):
moe_parallel_grads.append(p.grad.data)
@@ -191,32 +199,24 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
no_tensor_parallel_grads.append(p.grad.data)
if norm_type == 2.0:
tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type
else:
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type)
# Sum across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
# Sum across all moe-tensor-parallel GPUs
if len(moe_parallel_grads) > 0:
dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL))
no_tensor_parallel_norm += moe_parallel_norm
total_norm = tensor_parallel_norm + no_tensor_parallel_norm
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
total_norm = total_norm**(1.0 / norm_type)
if type(total_norm) == 'torch.cuda.FloatTensor':
total_norm = total_norm.item()
@@ -225,10 +225,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
if clip_coeff < 1.0:
grads = [p.grad.detach() for p in params]
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
return total_norm
@@ -254,12 +251,11 @@ def count_zeros_fp32(parameters):
# Sum across all model-parallel GPUs.
ops = []
ops.append(
dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True))
if gpc.is_initialized(ParallelMode.PIPELINE):
ops.append(
dist.all_reduce(total_num_zeros,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PIPELINE),
async_op=True))
@@ -279,9 +275,8 @@ def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank(
ParallelMode.TENSOR) == 0)
@contextmanager
...
@@ -3,12 +3,20 @@ from typing import Callable
import torch
from colossalai import nn as col_nn
from colossalai.builder.pipeline import partition_uniform
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.layer.utils import CheckpointModule, divide
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.registry import LAYERS, LOSSES, MODELS
from colossalai.utils import get_current_device
from torch import dtype, nn
__all__ = [
'GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt2_8B', 'gpt2_xl_pipeline',
'gpt2_8B_pipeline', 'gpt3', 'gpt3_pipeline'
]
@LAYERS.register_module
@@ -18,7 +26,7 @@ class GPTEmbedding(nn.Module):
vocab_size: int,
max_position_embeddings: int,
num_tokentypes: int = 0,
padding_idx: int = None,
dropout: float = 0.,
dtype: dtype = None) -> None:
super().__init__()
@@ -34,7 +42,7 @@ class GPTEmbedding(nn.Module):
def word_embedding_weight(self):
return self.word_embeddings.weight
def forward(self, input_ids, attention_mask=None, position_ids=None, tokentype_ids=None):
seq_length = input_ids.size(1)
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0)
@@ -42,7 +50,20 @@ class GPTEmbedding(nn.Module):
if self.tokentype_embeddings is not None and tokentype_ids is not None:
x = x + self.tokentype_embeddings(tokentype_ids)
x = self.dropout(x)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# Adapted from huggingface
if attention_mask is not None:
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = col_nn.partition_batch(attention_mask)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
return x, attention_mask
@LAYERS.register_module
@@ -53,20 +74,32 @@ class GPTSelfAttention(nn.Module):
attention_dropout: float,
dropout: float,
bias: bool = True,
fuse_scale_mask_softmax: bool = False,
dtype: dtype = None) -> None:
super().__init__()
self.fuse_scale_mask_softmax = fuse_scale_mask_softmax
self.attention_head_size = divide(dim, num_heads)
self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias)
if fuse_scale_mask_softmax:
from colossalai.kernel import FusedScaleMaskSoftmax
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
self.softmax = FusedScaleMaskSoftmax(input_in_fp16=True,
input_in_bf16=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=True,
mask_func=None,
softmax_in_fp32=True,
scale=math.sqrt(self.attention_head_size))
else:
self.softmax = nn.Softmax(dim=-1)
self.attention_dropout = col_nn.Dropout(attention_dropout)
self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True)
self.dropout = col_nn.Dropout(dropout)
def forward(self, x, attention_mask=None):
qkv = self.query_key_value(x)
all_head_size = qkv.shape[-1] // 3
num_attention_heads = divide(all_head_size, self.attention_head_size)
new_qkv_shape = qkv.shape[:-1] + \
(num_attention_heads, 3 * self.attention_head_size)
qkv = qkv.view(new_qkv_shape)
@@ -74,17 +107,20 @@ class GPTSelfAttention(nn.Module):
q, k, v = torch.chunk(qkv, 3, dim=-1)
x = torch.matmul(q, k.transpose(-1, -2))
if self.fuse_scale_mask_softmax:
x = self.softmax(x, attention_mask)
else:
x = x / math.sqrt(self.attention_head_size)
# causal mask
q_len, k_len = q.size(-2), k.size(-2)
causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
device=get_current_device())).view(1, 1, q_len, k_len).bool()
x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device()))
if attention_mask is not None:
x = x + attention_mask
x = self.softmax(x)
x = self.attention_dropout(x)
x = torch.matmul(x, v)
@@ -102,15 +138,16 @@ class GPTSelfAttention(nn.Module):
class GPTMLP(nn.Module):
def __init__(self,
dim: int,
mlp_ratio: float,
activation: Callable,
dropout: float,
dtype: dtype = None,
bias: bool = True):
super().__init__()
intermediate_dim = int(dim * mlp_ratio)
self.dense_1 = col_nn.Linear(dim, intermediate_dim, dtype=dtype, bias=bias)
self.activation = activation
self.dense_2 = col_nn.Linear(intermediate_dim, dim, dtype=dtype, bias=bias)
self.dropout = col_nn.Dropout(dropout)
def forward(self, x):
@@ -126,27 +163,44 @@ class GPTBlock(CheckpointModule):
def __init__(self,
dim: int,
num_heads: int,
mlp_ratio: float,
activation: Callable,
attention_dropout: float = 0.,
dropout: float = 0.,
layernorm_epsilon: float = 1e-5,
dtype: dtype = None,
bias: bool = True,
apply_post_layernorm: bool = False,
fuse_scale_mask_softmax: bool = False,
checkpoint: bool = False):
super().__init__(checkpoint)
self.apply_post_layernorm = apply_post_layernorm
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
self.attn = GPTSelfAttention(dim=dim,
num_heads=num_heads,
attention_dropout=attention_dropout,
dropout=dropout,
bias=bias,
fuse_scale_mask_softmax=fuse_scale_mask_softmax,
dtype=dtype)
self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
self.mlp = GPTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias)
def _forward(self, x, attention_mask=None):
if not self.apply_post_layernorm:
residual = x
x = self.norm1(x)
if self.apply_post_layernorm:
residual = x
x = residual + self.attn(x, attention_mask)
if not self.apply_post_layernorm:
residual = x
x = self.norm2(x)
if self.apply_post_layernorm:
residual = x
x = residual + self.mlp(x)
return x, attention_mask
@@ -161,6 +215,10 @@ class GPTLMHead(nn.Module):
super().__init__()
self.dense = col_nn.Classifier(dim, vocab_size, word_embeeding_weight, bias=bias, dtype=dtype)
@property
def weight(self):
return self.dense.weight
def forward(self, x):
x = self.dense(x)
return x
@@ -187,18 +245,19 @@ class GPT(nn.Module):
dim: int = 768,
num_heads: int = 12,
depth: int = 12,
mlp_ratio: float = 4.0,
dropout: float = 0.1,
embedding_dropout: float = 0.1,
attention_dropout: float = 0.1,
layernorm_epsilon: float = 1e-5,
activation: Callable = nn.functional.gelu,
padding_idx: int = None,
dtype: dtype = None,
bias: bool = True,
apply_post_layernorm: bool = False,
fuse_scale_mask_softmax: bool = False,
checkpoint: bool = False) -> None:
super().__init__()
self.embed = GPTEmbedding(embedding_dim=dim,
vocab_size=vocab_size,
max_position_embeddings=max_position_embeddings,
@@ -213,8 +272,11 @@ class GPT(nn.Module):
activation=activation,
attention_dropout=attention_dropout,
dropout=dropout,
layernorm_epsilon=layernorm_epsilon,
dtype=dtype,
bias=bias,
apply_post_layernorm=apply_post_layernorm,
fuse_scale_mask_softmax=fuse_scale_mask_softmax,
checkpoint=checkpoint,
) for _ in range(depth)
])
@@ -224,26 +286,79 @@ class GPT(nn.Module):
self.head = GPTLMHead(dim=dim,
vocab_size=vocab_size,
word_embeeding_weight=self.embed.word_embedding_weight,
bias=bias,
dtype=dtype)
def forward(self, input_ids, attention_mask=None):
x, attention_mask = self.embed(input_ids, attention_mask)
for block in self.blocks:
x, attention_mask = block(x, attention_mask)
x = self.head(self.norm(x))
return x
class PipelineGPT(nn.Module):
def __init__(self,
vocab_size: int = 50304,
max_position_embeddings: int = 1024,
dim: int = 768,
num_heads: int = 12,
depth: int = 12,
mlp_ratio: float = 4.0,
dropout: float = 0.1,
embedding_dropout: float = 0.1,
attention_dropout: float = 0.1,
layernorm_epsilon: float = 1e-5,
activation: Callable = nn.functional.gelu,
padding_idx: int = None,
dtype: dtype = None,
bias: bool = True,
apply_post_layernorm: bool = False,
fuse_scale_mask_softmax: bool = False,
checkpoint: bool = False,
first: bool = False,
last: bool = False):
super().__init__()
self.checkpoint = checkpoint
self.first = first
self.last = last
if first:
self.embed = GPTEmbedding(embedding_dim=dim,
vocab_size=vocab_size,
max_position_embeddings=max_position_embeddings,
padding_idx=padding_idx,
dropout=embedding_dropout,
dtype=dtype)
self.blocks = nn.ModuleList([
GPTBlock(
dim=dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
activation=activation,
attention_dropout=attention_dropout,
dropout=dropout,
layernorm_epsilon=layernorm_epsilon,
dtype=dtype,
bias=bias,
apply_post_layernorm=apply_post_layernorm,
fuse_scale_mask_softmax=fuse_scale_mask_softmax,
checkpoint=checkpoint,
) for _ in range(depth)
])
if self.last:
self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
self.head = GPTLMHead(dim=dim, vocab_size=vocab_size, dtype=dtype)
def forward(self, x=None, input_ids=None, attention_mask=None):
if self.first:
x, attention_mask = self.embed(input_ids, attention_mask)
for block in self.blocks:
x, attention_mask = block(x, attention_mask)
if self.last:
x = self.head(self.norm(x))
return x
@@ -254,6 +369,33 @@ def _create_gpt_model(**model_kwargs):
return model
def _create_gpt_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **model_kwargs):
logger = get_dist_logger()
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
rank = gpc.get_global_rank()
wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
parts = partition_uniform(depth, pipeline_size,
num_chunks)[pipeline_rank] if layer_partitions is None else layer_partitions
models = []
for start, end in parts:
model_kwargs['first'] = start == 0
model_kwargs['last'] = end == depth
model_kwargs['depth'] = end - start
chunk = PipelineGPT(**model_kwargs).to(get_current_device())
if start == 0:
wrapper.register_parameter(chunk.embed.word_embedding_weight)
elif end == depth:
wrapper.register_parameter(chunk.head.weight)
models.append(chunk)
logger.info(f'==> Rank {rank} built layer {start}-{end} / total {depth}')
if len(models) == 1:
model = models[0]
else:
model = nn.ModuleList(models)
return model
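partition_uniform returns, for every pipeline rank, a list of (start, end) layer ranges; the helper above builds one PipelineGPT chunk per range and ties the first-stage embedding weight to the last-stage head through PipelineSharedModuleWrapper. A hedged illustration of the partitioning; the printed ranges are what a uniform split of 48 layers over 4 stages would be expected to look like, not verified output.

```python
from colossalai.builder.pipeline import partition_uniform

depth, pipeline_size, num_chunks = 48, 4, 1
parts = partition_uniform(depth, pipeline_size, num_chunks)
for rank, ranges in enumerate(parts):
    print(rank, ranges)   # e.g. 0 [(0, 12)], 1 [(12, 24)], 2 [(24, 36)], 3 [(36, 48)]
```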
@MODELS.register_module
def gpt2_small(**kwargs):
model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs)
@@ -262,23 +404,47 @@ def gpt2_small(**kwargs):
@MODELS.register_module
def gpt2_medium(**kwargs):
model_kwargs = dict(dim=1024, depth=24, num_heads=8, **kwargs)
return _create_gpt_model(**model_kwargs)
@MODELS.register_module
def gpt2_large(**kwargs):
model_kwargs = dict(dim=1536, depth=36, num_heads=12, **kwargs)
return _create_gpt_model(**model_kwargs)
@MODELS.register_module
def gpt2_xl(**kwargs):
model_kwargs = dict(dim=1600, depth=48, num_heads=16, **kwargs)
return _create_gpt_model(**model_kwargs)
@MODELS.register_module
def gpt2_8B(**kwargs):
model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs)
return _create_gpt_model(**model_kwargs)
@MODELS.register_module
def gpt2_xl_pipeline(**kwargs):
model_kwargs = dict(dim=1600, depth=48, num_heads=20, **kwargs)
return _create_gpt_pipeline_model(**model_kwargs)
@MODELS.register_module
def gpt2_8B_pipeline(**kwargs):
model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs)
return _create_gpt_pipeline_model(**model_kwargs)
@MODELS.register_module
def gpt3(**kwargs):
model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs)
return _create_gpt_model(**model_kwargs)
@MODELS.register_module
def gpt3_pipeline(**kwargs):
model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs)
return _create_gpt_pipeline_model(**model_kwargs)
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import (Classifier1D, Embedding1D, Linear1D_Col, Linear1D_Row, VanillaClassifier,
VocabParallelClassifier1D, VocabParallelCrossEntropyLoss1D, VocabParallelEmbedding1D)
from colossalai.utils import get_current_device, print_rank_0
from torch.nn import Parameter
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
def check_linear_col():
@@ -144,3 +146,351 @@ def check_linear_row():
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear_row backward: pass')
def check_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = C_master.clone()
check_equal(out, C)
print_rank_0('embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('embed backward: pass')
def check_vocab_parallel_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = C_master.clone()
check_equal(out, C)
print_rank_0('vocab parallel embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('vocab parallel embed backward: pass')
def check_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
env.parallel_input_1d = False
parallel_input_1d = env.parallel_input_1d
layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True)
layer_master = layer_master.to(dtype).to(device)
W_master = layer_master.weight.data
dist.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=-1)[i]
layer.weight.data.copy_(W)
B_master = layer_master.bias.data
dist.broadcast(B_master, src=0)
B = B_master.clone()
layer.bias.data.copy_(B)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
if parallel_input_1d:
A = torch.chunk(A_master, DEPTH, dim=-1)[i]
A = A.clone()
else:
A = A_master.clone()
A.requires_grad = True
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = C_master.clone()
check_equal(out, C)
print_rank_0('classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
if parallel_input_1d:
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i]
check_equal(A_grad, A.grad)
W_grad = layer_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = layer_master.bias.grad
check_equal(B_grad, layer.bias.grad)
print_rank_0('classifier (no given weight) backward: pass')
def check_vocab_parallel_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(dtype).to(device)
W_master = layer_master.weight.data
dist.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=0)[i]
layer.weight.data.copy_(W)
B_master = layer_master.bias.data
dist.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[i]
layer.bias.data.copy_(B)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=-1)[i]
check_equal(out, C)
print_rank_0('vocab parallel classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
check_equal(A_grad, A.grad)
W_grad = layer_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = layer_master.bias.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, layer.bias.grad)
print_rank_0('vocab parallel classifier (no given weight) backward: pass')
def check_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
embed.weight.data.copy_(weight)
env.parallel_input_1d = False
layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = C_master.clone()
check_equal(out, C)
print_rank_0('classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('classifier (given embed weight) backward: pass')
def check_vocab_parallel_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
embed.weight.data.copy_(weight)
env.parallel_input_1d = False
layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=-1)[i]
check_equal(out, C)
print_rank_0('vocab parallel classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('vocab parallel classifier (given embed weight) backward: pass')
def check_vocab_parallel_loss():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
criterion = VocabParallelCrossEntropyLoss1D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=-1)[i]
out = out.clone()
out.requires_grad = True
loss = criterion(out, target_master)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('vocab parallel loss forward: pass')
loss.backward()
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i]
check_equal(out_grad, out.grad)
print_rank_0('vocab parallel loss backward: pass')
@@ -9,6 +9,7 @@ SEQ_LENGTH = 8
IMG_SIZE = 16
HIDDEN_SIZE = 8
NUM_CLASSES = 8
VOCAB_SIZE = 16
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True
@@ -7,6 +7,7 @@ import pytest
import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers
from colossalai.initialize import launch
from colossalai.utils import free_port
@@ -24,6 +25,7 @@ CONFIG = dict(
def check_layer(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG,
rank=rank,
world_size=world_size,
@@ -33,6 +35,13 @@ def check_layer(rank, world_size, port):
check_linear_col()
check_linear_row()
check_embed()
check_vocab_parallel_embed()
check_classifier_no_given_weight()
check_vocab_parallel_classifier_no_given_weight()
check_classifier_given_embed_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_vocab_parallel_loss()
gpc.destroy()
torch.cuda.empty_cache()
...
@@ -8,6 +8,9 @@ BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
NUM_CLASSES = 8
VOCAB_SIZE = 16
IMG_SIZE = 16
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)
@@ -5,8 +5,10 @@ TESSERACT_DEP = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
NUM_CLASSES = 8
VOCAB_SIZE = 16
IMG_SIZE = 16
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)
\ No newline at end of file