Commit a4cb4153 authored by Jared Casper

Merge branch 'scale_qkt_exp_lr' into 'master'

Scale Q*K (query times key) by 1/layer-number and add exponential decay option

See merge request ADLR/megatron-lm!27
parents 4ad5717e 6c2d0337
@@ -70,6 +70,12 @@ def add_fp16_config_args(parser):
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode')
+    group.add_argument('--apply-query-key-layer-scaling', action='store_true',
+                       help='Scale Q * K^T by 1 / layer-number. If this flag '
+                       'is set, then it will automatically set '
+                       'attention-softmax-in-fp32 to true')
+    group.add_argument('--attention-softmax-in-fp32', action='store_true',
+                       help='Run attention masking and softmax in fp32.')
     group.add_argument('--fp32-embedding', action='store_true',
                        help='embedding in fp32')
     group.add_argument('--fp32-layernorm', action='store_true',
...
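To illustrate how the two new flags are meant to interact (the enforcement itself lives in `ParallelSelfAttention.__init__`, further down in this diff), here is a minimal, hypothetical stand-alone sketch; the parser and flag names simply mirror the arguments added above:

```python
import argparse

# Hypothetical stand-alone parser that mirrors the two flags added above.
parser = argparse.ArgumentParser()
parser.add_argument('--apply-query-key-layer-scaling', action='store_true')
parser.add_argument('--attention-softmax-in-fp32', action='store_true')
args = parser.parse_args(['--apply-query-key-layer-scaling'])

# As in ParallelSelfAttention.__init__ below: layer scaling forces fp32 softmax.
attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if args.apply_query_key_layer_scaling:
    attention_softmax_in_fp32 = True

print(args.apply_query_key_layer_scaling, attention_softmax_in_fp32)  # True True
```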
@@ -24,7 +24,7 @@ from megatron.utils import print_rank_0
 class AnnealingLR(_LRScheduler):
     """Anneals the learning rate"""

-    DECAY_STYLES = ['linear', 'cosine', 'constant', 'None']
+    DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']

     def __init__(self, optimizer, start_lr, warmup_iter, num_iters,
                  decay_style=None, last_iter=-1, min_lr=0.0,
@@ -57,6 +57,9 @@ class AnnealingLR(_LRScheduler):
             lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter)
         elif self.decay_style == self.DECAY_STYLES[1]:
             lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1)
+        elif self.decay_style == self.DECAY_STYLES[2]:
+            # exp(-0.693) = 1/2
+            lr = self.start_lr * math.exp(-0.693 * (num_iters_ - self.warmup_iter) / self.end_iter)
         else:
             lr = self.start_lr
         return max(lr, self.min_lr)
...
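For the new 'exponential' decay style, the constant 0.693 is approximately ln(2), so the learning rate halves over the `end_iter` span after warmup. A minimal sketch of the schedule (the sample values for `start_lr`, `warmup_iter`, and `end_iter` are illustrative, not taken from this merge request):

```python
import math

# Illustrative values only.
start_lr, warmup_iter, end_iter, min_lr = 1.5e-4, 100, 10000, 0.0

def exponential_lr(num_iters):
    # Mirrors the new DECAY_STYLES[2] branch above: exp(-0.693) ~= 1/2.
    lr = start_lr * math.exp(-0.693 * (num_iters - warmup_iter) / end_iter)
    return max(lr, min_lr)

print(exponential_lr(warmup_iter))             # 1.5e-04 at the start of decay
print(exponential_lr(warmup_iter + end_iter))  # ~7.5e-05, i.e. half of start_lr
```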
@@ -119,7 +119,9 @@ class BertModel(MegatronModule):
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
                  num_tokentypes=0,
-                 parallel_output=True):
+                 parallel_output=True,
+                 apply_query_key_layer_scaling=False,
+                 attention_softmax_in_fp32=False):

         super(BertModel, self).__init__()
@@ -145,7 +147,9 @@ class BertModel(MegatronModule):
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(init_method_std,
                                                          num_layers),
-            residual_connection_post_layernorm=False)
+            residual_connection_post_layernorm=False,
+            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+            attention_softmax_in_fp32=attention_softmax_in_fp32)

         self.lm_head = BertLMHead(
             self.language_model.embedding.word_embeddings.weight.size(0),
...
@@ -48,7 +48,9 @@ class GPT2Model(MegatronModule):
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
                  num_tokentypes=0,
-                 parallel_output=True):
+                 parallel_output=True,
+                 apply_query_key_layer_scaling=False,
+                 attention_softmax_in_fp32=False):

         super(GPT2Model, self).__init__()
@@ -72,7 +74,9 @@ class GPT2Model(MegatronModule):
             init_method=init_method_normal(init_method_std),
             scaled_init_method=scaled_init_method_normal(init_method_std,
                                                          num_layers),
-            residual_connection_post_layernorm=False)
+            residual_connection_post_layernorm=False,
+            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+            attention_softmax_in_fp32=attention_softmax_in_fp32)

     def forward(self, input_ids, position_ids, attention_mask,
...
@@ -60,7 +60,9 @@ def get_language_model(num_layers,
                        layernorm_epsilon,
                        init_method,
                        scaled_init_method,
-                       residual_connection_post_layernorm):
+                       residual_connection_post_layernorm,
+                       apply_query_key_layer_scaling,
+                       attention_softmax_in_fp32):

     # Transformer hyperparameters.
     transformer_hparams = TransformerHyperparameters(
         hidden_size=hidden_size,
@@ -74,7 +76,9 @@ def get_language_model(num_layers,
         output_layer_init_method=scaled_init_method,
         checkpoint_activations=checkpoint_activations,
         checkpoint_num_layers=checkpoint_num_layers,
-        apply_residual_connection_post_layernorm=residual_connection_post_layernorm)
+        apply_residual_connection_post_layernorm=residual_connection_post_layernorm,
+        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+        attention_softmax_in_fp32=attention_softmax_in_fp32)

     # Language model.
     language_model = TransformerLanguageModel(
         transformer_hparams=transformer_hparams,
...
@@ -82,7 +82,9 @@ class TransformerHyperparameters:
                  output_layer_init_method=None,
                  checkpoint_activations=None,
                  checkpoint_num_layers=None,
-                 apply_residual_connection_post_layernorm=None):
+                 apply_residual_connection_post_layernorm=None,
+                 apply_query_key_layer_scaling=None,
+                 attention_softmax_in_fp32=None):

         self.params_dict = {}
         self.params_dict['hidden_size'] = hidden_size
         self.params_dict['num_layers'] = num_layers
@@ -97,6 +99,10 @@ class TransformerHyperparameters:
         self.params_dict['checkpoint_num_layers'] = checkpoint_num_layers
         self.params_dict['apply_residual_connection_post_layernorm'] \
             = apply_residual_connection_post_layernorm
+        self.params_dict['apply_query_key_layer_scaling'] \
+            = apply_query_key_layer_scaling
+        self.params_dict['attention_softmax_in_fp32'] \
+            = attention_softmax_in_fp32

     def __getitem__(self, key):
@@ -169,10 +175,17 @@ class ParallelSelfAttention(MegatronModule):
     and returns output of the same size.
     """

-    def __init__(self, hyperparameters, attention_mask_func):
+    def __init__(self, hyperparameters, attention_mask_func, layer_number):
         super(ParallelSelfAttention, self).__init__()

         self.attention_mask_func = attention_mask_func
+        self.apply_query_key_layer_scaling \
+            = hyperparameters['apply_query_key_layer_scaling']
+        self.attention_softmax_in_fp32 \
+            = hyperparameters['attention_softmax_in_fp32']
+        if self.apply_query_key_layer_scaling:
+            self.attention_softmax_in_fp32 = True
+        self.layer_number = max(1, layer_number)

         # Per attention head and per partition values.
         world_size = mpu.get_model_parallel_world_size()
@@ -239,7 +252,11 @@ class ParallelSelfAttention(MegatronModule):
     def _get_unmasked_attention_scores(self, query_layer, key_layer):
         """Unmasked attention scores with size [b, np, s, s]."""
-        norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
+        coeff = 1
+        if self.apply_query_key_layer_scaling:
+            coeff = self.layer_number
+        norm_factor = math.sqrt(coeff *
+                                math.sqrt(self.hidden_size_per_attention_head))

         # Raw attention scores. [b, np, s, s]
         return torch.matmul(query_layer/norm_factor,
                             key_layer.transpose(-1, -2)/norm_factor)
@@ -250,6 +267,8 @@ class ParallelSelfAttention(MegatronModule):
         the size [b, np, s, s].
         """
         # Attention probabilities. [b, np, s, s]
+        if self.apply_query_key_layer_scaling:
+            attention_scores = attention_scores * self.layer_number
         attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
@@ -304,6 +323,10 @@ class ParallelSelfAttention(MegatronModule):
         attention_scores = self._get_unmasked_attention_scores(
             query_layer, key_layer)

+        # fp32 conversion.
+        if self.attention_softmax_in_fp32:
+            attention_scores = attention_scores.float()
+
         # Apply attention mask. [b, np, s, s]
         if get_key_value:
             with torch.no_grad():
@@ -323,6 +346,10 @@ class ParallelSelfAttention(MegatronModule):
         # Attention probabilities. [b, np, s, s]
         attention_probs = self._get_attention_probs(attention_scores)

+        # fp16 conversion
+        if self.attention_softmax_in_fp32:
+            attention_probs = attention_probs.half()
+
         # Context layer. [b, s, hp]
         context_layer = self._get_attended_context(attention_probs, value_layer)
@@ -342,7 +369,7 @@ class ParallelTransformerLayer(MegatronModule):
     Transformore layer takes input with size [b, s, h] and returns an
     output of the same size.
     """

-    def __init__(self, hyperparameters, attention_mask_func):
+    def __init__(self, hyperparameters, attention_mask_func, layer_number):
         super(ParallelTransformerLayer, self).__init__()
@@ -356,8 +383,7 @@ class ParallelTransformerLayer(MegatronModule):
         # Self attention.
         self.attention = ParallelSelfAttention(
-            hyperparameters,
-            attention_mask_func)
+            hyperparameters, attention_mask_func, layer_number)

         # Layernorm on the input data.
         self.post_attention_layernorm = LayerNorm(
@@ -414,14 +440,13 @@ class ParallelTransformer(MegatronModule):
         self.checkpoint_activations = hyperparameters['checkpoint_activations']
         self.checkpoint_num_layers = hyperparameters['checkpoint_num_layers']

-        def get_layer():
+        def get_layer(layer_number):
             return ParallelTransformerLayer(
-                hyperparameters,
-                attention_mask_func)
+                hyperparameters, attention_mask_func, layer_number)

         # Transformer layers.
         self.layers = torch.nn.ModuleList(
-            [get_layer() for _ in range(hyperparameters['num_layers'])])
+            [get_layer(i+1) for i in range(hyperparameters['num_layers'])])

         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
...
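Numerically, the two attention changes above cancel out: `_get_unmasked_attention_scores` folds an extra 1/layer_number into the Q*K^T matmul (keeping the fp16 intermediate values small), and `_get_attention_probs` multiplies the scores back by layer_number before the softmax, so the result equals the usual Q*K^T / sqrt(d). A minimal sketch checking the equivalence (shapes and values are illustrative only):

```python
import math
import torch

# Illustrative sizes: batch, heads per partition, sequence length, head dim.
b, np_, s, d = 2, 4, 8, 64
layer_number = 12
query = torch.randn(b, np_, s, d)
key = torch.randn(b, np_, s, d)

# Standard unscaled scores: Q K^T / sqrt(d).
standard = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d)

# Layer-scaled variant from this diff: divide Q and K by sqrt(coeff * sqrt(d)),
# then multiply the scores by layer_number again before the softmax.
norm_factor = math.sqrt(layer_number * math.sqrt(d))
scaled = torch.matmul(query / norm_factor, key.transpose(-1, -2) / norm_factor)
recovered = scaled * layer_number

print(torch.allclose(standard, recovered, atol=1e-5))  # True
```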
@@ -47,7 +47,9 @@ def model_provider(args):
         add_binary_head=True,
         layernorm_epsilon=args.layernorm_epsilon,
         num_tokentypes=args.tokentype_size,
-        parallel_output=True)
+        parallel_output=True,
+        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
+        attention_softmax_in_fp32=args.attention_softmax_in_fp32)

     return model
...
@@ -46,7 +46,9 @@ def model_provider(args):
         add_binary_head=True,
         layernorm_epsilon=args.layernorm_epsilon,
         num_tokentypes=args.tokentype_size,
-        parallel_output=True)
+        parallel_output=True,
+        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
+        attention_softmax_in_fp32=args.attention_softmax_in_fp32)

     return model
...
@@ -43,7 +43,9 @@ def model_provider(args):
         checkpoint_activations=args.checkpoint_activations,
         checkpoint_num_layers=args.checkpoint_num_layers,
         layernorm_epsilon=args.layernorm_epsilon,
-        parallel_output=True)
+        parallel_output=True,
+        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
+        attention_softmax_in_fp32=args.attention_softmax_in_fp32)

     return model
...