Commit fe794c5a authored by Boris Fomitchev
Browse files

Replacing --erf-gelu option with explicit --onnx-safe option


Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
parent 7917774a
...@@ -158,9 +158,8 @@ def _add_network_size_args(parser): ...@@ -158,9 +158,8 @@ def _add_network_size_args(parser):
help='Use OpenAIs GeLU implementation. This option' help='Use OpenAIs GeLU implementation. This option'
'should not be used unless for backward compatibility' 'should not be used unless for backward compatibility'
'reasons.') 'reasons.')
group.add_argument('--erf-gelu', action='store_true', group.add_argument('--onnx-safe', action='store_true',
help='Python GeLU implementation equivalent to one in Torch. This option' help='Use workarounds for known problems with Torch ONNX exporter')
'should only be used to work around Torch bug exporting gelu() to ONNX in FP16')
return parser return parser
......
...@@ -95,8 +95,7 @@ class BertLMHead(MegatronModule): ...@@ -95,8 +95,7 @@ class BertLMHead(MegatronModule):
self.gelu = torch.nn.functional.gelu self.gelu = torch.nn.functional.gelu
if args.openai_gelu: if args.openai_gelu:
self.gelu = openai_gelu self.gelu = openai_gelu
# make it override elif args.onnx_safe:
if args.erf_gelu:
self.gelu = erf_gelu self.gelu = erf_gelu
def forward(self, hidden_states, word_embeddings_weight): def forward(self, hidden_states, word_embeddings_weight):
......
...@@ -52,7 +52,7 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler, ...@@ -52,7 +52,7 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
gelu = F.gelu gelu = F.gelu
if args.openai_gelu: if args.openai_gelu:
gelu = openai_gelu gelu = openai_gelu
if args.erf_gelu: elif args.onnx_safe:
gelu = erf_gelu gelu = erf_gelu
# Language model. # Language model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment