Commit 61f50c45 authored by Jared Casper

Merge branch 'onnx-erf' into 'master'

ONNX export fixes, including one for gelu()

See merge request ADLR/megatron-lm!101
parents 4ac91436 35bea728
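
These workarounds sit behind a new --onnx-safe argument; with the flag enabled, a model built from these modules can be traced by the standard exporter. A minimal sketch of such an export (illustrative only: the wrapper name is hypothetical, and model/input construction is elided):

    import torch

    def export_onnx(model, dummy_input, path='model.onnx'):
        # Assumes the model was built with args.onnx_safe set, so the
        # erf-based gelu and the mpu symbolic hooks below are in effect.
        model.eval()
        with torch.no_grad():
            torch.onnx.export(model, (dummy_input,), path)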
@@ -158,6 +158,8 @@ def _add_network_size_args(parser):
                        help='Use OpenAI\'s GeLU implementation. This option '
                             'should not be used unless for backward '
                             'compatibility reasons.')
+    group.add_argument('--onnx-safe', action='store_true',
+                       help='Use workarounds for known problems with the Torch ONNX exporter.')
     return parser
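
Stand-alone illustration of how the new flag behaves (hypothetical minimal parser; in Megatron it is attached to the network-size argument group as above):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--onnx-safe', action='store_true',
                        help='Use workarounds for known problems with the '
                             'Torch ONNX exporter.')

    assert parser.parse_args([]).onnx_safe is False   # off by default
    assert parser.parse_args(['--onnx-safe']).onnx_safe is True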
......
@@ -22,7 +22,7 @@ from megatron import mpu
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm
-from megatron.model.utils import openai_gelu
+from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
@@ -95,6 +95,8 @@ class BertLMHead(MegatronModule):
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
             self.gelu = openai_gelu
+        elif args.onnx_safe:
+            self.gelu = erf_gelu
 
     def forward(self, hidden_states, word_embeddings_weight):
         hidden_states = self.dense(hidden_states)
......
@@ -22,7 +22,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import openai_gelu
+from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
@@ -52,6 +52,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     gelu = F.gelu
     if args.openai_gelu:
         gelu = openai_gelu
+    elif args.onnx_safe:
+        gelu = erf_gelu
 
     # Language model.
     language_model = TransformerLanguageModel(
......
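
The same three-way selection appears in BertLMHead above. A sketch of the pattern factored into a single helper (hypothetical; not part of this merge request):

    import torch.nn.functional as F
    from megatron.model.utils import openai_gelu, erf_gelu

    def select_gelu(args):
        # Prefer the built-in fused gelu; fall back to the tanh
        # approximation (--openai-gelu) or the erf form that the ONNX
        # exporter can trace (--onnx-safe).
        if args.openai_gelu:
            return openai_gelu
        elif args.onnx_safe:
            return erf_gelu
        return F.gelu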
@@ -48,7 +48,6 @@ def get_linear_layer(rows, columns, init_method):
         layer.bias.zero_()
     return layer
-
 
 @torch.jit.script
 def gelu_impl(x):
     """OpenAI's gelu implementation."""
@@ -57,6 +56,10 @@ def gelu_impl(x):
 def openai_gelu(x):
     return gelu_impl(x)
 
+# This is actually the Python equivalent of torch.nn.functional.gelu(), with explicit casts for the ONNX exporter.
+@torch.jit.script
+def erf_gelu(x):
+    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
 
 def get_params_for_weight_decay_optimization(module):
     """Divide params into with-weight-decay and without-weight-decay groups.
......
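
Since F.gelu is also erf-based, erf_gelu should match it closely; the 1.41421 constant truncates sqrt(2), so agreement is to roughly 1e-4 rather than bit-exact. A quick sanity check (illustrative):

    import torch
    import torch.nn.functional as F
    from megatron.model.utils import erf_gelu

    x = torch.randn(1024)
    assert torch.allclose(erf_gelu(x), F.gelu(x), atol=1e-4)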
@@ -110,11 +110,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.scale_grad_by_freq = False
         self.sparse = False
         self._weight = None
+        self.model_parallel_size = get_model_parallel_world_size()
         # Divide the weight matrix along the vocabulary dimension.
         self.vocab_start_index, self.vocab_end_index = \
             VocabUtility.vocab_range_from_global_vocab_size(
                 self.num_embeddings, get_model_parallel_rank(),
-                get_model_parallel_world_size())
+                self.model_parallel_size)
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index
@@ -127,19 +128,23 @@
             self.num_embeddings_per_partition, 0, init_method)
 
     def forward(self, input_):
-        # Build the mask.
-        input_mask = (input_ < self.vocab_start_index) | \
-                     (input_ >= self.vocab_end_index)
-        # Mask the input.
-        masked_input = input_.clone() - self.vocab_start_index
-        masked_input[input_mask] = 0
-        # Get the embeddings.
+        if self.model_parallel_size > 1:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | \
+                         (input_ >= self.vocab_end_index)
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
+        # Get the embeddings.
         output_parallel = F.embedding(masked_input, self.weight,
                                       self.padding_idx, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq,
                                       self.sparse)
         # Mask the output embedding.
-        output_parallel[input_mask, :] = 0.0
+        if self.model_parallel_size > 1:
+            output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
         return output
......
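
For intuition, the masking trick above in stand-alone form (toy vocabulary range assumed; runs on a single process with no model parallelism):

    import torch
    import torch.nn.functional as F

    vocab_start, vocab_end = 4, 8      # this rank owns token ids [4, 8)
    weight = torch.randn(vocab_end - vocab_start, 16)

    input_ = torch.tensor([1, 5, 7, 9])
    mask = (input_ < vocab_start) | (input_ >= vocab_end)
    masked = input_.clone() - vocab_start
    masked[mask] = 0                   # clamp foreign ids to a valid row
    out = F.embedding(masked, weight)
    out[mask, :] = 0.0                 # zero rows this rank does not own
    # an all-reduce across ranks then sums the partial embeddings

Skipping the mask entirely when model_parallel_size == 1 is safe because a single rank owns the whole vocabulary, and it presumably spares the ONNX exporter the boolean-mask indexing.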
@@ -79,6 +79,10 @@ def _gather(input_):
 
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
     @staticmethod
     def forward(ctx, input_):
         return input_
@@ -91,6 +95,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
 
 class _ReduceFromModelParallelRegion(torch.autograd.Function):
     """All-reduce the input from the model parallel region."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _reduce(input_)
@@ -103,6 +111,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
 
 class _ScatterToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the chunk corresponding to the rank."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _split(input_)
@@ -115,6 +127,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
 
 class _GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from the model parallel region and concatenate."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _gather(input_)
......
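
The symbolic staticmethods added above tell the Torch ONNX exporter what each custom autograd Function contributes to the traced graph; without one, the exporter may fail on, or emit an opaque node for, a custom Function. A minimal self-contained instance of the pattern (identity example, hypothetical name):

    import torch

    class _Identity(torch.autograd.Function):
        @staticmethod
        def symbolic(graph, input_):
            # Consulted by the ONNX exporter in place of forward().
            return input_

        @staticmethod
        def forward(ctx, input_):
            return input_

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    x = torch.randn(3, requires_grad=True)
    _Identity.apply(x).sum().backward()   # gradient flows through unchanged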