Commit 3ea5491e authored by Mohammad

added faster L2 grad clipping and new torch gelu

parent 99410264
@@ -122,6 +122,10 @@ def _add_network_size_args(parser):
                        action='store_true',
                        help='If set, use original BERT residula connection '
                        'ordering.')
+    group.add_argument('--openai-gelu', action='store_true',
+                       help='Use the OpenAI GeLU implementation. This option '
+                            'should only be used for backward compatibility '
+                            'reasons.')
     return parser
...
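As a quick check of the new flag's behavior (a standalone sketch, not part of the commit, re-declaring the argument rather than importing Megatron's parser):

```python
# Standalone sketch: exercises the new --openai-gelu store_true flag.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='network size')
group.add_argument('--openai-gelu', action='store_true',
                   help='Use the OpenAI GeLU implementation. This option '
                        'should only be used for backward compatibility '
                        'reasons.')

print(parser.parse_args([]).openai_gelu)                 # False by default
print(parser.parse_args(['--openai-gelu']).openai_gelu)  # True when passed
```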
@@ -18,16 +18,15 @@
 import torch

 from megatron import get_args
+from megatron.model.language_model import parallel_lm_logits
+from megatron.model.language_model import get_language_model
+from megatron.model.transformer import LayerNorm
+from megatron.model.utils import openai_gelu
+from megatron.model.utils import get_linear_layer
+from megatron.model.utils import init_method_normal
+from megatron.model.utils import scaled_init_method_normal
 from megatron.module import MegatronModule
-from .language_model import parallel_lm_logits
-from .language_model import get_language_model
-from .transformer import LayerNorm
-from .utils import gelu
-from .utils import get_linear_layer
-from .utils import init_method_normal
-from .utils import scaled_init_method_normal


 def bert_attention_mask_func(attention_scores, attention_mask):
     attention_scores = attention_scores + attention_mask
@@ -82,6 +81,8 @@ class BertLMHead(MegatronModule):
         super(BertLMHead, self).__init__()
+        args = get_args()
+
         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         self.bias.model_parallel = True
         self.bias.partition_dim = 0
@@ -90,10 +91,13 @@ class BertLMHead(MegatronModule):
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.gelu = torch.nn.functional.gelu
+        if args.openai_gelu:
+            self.gelu = openai_gelu

     def forward(self, hidden_states, word_embeddings_weight):
         hidden_states = self.dense(hidden_states)
-        hidden_states = gelu(hidden_states)
+        hidden_states = self.gelu(hidden_states)
         hidden_states = self.layernorm(hidden_states)
         output = parallel_lm_logits(hidden_states,
                                     word_embeddings_weight,
...
@@ -21,9 +21,8 @@ import torch.nn.functional as F

 from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import gelu
+from megatron.model.utils import openai_gelu
 from megatron.model.utils import get_linear_layer
@@ -47,7 +46,13 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                        init_method, scaled_init_method):
     """Build language model and return along with the key to save."""
+    args = get_args()
+
+    # Use torch gelu unless otherwise forced.
+    gelu = F.gelu
+    if args.openai_gelu:
+        gelu = openai_gelu

     # Language model.
     language_model = TransformerLanguageModel(
         attention_mask_func=attention_mask_func,
@@ -54,9 +54,7 @@ def gelu_impl(x):
     """OpenAI's gelu implementation."""
     return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                        (1.0 + 0.044715 * x * x)))


-def gelu(x):
+def openai_gelu(x):
     return gelu_impl(x)
...
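For reference, torch.nn.functional.gelu evaluates the exact erf-based GeLU, while the renamed openai_gelu keeps the tanh approximation above. A small standalone check (not part of the commit) of how far the two are apart:

```python
# Standalone check: exact GeLU vs. the tanh approximation kept as openai_gelu.
import torch
import torch.nn.functional as F

def openai_gelu(x):
    # Same tanh-based formula as gelu_impl above.
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))

x = torch.linspace(-6.0, 6.0, steps=10001)
max_abs_diff = (F.gelu(x) - openai_gelu(x)).abs().max().item()
print(f'max |exact - tanh approx| = {max_abs_diff:.2e}')  # small but nonzero
```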
@@ -21,10 +21,47 @@

 import torch
 from torch._six import inf

+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
+
 from .initialize import get_model_parallel_group
 from .initialize import get_model_parallel_rank


+def l2_grad_clipper(parameters, max_norm):
+    """Efficient L2 norm gradient clipping."""
+    overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    # Count each gradient exactly once across model-parallel ranks:
+    # model-parallel parameters live on every rank, replicated ones
+    # are counted only on rank 0.
+    mp_rank_is_zero = (get_model_parallel_rank() == 0)
+    parameters = list(filter(lambda p: (p.grad is not None) and
+                             (p.model_parallel or mp_rank_is_zero),
+                             parameters))
+    grads = [p.grad for p in parameters]
+    # Fused L2 norm over this rank's gradients.
+    norm, _ = multi_tensor_applier(
+        amp_C.multi_tensor_l2norm,
+        overflow_buf,
+        [grads],
+        False)  # no per-parameter norm
+    # Sum across all model parallel GPUs.
+    norm_2 = norm * norm
+    torch.distributed.all_reduce(norm_2,
+                                 op=torch.distributed.ReduceOp.SUM,
+                                 group=get_model_parallel_group())
+    total_norm = norm_2.item() ** 0.5
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        # Fused in-place scaling of all gradients.
+        multi_tensor_applier(
+            amp_C.multi_tensor_scale,
+            overflow_buf,
+            [grads, grads],
+            clip_coef)
+    return total_norm
+
+
 def clip_grad_norm(parameters, max_norm, norm_type=2):
     """Clips gradient norm of an iterable of parameters.
@@ -55,6 +92,13 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=get_model_parallel_group())
         total_norm = total_norm_cuda[0].item()
+        clip_coef = max_norm / (total_norm + 1e-6)
+        if clip_coef < 1:
+            for p in parameters:
+                p.grad.data.mul_(clip_coef)
+    elif norm_type == 2:
+        total_norm = l2_grad_clipper(parameters, max_norm)
     else:
         total_norm = 0
         for p in parameters:
@@ -67,8 +111,8 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=get_model_parallel_group())
         total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
-    clip_coef = max_norm / (total_norm + 1e-6)
-    if clip_coef < 1:
-        for p in parameters:
-            p.grad.data.mul_(clip_coef)
+        clip_coef = max_norm / (total_norm + 1e-6)
+        if clip_coef < 1:
+            for p in parameters:
+                p.grad.data.mul_(clip_coef)
     return total_norm
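For environments without apex/amp_C, a rough un-fused equivalent of the new fast path could look like the sketch below. This is an illustration only: the function name and the model_parallel_group argument are made up for the example, torch.norm stands in for the fused amp_C kernels, and the per-rank parameter filtering done by l2_grad_clipper above is left to the caller.

```python
# Hedged sketch: pure-PyTorch stand-in for l2_grad_clipper when apex is unavailable.
import torch

def l2_grad_clip_reference(parameters, max_norm, model_parallel_group=None):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    # ||g||^2 summed over all local gradients.
    norm_2 = torch.stack([g.norm(2) ** 2 for g in grads]).sum()
    if model_parallel_group is not None:
        # Combine partial squared norms across model-parallel ranks.
        torch.distributed.all_reduce(norm_2,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=model_parallel_group)
    total_norm = norm_2.item() ** 0.5
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm
```

The fused multi_tensor_l2norm/multi_tensor_scale path above does the same norm and scaling with far fewer kernel launches, which is the main source of the speedup this commit targets.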