Commit a4cb4153 authored by Jared Casper

Merge branch 'scale_qkt_exp_lr' into 'master'

Scale Q*K (query times key) by 1/layer-number and add exponential decay option

See merge request ADLR/megatron-lm!27
parents 4ad5717e 6c2d0337
......@@ -70,6 +70,12 @@ def add_fp16_config_args(parser):
group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode')
group.add_argument('--apply-query-key-layer-scaling', action='store_true',
help='Scale Q * K^T by 1 / layer-number. If this flag '
'is set, then it will automatically set '
'attention-softmax-in-fp32 to true')
group.add_argument('--attention-softmax-in-fp32', action='store_true',
help='Run attention masking and softmax in fp32.')
group.add_argument('--fp32-embedding', action='store_true',
help='embedding in fp32')
group.add_argument('--fp32-layernorm', action='store_true',
......
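As the help strings note, the two flags are coupled: --apply-query-key-layer-scaling implies --attention-softmax-in-fp32, and that coupling is enforced inside ParallelSelfAttention (see the transformer changes below), not in argparse. A minimal standalone sketch of the behaviour, using a stand-in parser rather than Megatron's real add_fp16_config_args:

import argparse

# Stand-in parser that mirrors only the two new flags above.
parser = argparse.ArgumentParser()
parser.add_argument('--apply-query-key-layer-scaling', action='store_true',
                    help='Scale Q * K^T by 1 / layer-number; implies '
                         'attention-softmax-in-fp32.')
parser.add_argument('--attention-softmax-in-fp32', action='store_true',
                    help='Run attention masking and softmax in fp32.')
args = parser.parse_args(['--apply-query-key-layer-scaling'])

# Reproduce the coupling applied later in ParallelSelfAttention.__init__:
attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if args.apply_query_key_layer_scaling:
    attention_softmax_in_fp32 = True
print(attention_softmax_in_fp32)  # True, even though the flag was not passed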
......@@ -24,7 +24,7 @@ from megatron.utils import print_rank_0
class AnnealingLR(_LRScheduler):
"""Anneals the learning rate"""
DECAY_STYLES = ['linear', 'cosine', 'constant', 'None']
DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
def __init__(self, optimizer, start_lr, warmup_iter, num_iters,
decay_style=None, last_iter=-1, min_lr=0.0,
......@@ -57,6 +57,9 @@ class AnnealingLR(_LRScheduler):
lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter)
elif self.decay_style == self.DECAY_STYLES[1]:
lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1)
elif self.decay_style == self.DECAY_STYLES[2]:
# exp(-0.693) = 1/2
lr = self.start_lr * math.exp(-0.693 * (num_iters_ - self.warmup_iter) / self.end_iter)
else:
lr = self.start_lr
return max(lr, self.min_lr)
......
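For reference, the new 'exponential' style decays the post-warmup learning rate as lr(t) = start_lr * exp(-ln(2) * t / end_iter), so the rate is halved over the decay window and keeps shrinking smoothly instead of reaching zero the way the linear style does. A self-contained sketch of the schedule (the hyperparameter values are illustrative, and the linear warmup ramp is assumed from the rest of AnnealingLR, not shown in the hunk above):

import math

start_lr, warmup_iter, end_iter, min_lr = 1.5e-4, 1000, 100000, 1.0e-5

def exponential_lr(num_iters):
    if warmup_iter > 0 and num_iters <= warmup_iter:
        return start_lr * num_iters / warmup_iter      # assumed linear warmup ramp
    # Mirrors the new branch: ln(2) ~= 0.693, so lr halves over end_iter iterations.
    lr = start_lr * math.exp(-0.693 * (num_iters - warmup_iter) / end_iter)
    return max(lr, min_lr)

print(exponential_lr(warmup_iter + end_iter))  # ~0.5 * start_lr = 7.5e-05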
......@@ -119,7 +119,9 @@ class BertModel(MegatronModule):
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
num_tokentypes=0,
parallel_output=True):
parallel_output=True,
apply_query_key_layer_scaling=False,
attention_softmax_in_fp32=False):
super(BertModel, self).__init__()
......@@ -145,7 +147,9 @@ class BertModel(MegatronModule):
init_method=init_method,
scaled_init_method=scaled_init_method_normal(init_method_std,
num_layers),
residual_connection_post_layernorm=False)
residual_connection_post_layernorm=False,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attention_softmax_in_fp32=attention_softmax_in_fp32)
self.lm_head = BertLMHead(
self.language_model.embedding.word_embeddings.weight.size(0),
......
......@@ -48,7 +48,9 @@ class GPT2Model(MegatronModule):
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
num_tokentypes=0,
parallel_output=True):
parallel_output=True,
apply_query_key_layer_scaling=False,
attention_softmax_in_fp32=False):
super(GPT2Model, self).__init__()
......@@ -72,7 +74,9 @@ class GPT2Model(MegatronModule):
init_method=init_method_normal(init_method_std),
scaled_init_method=scaled_init_method_normal(init_method_std,
num_layers),
residual_connection_post_layernorm=False)
residual_connection_post_layernorm=False,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attention_softmax_in_fp32=attention_softmax_in_fp32)
def forward(self, input_ids, position_ids, attention_mask,
......
......@@ -60,7 +60,9 @@ def get_language_model(num_layers,
layernorm_epsilon,
init_method,
scaled_init_method,
residual_connection_post_layernorm):
residual_connection_post_layernorm,
apply_query_key_layer_scaling,
attention_softmax_in_fp32):
# Transformer hyperparameters.
transformer_hparams = TransformerHyperparameters(
hidden_size=hidden_size,
......@@ -74,7 +76,9 @@ def get_language_model(num_layers,
output_layer_init_method=scaled_init_method,
checkpoint_activations=checkpoint_activations,
checkpoint_num_layers=checkpoint_num_layers,
apply_residual_connection_post_layernorm=residual_connection_post_layernorm)
apply_residual_connection_post_layernorm=residual_connection_post_layernorm,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attention_softmax_in_fp32=attention_softmax_in_fp32)
# Language model.
language_model = TransformerLanguageModel(
transformer_hparams=transformer_hparams,
......
......@@ -82,7 +82,9 @@ class TransformerHyperparameters:
output_layer_init_method=None,
checkpoint_activations=None,
checkpoint_num_layers=None,
apply_residual_connection_post_layernorm=None):
apply_residual_connection_post_layernorm=None,
apply_query_key_layer_scaling=None,
attention_softmax_in_fp32=None):
self.params_dict = {}
self.params_dict['hidden_size'] = hidden_size
self.params_dict['num_layers'] = num_layers
......@@ -97,6 +99,10 @@ class TransformerHyperparameters:
self.params_dict['checkpoint_num_layers'] = checkpoint_num_layers
self.params_dict['apply_residual_connection_post_layernorm'] \
= apply_residual_connection_post_layernorm
self.params_dict['apply_query_key_layer_scaling'] \
= apply_query_key_layer_scaling
self.params_dict['attention_softmax_in_fp32'] \
= attention_softmax_in_fp32
def __getitem__(self, key):
......@@ -169,10 +175,17 @@ class ParallelSelfAttention(MegatronModule):
and returns output of the same size.
"""
def __init__(self, hyperparameters, attention_mask_func):
def __init__(self, hyperparameters, attention_mask_func, layer_number):
super(ParallelSelfAttention, self).__init__()
self.attention_mask_func = attention_mask_func
self.apply_query_key_layer_scaling \
= hyperparameters['apply_query_key_layer_scaling']
self.attention_softmax_in_fp32 \
= hyperparameters['attention_softmax_in_fp32']
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
......@@ -239,7 +252,11 @@ class ParallelSelfAttention(MegatronModule):
def _get_unmasked_attention_scores(self, query_layer, key_layer):
"""Unmasked attention scores with size [b, np, s, s]."""
norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
coeff = 1
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
norm_factor = math.sqrt(coeff *
math.sqrt(self.hidden_size_per_attention_head))
# Raw attention scores. [b, np, s, s]
return torch.matmul(query_layer/norm_factor,
key_layer.transpose(-1, -2)/norm_factor)
......@@ -250,6 +267,8 @@ class ParallelSelfAttention(MegatronModule):
the size [b, np, s, s].
"""
# Attention probabilities. [b, np, s, s]
if self.apply_query_key_layer_scaling:
attention_scores = attention_scores * self.layer_number
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
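Taken together, the two changes above are a numerically neutral rescaling. With norm_factor = sqrt(layer_number * sqrt(d)), the product (Q / norm_factor)(K / norm_factor)^T equals Q K^T / (layer_number * sqrt(d)); multiplying the scores by layer_number just before the softmax then restores the usual softmax(Q K^T / sqrt(d)). The point is that the intermediate fp16 matmul output is smaller by a factor of layer_number, keeping deeper layers further from the fp16 range limit. A short sketch of the equivalence (tensor shapes and the layer number are arbitrary):

import math
import torch

d, layer_number = 16, 24                      # head dim and layer index (arbitrary)
q = torch.randn(2, 4, 8, d)                   # [b, np, s, hn]
k = torch.randn(2, 4, 8, d)

# Standard attention scaling: softmax(Q K^T / sqrt(d)).
baseline = torch.softmax(torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d), dim=-1)

# Layer-scaled variant from the hunks above: scale Q and K down before the
# matmul, then multiply the scores back up before the softmax.
norm_factor = math.sqrt(layer_number * math.sqrt(d))
scores = torch.matmul(q / norm_factor, k.transpose(-1, -2) / norm_factor)
rescaled = torch.softmax(scores * layer_number, dim=-1)

print(torch.allclose(baseline, rescaled, atol=1e-6))  # True up to rounding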
......@@ -304,6 +323,10 @@ class ParallelSelfAttention(MegatronModule):
attention_scores = self._get_unmasked_attention_scores(
query_layer, key_layer)
# fp32 conversion.
if self.attention_softmax_in_fp32:
attention_scores = attention_scores.float()
# Apply attention mask. [b, np, s, s]
if get_key_value:
with torch.no_grad():
......@@ -323,6 +346,10 @@ class ParallelSelfAttention(MegatronModule):
# Attention probabilities. [b, np, s, s]
attention_probs = self._get_attention_probs(attention_scores)
# fp16 conversion
if self.attention_softmax_in_fp32:
attention_probs = attention_probs.half()
# Context layer. [b, s, hp]
context_layer = self._get_attended_context(attention_probs, value_layer)
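These two conversions form a float round-trip around the masking and softmax: when attention_softmax_in_fp32 is set (implicitly so whenever the layer scaling is enabled), the fp16 scores are upcast once, the mask fill and softmax normalization run in single precision, and the probabilities are cast back to half so the subsequent matmul with the value layer stays in fp16. A compact sketch of the pattern (the shapes are arbitrary):

import torch

attention_softmax_in_fp32 = True
attention_scores = torch.randn(2, 4, 8, 8).half()      # [b, np, s, s] in fp16

if attention_softmax_in_fp32:
    attention_scores = attention_scores.float()         # fp32 conversion
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
if attention_softmax_in_fp32:
    attention_probs = attention_probs.half()            # back to fp16 for the V matmul

print(attention_probs.dtype)                            # torch.float16
print(attention_probs.sum(dim=-1)[0, 0, 0])             # ~1.0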
......@@ -342,7 +369,7 @@ class ParallelTransformerLayer(MegatronModule):
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, hyperparameters, attention_mask_func):
def __init__(self, hyperparameters, attention_mask_func, layer_number):
super(ParallelTransformerLayer, self).__init__()
......@@ -356,8 +383,7 @@ class ParallelTransformerLayer(MegatronModule):
# Self attention.
self.attention = ParallelSelfAttention(
hyperparameters,
attention_mask_func)
hyperparameters, attention_mask_func, layer_number)
# Layernorm on the input data.
self.post_attention_layernorm = LayerNorm(
......@@ -414,14 +440,13 @@ class ParallelTransformer(MegatronModule):
self.checkpoint_activations = hyperparameters['checkpoint_activations']
self.checkpoint_num_layers = hyperparameters['checkpoint_num_layers']
def get_layer():
def get_layer(layer_number):
return ParallelTransformerLayer(
hyperparameters,
attention_mask_func)
hyperparameters, attention_mask_func, layer_number)
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer() for _ in range(hyperparameters['num_layers'])])
[get_layer(i+1) for i in range(hyperparameters['num_layers'])])
# Final layer norm before output.
self.final_layernorm = LayerNorm(
......
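Because get_layer(i+1) numbers layers from 1 (and ParallelSelfAttention clamps with max(1, layer_number)), the first layer's scores are scaled exactly as before, while deeper layers shrink the raw fp16 Q*K^T product progressively more. A quick way to see the per-layer factor (the head size below is just an example):

import math

hidden_size_per_attention_head = 64                 # example head size
for layer_number in (1, 2, 12, 24):                 # layers are numbered 1..num_layers
    norm_factor = math.sqrt(layer_number *
                            math.sqrt(hidden_size_per_attention_head))
    # (Q / norm_factor) @ (K / norm_factor)^T shrinks Q K^T by norm_factor**2,
    # i.e. by layer_number * sqrt(head_dim):
    print(layer_number, 1.0 / norm_factor ** 2)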
......@@ -47,7 +47,9 @@ def model_provider(args):
add_binary_head=True,
layernorm_epsilon=args.layernorm_epsilon,
num_tokentypes=args.tokentype_size,
parallel_output=True)
parallel_output=True,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
attention_softmax_in_fp32=args.attention_softmax_in_fp32)
return model
......
......@@ -46,7 +46,9 @@ def model_provider(args):
add_binary_head=True,
layernorm_epsilon=args.layernorm_epsilon,
num_tokentypes=args.tokentype_size,
parallel_output=True)
parallel_output=True,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
attention_softmax_in_fp32=args.attention_softmax_in_fp32)
return model
......
......@@ -43,7 +43,9 @@ def model_provider(args):
checkpoint_activations=args.checkpoint_activations,
checkpoint_num_layers=args.checkpoint_num_layers,
layernorm_epsilon=args.layernorm_epsilon,
parallel_output=True)
parallel_output=True,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
attention_softmax_in_fp32=args.attention_softmax_in_fp32)
return model
......