Commit 61f50c45 authored by Jared Casper

Merge branch 'onnx-erf' into 'master'

ONNX export fix, including one for gelu()

See merge request ADLR/megatron-lm!101
parents 4ac91436 35bea728
@@ -158,6 +158,8 @@ def _add_network_size_args(parser):
                        help='Use OpenAIs GeLU implementation. This option'
                        'should not be used unless for backward compatibility'
                        'reasons.')
+    group.add_argument('--onnx-safe', action='store_true',
+                       help='Use workarounds for known problems with Torch ONNX exporter')
     return parser
...
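For context, a minimal sketch (not from this MR) of how a boolean flag like `--onnx-safe` is parsed and then used to pick the GeLU implementation; it mirrors the selection logic in the hunks that follow, and the two stub functions stand in for the real helpers in megatron.model.utils:

```python
import argparse
import torch.nn.functional as F

# Hypothetical stand-ins for megatron.model.utils.openai_gelu / erf_gelu.
def openai_gelu(x): ...
def erf_gelu(x): ...

parser = argparse.ArgumentParser()
parser.add_argument('--openai-gelu', action='store_true')
parser.add_argument('--onnx-safe', action='store_true',
                    help='Use workarounds for known problems with Torch ONNX exporter')
args = parser.parse_args(['--onnx-safe'])  # both flags default to False

# Same precedence as in the BERT head and language model hunks below.
gelu = F.gelu
if args.openai_gelu:
    gelu = openai_gelu
elif args.onnx_safe:
    gelu = erf_gelu
```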
@@ -22,7 +22,7 @@ from megatron import mpu
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm
-from megatron.model.utils import openai_gelu
+from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
@@ -95,6 +95,8 @@ class BertLMHead(MegatronModule):
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
             self.gelu = openai_gelu
+        elif args.onnx_safe:
+            self.gelu = erf_gelu

     def forward(self, hidden_states, word_embeddings_weight):
         hidden_states = self.dense(hidden_states)
...
@@ -22,7 +22,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import openai_gelu
+from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
@@ -52,6 +52,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     gelu = F.gelu
     if args.openai_gelu:
         gelu = openai_gelu
+    elif args.onnx_safe:
+        gelu = erf_gelu

     # Language model.
     language_model = TransformerLanguageModel(
...
@@ -48,7 +48,6 @@ def get_linear_layer(rows, columns, init_method):
         layer.bias.zero_()
     return layer

-
 @torch.jit.script
 def gelu_impl(x):
     """OpenAI's gelu implementation."""
@@ -57,6 +56,10 @@ def gelu_impl(x):
 def openai_gelu(x):
     return gelu_impl(x)

+#This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
+@torch.jit.script
+def erf_gelu(x):
+    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))

 def get_params_for_weight_decay_optimization(module):
     """Divide params into with-weight-decay and without-weight-decay groups.
...
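As a sanity check (not part of the MR), the erf-based formulation should match PyTorch's exact GeLU, since 1.41421 approximates sqrt(2). A minimal sketch:

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def erf_gelu(x):
    # Same formula as the erf_gelu added in megatron/model/utils.py above;
    # 1.41421 ~ sqrt(2), so this is the exact (erf-based) GeLU.
    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))

x = torch.randn(4, 8)
print(torch.allclose(erf_gelu(x), F.gelu(x), atol=1e-4))  # expected: True
```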
@@ -110,11 +110,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.scale_grad_by_freq = False
         self.sparse = False
         self._weight = None
+        self.model_parallel_size = get_model_parallel_world_size()
         # Divide the weight matrix along the vocaburaly dimension.
         self.vocab_start_index, self.vocab_end_index = \
             VocabUtility.vocab_range_from_global_vocab_size(
                 self.num_embeddings, get_model_parallel_rank(),
-                get_model_parallel_world_size())
+                self.model_parallel_size)
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index
@@ -127,19 +128,23 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.num_embeddings_per_partition, 0, init_method)

     def forward(self, input_):
-        # Build the mask.
-        input_mask = (input_ < self.vocab_start_index) | \
-                     (input_ >= self.vocab_end_index)
-        # Mask the input.
-        masked_input = input_.clone() - self.vocab_start_index
-        masked_input[input_mask] = 0
+        if self.model_parallel_size > 1:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | \
+                         (input_ >= self.vocab_end_index)
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
         # Get the embeddings.
         output_parallel = F.embedding(masked_input, self.weight,
                                       self.padding_idx, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq,
                                       self.sparse)
         # Mask the output embedding.
-        output_parallel[input_mask, :] = 0.0
+        if self.model_parallel_size > 1:
+            output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
         return output
...
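For intuition (this is not Megatron code), a single-process sketch of the vocab-parallel lookup that the guarded masking above implements: each rank embeds only the tokens whose rows it owns, zeroes out the rest, and the all-reduce sums the partial results. With a single model-parallel rank there is nothing to mask, which is what the new `model_parallel_size > 1` checks let the ONNX-friendly single-GPU path skip:

```python
import torch
import torch.nn.functional as F

vocab, dim, world = 8, 4, 2
full_weight = torch.randn(vocab, dim)
tokens = torch.tensor([0, 3, 5, 7])

partials = []
for rank in range(world):
    start, end = rank * vocab // world, (rank + 1) * vocab // world
    shard = full_weight[start:end]              # this rank's slice of the table
    mask = (tokens < start) | (tokens >= end)   # tokens owned by other ranks
    local = tokens - start
    local[mask] = 0                             # clamp to a valid local row
    out = F.embedding(local, shard)
    out[mask, :] = 0.0                          # zero the borrowed rows
    partials.append(out)

summed = sum(partials)                          # stands in for the all-reduce
assert torch.equal(summed, F.embedding(tokens, full_weight))
```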
@@ -79,6 +79,10 @@ def _gather(input_):
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
     @staticmethod
     def forward(ctx, input_):
         return input_
@@ -91,6 +95,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
 class _ReduceFromModelParallelRegion(torch.autograd.Function):
     """All-redcue the input from the model parallel region."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _reduce(input_)
@@ -103,6 +111,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
 class _ScatterToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _split(input_)
@@ -115,6 +127,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
 class _GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _gather(input_)
...
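For reference, `symbolic` is the static method the torch.onnx exporter looks for on a `torch.autograd.Function` subclass, presumably so export can trace through these custom model-parallel ops instead of stopping at an opaque PythonOp. A minimal self-contained sketch (not from this repo) of the same pattern on an identity-like function:

```python
import torch

class _IdentityToParallelRegion(torch.autograd.Function):
    """Hypothetical op mirroring the pattern added above."""

    @staticmethod
    def symbolic(graph, input_):
        # Used by the ONNX exporter in place of forward(); this op is an
        # identity in the exported graph, so the input is returned as-is.
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.randn(2, 3, requires_grad=True)
y = _IdentityToParallelRegion.apply(x)
y.sum().backward()
print(torch.equal(x.grad, torch.ones_like(x)))  # expected: True
```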