Commit 158a99c3 authored by Boris Fomitchev's avatar Boris Fomitchev
Browse files

ONNX export fix, including one for gelu()


Signed-off-by: default avatarBoris Fomitchev <bfomitchev@nvidia.com>
parent 05620ee4
......@@ -158,6 +158,9 @@ def _add_network_size_args(parser):
help='Use OpenAIs GeLU implementation. This option'
'should not be used unless for backward compatibility'
'reasons.')
group.add_argument('--erf-gelu', action='store_true',
help='Python GeLU implementation equivalent to one in Torch. This option'
'should only be used to work around Torch bug exporting gelu() to ONNX in FP16')
return parser
......
......@@ -22,7 +22,7 @@ from megatron import mpu
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import openai_gelu
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
......@@ -95,6 +95,9 @@ class BertLMHead(MegatronModule):
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
# make it override
if args.erf_gelu:
self.gelu = openai_gelu
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
......
......@@ -22,7 +22,7 @@ from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import openai_gelu
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
......@@ -52,6 +52,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
gelu = F.gelu
if args.openai_gelu:
gelu = openai_gelu
if args.erf_gelu:
gelu = erf_gelu
# Language model.
language_model = TransformerLanguageModel(
......
......@@ -48,8 +48,6 @@ def get_linear_layer(rows, columns, init_method):
layer.bias.zero_()
return layer
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
......@@ -57,6 +55,10 @@ def gelu_impl(x):
def openai_gelu(x):
return gelu_impl(x)
#This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@torch.jit.script
def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))
def get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups.
......
......@@ -120,19 +120,23 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_embeddings_per_partition, 0, init_method)
def forward(self, input_):
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
# Get the embeddings.
if self.num_embeddings_per_partition < self.num_embeddings:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
# Mask the output embedding.
output_parallel[input_mask, :] = 0.0
if self.num_embeddings_per_partition < self.num_embeddings:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_model_parallel_region(output_parallel)
return output
......
......@@ -79,6 +79,10 @@ def _gather(input_):
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_):
return input_
......@@ -91,6 +95,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-redcue the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
......@@ -103,6 +111,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_):
return _split(input_)
......@@ -115,6 +127,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
@staticmethod
def forward(ctx, input_):
return _gather(input_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment