Commit deffcb6a authored by Mohammad

Refactor model arguments: model and transformer modules now read their hyperparameters from the global arguments (get_args()) instead of long constructor argument lists.

parent 601b19b7
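Summary of the change, as seen in the hunks below: model classes such as BertModel and GPT2Model no longer take sizes, dropouts, and epsilons as constructor parameters; they call get_args() and read them from the globally initialized argument namespace. A minimal sketch of a model provider after this commit (the GPT2Model import path is assumed, and megatron's arguments are assumed to be initialized before this runs):

    # Sketch only: assumes megatron's global arguments have already been
    # parsed by the standard initialization path before model_provider runs.
    from megatron import get_args, print_rank_0
    from megatron.model.gpt2_model import GPT2Model  # import path assumed


    def model_provider():
        """Build the model from the globally initialized arguments."""
        args = get_args()
        print_rank_0('building GPT2 model ...')
        # Only task-specific choices are passed explicitly; hidden size,
        # dropout, layernorm epsilon, etc. come from `args` inside the model.
        model = GPT2Model(num_tokentypes=0, parallel_output=True)
        return model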
@@ -108,6 +108,10 @@ def _add_network_size_args(parser):
                            'This is added for computational efficieny reasons.')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
+    group.add_argument('--apply-residual-connection-post-layernorm',
+                       action='store_true',
+                       help='If set, use original BERT residula connection '
+                            'ordering.')
     return parser
......
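For reference, the new flag is a plain store_true switch: it defaults to False (the GPT-2 residual ordering) and only selects the original BERT ordering when passed explicitly. A standalone argparse sketch illustrating that behaviour (the argument-group title here is made up for the example):

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('network size')  # group title assumed
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residual connection '
                            'ordering.')

    # Not passed: defaults to False, i.e. the GPT-2 style residual ordering.
    assert parser.parse_args([]).apply_residual_connection_post_layernorm is False
    # Passed: True, i.e. the original BERT ordering.
    assert parser.parse_args(
        ['--apply-residual-connection-post-layernorm']
    ).apply_residual_connection_post_layernorm is True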
@@ -17,6 +17,7 @@
 import torch
+from megatron import get_args
 from megatron.module import MegatronModule
 from .language_model import parallel_lm_logits
@@ -106,60 +107,33 @@ class BertLMHead(MegatronModule):
 class BertModel(MegatronModule):
     """Bert Language model."""
-    def __init__(self,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 checkpoint_activations,
-                 checkpoint_num_layers=1,
-                 add_binary_head=False,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
-                 num_tokentypes=0,
-                 parallel_output=True,
-                 apply_query_key_layer_scaling=False,
-                 attention_softmax_in_fp32=False):
+    def __init__(self, num_tokentypes=2, add_binary_head=True,
+                 parallel_output=True):
         super(BertModel, self).__init__()
+        args = get_args()
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
-        init_method = init_method_normal(init_method_std)
+        init_method = init_method_normal(args.init_method_std)
+        scaled_init_method = scaled_init_method_normal(args.init_method_std,
+                                                       args.num_layers)
         self.language_model, self._language_model_key = get_language_model(
-            num_layers=num_layers,
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            embedding_dropout_prob=embedding_dropout_prob,
-            attention_dropout_prob=attention_dropout_prob,
-            output_dropout_prob=output_dropout_prob,
-            max_sequence_length=max_sequence_length,
+            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
-            attention_mask_func=bert_attention_mask_func,
-            checkpoint_activations=checkpoint_activations,
-            checkpoint_num_layers=checkpoint_num_layers,
-            layernorm_epsilon=layernorm_epsilon,
             init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(init_method_std,
-                                                         num_layers),
-            residual_connection_post_layernorm=False,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32)
+            scaled_init_method=scaled_init_method)
         self.lm_head = BertLMHead(
             self.language_model.embedding.word_embeddings.weight.size(0),
-            hidden_size, init_method, layernorm_epsilon, parallel_output)
+            args.hidden_size, init_method, args.layernorm_epsilon,
+            parallel_output)
         self._lm_head_key = 'lm_head'
         if self.add_binary_head:
-            self.binary_head = get_linear_layer(hidden_size, 2, init_method)
+            self.binary_head = get_linear_layer(args.hidden_size, 2,
+                                                init_method)
             self._binary_head_key = 'binary_head'
......
@@ -17,6 +17,7 @@
 import torch
+from megatron import get_args
 from megatron.model.bert_model import bert_attention_mask_func
 from megatron.model.bert_model import bert_extended_attention_mask
 from megatron.model.bert_model import bert_position_ids
@@ -30,54 +31,24 @@ from megatron import print_rank_0
 class Classification(MegatronModule):
-    def __init__(self,
-                 num_classes,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 checkpoint_activations,
-                 checkpoint_num_layers=1,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
-                 num_tokentypes=2,
-                 apply_query_key_layer_scaling=False,
-                 attention_softmax_in_fp32=False):
+    def __init__(self, num_classes, num_tokentypes=2):
         super(Classification, self).__init__()
+        args = get_args()
         self.num_classes = num_classes
-        init_method = init_method_normal(init_method_std)
+        init_method = init_method_normal(args.init_method_std)
         self.language_model, self._language_model_key = get_language_model(
-            num_layers=num_layers,
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            embedding_dropout_prob=embedding_dropout_prob,
-            attention_dropout_prob=attention_dropout_prob,
-            output_dropout_prob=output_dropout_prob,
-            max_sequence_length=max_sequence_length,
+            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
-            attention_mask_func=bert_attention_mask_func,
-            checkpoint_activations=checkpoint_activations,
-            checkpoint_num_layers=checkpoint_num_layers,
-            layernorm_epsilon=layernorm_epsilon,
             init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(init_method_std,
-                                                         num_layers),
-            residual_connection_post_layernorm=False,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32)
+            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
         # Multi-choice head.
-        self.classification_dropout = torch.nn.Dropout(output_dropout_prob)
-        self.classification_head = get_linear_layer(hidden_size,
+        self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
+        self.classification_head = get_linear_layer(args.hidden_size,
                                                     self.num_classes,
                                                     init_method)
         self._classification_head_key = 'classification_head'
......
@@ -17,6 +17,7 @@
 import torch
+from megatron import get_args
 from megatron.module import MegatronModule
 from .language_model import parallel_lm_logits
@@ -34,49 +35,19 @@ def gpt2_attention_mask_func(attention_scores, ltor_mask):
 class GPT2Model(MegatronModule):
     """GPT-2 Language model."""
-    def __init__(self,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 checkpoint_activations,
-                 checkpoint_num_layers=1,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
-                 num_tokentypes=0,
-                 parallel_output=True,
-                 apply_query_key_layer_scaling=False,
-                 attention_softmax_in_fp32=False):
+    def __init__(self, num_tokentypes=0, parallel_output=True):
         super(GPT2Model, self).__init__()
+        args = get_args()
         self.parallel_output = parallel_output
         self.language_model, self._language_model_key = get_language_model(
-            num_layers=num_layers,
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            embedding_dropout_prob=embedding_dropout_prob,
-            attention_dropout_prob=attention_dropout_prob,
-            output_dropout_prob=output_dropout_prob,
-            max_sequence_length=max_sequence_length,
+            attention_mask_func=gpt2_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
-            attention_mask_func=gpt2_attention_mask_func,
-            checkpoint_activations=checkpoint_activations,
-            checkpoint_num_layers=checkpoint_num_layers,
-            layernorm_epsilon=layernorm_epsilon,
-            init_method=init_method_normal(init_method_std),
-            scaled_init_method=scaled_init_method_normal(init_method_std,
-                                                         num_layers),
-            residual_connection_post_layernorm=False,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32)
+            init_method=init_method_normal(args.init_method_std),
+            scaled_init_method=scaled_init_method_normal(args.init_method_std,
+                                                         args.num_layers))
     def forward(self, input_ids, position_ids, attention_mask,
......
@@ -18,13 +18,13 @@
 import torch
 import torch.nn.functional as F
+from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
-from .transformer import ParallelTransformer
-from .transformer import TransformerHyperparameters
-from .utils import gelu
-from .utils import get_linear_layer
+from megatron.model.transformer import ParallelTransformer
+from megatron.model.utils import gelu
+from megatron.model.utils import get_linear_layer
 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
@@ -40,52 +40,20 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     # Gather if needed.
     if parallel_output:
         return logits_parallel
-    else:
-        return mpu.gather_from_model_parallel_region(logits_parallel)
+    return mpu.gather_from_model_parallel_region(logits_parallel)
-def get_language_model(num_layers,
-                       vocab_size,
-                       hidden_size,
-                       num_attention_heads,
-                       embedding_dropout_prob,
-                       attention_dropout_prob,
-                       output_dropout_prob,
-                       max_sequence_length,
-                       num_tokentypes,
-                       attention_mask_func,
-                       add_pooler,
-                       checkpoint_activations,
-                       checkpoint_num_layers,
-                       layernorm_epsilon,
-                       init_method,
-                       scaled_init_method,
-                       residual_connection_post_layernorm,
-                       apply_query_key_layer_scaling,
-                       attention_softmax_in_fp32):
-    # Transformer hyperparameters.
-    transformer_hparams = TransformerHyperparameters(
-        hidden_size=hidden_size,
-        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
-        attention_dropout_prob=attention_dropout_prob,
-        output_dropout_prob=output_dropout_prob,
-        mlp_activation_func=gelu,
-        layernorm_epsilon=layernorm_epsilon,
-        init_method=init_method,
-        output_layer_init_method=scaled_init_method,
-        checkpoint_activations=checkpoint_activations,
-        checkpoint_num_layers=checkpoint_num_layers,
-        apply_residual_connection_post_layernorm=residual_connection_post_layernorm,
-        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-        attention_softmax_in_fp32=attention_softmax_in_fp32)
+def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
+                       init_method, scaled_init_method):
+    """Build language model and return along with the key to save."""
     # Language model.
     language_model = TransformerLanguageModel(
-        transformer_hparams=transformer_hparams,
         attention_mask_func=attention_mask_func,
-        vocab_size=vocab_size,
-        max_sequence_length=max_sequence_length,
-        embedding_dropout_prob=embedding_dropout_prob,
+        mlp_activation_func=gelu,
+        init_method=init_method,
+        output_layer_init_method=scaled_init_method,
        num_tokentypes=num_tokentypes,
        add_pooler=add_pooler)
     # key used for checkpoints.
@@ -293,33 +261,33 @@ class TransformerLanguageModel(MegatronModule):
        will ignore this embedding
    """
     def __init__(self,
-                 transformer_hparams,
                  attention_mask_func,
-                 vocab_size,
-                 max_sequence_length,
-                 embedding_dropout_prob,
+                 mlp_activation_func,
+                 init_method,
+                 output_layer_init_method,
                  num_tokentypes=0,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__()
+        args = get_args()
-        self.hidden_size = transformer_hparams['hidden_size']
+        self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
-        self.init_method = transformer_hparams['init_method']
+        self.init_method = init_method
         self.add_pooler = add_pooler
         # Embeddings
         self.embedding = Embedding(self.hidden_size,
-                                   vocab_size,
-                                   max_sequence_length,
-                                   embedding_dropout_prob,
+                                   args.padded_vocab_size,
+                                   args.max_position_embeddings,
+                                   args.hidden_dropout,
                                    self.init_method,
                                    self.num_tokentypes)
         self._embedding_key = 'embedding'
         # Transformer
         self.transformer = ParallelTransformer(
-            transformer_hparams,
-            attention_mask_func)
+            attention_mask_func, mlp_activation_func,
+            self.init_method, output_layer_init_method)
         self._transformer_key = 'transformer'
         # Pooler
......
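With the new signature, callers of get_language_model pass only the attention-mask function, the token-type count, the pooler flag, and the two init methods; everything else is read from the global args inside TransformerLanguageModel. A hedged usage sketch mirroring the BERT-style call in the hunks above (the import locations of the init helpers are assumed):

    from megatron import get_args
    from megatron.model.bert_model import bert_attention_mask_func
    from megatron.model.language_model import get_language_model
    # Assumption: the init helpers live in megatron.model.utils in this tree.
    from megatron.model.utils import init_method_normal, scaled_init_method_normal

    args = get_args()  # requires megatron arguments to be initialized already
    language_model, language_model_key = get_language_model(
        attention_mask_func=bert_attention_mask_func,
        num_tokentypes=2,
        add_pooler=True,
        init_method=init_method_normal(args.init_method_std),
        scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                     args.num_layers))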
@@ -17,6 +17,7 @@
 import torch
+from megatron import get_args
 from megatron.model.bert_model import bert_attention_mask_func
 from megatron.model.bert_model import bert_extended_attention_mask
 from megatron.model.bert_model import bert_position_ids
@@ -30,52 +31,24 @@ from megatron import print_rank_0
 class MultipleChoice(MegatronModule):
-    def __init__(self,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 checkpoint_activations,
-                 checkpoint_num_layers=1,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
-                 num_tokentypes=2,
-                 apply_query_key_layer_scaling=False,
-                 attention_softmax_in_fp32=False):
+    def __init__(self, num_tokentypes=2):
         super(MultipleChoice, self).__init__()
+        args = get_args()
-        init_method = init_method_normal(init_method_std)
+        init_method = init_method_normal(args.init_method_std)
         self.language_model, self._language_model_key = get_language_model(
-            num_layers=num_layers,
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            embedding_dropout_prob=embedding_dropout_prob,
-            attention_dropout_prob=attention_dropout_prob,
-            output_dropout_prob=output_dropout_prob,
-            max_sequence_length=max_sequence_length,
+            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
-            attention_mask_func=bert_attention_mask_func,
-            checkpoint_activations=checkpoint_activations,
-            checkpoint_num_layers=checkpoint_num_layers,
-            layernorm_epsilon=layernorm_epsilon,
             init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(init_method_std,
-                                                         num_layers),
-            residual_connection_post_layernorm=False,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32)
+            scaled_init_method=scaled_init_method_normal(args.init_method_std,
+                                                         args.num_layers))
         # Multi-choice head.
-        self.multichoice_dropout = torch.nn.Dropout(output_dropout_prob)
-        self.multichoice_head = get_linear_layer(hidden_size, 1, init_method)
+        self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
+        self.multichoice_head = get_linear_layer(args.hidden_size, 1,
+                                                 init_method)
         self._multichoice_head_key = 'multichoice_head'
......
@@ -20,6 +20,7 @@ import math
 import torch
 from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
@@ -45,85 +46,6 @@ from megatron.module import MegatronModule
         unmaksed-attention-scores, attention-mask)
 """
-class TransformerHyperparameters:
-    """Hyperparameters used to build and run the transformer.
-    Arguments:
-        hidden_size: hidden size (h)
-        num_layers: number of layers (l)
-        num_attention_heads: number of attention heads (n)
-        attention_dropout_prob: dropout probability for the attention
-                                probabiliies
-        output_dropout_prob: dropout probability for the output
-                             layers (attention output and mlp output)
-        mlp_activation_func: activation function for the mlp layer
-        layernorm_epsilon: tolerance parameters used for layer norm
-                           dividions
-        init_method: init method used for all weights except layer
-                     norm and output weights
-        output_layer_init_method: init method for output weights (
-                                  attention output and mlp output)
-        checkpoint_activations: flag to use activation checkpointing
-        checkpoint_num_layers: number of layers use in each chunk of
-                               activation checkpointing
-        apply_residual_connection_post_layernorm: Take the post layer-norm
-            values for resudual connecton. BERT: True, GPT-2: False
-    """
-    def __init__(self,
-                 hidden_size=None,
-                 num_layers=None,
-                 num_attention_heads=None,
-                 attention_dropout_prob=None,
-                 output_dropout_prob=None,
-                 mlp_activation_func=None,
-                 layernorm_epsilon=None,
-                 init_method=None,
-                 output_layer_init_method=None,
-                 checkpoint_activations=None,
-                 checkpoint_num_layers=None,
-                 apply_residual_connection_post_layernorm=None,
-                 apply_query_key_layer_scaling=None,
-                 attention_softmax_in_fp32=None):
-        self.params_dict = {}
-        self.params_dict['hidden_size'] = hidden_size
-        self.params_dict['num_layers'] = num_layers
-        self.params_dict['num_attention_heads'] = num_attention_heads
-        self.params_dict['attention_dropout_prob'] = attention_dropout_prob
-        self.params_dict['output_dropout_prob'] = output_dropout_prob
-        self.params_dict['mlp_activation_func'] = mlp_activation_func
-        self.params_dict['layernorm_epsilon'] = layernorm_epsilon
-        self.params_dict['init_method'] = init_method
-        self.params_dict['output_layer_init_method'] = output_layer_init_method
-        self.params_dict['checkpoint_activations'] = checkpoint_activations
-        self.params_dict['checkpoint_num_layers'] = checkpoint_num_layers
-        self.params_dict['apply_residual_connection_post_layernorm'] \
-            = apply_residual_connection_post_layernorm
-        self.params_dict['apply_query_key_layer_scaling'] \
-            = apply_query_key_layer_scaling
-        self.params_dict['attention_softmax_in_fp32'] \
-            = attention_softmax_in_fp32
-    def __getitem__(self, key):
-        """Custom retrieval with error checks."""
-        try:
-            value = self.params_dict[key]
-        except KeyError:
-            raise Exception(
-                'could not find {} in transformer hyperparameters'.format(key))
-        except Exception as e:
-            print('unexpected error in transformer hyperparameters:', e)
-            raise Exception()
-        else:
-            assert value is not None, \
-                'parameter value for {} is not set in transformer '\
-                'hyperparameters'.format(key)
-            return value
-        raise Exception('should not be here')
 class ParallelMLP(MegatronModule):
     """MLP.
@@ -133,26 +55,28 @@ class ParallelMLP(MegatronModule):
     applied.
     """
-    def __init__(self, hyperparameters):
+    def __init__(self, mlp_activation_func, init_method,
+                 output_layer_init_method):
         super(ParallelMLP, self).__init__()
+        args = get_args()
         # Project to 4h.
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
-            hyperparameters['hidden_size'],
-            4*hyperparameters['hidden_size'],
+            args.hidden_size,
+            4*args.hidden_size,
             gather_output=False,
-            init_method=hyperparameters['init_method'])
-        self.activation_func = hyperparameters['mlp_activation_func']
+            init_method=init_method)
+        self.activation_func = mlp_activation_func
         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
-            4*hyperparameters['hidden_size'],
-            hyperparameters['hidden_size'],
+            4*args.hidden_size,
+            args.hidden_size,
             input_is_parallel=True,
-            init_method=hyperparameters['output_layer_init_method'])
-        self.dropout = torch.nn.Dropout(hyperparameters['output_dropout_prob'])
+            init_method=output_layer_init_method)
+        self.dropout = torch.nn.Dropout(args.hidden_dropout)
     def forward(self, hidden_states):
@@ -174,51 +98,47 @@ class ParallelSelfAttention(MegatronModule):
     Self-attention layer takes input with size [b, s, h]
     and returns output of the same size.
     """
-    def __init__(self, hyperparameters, attention_mask_func, layer_number):
+    def __init__(self, attention_mask_func, init_method,
+                 output_layer_init_method, layer_number):
         super(ParallelSelfAttention, self).__init__()
+        args = get_args()
         self.attention_mask_func = attention_mask_func
-        self.apply_query_key_layer_scaling \
-            = hyperparameters['apply_query_key_layer_scaling']
-        self.attention_softmax_in_fp32 \
-            = hyperparameters['attention_softmax_in_fp32']
+        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
         if self.apply_query_key_layer_scaling:
             self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)
         # Per attention head and per partition values.
         world_size = mpu.get_model_parallel_world_size()
-        self.hidden_size_per_partition = mpu.divide(
-            hyperparameters['hidden_size'], world_size)
+        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
+                                                    world_size)
         self.hidden_size_per_attention_head = mpu.divide(
-            hyperparameters['hidden_size'],
-            hyperparameters['num_attention_heads'])
+            args.hidden_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = mpu.divide(
-            hyperparameters['num_attention_heads'], world_size)
+            args.num_attention_heads, world_size)
         # Strided linear layer.
         self.query_key_value = mpu.ColumnParallelLinear(
-            hyperparameters['hidden_size'],
-            3*hyperparameters['hidden_size'],
+            args.hidden_size,
+            3*args.hidden_size,
             stride=3,
             gather_output=False,
-            init_method=hyperparameters['init_method'])
+            init_method=init_method)
         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
-        self.attention_dropout = torch.nn.Dropout(
-            hyperparameters['attention_dropout_prob'])
+        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
         # Output.
         self.dense = mpu.RowParallelLinear(
-            hyperparameters['hidden_size'],
-            hyperparameters['hidden_size'],
+            args.hidden_size,
+            args.hidden_size,
             input_is_parallel=True,
-            init_method=hyperparameters['output_layer_init_method'])
-        self.output_dropout = torch.nn.Dropout(
-            hyperparameters['output_dropout_prob'])
+            init_method=output_layer_init_method)
+        self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
     def _transpose_for_scores(self, tensor):
@@ -369,30 +289,34 @@ class ParallelTransformerLayer(MegatronModule):
     Transformore layer takes input with size [b, s, h] and returns an
     output of the same size.
     """
-    def __init__(self, hyperparameters, attention_mask_func, layer_number):
+    def __init__(self, attention_mask_func, mlp_activation_func,
+                 init_method, output_layer_init_method, layer_number):
+        args = get_args()
         super(ParallelTransformerLayer, self).__init__()
         self.layer_number = layer_number
         self.apply_residual_connection_post_layernorm \
-            = hyperparameters['apply_residual_connection_post_layernorm']
+            = args.apply_residual_connection_post_layernorm
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
-            hyperparameters['hidden_size'],
-            eps=hyperparameters['layernorm_epsilon'])
+            args.hidden_size,
+            eps=args.layernorm_epsilon)
         # Self attention.
-        self.attention = ParallelSelfAttention(
-            hyperparameters, attention_mask_func, layer_number)
+        self.attention = ParallelSelfAttention(attention_mask_func, init_method,
+                                               output_layer_init_method,
+                                               layer_number)
         # Layernorm on the input data.
         self.post_attention_layernorm = LayerNorm(
-            hyperparameters['hidden_size'],
-            eps=hyperparameters['layernorm_epsilon'])
+            args.hidden_size,
+            eps=args.layernorm_epsilon)
         # MLP
-        self.mlp = ParallelMLP(hyperparameters)
+        self.mlp = ParallelMLP(mlp_activation_func, init_method,
+                               output_layer_init_method)
     def forward(self, hidden_states, attention_mask, layer_past=None,
@@ -434,25 +358,28 @@ class ParallelTransformerLayer(MegatronModule):
 class ParallelTransformer(MegatronModule):
     """Transformer class."""
-    def __init__(self, hyperparameters, attention_mask_func):
+    def __init__(self, attention_mask_func, mlp_activation_func,
+                 init_method, output_layer_init_method):
         super(ParallelTransformer, self).__init__()
+        args = get_args()
         # Store activation checkpoiting flag.
-        self.checkpoint_activations = hyperparameters['checkpoint_activations']
-        self.checkpoint_num_layers = hyperparameters['checkpoint_num_layers']
+        self.checkpoint_activations = args.checkpoint_activations
+        self.checkpoint_num_layers = args.checkpoint_num_layers
         def get_layer(layer_number):
             return ParallelTransformerLayer(
-                hyperparameters, attention_mask_func, layer_number)
+                attention_mask_func, mlp_activation_func,
+                init_method, output_layer_init_method, layer_number)
         # Transformer layers.
         self.layers = torch.nn.ModuleList(
-            [get_layer(i+1) for i in range(hyperparameters['num_layers'])])
+            [get_layer(i+1) for i in range(args.num_layers)])
         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
-            hyperparameters['hidden_size'],
-            eps=hyperparameters['layernorm_epsilon'])
+            args.hidden_size,
+            eps=args.layernorm_epsilon)
     def _checkpointed_forward(self, hidden_states, attention_mask):
......
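The pattern applied throughout this file: every hyperparameters['key'] lookup on the removed TransformerHyperparameters object becomes an attribute read on the namespace returned by get_args(), while the callables (activation function, init methods, attention-mask function) and the layer number remain constructor arguments. A minimal sketch of the new style (the helper function name is illustrative, not part of the commit):

    from megatron import get_args

    def mlp_projection_sizes():
        """Illustrative helper: the h -> 4h sizes that ParallelMLP now reads
        from the global args instead of a hyperparameters dict."""
        args = get_args()
        return args.hidden_size, 4 * args.hidden_size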
@@ -36,22 +36,9 @@ def model_provider():
     print_rank_0('building BERT model ...')
     model = BertModel(
-        num_layers=args.num_layers,
-        vocab_size=args.padded_vocab_size,
-        hidden_size=args.hidden_size,
-        num_attention_heads=args.num_attention_heads,
-        embedding_dropout_prob=args.hidden_dropout,
-        attention_dropout_prob=args.attention_dropout,
-        output_dropout_prob=args.hidden_dropout,
-        max_sequence_length=args.max_position_embeddings,
-        checkpoint_activations=args.checkpoint_activations,
-        checkpoint_num_layers=args.checkpoint_num_layers,
-        add_binary_head=True,
-        layernorm_epsilon=args.layernorm_epsilon,
         num_tokentypes=2,
-        parallel_output=True,
-        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
-        attention_softmax_in_fp32=args.attention_softmax_in_fp32)
+        add_binary_head=True,
+        parallel_output=True)
     return model
......
@@ -37,20 +37,7 @@ def model_provider():
     args = get_args()
     print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_layers=args.num_layers,
-                      vocab_size=args.padded_vocab_size,
-                      hidden_size=args.hidden_size,
-                      num_attention_heads=args.num_attention_heads,
-                      embedding_dropout_prob=args.hidden_dropout,
-                      attention_dropout_prob=args.attention_dropout,
-                      output_dropout_prob=args.hidden_dropout,
-                      max_sequence_length=args.max_position_embeddings,
-                      checkpoint_activations=args.checkpoint_activations,
-                      checkpoint_num_layers=args.checkpoint_num_layers,
-                      layernorm_epsilon=args.layernorm_epsilon,
-                      parallel_output=True,
-                      apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
-                      attention_softmax_in_fp32=args.attention_softmax_in_fp32)
+    model = GPT2Model(num_tokentypes=0, parallel_output=True)
     return model
......
@@ -46,17 +46,7 @@ def glue_classification(num_classes, Dataset,
         print_rank_0('building classification model for {} ...'.format(
             args.task))
-        return Classification(
-            num_classes=num_classes,
-            num_layers=args.num_layers,
-            vocab_size=args.padded_vocab_size,
-            hidden_size=args.hidden_size,
-            num_attention_heads=args.num_attention_heads,
-            embedding_dropout_prob=args.hidden_dropout,
-            attention_dropout_prob=args.attention_dropout,
-            output_dropout_prob=args.hidden_dropout,
-            max_sequence_length=args.max_position_embeddings,
-            checkpoint_activations=args.checkpoint_activations)
+        return Classification(num_classes=num_classes, num_tokentypes=2)
     def metrics_func_provider():
......
@@ -39,20 +39,10 @@ def train_valid_datasets_provider():
 def model_provider():
     """Build the model."""
-    args = get_args()
     print_rank_0('building multichoice model for RACE ...')
-    return MultipleChoice(
-        num_layers=args.num_layers,
-        vocab_size=args.padded_vocab_size,
-        hidden_size=args.hidden_size,
-        num_attention_heads=args.num_attention_heads,
-        embedding_dropout_prob=args.hidden_dropout,
-        attention_dropout_prob=args.attention_dropout,
-        output_dropout_prob=args.hidden_dropout,
-        max_sequence_length=args.max_position_embeddings,
-        checkpoint_activations=args.checkpoint_activations)
+    return MultipleChoice(num_tokentypes=2)
 def metrics_func_provider():
......