OpenDAS / Megatron-LM · Commits

Commit 158a99c3
Authored Jul 23, 2020 by Boris Fomitchev

    ONNX export fix, including one for gelu()

    Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

Parent: 05620ee4

Showing 6 changed files with 42 additions and 12 deletions (+42, -12)
megatron/arguments.py             +3   -0
megatron/model/bert_model.py      +4   -1
megatron/model/language_model.py  +3   -1
megatron/model/utils.py           +4   -2
megatron/mpu/layers.py            +12  -8
megatron/mpu/mappings.py          +16  -0
megatron/arguments.py

...
@@ -158,6 +158,9 @@ def _add_network_size_args(parser):
                        help='Use OpenAIs GeLU implementation. This option'
                        'should not be used unless for backward compatibility'
                        'reasons.')
+    group.add_argument('--erf-gelu', action='store_true',
+                       help='Python GeLU implementation equivalent to one in Torch. This option'
+                       'should only be used to work around Torch bug exporting gelu() to ONNX in FP16')
     return parser
...
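For reference, the new --erf-gelu option is a plain argparse store_true flag alongside the existing --openai-gelu flag. A minimal standalone sketch of that pattern (the parser and group below are illustrative stand-ins, not megatron's real argument parser):

import argparse

# Sketch of the store_true flag pattern used above; not megatron's parser.
parser = argparse.ArgumentParser(description='GeLU selection flags (sketch)')
group = parser.add_argument_group(title='network size')
group.add_argument('--openai-gelu', action='store_true',
                   help='Use the tanh-approximation GeLU.')
group.add_argument('--erf-gelu', action='store_true',
                   help='Use the erf-based GeLU (FP16 ONNX export workaround).')

args = parser.parse_args(['--erf-gelu'])
assert args.erf_gelu and not args.openai_gelu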
megatron/model/bert_model.py

...
@@ -22,7 +22,7 @@ from megatron import mpu
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm
-from megatron.model.utils import openai_gelu
+from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
...
@@ -95,6 +95,9 @@ class BertLMHead(MegatronModule):
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
             self.gelu = openai_gelu
+        # make it override
+        if args.erf_gelu:
+            self.gelu = openai_gelu

     def forward(self, hidden_states, word_embeddings_weight):
         hidden_states = self.dense(hidden_states)
...
megatron/model/language_model.py

...
@@ -22,7 +22,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import openai_gelu
+from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
...
@@ -52,6 +52,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     gelu = F.gelu
     if args.openai_gelu:
         gelu = openai_gelu
+    if args.erf_gelu:
+        gelu = erf_gelu

     # Language model.
     language_model = TransformerLanguageModel(
...
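The selection in get_language_model() keeps torch.nn.functional.gelu as the default and lets the flags override it; since --erf-gelu is checked last, it wins when both flags are given. A standalone sketch of that cascade (the placeholder functions below stand in for the ones in megatron/model/utils.py):

import types
import torch.nn.functional as F

# Placeholders standing in for megatron.model.utils.openai_gelu / erf_gelu.
def openai_gelu(x): ...
def erf_gelu(x): ...

def select_gelu(args):
    gelu = F.gelu           # default: Torch's built-in GeLU
    if args.openai_gelu:
        gelu = openai_gelu  # tanh approximation
    if args.erf_gelu:
        gelu = erf_gelu     # erf-based implementation added by this commit
    return gelu

args = types.SimpleNamespace(openai_gelu=True, erf_gelu=True)
assert select_gelu(args) is erf_gelu  # --erf-gelu takes precedence when both are set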
megatron/model/utils.py

...
@@ -48,8 +48,6 @@ def get_linear_layer(rows, columns, init_method):
         layer.bias.zero_()
     return layer

-@torch.jit.script
 def gelu_impl(x):
     """OpenAI's gelu implementation."""
     return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
...
@@ -57,6 +55,10 @@ def gelu_impl(x):
 def openai_gelu(x):
     return gelu_impl(x)

+#This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
+@torch.jit.script
+def erf_gelu(x):
+    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))

 def get_params_for_weight_decay_optimization(module):
     """Divide params into with-weight-decay and without-weight-decay groups.
...
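The added erf_gelu is the exact, erf-based form of GeLU (the formulation torch.nn.functional.gelu computes by default), with 1.41421 approximating sqrt(2) and explicit dtype casts so intermediate results stay in the input's dtype, which is the FP16 ONNX-export workaround the help string refers to. A standalone numerical check under that reading (not part of the repository):

import torch

# Same definition as the erf_gelu added above.
@torch.jit.script
def erf_gelu(x):
    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) +
                      torch.ones_like(x).to(dtype=x.dtype))

x = torch.linspace(-4.0, 4.0, steps=101)
# The truncated sqrt(2) constant only introduces deviations around 1e-5 on this range.
assert torch.allclose(erf_gelu(x), torch.nn.functional.gelu(x), atol=1e-4)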
megatron/mpu/layers.py

...
@@ -120,19 +120,23 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.num_embeddings_per_partition, 0, init_method)

     def forward(self, input_):
-        # Build the mask.
-        input_mask = (input_ < self.vocab_start_index) | \
-                     (input_ >= self.vocab_end_index)
-        # Mask the input.
-        masked_input = input_.clone() - self.vocab_start_index
-        masked_input[input_mask] = 0
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | \
+                         (input_ >= self.vocab_end_index)
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
         # Get the embeddings.
         output_parallel = F.embedding(masked_input, self.weight,
                                       self.padding_idx, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq,
                                       self.sparse)
         # Mask the output embedding.
-        output_parallel[input_mask, :] = 0.0
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
         return output
...
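After this change the out-of-range masking runs only when this rank's vocabulary slice is smaller than the full vocabulary, and the boolean-mask indexing is skipped entirely in the unpartitioned case, which keeps those ops out of a traced/exported graph when they are not needed. A standalone sketch of the masking logic on plain tensors, without the model-parallel all-reduce (the helper name and values below are illustrative, not the repository's):

import torch
import torch.nn.functional as F

def partitioned_embedding(input_, weight, vocab_start, vocab_end, full_vocab_size):
    # Sketch of VocabParallelEmbedding.forward after this commit: ids outside
    # [vocab_start, vocab_end) are zeroed before the lookup and their embeddings
    # zeroed after it, but only when the table is actually split across ranks.
    per_partition = vocab_end - vocab_start
    if per_partition < full_vocab_size:
        input_mask = (input_ < vocab_start) | (input_ >= vocab_end)
        masked_input = input_.clone() - vocab_start
        masked_input[input_mask] = 0
    else:
        masked_input = input_
    output = F.embedding(masked_input, weight)
    if per_partition < full_vocab_size:
        output[input_mask, :] = 0.0
    return output  # the real module then all-reduces this across model-parallel ranks

# Rank owning vocab ids [4, 8) of an 8-word vocabulary, embedding dim 3:
weight = torch.randn(4, 3)
ids = torch.tensor([1, 5, 7])
out = partitioned_embedding(ids, weight, 4, 8, 8)
assert torch.all(out[0] == 0)          # id 1 falls outside this rank's slice
assert torch.equal(out[1], weight[1])  # id 5 maps to local row 1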
megatron/mpu/mappings.py

...
@@ -79,6 +79,10 @@ def _gather(input_):
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
     @staticmethod
     def forward(ctx, input_):
         return input_
...
@@ -91,6 +95,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
 class _ReduceFromModelParallelRegion(torch.autograd.Function):
     """All-redcue the input from the model parallel region."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _reduce(input_)
...
@@ -103,6 +111,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
 class _ScatterToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _split(input_)
...
@@ -115,6 +127,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
 class _GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _gather(input_)
...
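The added symbolic() staticmethods use PyTorch's documented static symbolic method hook: during torch.onnx.export the exporter calls symbolic(graph, ...) instead of tracing through the autograd Function, so each communication op gets an explicit graph-level definition (for _CopyToModelParallelRegion it is simply a pass-through). A standalone sketch of that hook, whose exact behavior depends on the PyTorch version (this commit predates PyTorch 1.6; the class and model below are illustrative, not from the repository):

import io
import torch

class _PassThrough(torch.autograd.Function):
    """Identity in eager mode; symbolic() removes it from the exported graph."""

    @staticmethod
    def symbolic(graph, input_):
        # Consulted only by torch.onnx.export: return the traced value unchanged.
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

class Model(torch.nn.Module):
    def forward(self, x):
        return _PassThrough.apply(x) + 1.0

# Export through the TorchScript-based exporter; the _PassThrough node is
# replaced by whatever symbolic() returned (here, just its input).
buffer = io.BytesIO()
torch.onnx.export(Model(), (torch.randn(2, 3),), buffer)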