Commit 158a99c3 authored by Boris Fomitchev

ONNX export fix, including one for gelu()


Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
parent 05620ee4
@@ -158,6 +158,9 @@ def _add_network_size_args(parser):
                        help='Use OpenAIs GeLU implementation. This option'
                        'should not be used unless for backward compatibility'
                        'reasons.')
+    group.add_argument('--erf-gelu', action='store_true',
+                       help='Python GeLU implementation equivalent to one in Torch. This option'
+                       'should only be used to work around Torch bug exporting gelu() to ONNX in FP16')
     return parser
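The new switch sits alongside the existing --openai-gelu option. As a standalone sketch (this is not Megatron's actual argument plumbing, just plain argparse behavior), the flag surfaces as the args.erf_gelu attribute that the model code in the later hunks checks:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='network size')
group.add_argument('--erf-gelu', action='store_true',
                   help='Python GeLU implementation equivalent to the one '
                        'in Torch; works around a Torch bug exporting '
                        'gelu() to ONNX in FP16.')

args = parser.parse_args(['--erf-gelu'])
assert args.erf_gelu is True  # dashes become underscores on the namespace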
@@ -22,7 +22,7 @@ from megatron import mpu
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm
-from megatron.model.utils import openai_gelu
+from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
@@ -95,6 +95,9 @@ class BertLMHead(MegatronModule):
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
             self.gelu = openai_gelu
+        # erf_gelu overrides openai_gelu when both flags are set
+        if args.erf_gelu:
+            self.gelu = erf_gelu

     def forward(self, hidden_states, word_embeddings_weight):
         hidden_states = self.dense(hidden_states)
@@ -22,7 +22,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import openai_gelu
+from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
@@ -52,6 +52,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     gelu = F.gelu
     if args.openai_gelu:
         gelu = openai_gelu
+    if args.erf_gelu:
+        gelu = erf_gelu

     # Language model.
     language_model = TransformerLanguageModel(
@@ -48,8 +48,6 @@ def get_linear_layer(rows, columns, init_method):
         layer.bias.zero_()
     return layer

-@torch.jit.script
 def gelu_impl(x):
     """OpenAI's gelu implementation."""
     return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
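For reference, the constant 0.7978845608028654 in gelu_impl is sqrt(2/pi), from the well-known tanh approximation gelu(x) ≈ 0.5·x·(1 + tanh(sqrt(2/pi)·(x + 0.044715·x³))). A minimal check, not part of the commit, that the constant and the approximation behave as claimed:

import math
import torch

assert abs(0.7978845608028654 - math.sqrt(2.0 / math.pi)) < 1e-12

x = torch.linspace(-4, 4, 101)
approx = 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                     (1.0 + 0.044715 * x * x)))
exact = torch.nn.functional.gelu(x)   # exact erf-based GeLU
assert (approx - exact).abs().max() < 1e-2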
@@ -57,6 +55,10 @@ def gelu_impl(x):
 def openai_gelu(x):
     return gelu_impl(x)

+# Python equivalent of torch.nn.functional.gelu(), with type hints for the ONNX exporter
+@torch.jit.script
+def erf_gelu(x):
+    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))

 def get_params_for_weight_decay_optimization(module):
     """Divide params into with-weight-decay and without-weight-decay groups.
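Since the hunk above defines the workaround itself, a quick numerical sanity check may help (this sketch is not part of the commit). torch.nn.functional.gelu computes the exact form x·0.5·(1 + erf(x/√2)); 1.41421 stands in for √2, and every term is cast to the input dtype, so an FP16 input yields an all-FP16 computation:

import torch

@torch.jit.script
def erf_gelu(x):
    # Exact GeLU: x * Phi(x) = x * 0.5 * (1 + erf(x / sqrt(2))).
    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)
                      + torch.ones_like(x).to(dtype=x.dtype))

x = torch.randn(1024)
# Agrees with the built-in erf-based gelu to rounding error.
assert (erf_gelu(x) - torch.nn.functional.gelu(x)).abs().max() < 1e-4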
@@ -120,19 +120,23 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.num_embeddings_per_partition, 0, init_method)

     def forward(self, input_):
-        # Build the mask.
-        input_mask = (input_ < self.vocab_start_index) | \
-                     (input_ >= self.vocab_end_index)
-        # Mask the input.
-        masked_input = input_.clone() - self.vocab_start_index
-        masked_input[input_mask] = 0
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | \
+                         (input_ >= self.vocab_end_index)
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
         # Get the embeddings.
         output_parallel = F.embedding(masked_input, self.weight,
                                       self.padding_idx, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq,
                                       self.sparse)
         # Mask the output embedding.
-        output_parallel[input_mask, :] = 0.0
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
         return output
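The guard above makes the single-partition case trivial: when one rank holds the whole vocabulary, the mask would be all-False anyway, and, in line with the commit's ONNX theme, skipping it keeps data-dependent boolean-mask writes out of the traced graph. A standalone sketch of the patched control flow, with simplified names and without the model-parallel all-reduce:

import torch
import torch.nn.functional as F

def partitioned_embedding(input_, weight, vocab_start, vocab_end,
                          num_per_partition, num_total):
    if num_per_partition < num_total:
        # Tokens owned by other ranks are masked out of the local lookup.
        input_mask = (input_ < vocab_start) | (input_ >= vocab_end)
        masked_input = input_.clone() - vocab_start
        masked_input[input_mask] = 0
    else:
        masked_input = input_
    output = F.embedding(masked_input, weight)
    if num_per_partition < num_total:
        # Zero rows this rank does not own; the all-reduce would fill them in.
        output[input_mask, :] = 0.0
    return output

weight = torch.randn(100, 16)            # one rank holds the full vocab
tokens = torch.randint(0, 100, (4, 8))
out = partitioned_embedding(tokens, weight, 0, 100, 100, 100)
assert out.shape == (4, 8, 16)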
@@ -79,6 +79,10 @@ def _gather(input_):
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
     @staticmethod
     def forward(ctx, input_):
         return input_

@@ -91,6 +95,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
 class _ReduceFromModelParallelRegion(torch.autograd.Function):
     """All-redcue the input from the model parallel region."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _reduce(input_)

@@ -103,6 +111,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
 class _ScatterToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _split(input_)

@@ -115,6 +127,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
 class _GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _gather(input_)
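All four hunks above follow one pattern: when tracing for export, torch.onnx looks for a symbolic staticmethod on custom torch.autograd.Function subclasses to learn what the op means; without one it cannot translate the opaque Python call (it errors or emits a nonstandard ATen op, depending on settings). A minimal self-contained sketch of the pattern, with a plain gradient pass-through standing in for the real all-reduce in backward:

import torch

class CopyToRegion(torch.autograd.Function):
    """Identity in forward; exposes symbolic() so the ONNX exporter
    can trace through it instead of rejecting the Python op."""

    @staticmethod
    def symbolic(graph, input_):
        # Transparent for export: hand the traced value straight on.
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # Megatron's version all-reduces gradients here; a pass-through
        # keeps this sketch self-contained.
        return grad_output

x = torch.randn(3, requires_grad=True)
CopyToRegion.apply(x).sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))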