"megatron/git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "fd33e9303732416c47c95e6941147839912700bb"
Commit e1354f9d authored by liangjing's avatar liangjing
Browse files

update

parents
Pipeline #1025 failed with stages
in 0 seconds
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""BERT model."""
import torch
from megatron import get_args
from megatron.core import tensor_parallel
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model import LayerNorm
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Convert attention mask to binary:
extended_attention_mask = (extended_attention_mask < 0.5)
return extended_attention_mask
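# Illustrative usage sketch (toy values): a 2-D padding mask [[1, 1, 0]] of
# shape [b=1, s=3] becomes a [1, 1, 3, 3] boolean tensor that is True exactly
# where attention must be masked out (any position involving the pad token):
#   >>> mask = torch.tensor([[1, 1, 0]])
#   >>> bert_extended_attention_mask(mask).shape
#   torch.Size([1, 1, 3, 3])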
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
Arguments:
config: TransformerConfig object
        mpu_vocab_size: size of the model-parallel partition of the vocabulary.
        hidden_size: hidden size.
        parallel_output: whether the output logits stay distributed across
            tensor-parallel ranks (True) or are gathered (False).
"""
def __init__(self, mpu_vocab_size, hidden_size, config, parallel_output):
super().__init__(config=config)
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, config.init_method, gather_params_on_init=args.zero_stage == 3)
setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel)
setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)
self.layernorm = LayerNorm(hidden_size,
eps=config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
elif args.onnx_safe:
self.gelu = erf_gelu
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states,
word_embeddings_weight,
self.parallel_output,
bias=self.bias)
return output
def post_language_model_processing(lm_output, pooled_output,
lm_head, binary_head,
lm_labels,
logit_weights,
fp16_lm_cross_entropy):
# Output.
lm_logits = lm_head(
lm_output, logit_weights)
binary_logits = None
if binary_head is not None:
binary_logits = binary_head(pooled_output)
if lm_labels is None:
# [s b h] => [b s h]
return lm_logits.transpose(0,1).contiguous(), binary_logits
else:
# [b s] => [s b]
lm_labels = lm_labels.transpose(0,1).contiguous()
# lm_logits : [s, b, h] and lm_labels: [s, b]
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
# [s, b] => [b s]
lm_loss = lm_loss.transpose(0,1).contiguous()
return lm_loss, binary_logits
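# Shape summary for post_language_model_processing (illustrative sizes): with
# s=128, b=8 and vocabulary dimension v (full vocab or a tensor-parallel
# partition, depending on parallel_output), lm_logits is [128, 8, v]; with
# lm_labels=None the logits are returned transposed to [8, 128, v], otherwise
# the per-token loss is computed in [s, b] layout and returned as [8, 128],
# in both cases alongside binary_logits.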
class BertModel(MegatronModule):
"""Bert Language model."""
def __init__(self,
config,
num_tokentypes=2,
add_binary_head=True,
parallel_output=True,
pre_process=True,
post_process=True,
return_moe_loss=False):
super().__init__(config=config)
args = get_args()
# TODO this option is not yet implemented in BERT
assert args.untie_embeddings_and_output_weights is False
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.add_binary_head = add_binary_head
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.return_moe_loss = return_moe_loss
self.return_embeddings = args.output_bert_embeddings
if self.return_embeddings:
assert self.post_process and self.add_binary_head
self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
encoder_attn_mask_type=AttnMaskType.padding,
pre_process=self.pre_process,
post_process=self.post_process,
num_experts=args.num_experts,
)
self.initialize_word_embeddings()
if self.post_process:
self.lm_head = BertLMHead(self.shared_embedding_or_output_weight().size(0), config.hidden_size,
config, parallel_output)
self._lm_head_key = 'lm_head'
self.binary_head = None
if self.add_binary_head:
self.binary_head = get_linear_layer(config.hidden_size, 2,
config.init_method,
args.zero_stage == 3)
self._binary_head_key = 'binary_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, bert_model_input, attention_mask,
tokentype_ids=None, lm_labels=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids
)
if self.post_process and self.add_binary_head:
lm_output, pooled_output, moe_losses = lm_output
# Return pooled output (e.g., when computing Bert embeddings).
if self.return_embeddings:
# Sum attention mask.
embeddings = torch.transpose(lm_output, 0, 1)
masks = torch.sum(attention_mask, dim=1)
# Collect masked embeddings.
output = torch.zeros(
size=(embeddings.shape[0], embeddings.shape[2]),
dtype=torch.float32,
device=torch.cuda.current_device())
for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
output[i, :] = torch.mean(embedding[1: mask - 1], dim=0)
return output
else:
pooled_output = None
if self.post_process:
if not self.add_binary_head:
lm_output, moe_losses = lm_output
lm_output = post_language_model_processing(lm_output, pooled_output,
self.lm_head, self.binary_head,
lm_labels,
self.shared_embedding_or_output_weight(),
self.fp16_lm_cross_entropy)
            return (*lm_output, moe_losses) if self.return_moe_loss else lm_output
else:
return lm_output
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.post_process:
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.post_process and self.add_binary_head:
state_dict_[self._binary_head_key] \
= self.binary_head.state_dict(prefix=prefix, keep_vars=keep_vars)
# Save word_embeddings.
if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self.post_process:
self.lm_head.load_state_dict(
state_dict[self._lm_head_key], strict=strict)
if self.post_process and self.add_binary_head:
self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict)
# Load word_embeddings.
if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
import os
import torch
import sys
from megatron import get_args, print_rank_0, get_tokenizer
from megatron.core import mpu
from megatron.checkpointing import fix_query_key_value_ordering
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import get_checkpoint_name
from megatron.model.bert_model import bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def get_model_provider(only_query_model=False, only_context_model=False,
biencoder_shared_query_context_model=False):
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
        print_rank_0('building Biencoder model ...')
model = biencoder_model_provider(only_query_model=only_query_model,
only_context_model = only_context_model,
biencoder_shared_query_context_model = \
biencoder_shared_query_context_model,
pre_process=pre_process, post_process=post_process)
return model
return model_provider
def biencoder_model_provider(only_query_model=False,
only_context_model=False,
biencoder_shared_query_context_model=False,
pre_process=True,
post_process=True):
"""Build the model."""
assert mpu.get_tensor_model_parallel_world_size() == 1 and \
mpu.get_pipeline_model_parallel_world_size() == 1, \
"Model parallel size > 1 not supported for ICT"
print_rank_0('building BiEncoderModel...')
# simpler to just keep using 2 tokentypes since
# the LM we initialize with has 2 tokentypes
model = BiEncoderModel(
num_tokentypes=2,
parallel_output=False,
only_query_model=only_query_model,
only_context_model=only_context_model,
biencoder_shared_query_context_model=\
biencoder_shared_query_context_model,
pre_process=pre_process,
post_process=post_process)
return model
class BiEncoderModel(MegatronModule):
"""Bert-based module for Biencoder model."""
def __init__(self,
num_tokentypes=1,
parallel_output=True,
only_query_model=False,
only_context_model=False,
biencoder_shared_query_context_model=False,
pre_process=True,
post_process=True):
super(BiEncoderModel, self).__init__()
args = get_args()
bert_kwargs = dict(
num_tokentypes=num_tokentypes,
parallel_output=parallel_output,
pre_process=pre_process,
post_process=post_process)
self.biencoder_shared_query_context_model = \
biencoder_shared_query_context_model
assert not (only_context_model and only_query_model)
self.use_context_model = not only_query_model
self.use_query_model = not only_context_model
self.biencoder_projection_dim = args.biencoder_projection_dim
if self.biencoder_shared_query_context_model:
self.model = PretrainedBertModel(**bert_kwargs)
self._model_key = 'shared_model'
self.query_model, self.context_model = self.model, self.model
else:
if self.use_query_model:
# this model embeds (pseudo-)queries - Embed_input in the paper
self.query_model = PretrainedBertModel(**bert_kwargs)
self._query_key = 'query_model'
if self.use_context_model:
# this model embeds evidence blocks - Embed_doc in the paper
self.context_model = PretrainedBertModel(**bert_kwargs)
self._context_key = 'context_model'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
        # This is just a placeholder and will be needed when model
        # parallelism is used.
        # self.language_model.set_input_tensor(input_tensor)
return
def forward(self, query_tokens, query_attention_mask, query_types,
context_tokens, context_attention_mask, context_types):
"""Run a forward pass for each of the models and
return the respective embeddings."""
if self.use_query_model:
query_logits = self.embed_text(self.query_model,
query_tokens,
query_attention_mask,
query_types)
else:
raise ValueError("Cannot embed query without the query model.")
if self.use_context_model:
context_logits = self.embed_text(self.context_model,
context_tokens,
context_attention_mask,
context_types)
else:
raise ValueError("Cannot embed block without the block model.")
return query_logits, context_logits
@staticmethod
def embed_text(model, tokens, attention_mask, token_types):
"""Embed a batch of tokens using the model"""
logits = model(tokens,
attention_mask,
token_types)
return logits
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Save dict with state dicts of each of the models."""
state_dict_ = {}
if self.biencoder_shared_query_context_model:
state_dict_[self._model_key] = \
self.model.state_dict_for_save_checkpoint(
prefix=prefix, keep_vars=keep_vars)
else:
if self.use_query_model:
state_dict_[self._query_key] = \
self.query_model.state_dict_for_save_checkpoint(
prefix=prefix, keep_vars=keep_vars)
if self.use_context_model:
state_dict_[self._context_key] = \
self.context_model.state_dict_for_save_checkpoint(
prefix=prefix, keep_vars=keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Load the state dicts of each of the models"""
if self.biencoder_shared_query_context_model:
print_rank_0("Loading shared query-context model")
self.model.load_state_dict(state_dict[self._model_key], \
strict=strict)
else:
if self.use_query_model:
print_rank_0("Loading query model")
self.query_model.load_state_dict( \
state_dict[self._query_key], strict=strict)
if self.use_context_model:
print_rank_0("Loading context model")
self.context_model.load_state_dict( \
state_dict[self._context_key], strict=strict)
def init_state_dict_from_bert(self):
"""Initialize the state from a pretrained BERT model
on iteration zero of ICT pretraining"""
args = get_args()
if args.bert_load is None:
print_rank_0("bert-load argument is None")
return
tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
if not os.path.isfile(tracker_filename):
raise FileNotFoundError("Could not find BERT checkpoint")
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
assert iteration > 0
checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading BERT checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
# Load the checkpoint.
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
except ModuleNotFoundError:
from megatron.fp16_deprecated import loss_scaler
# For backward compatibility.
print_rank_0(' > deserializing using the old code structure ...')
sys.modules['fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)
except BaseException:
print_rank_0('could not load the BERT checkpoint')
sys.exit()
checkpoint_version = state_dict.get('checkpoint_version', 0)
# load the LM state dict into each model
model_dict = state_dict['model']['language_model']
if self.biencoder_shared_query_context_model:
self.model.language_model.load_state_dict(model_dict)
fix_query_key_value_ordering(self.model, checkpoint_version)
else:
if self.use_query_model:
self.query_model.language_model.load_state_dict(model_dict)
# give each model the same ict_head to begin with as well
if self.biencoder_projection_dim > 0:
query_proj_state_dict = \
self.state_dict_for_save_checkpoint()\
[self._query_key]['projection_enc']
fix_query_key_value_ordering(self.query_model, checkpoint_version)
if self.use_context_model:
self.context_model.language_model.load_state_dict(model_dict)
if self.query_model is not None and \
self.biencoder_projection_dim > 0:
self.context_model.projection_enc.load_state_dict\
(query_proj_state_dict)
fix_query_key_value_ordering(self.context_model, checkpoint_version)
class PretrainedBertModel(MegatronModule):
"""BERT-based encoder for queries or contexts used for
learned information retrieval."""
def __init__(self, num_tokentypes=2,
parallel_output=True, pre_process=True, post_process=True):
super(PretrainedBertModel, self).__init__()
args = get_args()
tokenizer = get_tokenizer()
self.pad_id = tokenizer.pad
self.biencoder_projection_dim = args.biencoder_projection_dim
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers)
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
if args.biencoder_projection_dim > 0:
self.projection_enc = get_linear_layer(args.hidden_size,
args.biencoder_projection_dim,
init_method,
gather_params_on_init=args.zero_stage == 3)
self._projection_enc_key = 'projection_enc'
def forward(self, input_ids, attention_mask, tokentype_ids=None):
extended_attention_mask = attention_mask.unsqueeze(1)
#extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
# This mask will be used in average-pooling and max-pooling
pool_mask = (input_ids == self.pad_id).unsqueeze(2)
# Taking the representation of the [CLS] token of BERT
pooled_output = lm_output[0, :, :]
        # Keep the pooled output in the language model's dtype.
pooled_output = pooled_output.to(lm_output.dtype)
# Output.
if self.biencoder_projection_dim:
pooled_output = self.projection_enc(pooled_output)
return pooled_output
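    # Note on the pooling in forward() above: the language model returns
    # activations in [s, b, h] layout, so lm_output[0, :, :] selects the first
    # ([CLS]) token of every sequence; pool_mask marks padding positions for
    # average/max pooling variants but is not used on this [CLS]-pooling path.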
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
prefix=prefix, keep_vars=keep_vars)
if self.biencoder_projection_dim > 0:
state_dict_[self._projection_enc_key] = \
self.projection_enc.state_dict(prefix=prefix,
keep_vars=keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
print_rank_0("loading pretrained weights")
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self.biencoder_projection_dim > 0:
print_rank_0("loading projection head weights")
self.projection_enc.load_state_dict(
state_dict[self._projection_enc_key], strict=strict)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Classification model."""
import torch
from megatron import get_args, print_rank_last
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
class Classification(MegatronModule):
def __init__(self,
config,
num_classes,
num_tokentypes=2,
pre_process=True,
post_process=True):
super().__init__(config=config, share_embeddings_and_output_weights=False)
args = get_args()
self.num_classes = num_classes
self.pre_process = pre_process
self.post_process = post_process
self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
pre_process=self.pre_process,
post_process=self.post_process)
# Multi-choice head.
if self.post_process:
self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.classification_head = get_linear_layer(args.hidden_size,
                                                        self.num_classes,
                                                        config.init_method,
                                                        gather_params_on_init=args.zero_stage == 3)
self._classification_head_key = 'classification_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, model_input, attention_mask, tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = model_input
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids
)
if self.post_process:
_, pooled_output = lm_output[0], lm_output[1]
classification_output = self.classification_dropout(pooled_output)
classification_logits = self.classification_head(classification_output)
# Reshape back to separate choices.
classification_logits = classification_logits.view(-1, self.num_classes)
return classification_logits
return lm_output
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.post_process:
state_dict_[self._classification_head_key] \
= self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self.post_process:
if self._classification_head_key in state_dict:
self.classification_head.load_state_dict(
state_dict[self._classification_head_key], strict=strict)
else:
print_rank_last('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._classification_head_key))
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC
from abc import abstractmethod
import math
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_args
from megatron.core import mpu
from .module import MegatronModule
from deepspeed.accelerator import get_accelerator
class MemoryBuffer:
def __init__(self, numel, numel_padded, dtype):
self.numel = numel
self.numel_padded = numel_padded
self.dtype = dtype
self.data = torch.zeros(self.numel_padded,
dtype=self.dtype,
device=get_accelerator().current_device_name(),
requires_grad=False)
def zero(self):
"""Reset the buffer to zero."""
self.data.zero_()
def get(self, shape, start_index):
"""Return a tensor with the input `shape` as a view into the
1-D data starting at `start_index`."""
end_index = start_index + shape.numel()
assert end_index <= self.numel, \
'requested tensor is out of the buffer range.'
buffer_tensor = self.data[start_index:end_index]
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
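# Illustrative MemoryBuffer usage sketch (hypothetical sizes):
#   buf = MemoryBuffer(numel=6, numel_padded=8, dtype=torch.float)
#   grad_view = buf.get(torch.Size([2, 3]), start_index=0)  # a [2, 3] view
#   grad_view.fill_(1.0)  # writes through to buf.data[0:6]; buf.data[6:8] stays 0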
class DistributedDataParallelBase(MegatronModule, ABC):
"""Abstract class for DDP."""
def __init__(self, module):
super(DistributedDataParallelBase, self).__init__()
# Keep a pointer to the model.
self.module = module
@abstractmethod
def allreduce_gradients(self):
pass
def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)
def state_dict(self, prefix='', keep_vars=False):
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
return self.module.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
class DistributedDataParallel(DistributedDataParallelBase):
"""DDP with contiguous buffers options to storre and accumulate gradients.
This class:
- has the potential to reduce memory fragmentation.
- provides the option to do the gradient accumulation
in a type other than the params type (for example fp32)
Arguments:
module: input model.
        accumulate_allreduce_grads_in_fp32: if true, do the gradient accumulation
            and the gradient all-reduce in float32. If this option is
            true, we require `use_contiguous_buffers` to be true too.
use_contiguous_buffers: if true, use a contiguous buffer to store the
gradients.
"""
def __init__(self, module,
accumulate_allreduce_grads_in_fp32,
use_contiguous_buffers):
super(DistributedDataParallel, self).__init__(module)
self.accumulate_allreduce_grads_in_fp32 \
= accumulate_allreduce_grads_in_fp32
self.use_contiguous_buffers = use_contiguous_buffers
# If we are using fp32-accumulate-allreduce explicitly
        # this means we need main grads in a contiguous buffer.
if self.accumulate_allreduce_grads_in_fp32:
assert self.use_contiguous_buffers
# ===================================
# Rest of this part applies only to
        # the case we use contiguous buffers.
# ===================================
self._grad_buffers = None
self._grad_buffer_param_index_map = None
if self.use_contiguous_buffers:
self._grad_buffers = {}
self._grad_buffer_param_index_map = {}
data_parallel_world_size = mpu.get_data_parallel_world_size()
# Simple function to define buffer type.
def _get_buffer_type(param):
return torch.float if \
self.accumulate_allreduce_grads_in_fp32 else param.dtype
# First calculate total number of elements per type.
type_num_elements = {}
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
+ param.data.nelement()
# Allocate the buffer.
for dtype, num_elements in type_num_elements.items():
# If using distributed optimizer, pad memory buffer to be
# multiple of data_parallel_world_size. (This padding is done
# due to a constraint with the reduce_scatter op, which requires
# all tensors have equal size. See: optimizer.py.)
num_elements_padded = data_parallel_world_size * \
int(math.ceil(num_elements / data_parallel_world_size))
# Allocate grad buffer.
self._grad_buffers[dtype] = MemoryBuffer(num_elements,
num_elements_padded,
dtype)
            # Assume the back-prop order is the reverse of the params order;
            # store the start index for the gradients.
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] -= param.data.nelement()
param.main_grad = self._grad_buffers[dtype].get(
param.data.shape, type_num_elements[dtype])
if dtype not in self._grad_buffer_param_index_map:
self._grad_buffer_param_index_map[dtype] = {}
self._grad_buffer_param_index_map[dtype][param] = (
type_num_elements[dtype],
type_num_elements[dtype] + param.data.nelement(),
)
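            # Worked padding example (made-up sizes): num_elements = 10 with
            # data_parallel_world_size = 4 gives num_elements_padded =
            # 4 * ceil(10 / 4) = 12, i.e. equal 3-element shards per rank for
            # reduce_scatter. The param.main_grad views start at the end of the
            # buffer and move toward 0 as parameters are visited in order.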
# Backward hook.
            # Accumulation function for the gradients. We need
# to store them so they don't go out of scope.
self.grad_accs = []
# Loop over all the parameters in the model.
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
                    # Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param))
self.grad_accs.append(grad_acc)
def _make_param_hook(self, param):
"""Create the all-reduce hook for backprop."""
# Hook used for back-prop.
def param_hook(*unused):
# Add the gradient to the buffer.
if param.grad is not None:
# The gradient function of linear layers is fused with GEMMs
param.main_grad.add_(param.grad.data)
# Now we can deallocate grad memory.
param.grad = None
return param_hook
def zero_grad_buffer(self):
"""Set the grad buffer data to zero. Needs to be called at the
        beginning of each iteration."""
assert self._grad_buffers is not None, 'buffers are not initialized.'
for _, buffer_ in self._grad_buffers.items():
buffer_.zero()
def broadcast_params(self):
for param in self.module.parameters():
torch.distributed.broadcast(param.data,
src=mpu.get_data_parallel_src_rank(),
group=mpu.get_data_parallel_group())
def allreduce_gradients(self):
"""Reduce gradients across data parallel ranks."""
# If we have buffers, simply reduce the data in the buffer.
if self._grad_buffers is not None:
for _, buffer_ in self._grad_buffers.items():
buffer_.data /= mpu.get_data_parallel_world_size()
torch.distributed.all_reduce(
buffer_.data, group=mpu.get_data_parallel_group())
else:
# Otherwise, bucketize and all-reduce
buckets = {}
# Pack the buckets.
for param in self.module.parameters():
if param.requires_grad and param.grad is not None:
tp = param.data.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
# For each bucket, all-reduce and copy all-reduced grads.
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
coalesced /= mpu.get_data_parallel_world_size()
torch.distributed.all_reduce(
coalesced, group=mpu.get_data_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)
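# Illustrative sketch of the flatten / all-reduce / unflatten pattern used in
# allreduce_gradients (single process, so the collective call is omitted;
# tensor values are made up):
#   grads = [torch.ones(2, 2), torch.ones(3)]
#   coalesced = _flatten_dense_tensors(grads)   # 1-D tensor of 7 elements
#   for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
#       buf.copy_(synced)                        # round-trips the values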
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
retro_encoder = 3
retro_decoder = 4
retro_decoder_with_retriever = 5
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
# For backward compatibility with old model checkpoints
from megatron.core.enums import ModelType
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is the tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
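# Minimal sanity-check sketch (assumes PyTorch >= 1.12 for approximate='tanh'):
# the fused bias_gelu should match PyTorch's tanh-approximated GELU on (y + bias).
if __name__ == "__main__":
    y = torch.randn(4, 8)
    bias = torch.randn(8)
    fused = bias_gelu(bias, y)
    reference = torch.nn.functional.gelu(y + bias, approximate='tanh')
    assert torch.allclose(fused, reference, atol=1e-5), "fused bias-gelu mismatch"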