Commit 9b174da8 authored by Deepak Narayanan

Merge branch 'pipeline_parallel_main' into 'main'

Pipeline parallelism and inter-layer model parallelism implementation

See merge request ADLR/megatron-lm!159
parents 3aacd955 6e83649f
@@ -112,7 +112,7 @@ def main():
    args = parser.parse_args()
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
-    args.model_parallel_size = 1
+    args.tensor_model_parallel_size = 1
    if args.dataset_impl == "infer":
        args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
...
@@ -26,6 +26,7 @@ from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron.module import MegatronModule
+from megatron import mpu

FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
@@ -71,7 +72,12 @@ class FP16_Module(MegatronModule):
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
-        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
+        if mpu.is_pipeline_first_stage():
+            inputs = fp32_to_fp16(inputs)
+        outputs = self.module(*inputs, **kwargs)
+        if mpu.is_pipeline_last_stage():
+            outputs = fp16_to_fp32(outputs)
+        return outputs

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)
@@ -214,7 +220,7 @@ class FP16_Optimizer(object):
                master_param = param.detach().clone().float()
                master_param.requires_grad = True
                # Copy the model parallel flag.
-                master_param.model_parallel = param.model_parallel
+                master_param.tensor_model_parallel = param.tensor_model_parallel
                param_group['params'][i] = master_param
                fp32_from_fp16_params_this_group.append(master_param)
                # Reset existing state dict key to the new master param.
@@ -512,7 +518,8 @@ class FP16_Optimizer(object):
        return retval

-    def backward(self, loss, update_master_grads=True, retain_graph=False):
+    def backward(self, output_tensor, update_master_grads=True, retain_graph=False,
+                 output_tensor_grad=None):
        """
        :attr:`backward` performs the following conceptual steps:
@@ -570,7 +577,8 @@ class FP16_Optimizer(object):
        # a loss scale that works. After you find a loss scale that works, do a final dummy
        # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
        # discarding the iteration, but probably wouldn't improve overall efficiency.
-        self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
+        self.loss_scaler.backward(output_tensor, retain_graph=retain_graph,
+                                  output_tensor_grad=output_tensor_grad)
        if update_master_grads:
            self.update_master_grads()
...
@@ -68,9 +68,17 @@ class LossScaler:
                                        self.loss_scale)
        return grad_in

-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
+    def backward(self, output_tensor, retain_graph=False, output_tensor_grad=None):
+        # If output_tensor_grad is None, this is the last stage, and
+        # output_tensor is actually the loss and needs to be scaled.
+        # Otherwise, output_tensor does not need to be scaled again since
+        # output_tensor_grad is already scaled.
+        if output_tensor_grad is None:
+            scaled_output_tensor = output_tensor * self.loss_scale
+        else:
+            scaled_output_tensor = output_tensor
+        torch.autograd.backward(scaled_output_tensor, grad_tensors=output_tensor_grad,
+                                retain_graph=retain_graph)


class DynamicLossScaler:
@@ -196,9 +204,17 @@ class DynamicLossScaler:
                                        self.loss_scale)
        return grad_in

-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
+    def backward(self, output_tensor, retain_graph=False, output_tensor_grad=None):
+        # If output_tensor_grad is None, this is the last stage, and
+        # output_tensor is actually the loss and needs to be scaled.
+        # Otherwise, output_tensor does not need to be scaled again since
+        # output_tensor_grad is already scaled.
+        if output_tensor_grad is None:
+            scaled_output_tensor = output_tensor * self.loss_scale
+        else:
+            scaled_output_tensor = output_tensor
+        torch.autograd.backward(scaled_output_tensor, grad_tensors=output_tensor_grad,
+                                retain_graph=retain_graph)

##############################################################
...
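For context on how the two calling conventions differ, here is a minimal sketch (not part of this merge request; `scaler` is a LossScaler or DynamicLossScaler instance, and `recv_backward_from_next_stage` is a placeholder for pipeline communication code that lives elsewhere):

# Last pipeline stage: output_tensor is the scalar loss, so it is scaled here.
scaler.backward(loss)                                    # output_tensor_grad defaults to None

# Earlier pipeline stage: output_tensor is an activation tensor, and the
# gradient received from the next stage is already scaled, so it is passed
# through unchanged.
output_tensor_grad = recv_backward_from_next_stage()     # placeholder helper
scaler.backward(output_tensor, output_tensor_grad=output_tensor_grad)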
@@ -80,7 +80,7 @@ __global__ void scaled_masked_softmax_warp_forward(
    const input_t *src,
    const uint8_t *mask,
    const acc_t scale,
-    int batch_size,
+    int micro_batch_size,
    int stride,
    int element_count,
    int pad_batches)
@@ -102,9 +102,9 @@ __global__ void scaled_masked_softmax_warp_forward(
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

-    // batch_size might not be a multiple of WARP_BATCH. Check how
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
-    int local_batches = batch_size - first_batch;
+    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;
@@ -184,7 +184,7 @@ __global__ void scaled_masked_softmax_warp_backward(
    input_t *grad,
    const input_t *output,
    acc_t scale,
-    int batch_size,
+    int micro_batch_size,
    int stride,
    int element_count)
{
@@ -199,9 +199,9 @@ __global__ void scaled_masked_softmax_warp_backward(
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

-    // batch_size might not be a multiple of WARP_BATCH. Check how
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
-    int local_batches = batch_size - first_batch;
+    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;
...
@@ -79,7 +79,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
-    int batch_size,
+    int micro_batch_size,
    int stride,
    int element_count)
{
@@ -94,9 +94,9 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
    int local_seq = blockIdx.x + 1;
    int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE;

-    // batch_size might not be a multiple of WARP_BATCH. Check how
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
-    int local_batches = batch_size - first_batch;
+    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;
@@ -173,7 +173,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
    input_t *grad,
    const input_t *output,
    acc_t scale,
-    int batch_size,
+    int micro_batch_size,
    int stride,
    int element_count)
{
@@ -187,9 +187,9 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;

-    // batch_size might not be a multiple of WARP_BATCH. Check how
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
-    int local_batches = batch_size - first_batch;
+    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;
...
@@ -23,8 +23,10 @@ import torch

from megatron.tokenizer import build_tokenizer
from .arguments import parse_args
+from .microbatches import build_num_microbatches_calculator

_GLOBAL_ARGS = None
+_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
@@ -37,6 +39,19 @@ def get_args():
    return _GLOBAL_ARGS


+def get_num_microbatches():
+    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
+
+
+def get_current_global_batch_size():
+    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
+
+
+def update_num_microbatches(consumed_samples, consistency_check=True):
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
+                                               consistency_check)
+
+
def get_tokenizer():
    """Return tokenizer."""
    _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
@@ -67,6 +82,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
    args = _parse_args(extra_args_provider=extra_args_provider,
                       defaults=args_defaults,
                       ignore_unknown_args=ignore_unknown_args)
+    _build_num_microbatches_calculator(args)
    _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
@@ -84,6 +100,16 @@ def _parse_args(extra_args_provider=None, defaults={},
    return _GLOBAL_ARGS


+def _build_num_microbatches_calculator(args):
+    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
+    _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
+                                   'num microbatches calculator')
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
+        args)
+
+
def _build_tokenizer(args):
    """Initialize tokenizer."""
    global _GLOBAL_TOKENIZER
@@ -105,7 +131,7 @@ def _set_tensorboard_writer(args):
                                   'tensorboard writer')
    if hasattr(args, 'tensorboard_dir') and \
-       args.tensorboard_dir and args.rank == 0:
+       args.tensorboard_dir and args.rank == (args.world_size - 1):
        try:
            from torch.utils.tensorboard import SummaryWriter
            print('> setting tensorboard ...')
@@ -216,7 +242,7 @@ class Timers:
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
-            writer.add_scalar(name + '_time', value, iteration)
+            writer.add_scalar(name + '-time', value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
@@ -227,7 +253,8 @@
                                             reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, elapsed_time)
        if torch.distributed.is_initialized():
-            if torch.distributed.get_rank() == 0:
+            if torch.distributed.get_rank() == (
+                    torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)
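The new accessors are meant to be polled by the training loop as samples are consumed. A hypothetical usage sketch (the loop variables below are illustrative and not taken from this merge request):

from megatron.global_vars import (get_num_microbatches,
                                  get_current_global_batch_size,
                                  update_num_microbatches)

consumed_samples = 0
train_iters = 1000                       # placeholder
for iteration in range(train_iters):
    update_num_microbatches(consumed_samples)
    for _ in range(get_num_microbatches()):
        pass                             # run one micro-batch forward/backward here
    consumed_samples += get_current_global_batch_size()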
@@ -26,7 +26,7 @@ from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
-from megatron.mpu import set_model_parallel_rank, set_model_parallel_world_size
+from megatron.mpu import set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
@@ -65,9 +65,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
            args.use_cpu_initialization=True
            # delayed initialization of DDP-related stuff
            # We only set basic DDP globals
-            set_model_parallel_world_size(args.model_parallel_size)
+            set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
            # and return function for external DDP manager to call when it has DDP initialized
-            set_model_parallel_rank(args.rank)
+            set_tensor_model_parallel_rank(args.rank)
            return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
@@ -79,8 +79,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
        # Autoresume.
        _init_autoresume()

-        # Write arguments to tensorboard.
-        _write_args_to_tensorboard()
        # No continuation function
        return None
@@ -121,12 +119,14 @@ def _initialize_distributed():
                world_size=args.world_size, rank=args.rank,
                init_method=init_method)

-    # Set the model-parallel / data-parallel communicators.
+    # Set the tensor model-parallel, pipeline model-parallel, and
+    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
-            mpu.initialize_model_parallel(args.model_parallel_size)
+            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
+                                          args.pipeline_model_parallel_size)


def _init_autoresume():
@@ -138,9 +138,11 @@ def _init_autoresume():
        torch.distributed.barrier()


-def _set_random_seed(seed):
+def _set_random_seed(seed_):
    """Set random seed for reproducability."""
-    if seed is not None and seed > 0:
+    if seed_ is not None and seed_ > 0:
+        # Ensure that different pipeline MP stages get different seeds.
+        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
@@ -150,13 +152,14 @@ def _set_random_seed(seed):
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))


-def _write_args_to_tensorboard():
+def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
-            writer.add_text(arg, str(getattr(args, arg)))
+            writer.add_text(arg, str(getattr(args, arg)),
+                            global_step=args.iteration)


def _initialize_mem_buffs():
...
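As a concrete illustration of the per-stage seed offset introduced above (the seed value and stage count are made up):

# Illustration only: with --seed 1234 and four pipeline stages, each stage
# derives a different seed so that, e.g., dropout patterns differ across stages.
seed_ = 1234
for pipeline_rank in range(4):
    print(pipeline_rank, seed_ + 100 * pipeline_rank)   # 1234, 1334, 1434, 1534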
@@ -23,8 +23,7 @@ class AnnealingLR(object):
    """Anneals the learning rate."""

    def __init__(self, optimizer, max_lr, min_lr,
-                 warmup_steps, decay_steps,
-                 decay_style, num_steps,
+                 warmup_steps, decay_steps, decay_style,
                 use_checkpoint_lr_scheduler=True,
                 override_lr_scheduler=False):
@@ -37,7 +36,7 @@ class AnnealingLR(object):
        assert self.max_lr >= self.min_lr

        self.warmup_steps = warmup_steps
-        self.num_steps = num_steps
+        self.num_steps = 0
        self.decay_steps = decay_steps
        assert self.decay_steps > 0
        assert self.warmup_steps < self.decay_steps
@@ -51,7 +50,7 @@ class AnnealingLR(object):
            'use-checkpoint are set.'

        # Set the learning rate
-        self.step(step_num=self.num_steps)
+        self.step(0)

        print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
@@ -92,11 +91,9 @@ class AnnealingLR(object):
        return self.min_lr + coeff * delta_lr

-    def step(self, increment=1, step_num=None):
+    def step(self, increment):
        """Set lr for all parameters groups."""
-        if step_num is None:
-            step_num = self.num_steps + increment
-        self.num_steps = step_num
+        self.num_steps += increment
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr
@@ -122,8 +119,9 @@ class AnnealingLR(object):
            return cls_value

        if not self.use_checkpoint_lr_scheduler:
-            assert cls_value == sd_value, 'AnnealingLR: class input value' \
-                'and checkpoint values for {} do not match'.format(name)
+            assert cls_value == sd_value, \
+                f'AnnealingLR: class input value {cls_value} and checkpoint' \
+                f'value {sd_value} for {name} do not match'
        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                  name))
        return sd_value
@@ -160,7 +158,7 @@ class AnnealingLR(object):
                                 'decay style')

        if 'num_iters' in sd:
-            self.num_steps = sd['num_iters']
+            num_steps = sd['num_iters']
        else:
-            self.num_steps = sd['num_steps']
+            num_steps = sd['num_steps']
-        self.step(step_num=self.num_steps)
+        self.step(increment=num_steps)
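The scheduler now tracks progress purely through increments. A hedged sketch of the new calling convention (the optimizer and hyper-parameters below are placeholders, and 'linear' is assumed to be an accepted decay style):

import torch
from megatron.learning_rates import AnnealingLR

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
lr_scheduler = AnnealingLR(optimizer, max_lr=1.0e-4, min_lr=1.0e-5,
                           warmup_steps=100, decay_steps=1000,
                           decay_style='linear')

# Old API: step() advanced by one, or step(step_num=n) jumped to an absolute step.
# New API: the caller reports how much progress was just made.
lr_scheduler.step(increment=1)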
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron number of micro-batches calculators."""
from abc import ABC
from abc import abstractmethod
def build_num_microbatches_calculator(args):

    # Constant num micro-batches.
    if args.rampup_batch_size is None:
        num_microbatches_calculator = ConstantNumMicroBatches(
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)
        if args.rank == 0:
            print('setting number of micro-batches to constant {}'.format(
                num_microbatches_calculator.get()), flush=True)

    else:
        assert len(args.rampup_batch_size) == 3, 'expected the following ' \
            'format: --rampup-batch-size <start batch size> ' \
            '<batch size incerement> <ramp-up samples>'
        start_batch_size = int(args.rampup_batch_size[0])
        batch_size_increment = int(args.rampup_batch_size[1])
        ramup_samples = int(args.rampup_batch_size[2])
        if args.rank == 0:
            print('will use batch size rampup starting from global batch '
                  'size {} to global batch size {} with batch size increments '
                  '{} over {} samples.'.format(start_batch_size,
                                               args.global_batch_size,
                                               batch_size_increment,
                                               ramup_samples), flush=True)
        num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
            start_batch_size, batch_size_increment, ramup_samples,
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)

    return num_microbatches_calculator


class NumMicroBatchesCalculator(ABC):

    def __init__(self):
        self.num_micro_batches = None
        self.current_global_batch_size = None

    def get(self):
        return self.num_micro_batches

    def get_current_global_batch_size(self):
        return self.current_global_batch_size

    @abstractmethod
    def update(self, consumed_samples, consistency_check):
        pass


class ConstantNumMicroBatches(NumMicroBatchesCalculator):

    def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
        micro_batch_times_data_parallel = micro_batch_size * \
            data_parallel_size
        assert global_batch_size % micro_batch_times_data_parallel == 0, \
            'global batch size ({}) is not divisible by micro batch size ({})' \
            ' times data parallel size ({})'.format(global_batch_size,
                                                    micro_batch_size,
                                                    data_parallel_size)
        self.num_micro_batches = global_batch_size // \
            micro_batch_times_data_parallel
        assert self.num_micro_batches >= 1
        self.current_global_batch_size = global_batch_size

    def update(self, consumed_samples, consistency_check):
        pass


class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):

    def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
                 global_batch_size, micro_batch_size, data_parallel_size):
        """Batch size ramp up.

        Over
            steps = (global-batch-size - start-batch-size) / batch_size_increment
        increment batch size from start-batch-size to global-batch-size using
            rampup-samples / steps
        samples.

        Arguments:
            start_batch_size: global batch size to start with
            batch_size_increment: global batch size increments
            ramup_samples: number of samples to use ramp up global
                batch size from `start_batch_size` to `global_batch_size`
            global_batch_size: global batch size post rampup
            micro_batch_size: micro batch size
            data_parallel_size: data parallel size.
        """

        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
            self.data_parallel_size
        assert self.micro_batch_times_data_parallel_size > 0

        assert start_batch_size > 0
        self.start_batch_size = start_batch_size

        assert global_batch_size > 0
        self.global_batch_size = global_batch_size
        diff_batch_size = self.global_batch_size - self.start_batch_size
        assert diff_batch_size >= 0
        assert batch_size_increment > 0
        self.batch_size_increment = batch_size_increment
        assert diff_batch_size % batch_size_increment == 0, 'expected ' \
            'global batch size interval ({}) to be divisible by global batch ' \
            'size increment ({})'.format(diff_batch_size, batch_size_increment)

        num_increments = diff_batch_size // self.batch_size_increment
        self.ramup_samples = ramup_samples
        assert self.ramup_samples >= 0
        self.rampup_samples_per_increment = self.ramup_samples / num_increments

        # Initialize number of microbatches.
        self.update(0, False)

    def update(self, consumed_samples, consistency_check):

        if consumed_samples > self.ramup_samples:
            self.current_global_batch_size = self.global_batch_size
        else:
            steps = int(consumed_samples / self.rampup_samples_per_increment)
            self.current_global_batch_size = self.start_batch_size + \
                steps * self.batch_size_increment
            assert self.current_global_batch_size <= self.global_batch_size

        if consistency_check:
            assert self.current_global_batch_size % \
                self.micro_batch_times_data_parallel_size == 0, 'current global ' \
                'batch size ({}) is not divisible by micro-batch-size ({}) times' \
                'data parallel size ({})'.format(self.current_global_batch_size,
                                                 self.micro_batch_size,
                                                 self.data_parallel_size)
        self.num_micro_batches = self.current_global_batch_size // \
            self.micro_batch_times_data_parallel_size
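To make the ramp-up arithmetic concrete, a small worked example using the class above (all numbers are illustrative):

# Ramp the global batch size from 32 to 128 in increments of 32 over 3,000,000
# samples, with micro-batch size 4 and data-parallel size 8.
calc = RampupBatchsizeNumMicroBatches(
    start_batch_size=32, batch_size_increment=32, ramup_samples=3000000,
    global_batch_size=128, micro_batch_size=4, data_parallel_size=8)

# (128 - 32) / 32 = 3 increments, i.e. one increment per 1,000,000 samples.
print(calc.get_current_global_batch_size(), calc.get())   # 32 and 1 micro-batch
calc.update(consumed_samples=1500000, consistency_check=True)
print(calc.get_current_global_batch_size(), calc.get())   # 64 and 2 micro-batches
calc.update(consumed_samples=4000000, consistency_check=True)
print(calc.get_current_global_batch_size(), calc.get())   # 128 and 4 micro-batches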
@@ -14,8 +14,8 @@
# limitations under the License.

from .distributed import *
-from .bert_model import BertModel
+from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from .realm_model import ICTBertModel
-from .gpt2_model import GPT2Model
+from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
from .utils import get_params_for_weight_decay_optimization
from .language_model import get_language_model
@@ -26,7 +26,7 @@ from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
-from megatron.module import MegatronModule
+from megatron.module import MegatronModule, PipelinedMegatronModule


def bert_attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
@@ -77,7 +77,7 @@ class BertLMHead(MegatronModule):
        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
-        self.bias.model_parallel = True
+        self.bias.tensor_model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output
@@ -101,17 +101,43 @@ class BertLMHead(MegatronModule):
        return output


-class BertModel(MegatronModule):
+def post_language_model_processing(lm_output, pooled_output,
+                                   lm_head, binary_head,
+                                   lm_labels,
+                                   logit_weights,
+                                   fp16_lm_cross_entropy):
+    # Output.
+    lm_logits = lm_head(
+        lm_output, logit_weights)
+
+    binary_logits = None
+    if binary_head is not None:
+        binary_logits = binary_head(pooled_output)
+
+    if lm_labels is None:
+        return lm_logits, binary_logits
+    else:
+        if fp16_lm_cross_entropy:
+            assert lm_logits.dtype == torch.half
+            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+        else:
+            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                       lm_labels)
+        return lm_loss, binary_logits
+
+
+class BertModelBase(PipelinedMegatronModule):
    """Bert Language model."""

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
-        super(BertModel, self).__init__()
+        super(BertModelBase, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)
@@ -123,52 +149,45 @@ class BertModel(MegatronModule):
            init_method=init_method,
            scaled_init_method=scaled_init_method)

-        self.lm_head = BertLMHead(
-            self.language_model.embedding.word_embeddings.weight.size(0),
-            args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
-        self._lm_head_key = 'lm_head'
-
-        if self.add_binary_head:
-            self.binary_head = get_linear_layer(args.hidden_size, 2,
-                                                init_method)
-            self._binary_head_key = 'binary_head'
+        self.initialize_word_embeddings(init_method_normal)
+        if mpu.is_pipeline_last_stage():
+            self.lm_head = BertLMHead(
+                self.word_embeddings_weight().size(0),
+                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
+            self._lm_head_key = 'lm_head'
+            self.binary_head = None
+            if self.add_binary_head:
+                self.binary_head = get_linear_layer(args.hidden_size, 2,
+                                                    init_method)
+                self._binary_head_key = 'binary_head'

-    def forward(self, input_ids, attention_mask,
+    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
-        position_ids = bert_position_ids(input_ids)

-        if self.add_binary_head:
-            lm_output, pooled_output = self.language_model(
-                input_ids,
-                position_ids,
-                extended_attention_mask,
-                tokentype_ids=tokentype_ids)
-        else:
-            lm_output = self.language_model(
-                input_ids,
-                position_ids,
-                extended_attention_mask,
-                tokentype_ids=tokentype_ids)
-
-        # Output.
-        lm_logits = self.lm_head(
-            lm_output, self.language_model.embedding.word_embeddings.weight)
-
-        binary_logits = None
-        if self.add_binary_head:
-            binary_logits = self.binary_head(pooled_output)
-
-        if lm_labels is None:
-            return lm_logits, binary_logits
-        else:
-            if self.fp16_lm_cross_entropy:
-                assert lm_logits.dtype == torch.half
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
-            else:
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
-                                                           lm_labels)
-            return lm_loss, binary_logits
+        kwargs = {}
+        if mpu.is_pipeline_first_stage():
+            input_ids = bert_model_input
+            position_ids = bert_position_ids(input_ids)
+            args = [input_ids, position_ids, extended_attention_mask]
+            kwargs['tokentype_ids'] = tokentype_ids
+        else:
+            args = [bert_model_input, extended_attention_mask]
+        lm_output = self.language_model(*args, **kwargs)
+        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+            lm_output, pooled_output = lm_output
+        else:
+            pooled_output = None
+
+        if mpu.is_pipeline_last_stage():
+            return post_language_model_processing(lm_output, pooled_output,
+                                                  self.lm_head, self.binary_head,
+                                                  lm_labels,
+                                                  self.word_embeddings_weight(),
+                                                  self.fp16_lm_cross_entropy)
+        else:
+            return lm_output
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
@@ -180,12 +199,17 @@ class BertModel(MegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
-        state_dict_[self._lm_head_key] \
-            = self.lm_head.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
-        if self.add_binary_head:
+        if mpu.is_pipeline_last_stage():
+            state_dict_[self._lm_head_key] \
+                = self.lm_head.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
+        if mpu.is_pipeline_last_stage() and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
+        # Save word_embeddings.
+        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+            state_dict_[self._word_embeddings_for_head_key] \
+                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
@@ -193,8 +217,74 @@ class BertModel(MegatronModule):
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
-        self.lm_head.load_state_dict(
-            state_dict[self._lm_head_key], strict=strict)
-        if self.add_binary_head:
+        if mpu.is_pipeline_last_stage():
+            self.lm_head.load_state_dict(
+                state_dict[self._lm_head_key], strict=strict)
+        if mpu.is_pipeline_last_stage() and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
+        # Load word_embeddings.
+        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+            self.word_embeddings.load_state_dict(
+                state_dict[self._word_embeddings_for_head_key], strict=strict)
class BertModel(BertModelBase):

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super(BertModel, self).__init__(
            num_tokentypes=num_tokentypes,
            add_binary_head=add_binary_head,
            parallel_output=parallel_output)

    def forward(self, input_ids, attention_mask,
                tokentype_ids=None, lm_labels=None):
        return super(BertModel, self).forward(
            input_ids,
            attention_mask,
            tokentype_ids=tokentype_ids,
            lm_labels=lm_labels)


class BertModelFirstStage(BertModelBase):

    def __init__(self, num_tokentypes=2):
        super(BertModelFirstStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, input_ids, attention_mask,
                tokentype_ids=None):
        return super(BertModelFirstStage, self).forward(
            input_ids,
            attention_mask,
            tokentype_ids=tokentype_ids)


class BertModelIntermediateStage(BertModelBase):

    def __init__(self, num_tokentypes=2):
        super(BertModelIntermediateStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask):
        return super(BertModelIntermediateStage, self).forward(
            hidden_state,
            attention_mask)


class BertModelLastStage(BertModelBase):

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super(BertModelLastStage, self).__init__(
            num_tokentypes=num_tokentypes,
            add_binary_head=add_binary_head,
            parallel_output=parallel_output)

    def forward(self, hidden_state, attention_mask,
                lm_labels=None):
        return super(BertModelLastStage, self).forward(
            hidden_state,
            attention_mask,
            lm_labels=lm_labels)
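The stage-specific wrappers above are meant to be chosen according to this rank's position in the pipeline. A hypothetical model-provider sketch (the selection logic mirrors the checks used in get_language_model() later in this diff; the pretrain scripts that actually do this are not part of this excerpt):

from megatron import mpu
from megatron.model import (BertModel, BertModelFirstStage,
                            BertModelIntermediateStage, BertModelLastStage)

def bert_model_provider():
    # Illustration only: a single-stage pipeline is both first and last stage.
    if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
        return BertModel(num_tokentypes=2, add_binary_head=True)
    if mpu.is_pipeline_first_stage():
        return BertModelFirstStage(num_tokentypes=2)
    if mpu.is_pipeline_last_stage():
        return BertModelLastStage(num_tokentypes=2, add_binary_head=True)
    return BertModelIntermediateStage(num_tokentypes=2)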
@@ -18,18 +18,19 @@
import torch

from megatron import get_args, print_rank_0
+from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
-from megatron.module import MegatronModule
+from megatron.module import PipelinedMegatronModule


-class Classification(MegatronModule):
+class ClassificationBase(PipelinedMegatronModule):

    def __init__(self, num_classes, num_tokentypes=2):
-        super(Classification, self).__init__()
+        super(ClassificationBase, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.num_classes = num_classes
@@ -50,24 +51,30 @@ class Classification(MegatronModule):
                                                    init_method)
        self._classification_head_key = 'classification_head'

-    def forward(self, input_ids, attention_mask, tokentype_ids):
+    def forward(self, model_input, attention_mask, tokentype_ids=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
-        position_ids = bert_position_ids(input_ids)

-        _, pooled_output = self.language_model(input_ids,
-                                               position_ids,
-                                               extended_attention_mask,
-                                               tokentype_ids=tokentype_ids)
-
-        # Output.
-        classification_output = self.classification_dropout(pooled_output)
-        classification_logits = self.classification_head(classification_output)
-
-        # Reshape back to separate choices.
-        classification_logits = classification_logits.view(-1, self.num_classes)
-
-        return classification_logits
+        kwargs = {}
+        if mpu.is_pipeline_first_stage():
+            input_ids = model_input
+            position_ids = bert_position_ids(input_ids)
+            args = [input_ids, position_ids, extended_attention_mask]
+            kwargs['tokentype_ids'] = tokentype_ids
+        else:
+            args = [model_input, extended_attention_mask]
+        lm_output = self.language_model(*args, **kwargs)
+        if mpu.is_pipeline_last_stage():
+            _, pooled_output = lm_output
+
+            classification_output = self.classification_dropout(pooled_output)
+            classification_logits = self.classification_head(classification_output)
+
+            # Reshape back to separate choices.
+            classification_logits = classification_logits.view(-1, self.num_classes)
+
+            return classification_logits
+        return lm_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
@@ -95,3 +102,55 @@ class Classification(MegatronModule):
        print_rank_0('***WARNING*** could not find {} in the checkpoint, '
                     'initializing to random'.format(
                         self._classification_head_key))
class Classification(ClassificationBase):

    def __init__(self, num_classes, num_tokentypes=2):
        super(Classification, self).__init__(
            num_classes, num_tokentypes=num_tokentypes)

    def forward(self, input_ids, attention_mask,
                tokentype_ids=None):
        return super(Classification, self).forward(
            input_ids,
            attention_mask,
            tokentype_ids=tokentype_ids)


class ClassificationFirstStage(ClassificationBase):

    def __init__(self, num_classes, num_tokentypes=2):
        super(ClassificationFirstStage, self).__init__(
            num_classes, num_tokentypes=num_tokentypes)

    def forward(self, input_ids, attention_mask,
                tokentype_ids=None):
        return super(ClassificationFirstStage, self).forward(
            input_ids,
            attention_mask,
            tokentype_ids=tokentype_ids)


class ClassificationIntermediateStage(ClassificationBase):

    def __init__(self, num_classes, num_tokentypes=2):
        super(ClassificationIntermediateStage, self).__init__(
            num_classes, num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask):
        return super(ClassificationIntermediateStage, self).forward(
            hidden_state,
            attention_mask)


class ClassificationLastStage(ClassificationBase):

    def __init__(self, num_classes, num_tokentypes=2):
        super(ClassificationLastStage, self).__init__(
            num_classes, num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask):
        return super(ClassificationLastStage, self).forward(
            hidden_state,
            attention_mask)
@@ -19,7 +19,7 @@ import torch

from megatron import get_args
from megatron import mpu
-from megatron.module import MegatronModule
+from megatron.module import PipelinedMegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
@@ -32,11 +32,40 @@ def gpt2_attention_mask_func(attention_scores, ltor_mask):
    return attention_scores


-class GPT2Model(MegatronModule):
+def post_language_model_processing(lm_output, labels, logit_weights,
+                                   get_key_value, parallel_output,
+                                   forward_method_parallel_output,
+                                   fp16_lm_cross_entropy):
+    if get_key_value:
+        lm_output, presents = lm_output
+
+    # Output.
+    if forward_method_parallel_output is not None:
+        parallel_output = forward_method_parallel_output
+    output = parallel_lm_logits(
+        lm_output,
+        logit_weights,
+        parallel_output)
+
+    if get_key_value:
+        output = [output, presents]
+
+    if labels is None:
+        return output
+    else:
+        if fp16_lm_cross_entropy:
+            assert output.dtype == torch.half
+            loss = mpu.vocab_parallel_cross_entropy(output, labels)
+        else:
+            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+        return loss
+
+
+class GPT2ModelBase(PipelinedMegatronModule):
    """GPT-2 Language model."""

    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPT2Model, self).__init__()
+        super(GPT2ModelBase, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
@@ -50,43 +79,31 @@ class GPT2Model(MegatronModule):
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))

-    def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
-
-        # Language model.
-        lm_output = self.language_model(input_ids,
-                                        position_ids,
-                                        attention_mask,
-                                        tokentype_ids=tokentype_ids,
-                                        layer_past=layer_past,
-                                        get_key_value=get_key_value)
-
-        if get_key_value:
-            lm_output, presents = lm_output
-
-        # Output.
-        parallel_output = self.parallel_output
-        if forward_method_parallel_output is not None:
-            parallel_output = forward_method_parallel_output
-        output = parallel_lm_logits(
-            lm_output,
-            self.language_model.embedding.word_embeddings.weight,
-            parallel_output)
-
-        if get_key_value:
-            output = [output, presents]
-
-        if labels is None:
-            return output
-        else:
-            if self.fp16_lm_cross_entropy:
-                assert output.dtype == torch.half
-                loss = mpu.vocab_parallel_cross_entropy(output, labels)
-            else:
-                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
-            return loss
+        self.initialize_word_embeddings(init_method_normal)
+
+    def forward(self, gpt2_model_input, attention_mask, labels=None,
+                tokentype_ids=None, layer_past=None, get_key_value=False,
+                forward_method_parallel_output=None):
+
+        kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
+        if mpu.is_pipeline_first_stage():
+            (input_ids, position_ids) = gpt2_model_input
+            args = [input_ids, position_ids, attention_mask]
+            kwargs['tokentype_ids'] = tokentype_ids
+        else:
+            args = [gpt2_model_input, attention_mask]
+        lm_output = self.language_model(*args, **kwargs)
+
+        if mpu.is_pipeline_last_stage():
+            return post_language_model_processing(
+                lm_output, labels,
+                self.word_embeddings_weight(),
+                get_key_value,
+                self.parallel_output,
+                forward_method_parallel_output,
+                self.fp16_lm_cross_entropy)
+        else:
+            return lm_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
@@ -95,11 +112,89 @@ class GPT2Model(MegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
+        # Save word_embeddings.
+        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+            state_dict_[self._word_embeddings_for_head_key] \
+                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
+        # Load word_embeddings.
+        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+            self.word_embeddings.load_state_dict(
+                state_dict[self._word_embeddings_for_head_key], strict=strict)
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)
class GPT2Model(GPT2ModelBase):

    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(GPT2Model, self).__init__(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)

    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):
        return super(GPT2Model, self).forward(
            (input_ids, position_ids),
            attention_mask,
            labels=labels,
            tokentype_ids=tokentype_ids,
            layer_past=layer_past,
            get_key_value=get_key_value,
            forward_method_parallel_output=forward_method_parallel_output)


class GPT2ModelFirstStage(GPT2ModelBase):

    def __init__(self, num_tokentypes=0):
        super(GPT2ModelFirstStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, input_ids, position_ids, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False):
        return super(GPT2ModelFirstStage, self).forward(
            (input_ids, position_ids),
            attention_mask,
            tokentype_ids=tokentype_ids,
            layer_past=layer_past,
            get_key_value=get_key_value)


class GPT2ModelIntermediateStage(GPT2ModelBase):

    def __init__(self, num_tokentypes=0):
        super(GPT2ModelIntermediateStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask,
                layer_past=None, get_key_value=False):
        return super(GPT2ModelIntermediateStage, self).forward(
            hidden_state,
            attention_mask,
            layer_past=layer_past,
            get_key_value=get_key_value)


class GPT2ModelLastStage(GPT2ModelBase):

    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(GPT2ModelLastStage, self).__init__(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)

    def forward(self, hidden_state, attention_mask, labels=None,
                layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):
        return super(GPT2ModelLastStage, self).forward(
            hidden_state,
            attention_mask,
            labels=labels,
            layer_past=layer_past,
            get_key_value=get_key_value,
            forward_method_parallel_output=forward_method_parallel_output)
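Between stages, only the hidden state is fed from one model class to the next; the attention mask and labels are assumed to be available on each rank. A hypothetical sketch of one forward pass across the pipeline (the send/recv helpers are placeholders for scheduling code that is not part of this excerpt):

if mpu.is_pipeline_first_stage():
    output = model(input_ids, position_ids, attention_mask)     # GPT2ModelFirstStage
    send_forward(output)                                         # placeholder
elif not mpu.is_pipeline_last_stage():
    hidden_state = recv_forward()                                # placeholder
    send_forward(model(hidden_state, attention_mask))            # GPT2ModelIntermediateStage
else:
    hidden_state = recv_forward()                                # placeholder
    loss = model(hidden_state, attention_mask, labels=labels)    # GPT2ModelLastStage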
@@ -29,7 +29,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights."""
    # Parallel logits.
-    input_parallel = mpu.copy_to_model_parallel_region(input_)
+    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
    # Matrix multiply.
    if bias is None:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
@@ -39,7 +39,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
    if parallel_output:
        return logits_parallel

-    return mpu.gather_from_model_parallel_region(logits_parallel)
+    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)


def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
@@ -54,12 +54,24 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
    scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)

-    # Language model.
-    language_model = TransformerLanguageModel(
-        attention_mask_func=attention_mask_func,
-        init_method=init_method,
-        output_layer_init_method=scaled_init_method,
-        num_tokentypes=num_tokentypes,
-        add_pooler=add_pooler)
+    args = [attention_mask_func, init_method, scaled_init_method]
+    kwargs = {}
+    cls = None
+    if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
+        cls = TransformerLanguageModel
+        kwargs['num_tokentypes'] = num_tokentypes
+        kwargs['add_pooler'] = add_pooler
+    elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
+        cls = TransformerLanguageModelFirstStage
+        kwargs['num_tokentypes'] = num_tokentypes
+    elif not mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
+        cls = TransformerLanguageModelLastStage
+        kwargs['add_pooler'] = add_pooler
+    else:
+        cls = TransformerLanguageModelIntermediateStage
+
+    # Language model.
+    language_model = cls(*args, **kwargs)

    # key used for checkpoints.
    language_model_key = 'language_model'
@@ -118,9 +130,12 @@ class Embedding(MegatronModule):
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

+        args = get_args()
+
        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
-            vocab_size, self.hidden_size, init_method=self.init_method)
+            vocab_size, self.hidden_size,
+            init_method=self.init_method)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
@@ -160,6 +175,7 @@ class Embedding(MegatronModule):
            self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
+            args = get_args()
            self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
@@ -241,7 +257,7 @@ class Embedding(MegatronModule):
                  'checkpoint but could not find it', flush=True)


-class TransformerLanguageModel(MegatronModule):
+class TransformerLanguageModelBase(MegatronModule):
    """Transformer language model.

    Arguments:
@@ -266,7 +282,7 @@ class TransformerLanguageModel(MegatronModule):
                 output_layer_init_method,
                 num_tokentypes=0,
                 add_pooler=False):
-        super(TransformerLanguageModel, self).__init__()
+        super(TransformerLanguageModelBase, self).__init__()
        args = get_args()

        self.hidden_size = args.hidden_size
@@ -274,41 +290,47 @@ class TransformerLanguageModel(MegatronModule):
        self.init_method = init_method
        self.add_pooler = add_pooler

-        # Embeddings
-        self.embedding = Embedding(self.hidden_size,
-                                   args.padded_vocab_size,
-                                   args.max_position_embeddings,
-                                   args.hidden_dropout,
-                                   self.init_method,
-                                   self.num_tokentypes)
-        self._embedding_key = 'embedding'
+        # Embeddings.
+        if mpu.is_pipeline_first_stage():
+            self.embedding = Embedding(self.hidden_size,
+                                       args.padded_vocab_size,
+                                       args.max_position_embeddings,
+                                       args.hidden_dropout,
+                                       self.init_method,
+                                       self.num_tokentypes)
+            self._embedding_key = 'embedding'

-        # Transformer
+        # Transformer.
        self.transformer = ParallelTransformer(
            attention_mask_func, self.init_method,
            output_layer_init_method)
        self._transformer_key = 'transformer'

-        # Pooler
-        if self.add_pooler:
+        # Pooler.
+        if mpu.is_pipeline_last_stage() and self.add_pooler:
            self.pooler = Pooler(self.hidden_size, self.init_method)
            self._pooler_key = 'pooler'

-    def forward(self, input_ids, position_ids, attention_mask,
+    def forward(self, language_model_input, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                pooling_sequence_index=0):

        # Embeddings.
-        embedding_output = self.embedding(input_ids, position_ids,
-                                          tokentype_ids=tokentype_ids)
+        if mpu.is_pipeline_first_stage():
+            (input_ids, position_ids) = language_model_input
+            embedding_output = self.embedding(input_ids, position_ids,
                                              tokentype_ids=tokentype_ids)
+            transformer_input = embedding_output
+        else:
+            transformer_input = language_model_input

        # Transformer.
-        transformer_output = self.transformer(embedding_output,
+        transformer_output = self.transformer(transformer_input,
                                              attention_mask,
                                              layer_past=layer_past,
                                              get_key_value=get_key_value)

-        if self.add_pooler:
+        if mpu.is_pipeline_last_stage() and self.add_pooler:
            pooled_output = self.pooler(transformer_output,
                                        pooling_sequence_index)
            return transformer_output, pooled_output
...@@ -320,13 +342,14 @@ class TransformerLanguageModel(MegatronModule): ...@@ -320,13 +342,14 @@ class TransformerLanguageModel(MegatronModule):
"""For easy load.""" """For easy load."""
state_dict_ = {} state_dict_ = {}
state_dict_[self._embedding_key] \ if mpu.is_pipeline_first_stage():
= self.embedding.state_dict_for_save_checkpoint( state_dict_[self._embedding_key] \
destination, prefix, keep_vars) = self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._transformer_key] \ state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint( = self.transformer.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
if self.add_pooler: if mpu.is_pipeline_last_stage() and self.add_pooler:
state_dict_[self._pooler_key] \ state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint( = self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
...@@ -337,15 +360,16 @@ class TransformerLanguageModel(MegatronModule): ...@@ -337,15 +360,16 @@ class TransformerLanguageModel(MegatronModule):
"""Customized load.""" """Customized load."""
# Embedding. # Embedding.
if self._embedding_key in state_dict: if mpu.is_pipeline_first_stage():
state_dict_ = state_dict[self._embedding_key] if self._embedding_key in state_dict:
else: state_dict_ = state_dict[self._embedding_key]
# for backward compatibility. else:
state_dict_ = {} # for backward compatibility.
for key in state_dict.keys(): state_dict_ = {}
if '_embeddings' in key: for key in state_dict.keys():
state_dict_[key] = state_dict[key] if '_embeddings' in key:
self.embedding.load_state_dict(state_dict_, strict=strict) state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Transformer. # Transformer.
if self._transformer_key in state_dict: if self._transformer_key in state_dict:
...@@ -359,8 +383,118 @@ class TransformerLanguageModel(MegatronModule): ...@@ -359,8 +383,118 @@ class TransformerLanguageModel(MegatronModule):
self.transformer.load_state_dict(state_dict_, strict=strict) self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler. # Pooler.
if self.add_pooler: if mpu.is_pipeline_last_stage() and self.add_pooler:
assert 'pooler' in state_dict, \ assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint' 'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key], self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict) strict=strict)
class TransformerLanguageModel(TransformerLanguageModelBase):
"""Transformer language model (see TransformerLanguageModelBase
for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=0,
add_pooler=False):
super(TransformerLanguageModel, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=num_tokentypes,
add_pooler=add_pooler)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
return super(TransformerLanguageModel, self).forward(
(input_ids, position_ids),
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
)
class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
"""Transformer language model, first stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=0):
super(TransformerLanguageModelFirstStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
return super(TransformerLanguageModelFirstStage, self).forward(
(input_ids, position_ids),
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value
)
class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
"""Transformer language model, intermediate stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method):
super(TransformerLanguageModelIntermediateStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False):
return super(TransformerLanguageModelIntermediateStage, self).forward(
hidden_states,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value
)
class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
"""Transformer language model, final stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
add_pooler=add_pooler)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False,
pooling_sequence_index=0):
return super(TransformerLanguageModelLastStage, self).forward(
hidden_states,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
)
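Note: the stage-specific classes above all funnel into TransformerLanguageModelBase; only the shape of the first argument changes (token/position ids on the first stage, hidden states elsewhere). The repository's own model provider is not shown in this hunk, so the following is only a hypothetical sketch, assuming the mpu helpers exported by this merge, of how a caller might pick the right variant for its pipeline rank:

    # Hypothetical provider sketch (not the actual code in this MR).
    from megatron import mpu

    def build_language_model_for_stage(attention_mask_func, init_method,
                                       output_layer_init_method, add_pooler=False):
        # Single-stage pipelines keep the original monolithic class.
        if mpu.get_pipeline_model_parallel_world_size() == 1:
            return TransformerLanguageModel(attention_mask_func, init_method,
                                            output_layer_init_method,
                                            add_pooler=add_pooler)
        if mpu.is_pipeline_first_stage():
            return TransformerLanguageModelFirstStage(
                attention_mask_func, init_method, output_layer_init_method)
        if mpu.is_pipeline_last_stage():
            return TransformerLanguageModelLastStage(
                attention_mask_func, init_method, output_layer_init_method,
                add_pooler=add_pooler)
        return TransformerLanguageModelIntermediateStage(
            attention_mask_func, init_method, output_layer_init_method)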
@@ -18,18 +18,19 @@
import torch

from megatron import get_args, print_rank_0
+from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
-from megatron.module import MegatronModule
+from megatron.module import PipelinedMegatronModule


-class MultipleChoice(MegatronModule):
+class MultipleChoiceBase(PipelinedMegatronModule):

    def __init__(self, num_tokentypes=2):
-        super(MultipleChoice, self).__init__()
+        super(MultipleChoiceBase, self).__init__(share_word_embeddings=False)
        args = get_args()

        init_method = init_method_normal(args.init_method_std)
@@ -48,38 +49,44 @@ class MultipleChoice(MegatronModule):
            init_method)
        self._multichoice_head_key = 'multichoice_head'

-    def forward(self, input_ids, attention_mask, tokentype_ids):
+    def forward(self, model_input, attention_mask, tokentype_ids=None):

        # [batch, choices, sequence] --> [batch * choices, sequence] -->
        #    transformer --> [batch, choices] --> softmax

        # Ensure the shape is [batch-size, choices, sequence]
-        assert len(input_ids.shape) == 3
        assert len(attention_mask.shape) == 3
-        assert len(tokentype_ids.shape) == 3
+        num_choices = attention_mask.shape[1]

        # Reshape and treat choice dimension the same as batch.
-        num_choices = input_ids.shape[1]
-        input_ids = input_ids.view(-1, input_ids.size(-1))
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
-        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
-        position_ids = bert_position_ids(input_ids)

-        _, pooled_output = self.language_model(input_ids,
-                                               position_ids,
-                                               extended_attention_mask,
-                                               tokentype_ids=tokentype_ids)

-        # Output.
-        multichoice_output = self.multichoice_dropout(pooled_output)
-        multichoice_logits = self.multichoice_head(multichoice_output)
+        kwargs = {}
+        if mpu.is_pipeline_first_stage():
+            input_ids = model_input
+            # Do the same as attention_mask for input_ids, tokentype_ids
+            assert len(input_ids.shape) == 3
+            assert len(tokentype_ids.shape) == 3
+            input_ids = input_ids.view(-1, input_ids.size(-1))
+            tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
+            position_ids = bert_position_ids(input_ids)
+            args = [input_ids, position_ids, extended_attention_mask]
+            kwargs['tokentype_ids'] = tokentype_ids
+        else:
+            args = [model_input, extended_attention_mask]
+        lm_output = self.language_model(*args, **kwargs)
+        if mpu.is_pipeline_last_stage():
+            _, pooled_output = lm_output
+            multichoice_output = self.multichoice_dropout(pooled_output)
+            multichoice_logits = self.multichoice_head(multichoice_output)

            # Reshape back to separate choices.
            multichoice_logits = multichoice_logits.view(-1, num_choices)

            return multichoice_logits
+        return lm_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
@@ -107,3 +114,54 @@ class MultipleChoice(MegatronModule):
            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
                         'initializing to random'.format(
                             self._multichoice_head_key))
class MultipleChoice(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoice, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(MultipleChoice, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class MultipleChoiceFirstStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceFirstStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(MultipleChoiceFirstStage, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class MultipleChoiceIntermediateStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceIntermediateStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(MultipleChoiceIntermediateStage, self).forward(
hidden_state,
attention_mask)
class MultipleChoiceLastStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceLastStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(MultipleChoiceLastStage, self).forward(
hidden_state,
attention_mask)
@@ -18,8 +18,7 @@ def general_ict_model_provider(only_query_model=False, only_block_model=False):
    args = get_args()
    assert args.ict_head_size is not None, \
        "Need to specify --ict-head-size to provide an ICTBertModel"
-    assert args.model_parallel_size == 1, \
+    assert mpu.get_tensor_model_parallel_world_size() == 1 and mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building ICTBertModel...')
...
@@ -130,7 +130,7 @@ class ParallelSelfAttention(MegatronModule):
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
-        world_size = mpu.get_model_parallel_world_size()
+        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
@@ -504,46 +504,28 @@ class ParallelTransformer(MegatronModule):
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

-        # Number of layers:
-        self.num_layers = args.num_layers
-        self.num_unique_layers = args.num_unique_layers
-        if self.num_unique_layers is None:
-            self.num_unique_layers = self.num_layers
-        assert self.num_layers % self.num_unique_layers == 0, \
-            'number of layers should be divisible by number of unique layers'
-        self.param_sharing_style = args.param_sharing_style
+        # Number of layers.
+        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
+            'num_layers must be divisible by pipeline_model_parallel_size'
+        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                attention_mask_func, init_method,
                output_layer_init_method, layer_number)
+        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
        self.layers = torch.nn.ModuleList(
-            [build_layer(i + 1) for i in range(self.num_unique_layers)])
+            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

-        # Print layer ordering.
-        if self.num_layers != self.num_unique_layers:
-            if torch.distributed.get_rank() == 0:
-                print('> will be using the following layer ordering:')
-                for i in range(self.num_layers):
-                    print('   layer id: {:3d} --> unique layer id: '
-                          '{:3d}'.format(i, self._get_layer_index(i)),
-                          flush=True)

-        # Final layer norm before output.
-        self.final_layernorm = LayerNorm(
-            args.hidden_size,
-            eps=args.layernorm_epsilon)
+        if mpu.is_pipeline_last_stage():
+            # Final layer norm before output.
+            self.final_layernorm = LayerNorm(
+                args.hidden_size,
+                eps=args.layernorm_epsilon)

-    def _get_layer_index(self, layer_number):
-        if self.param_sharing_style == 'grouped':
-            return layer_number % self.num_unique_layers
-        if self.param_sharing_style == 'spaced':
-            return layer_number // (self.num_layers // self.num_unique_layers)
-        assert False, 'should not be here'

    def _get_layer(self, layer_number):
-        return self.layers[self._get_layer_index(layer_number)]
+        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing."""
@@ -570,7 +552,7 @@ class ParallelTransformer(MegatronModule):
    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):

-        # Checks
+        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
@@ -580,8 +562,9 @@ class ParallelTransformer(MegatronModule):
                'get_key_value does not work with ' \
                'activation checkpointing'

-        # data format change to avoid explicit tranposes : [b s h] --> [s b h]
-        hidden_states = hidden_states.transpose(0, 1).contiguous()
+        if mpu.is_pipeline_first_stage():
+            # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+            hidden_states = hidden_states.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
@@ -602,11 +585,13 @@ class ParallelTransformer(MegatronModule):
                    hidden_states, present = hidden_states
                    presents.append(present)

-        # reverting data format change [s b h] --> [b s h]
-        hidden_states = hidden_states.transpose(0, 1).contiguous()

        # Final layer norm.
-        output = self.final_layernorm(hidden_states)
+        if mpu.is_pipeline_last_stage():
+            # Reverting data format change [s b h] --> [b s h].
+            hidden_states = hidden_states.transpose(0, 1).contiguous()
+            output = self.final_layernorm(hidden_states)
+        else:
+            output = hidden_states

        if get_key_value:
            output = [output, presents]
...
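Note on the layer-partitioning change in transformer.py above: each pipeline stage now builds only num_layers / pipeline_world_size layers, offset by its stage rank, so global layer numbers stay consistent across stages. A minimal stand-alone sketch of that arithmetic (numbers here are made up purely for illustration):

    # Illustrative only: how global layer numbers map to a pipeline stage,
    # assuming num_layers is divisible by the pipeline-parallel world size,
    # matching the assertion added in ParallelTransformer above.
    def layers_for_stage(num_layers, pipeline_world_size, pipeline_rank):
        assert num_layers % pipeline_world_size == 0
        layers_per_stage = num_layers // pipeline_world_size
        offset = pipeline_rank * layers_per_stage
        return [offset + i + 1 for i in range(layers_per_stage)]

    # Example: 24 layers over 4 stages -> stage 2 owns layers 13..18.
    print(layers_for_stage(24, 4, 2))  # [13, 14, 15, 16, 17, 18]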
@@ -17,9 +17,12 @@

import torch

+from megatron import get_args
+from megatron import mpu
+

class MegatronModule(torch.nn.Module):
-    """Megatron specific extentions of torch Module."""
+    """Megatron specific extensions of torch Module."""

    def __init__(self):
        super(MegatronModule, self).__init__()
@@ -29,3 +32,54 @@ class MegatronModule(torch.nn.Module):
        """Use this function to override the state dict for
        saving checkpoints."""
        return self.state_dict(destination, prefix, keep_vars)
class PipelinedMegatronModule(MegatronModule):
"""Pipelining specific extensions of MegatronModule."""
def __init__(self, share_word_embeddings=True):
super(PipelinedMegatronModule, self).__init__()
args = get_args()
self.share_word_embeddings = share_word_embeddings
def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage():
return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage():
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last stage, '
'but share_word_embeddings is false')
return self.word_embeddings.weight
raise Exception('word_embeddings_weight() should be '
'called for first and last stage only')
def initialize_word_embeddings(self, init_method_normal):
args = get_args()
if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false')
# Parameters are shared between the word embeddings layer, and the heads at
# the end of the model. In a pipelined setup with more than one stage, the
# initial embedding layer and the head are on different workers, so we do
# the following:
# 1. Create a second copy of word_embeddings on the last stage, with initial
# parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that the
# two copies of word_embeddings start off with the same parameter values.
# 3. In the training loop, before an all-reduce between the grads of the two
# word_embeddings layers to ensure that every applied weight update is the
# same on both stages.
if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage():
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# If first and last stages are different, set word_embeddings
# weights to 0 here, then copy first stage's weights using all_reduce
# below.
self.word_embeddings = mpu.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std))
self.word_embeddings.weight.data.fill_(0)
# Ensure that first and last stages have the same initial parameter values.
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
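Step 3 of the comment above happens outside this class: during training, the gradients of the two word_embeddings copies have to be all-reduced over the embedding group before the optimizer step. The training-loop code is not part of this excerpt, so the sketch below is only illustrative of that step, assuming a PipelinedMegatronModule with shared word embeddings and ordinary .grad tensors:

    # Illustrative sketch of step 3: keep the two word-embedding copies in sync.
    import torch
    from megatron import mpu

    def allreduce_word_embedding_grads(model):
        if not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()):
            return  # intermediate stages hold no word embeddings
        if mpu.get_pipeline_model_parallel_world_size() == 1:
            return  # single stage: only one copy exists, nothing to synchronize
        weight = model.word_embeddings_weight()
        if weight.grad is not None:
            torch.distributed.all_reduce(weight.grad,
                                         group=mpu.get_embedding_group())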
@@ -26,10 +26,18 @@ from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank
from .initialize import get_data_parallel_world_size
+from .initialize import get_embedding_group
from .initialize import get_model_parallel_group
-from .initialize import get_model_parallel_rank, set_model_parallel_rank
-from .initialize import get_model_parallel_src_rank
-from .initialize import get_model_parallel_world_size, set_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_pipeline_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
+from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
+from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
+from .initialize import get_tensor_model_parallel_src_rank
+from .initialize import get_pipeline_model_parallel_first_rank
+from .initialize import get_pipeline_model_parallel_last_rank
+from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
+from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
@@ -38,10 +46,10 @@ from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding

-from .mappings import copy_to_model_parallel_region
-from .mappings import gather_from_model_parallel_region
-from .mappings import reduce_from_model_parallel_region
-from .mappings import scatter_to_model_parallel_region
+from .mappings import copy_to_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import reduce_from_tensor_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region

from .random import checkpoint
from .random import get_cuda_rng_tracker
...
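The mpu exports above split the old "model parallel" helpers into tensor-model-parallel and pipeline-model-parallel variants. A small hypothetical usage sketch, only to show how the two axes are queried independently (the function name here is made up):

    # Hypothetical: inspect this rank's position along both parallel axes,
    # using only helpers exported by megatron.mpu in this merge.
    from megatron import mpu

    def describe_parallel_topology():
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        pp_rank = mpu.get_pipeline_model_parallel_rank()
        pp_size = mpu.get_pipeline_model_parallel_world_size()
        print('tensor parallel {}/{}, pipeline stage {}/{}, first={}, last={}'.format(
            tp_rank, tp_size, pp_rank, pp_size,
            mpu.is_pipeline_first_stage(), mpu.is_pipeline_last_stage()))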
@@ -16,9 +16,9 @@

import torch

-from .initialize import get_model_parallel_group
-from .initialize import get_model_parallel_rank
-from .initialize import get_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_world_size
from .utils import VocabUtility
@@ -31,15 +31,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(logits_max,
                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=get_model_parallel_group())
+                                     group=get_tensor_model_parallel_group())
        # Subtract the maximum value.
        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

        # Get the partition's vocab indecies
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = vocab_parallel_logits.size()[-1]
-        rank = get_model_parallel_rank()
-        world_size = get_model_parallel_world_size()
+        rank = get_tensor_model_parallel_rank()
+        world_size = get_tensor_model_parallel_world_size()
        vocab_start_index, vocab_end_index = get_vocab_range(
            partition_vocab_size, rank, world_size)
@@ -62,7 +62,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
        # All reduce is needed to get the chunks from other GPUs.
        torch.distributed.all_reduce(predicted_logits,
                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=get_model_parallel_group())
+                                     group=get_tensor_model_parallel_group())

        # Sum of exponential of logits along vocab dimension across all GPUs.
        exp_logits = vocab_parallel_logits
@@ -70,7 +70,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(sum_exp_logits,
                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=get_model_parallel_group())
+                                     group=get_tensor_model_parallel_group())

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits
...
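For reference, the all-reduces in the hunk above only assemble, across tensor-parallel ranks, the same numerically stable cross entropy that a single GPU would compute directly. A self-contained single-rank sketch of that formula (no tensor parallelism, illustrative only):

    import torch

    def stable_cross_entropy(logits, target):
        # Same sequence of steps as _VocabParallelCrossEntropy, minus the
        # all-reduces: subtract the max, pick the predicted logit, then
        # loss = log(sum(exp(logits))) - predicted_logit.
        logits_max = logits.max(dim=-1, keepdim=True)[0]
        logits = logits - logits_max
        predicted_logits = logits.gather(-1, target.unsqueeze(-1)).squeeze(-1)
        sum_exp_logits = logits.exp().sum(dim=-1)
        return torch.log(sum_exp_logits) - predicted_logits

    logits = torch.randn(4, 10)
    target = torch.randint(0, 10, (4,))
    assert torch.allclose(
        stable_cross_entropy(logits, target),
        torch.nn.functional.cross_entropy(logits, target, reduction='none'),
        atol=1e-5)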