OpenDAS / Megatron-LM / Commits / 5e56e563

Commit 5e56e563, authored Apr 28, 2020 by Neel Kant

    Merge master into realm-mlm

Parents: 6c0a5bd8, 569b3dab
Changes: 107
Showing 20 changed files with 380 additions and 315 deletions (+380, -315)
megatron/deprecated_data_utils/tokenization.py        +105   -71
megatron/deprecated_data_utils/tokenization_gpt2.py    +32   -17
megatron/deprecated_data_utils/wordpiece.py             +9    -8
megatron/fp16/__init__.py                               +1    -1
megatron/fp16/fp16.py                                 +105   -91
megatron/fp16/fp16util.py                              +20    -8
megatron/fp16/loss_scaler.py                           +38   -19
megatron/global_vars.py                                +10    -6
megatron/initialize.py                                  +5    -3
megatron/learning_rates.py                              +1    -6
megatron/model/__init__.py                              +1    -1
megatron/model/bert_model.py                           +15   -17
megatron/model/classification.py                        +1    -4
megatron/model/distributed.py                           +3    -8
megatron/model/gpt2_model.py                            +2    -6
megatron/model/language_model.py                       +12   -14
megatron/model/multiple_choice.py                       +1    -4
megatron/model/transformer.py                          +15   -25
megatron/model/utils.py                                 +3    -3
megatron/module.py                                      +1    -3
megatron/deprecated_data_utils/tokenization.py  @ 5e56e563

    This diff is collapsed (not shown).
megatron/deprecated_data_utils/tokenization_gpt2.py  @ 5e56e563

@@ -27,7 +27,8 @@ try:
     from functools import lru_cache
 except ImportError:
     # Just a dummy decorator to get the checks to run on python2
-    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    # because honestly I don't want to support a byte-level unicode BPE
+    # tokenizer on python 2 right now.
     def lru_cache():
         return lambda func: func

@@ -48,6 +49,7 @@ VOCAB_NAME = 'vocab.json'
 MERGES_NAME = 'merges.txt'
 SPECIAL_TOKENS_NAME = 'special_tokens.txt'

+
 @lru_cache()
 def bytes_to_unicode():
     """

@@ -60,17 +62,19 @@ def bytes_to_unicode():
     And avoids mapping to whitespace/control characters the bpe code barfs on.
     """
     _chr = unichr if sys.version_info[0] == 2 else chr
-    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
+        list(range(ord("®"), ord("ÿ") + 1))
     cs = bs[:]
     n = 0
     for b in range(2**8):
         if b not in bs:
             bs.append(b)
             cs.append(2**8 + n)
             n += 1
     cs = [_chr(n) for n in cs]
     return dict(zip(bs, cs))

+
 def get_pairs(word):
     """Return set of symbol pairs in a word.

@@ -83,6 +87,7 @@ def get_pairs(word):
         prev_char = char
     return pairs

+
 class GPT2Tokenizer(object):
     """
     GPT-2 BPE tokenizer. Peculiarities:

@@ -138,23 +143,31 @@ class GPT2Tokenizer(object):
             special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
         else:
             special_tokens = kwargs.pop('special_tokens', [])
         tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
         return tokenizer

     def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
         self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v: k for k, v in self.encoder.items()}
         self.errors = errors  # how to handle errors in decoding
         self.byte_encoder = bytes_to_unicode()
         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
         bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
         bpe_merges = [tuple(merge.split()) for merge in bpe_data]
         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
         self.cache = {}
-        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        # Should haved added re.IGNORECASE so BPE merges can happen for
+        # capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
         self.special_tokens = {}
         self.special_tokens_decoder = {}

@@ -172,8 +185,9 @@ class GPT2Tokenizer(object):
             self.special_tokens = {}
             self.special_tokens_decoder = {}
             return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
+        self.special_tokens = dict((tok, len(self.encoder) + i)
+                                   for i, tok in enumerate(special_tokens))
         self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
         logger.info("Special tokens {}".format(self.special_tokens))

     def bpe(self, token):

@@ -186,7 +200,7 @@ class GPT2Tokenizer(object):
             return token
         while True:
             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
             if bigram not in self.bpe_ranks:
                 break
             first, second = bigram

@@ -197,12 +211,12 @@ class GPT2Tokenizer(object):
                     j = word.index(first, i)
                     new_word.extend(word[i:j])
                     i = j
-                except:
+                except BaseException:
                     new_word.extend(word[i:])
                     break
                 if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                     new_word.append(first + second)
                     i += 2
                 else:
                     new_word.append(word[i])

@@ -245,7 +259,8 @@ class GPT2Tokenizer(object):
             logger.warning(
                 "Token indices sequence length is longer than the specified maximum "
                 " sequence length for this OpenAI GPT model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
+                " sequence through the model will result in indexing errors".format(
+                    len(ids), self.max_len)
             )
         return ids
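The bytes_to_unicode() hunk above is formatting-only, but the function is the heart of the byte-level BPE trick: every possible byte value gets a printable unicode stand-in so arbitrary UTF-8 input round-trips through the tokenizer losslessly. A minimal sketch of that round trip, assuming only that the import path matches this repository layout:

    # Round-trip raw bytes through the byte<->unicode maps defined above.
    from megatron.deprecated_data_utils.tokenization_gpt2 import bytes_to_unicode

    byte_encoder = bytes_to_unicode()                       # byte value -> printable unicode char
    byte_decoder = {v: k for k, v in byte_encoder.items()}  # inverse map, as in GPT2Tokenizer.__init__

    text = "héllo world"
    as_unicode = ''.join(byte_encoder[b] for b in text.encode('utf-8'))
    restored = bytearray(byte_decoder[c] for c in as_unicode).decode('utf-8')
    assert restored == text  # the mapping is lossless in both directions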
megatron/deprecated_data_utils/wordpiece.py  @ 5e56e563

@@ -123,7 +123,8 @@ class BertTokenizer(object):
             logger.warning(
                 "Token indices sequence length is longer than the specified maximum "
                 " sequence length for this BERT model ({} > {}). Running this"
-                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+                " sequence through BERT will result in indexing errors".format(
+                    len(ids), self.max_len)
             )
         return ids
megatron/fp16/__init__.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
megatron/fp16/fp16.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -22,12 +22,15 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from .loss_scaler import DynamicLossScaler, LossScaler
 from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C

 from megatron.module import MegatronModule

 FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
 HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)

+
 def conversion_helper(val, conversion):
     """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
     if not isinstance(val, (tuple, list)):

@@ -37,6 +40,7 @@ def conversion_helper(val, conversion):
         rtn = tuple(rtn)
     return rtn

+
 def fp32_to_fp16(val):
     """Convert fp32 `val` to fp16"""
     def half_conversion(val):

@@ -48,6 +52,7 @@ def fp32_to_fp16(val):
         return val
     return conversion_helper(val, half_conversion)

+
 def fp16_to_fp32(val):
     """Convert fp16 `val` to fp32"""
     def float_conversion(val):

@@ -59,6 +64,7 @@ def fp16_to_fp32(val):
         return val
     return conversion_helper(val, float_conversion)

+
 class FP16_Module(MegatronModule):
     def __init__(self, module):
         super(FP16_Module, self).__init__()

@@ -79,6 +85,8 @@ class FP16_Module(MegatronModule):
         self.module.load_state_dict(state_dict, strict=strict)

+
 # TODO: Update overflow check + downscale to use Carl's fused kernel.
+
 class FP16_Optimizer(object):
     """
     :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,

@@ -305,7 +313,8 @@ class FP16_Optimizer(object):
                 master_params_to_model_params(fp32_from_fp16_group, fp16_group)

     # To consider: Integrate distributed with this wrapper by registering a hook on each variable
-    # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
+    # that does the overflow check, gradient copy + downscale, and fp32
+    # allreduce in a different stream.
     def _model_grads_to_master_grads(self):
         for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
             model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)

@@ -313,9 +322,12 @@ class FP16_Optimizer(object):
     def _downscale_master(self):
         if self.loss_scale != 1.0:
             for group in self.optimizer.param_groups:
-                for param in group['params']:
-                    if param.grad is not None:
-                        param.grad.data.mul_(1. / self.loss_scale)
+                grads = [p.grad for p in group['params'] if p.grad is not None]
+                _overflow_buf = torch.cuda.IntTensor([0])
+                multi_tensor_applier(amp_C.multi_tensor_scale,
+                                     _overflow_buf,
+                                     [grads, grads],
+                                     1. / self.loss_scale)

     def clip_master_grads(self, max_norm, norm_type=2):
         """

@@ -400,7 +412,8 @@ class FP16_Optimizer(object):
         # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
         # constructed in the same way as the one whose state_dict we are loading, the same master params
         # are guaranteed to exist, so we can just copy_() from the saved master params.
-        for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
+        for current_group, saved_group in zip(
+                self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
             for current, saved in zip(current_group, saved_group):
                 current.data.copy_(saved.data)

@@ -570,7 +583,8 @@ class FP16_Optimizer(object):
         """
         if self.dynamic_loss_scale:
             self._check_overflow()
-            if self.overflow: return
+            if self.overflow:
+                return

         self._model_grads_to_master_grads()
         self._downscale_master()

@@ -607,8 +621,8 @@ class FP16_Optimizer(object):
             master_grads_data.append(master_grads_this_group)
         return master_grads_data

     # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
     def _get_loss_scale(self):
         return self.loss_scaler.loss_scale
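The _downscale_master() hunk above is the main functional change in fp16.py: the per-parameter mul_ loop becomes a single fused apex kernel launch per parameter group. A standalone sketch of that pattern (the helper name is ours; the apex call mirrors the diff):

    import torch
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier

    def downscale_grads_(grads, loss_scale):
        """Multiply every gradient tensor in `grads` by 1/loss_scale, in place."""
        if loss_scale != 1.0 and grads:
            overflow_buf = torch.cuda.IntTensor([0])      # set non-zero if inf/nan is encountered
            multi_tensor_applier(amp_C.multi_tensor_scale,
                                 overflow_buf,
                                 [grads, grads],           # read from and write back to the same tensors
                                 1. / loss_scale)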
megatron/fp16/fp16util.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -18,6 +18,9 @@ import torch.nn as nn
 from torch.autograd import Variable
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C

 from megatron import mpu

@@ -102,6 +105,7 @@ class FP16Model(nn.Module):
 def backwards_debug_hook(grad):
     raise RuntimeError("master_params recieved a gradient in the backward pass!")

+
 def prep_param_lists(model, flat_master=False):
     """
     Creates a list of FP32 master parameters for a given model, as in

@@ -131,7 +135,7 @@ def prep_param_lists(model, flat_master=False):
             # flatten_dense_tensors returns a contiguous flat array.
             # http://pytorch.org/docs/master/_modules/torch/_utils.html
             master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
-        except:
+        except BaseException:
             print("Error in prep_param_lists: model may contain a mixture of parameters "
                   "of different types. Use flat_master=False, or use F16_Optimizer.")
             raise

@@ -165,9 +169,15 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False):
             if model.grad is not None:
                 if master.grad is None:
                     master.grad = Variable(master.data.new(*master.data.size()))
-                master.grad.data.copy_(model.grad.data)
             else:
                 master.grad = None
+        model_grads = [p.grad for p in model_params if p.grad is not None]
+        master_grads = [p.grad for p in master_params if p.grad is not None]
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [model_grads, master_grads],
+                             1.0)

 def master_params_to_model_params(model_params, master_params, flat_master=False):

@@ -188,17 +198,19 @@ def master_params_to_model_params(model_params, master_params, flat_master=False):
 # Backward compatibility fixes

+
 def to_python_float(t):
     if hasattr(t, 'item'):
         return t.item()
     else:
         return t[0]

+
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])

 clip_grad_norm = mpu.clip_grad_norm
-#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
+# elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
 #    clip_grad_norm = torch.nn.utils.clip_grad_norm
-#else:
+# else:
 #    clip_grad_norm = torch.nn.utils.clip_grad_norm_
megatron/fp16/loss_scaler.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -14,15 +14,22 @@
 # limitations under the License.

 import torch
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C

 from megatron import mpu

+
 # item() is a recent addition, so this helps with backward compatibility.
 def to_python_float(t):
     if hasattr(t, 'item'):
         return t.item()
     else:
         return t[0]

+
 class LossScaler:
     """
     Class that manages a static loss scale. This class is intended to interact with

@@ -54,12 +61,18 @@ class LossScaler:
         return self.cur_scale

     def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [grad_in, grad_in],
+                             self.loss_scale)
+        return grad_in

     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale
         scaled_loss.backward(retain_graph=retain_graph)

+
 class DynamicLossScaler:
     """
     Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`

@@ -122,8 +135,8 @@ class DynamicLossScaler:
         overflow = overflow_gpu[0].item()
         return bool(overflow)

     # `x` is a torch.Tensor
     def _has_inf_or_nan(x):
         try:
             # if x is half, the .float() incurs an additional deep copy, but it's necessary if

@@ -158,7 +171,7 @@ class DynamicLossScaler:
         if overflow:
             # self.cur_scale /= self.scale_factor
             if self.delayed_shift == 1 or self.cur_hysteresis == 1:
                 self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
             else:
                 self.cur_hysteresis -= 1
             self.last_overflow_iter = self.cur_iter

@@ -176,12 +189,18 @@ class DynamicLossScaler:
         return self.cur_scale

     def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [grad_in, grad_in],
+                             self.loss_scale)
+        return grad_in

     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale
         scaled_loss.backward(retain_graph=retain_graph)

+
 ##############################################################
 # Example usage below here -- assuming it's in a separate file
 ##############################################################
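The file keeps its own "Example usage" section below the region elided here; the dynamic loss-scaling loop it describes looks roughly like this. The training objects are hypothetical, and the method names follow the usual Megatron/apex loss-scaler interface, so treat this as a sketch rather than a copy of that example:

    # Hedged sketch of a dynamic loss-scaling training step.
    scaler = DynamicLossScaler(init_scale=2**32, scale_factor=2., scale_window=1000)

    for batch in loader:
        optimizer.zero_grad()
        loss = compute_loss(model, batch)
        scaler.backward(loss)                      # backward() on loss * cur_scale
        params = [p for p in model.parameters() if p.grad is not None]
        if scaler.has_overflow(params):            # inf/nan grads: skip the step, shrink the scale
            scaler.update_scale(overflow=True)
            continue
        for p in params:                           # undo the scaling before stepping
            p.grad.data.mul_(1. / scaler.loss_scale)
        optimizer.step()
        scaler.update_scale(overflow=False)        # grow the scale after a clean window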
megatron/global_vars.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -61,22 +61,26 @@ def get_timers():
     return _GLOBAL_TIMERS

-def set_global_variables(extra_args_provider=None, args_defaults={}):
+def set_global_variables(extra_args_provider=None, args_defaults={},
+                         ignore_unknown_args=False):
     """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
     args = _parse_args(extra_args_provider=extra_args_provider,
-                       defaults=args_defaults)
+                       defaults=args_defaults,
+                       ignore_unknown_args=ignore_unknown_args)
     _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
     _set_timers()

-def _parse_args(extra_args_provider=None, defaults={}):
+def _parse_args(extra_args_provider=None, defaults={},
+                ignore_unknown_args=False):
     """Parse entire arguments."""
     global _GLOBAL_ARGS
     _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
     _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
-                              defaults=defaults)
+                              defaults=defaults,
+                              ignore_unknown_args=ignore_unknown_args)
     return _GLOBAL_ARGS

@@ -124,7 +128,7 @@ def _set_adlr_autoresume(args):
     sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
     try:
         from userlib.auto_resume import AutoResume
-    except:
+    except BaseException:
         print('ADLR autoresume is not available, exiting ...')
         sys.exit()
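The new ignore_unknown_args flag is threaded from set_global_variables() down to parse_args(). The actual parser lives in megatron/arguments.py and is not part of this diff; the sketch below only illustrates what such a switch typically toggles in an argparse-based parser (the flag shown is illustrative):

    import argparse

    def parse_args_sketch(ignore_unknown_args=False):
        parser = argparse.ArgumentParser(description='Megatron-LM arguments')
        parser.add_argument('--num-layers', type=int, default=24)   # illustrative argument
        if ignore_unknown_args:
            args, _unknown = parser.parse_known_args()   # silently drop unrecognized flags
        else:
            args = parser.parse_args()                   # error out on unrecognized flags
        return args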
megatron/initialize.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -28,7 +28,8 @@ from megatron import mpu
 from megatron.global_vars import set_global_variables

-def initialize_megatron(extra_args_provider=None, args_defaults={}):
+def initialize_megatron(extra_args_provider=None, args_defaults={},
+                        ignore_unknown_args=False):
     """Set global variables, initialize distributed, and
     set autoresume and random seeds."""
     # Make sure cuda is available.

@@ -37,7 +38,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}):
     # Parse args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
     set_global_variables(extra_args_provider=extra_args_provider,
-                         args_defaults=args_defaults)
+                         args_defaults=args_defaults,
+                         ignore_unknown_args=ignore_unknown_args)

     # Pytorch distributed.
     _initialize_distributed()
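On the caller side, the new keyword lets a downstream script initialize Megatron even when sys.argv carries flags Megatron does not know about. A hedged usage sketch (the defaults dict is illustrative, not taken from this commit):

    from megatron.initialize import initialize_megatron

    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
                        ignore_unknown_args=True)   # tolerate unrelated CLI flags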
megatron/learning_rates.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -48,7 +48,6 @@ class AnnealingLR(object):
         print_rank_0('> learning rate decay style: {}'.format(self.decay_style))

-
     def get_lr(self):
         """Learning rate decay functions from:
              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

@@ -71,7 +70,6 @@ class AnnealingLR(object):
             lr = self.start_lr
         return max(lr, self.min_lr)

-
     def step(self, step_num=None):
         """Set lr for all parameters groups."""
         if step_num is None:

@@ -81,7 +79,6 @@ class AnnealingLR(object):
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr

-
     def state_dict(self):
         state_dict = {
             'start_lr': self.start_lr,

@@ -93,7 +90,6 @@ class AnnealingLR(object):
         }
         return state_dict

-
     def _check_and_set(self, cls_value, sd_value, name):
         """Auxiliary function for checking the values in the checkpoint and
         setting them."""

@@ -108,7 +104,6 @@ class AnnealingLR(object):
                                  name))
         return sd_value

-
     def load_state_dict(self, sd):
         self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
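The learning_rates.py hunks only drop blank lines, but get_lr() is the scheduling core referenced above (warmup followed by a decay style from the cited paper). A self-contained sketch of that kind of schedule, written independently rather than copied from AnnealingLR:

    import math

    def warmup_then_decay(step, start_lr, warmup_steps, total_steps,
                          min_lr=0.0, decay_style='linear'):
        """Illustrative warmup + decay schedule; not a copy of AnnealingLR.get_lr()."""
        if warmup_steps > 0 and step <= warmup_steps:
            return start_lr * step / warmup_steps            # linear warmup
        frac = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        frac = min(max(frac, 0.0), 1.0)
        if decay_style == 'linear':
            lr = start_lr * (1.0 - frac)
        elif decay_style == 'cosine':
            lr = start_lr * 0.5 * (1.0 + math.cos(math.pi * frac))
        else:                                                # constant or unknown style
            lr = start_lr
        return max(lr, min_lr)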
megatron/model/__init__.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
megatron/model/bert_model.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -22,16 +22,15 @@ import torch
 import torch.nn.functional as F

 from megatron import get_args
+from megatron.model.language_model import parallel_lm_logits
+from megatron.model.language_model import get_language_model
+from megatron.model.transformer import LayerNorm
+from megatron.model.utils import openai_gelu
+from megatron.model.utils import get_linear_layer
+from megatron.model.utils import init_method_normal
+from megatron.model.utils import scaled_init_method_normal
 from megatron.module import MegatronModule
-from .language_model import parallel_lm_logits
-from .language_model import get_language_model
-from .transformer import LayerNorm
-from .utils import gelu
-from .utils import get_linear_layer
-from .utils import init_method_normal
-from .utils import scaled_init_method_normal

 def bert_attention_mask_func(attention_scores, attention_mask):
     attention_scores = attention_scores + attention_mask

@@ -70,7 +69,6 @@ def bert_position_ids(token_ids):
     return position_ids

-
 class BertLMHead(MegatronModule):
     """Masked LM head for Bert

@@ -81,11 +79,14 @@ class BertLMHead(MegatronModule):
         layernorm_epsilon: tolerance for layer norm divisions
         parallel_output: whether output logits being distributed or not.
     """
+
     def __init__(self, mpu_vocab_size, hidden_size, init_method,
                  layernorm_epsilon, parallel_output):
         super(BertLMHead, self).__init__()
+
+        args = get_args()

         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         self.bias.model_parallel = True
         self.bias.partition_dim = 0

@@ -94,11 +95,13 @@ class BertLMHead(MegatronModule):
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.gelu = torch.nn.functional.gelu
+        if args.openai_gelu:
+            self.gelu = openai_gelu

     def forward(self, hidden_states, word_embeddings_weight):
         hidden_states = self.dense(hidden_states)
-        hidden_states = gelu(hidden_states)
+        hidden_states = self.gelu(hidden_states)
         hidden_states = self.layernorm(hidden_states)
         output = parallel_lm_logits(hidden_states,
                                     word_embeddings_weight,

@@ -107,7 +110,6 @@ class BertLMHead(MegatronModule):
         return output

-
 class BertModel(MegatronModule):
     """Bert Language model."""

@@ -184,7 +186,6 @@ class BertModel(MegatronModule):
         return lm_logits, None

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,

@@ -206,7 +207,6 @@ class BertModel(MegatronModule):
                 = self.ict_head.state_dict(destination, prefix, keep_vars)
         return state_dict_

-
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""

@@ -224,8 +224,6 @@ class BertModel(MegatronModule):
 class REALMBertModel(MegatronModule):
-    # TODO: load BertModel checkpoint
     def __init__(self, retriever):
         super(REALMBertModel, self).__init__()
         bert_args = dict(
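The BertLMHead hunks switch the head from a hard-coded gelu to an instance attribute that can be overridden by the openai-gelu option. A standalone sketch of that selection pattern follows; the openai_gelu body here is the standard tanh approximation and is illustrative rather than copied from megatron/model/utils.py:

    import math
    import torch
    import torch.nn.functional as F

    def openai_gelu(x):
        # standard tanh approximation of gelu (illustrative)
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    class LMHeadSketch(torch.nn.Module):
        def __init__(self, hidden_size, use_openai_gelu=False):
            super().__init__()
            self.dense = torch.nn.Linear(hidden_size, hidden_size)
            self.gelu = F.gelu              # default: torch implementation
            if use_openai_gelu:
                self.gelu = openai_gelu     # opt-in: tanh approximation

        def forward(self, hidden_states):
            return self.gelu(self.dense(hidden_states))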
megatron/model/classification.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -53,7 +53,6 @@ class Classification(MegatronModule):
             init_method)
         self._classification_head_key = 'classification_head'

-
     def forward(self, input_ids, attention_mask, tokentype_ids):

         extended_attention_mask = bert_extended_attention_mask(

@@ -74,7 +73,6 @@ class Classification(MegatronModule):
         return classification_logits

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,

@@ -89,7 +87,6 @@ class Classification(MegatronModule):
                 destination, prefix, keep_vars)
         return state_dict_

-
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
megatron/model/distributed.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -31,10 +31,6 @@ class DistributedDataParallel(MegatronModule):
         self.module = module
         self.data_parallel_group = mpu.get_data_parallel_group()
-        src_rank = mpu.get_model_parallel_rank()
-        for p in self.module.parameters():
-            if torch.is_tensor(p):
-                dist.broadcast(p, src_rank, group=self.data_parallel_group)

         def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
             if(self.needs_reduction):

@@ -71,8 +67,8 @@ class DistributedDataParallel(MegatronModule):
             def allreduce_hook(*unused):
                 Variable._execution_engine.queue_callback(allreduce_params)
             # handle = param.register_hook(allreduce_hook)
-            #self.hooks.append(allreduce_hook)
-            #self.hook_handles.append(handle)
+            # self.hooks.append(allreduce_hook)
+            # self.hook_handles.append(handle)
         self.allreduce_params = allreduce_params

     def forward(self, *inputs, **kwargs):

@@ -114,4 +110,3 @@ class DistributedDataParallel(MegatronModule):
         super(DistributedDataParallel, self).train(mode)
         self.module.train(mode)
     '''
-
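For context on the deleted constructor lines: they implemented the usual replica-synchronization step, broadcasting every parameter from one source rank to the rest of the data-parallel group when the module is wrapped. A standalone sketch of that pattern (the function name is ours):

    import torch
    import torch.distributed as dist

    def broadcast_parameters(module, src_rank, group):
        """Make every data-parallel replica start from identical weights."""
        for p in module.parameters():
            if torch.is_tensor(p):
                dist.broadcast(p, src_rank, group=group)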
megatron/model/gpt2_model.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal
 def gpt2_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores = torch.mul(attention_scores, ltor_mask) - \
-        10000.0 * (1.0 - ltor_mask)
+    attention_scores.masked_fill_(ltor_mask, -10000.0)
     return attention_scores

@@ -49,7 +48,6 @@ class GPT2Model(MegatronModule):
         scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                      args.num_layers))

-
     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):

@@ -79,7 +77,6 @@ class GPT2Model(MegatronModule):
         return output

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):

@@ -89,7 +86,6 @@ class GPT2Model(MegatronModule):
                 destination, prefix, keep_vars)
         return state_dict_

-
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
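The gpt2_attention_mask_func hunk replaces the arithmetic mask with an in-place masked_fill_. Note that the mask convention flips with it: the arithmetic form expects 1 at allowed positions, while masked_fill_ expects True at blocked positions. Both push blocked scores to -10000 before the softmax, as the sketch below illustrates (function names are ours):

    import torch

    def mask_arithmetic(scores, allowed_mask):
        # old style: zero the blocked scores, then push them far negative
        return torch.mul(scores, allowed_mask) - 10000.0 * (1.0 - allowed_mask)

    def mask_fill_(scores, blocked_mask):
        # new style: write -10000 directly into the blocked positions, in place
        scores.masked_fill_(blocked_mask, -10000.0)
        return scores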
megatron/model/language_model.py  @ 5e56e563

 # coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -21,9 +21,8 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import gelu
+from megatron.model.utils import openai_gelu
 from megatron.model.utils import get_linear_layer

@@ -47,6 +46,12 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                        init_method, scaled_init_method, max_pos_embeds=None):
     """Build language model and return along with the key to save."""
+    args = get_args()
+
+    # Use torch gelu unless otherwise forced.
+    gelu = F.gelu
+    if args.openai_gelu:
+        gelu = openai_gelu

     # Language model.
     language_model = TransformerLanguageModel(

@@ -63,7 +68,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     return language_model, language_model_key

-
 class Pooler(MegatronModule):
     """Pooler layer.

@@ -75,11 +79,11 @@ class Pooler(MegatronModule):
         init_method: weight initialization method for the linear layer.
                      bias is set to zero.
     """
     def __init__(self, hidden_size, init_method):
         super(Pooler, self).__init__()
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

     def forward(self, hidden_states, sequence_index=0):
         # hidden_states: [b, s, h]
         # sequence_index: index of the token to pool.

@@ -102,6 +106,7 @@ class Embedding(MegatronModule):
         num_tokentypes: size of the token-type embeddings. 0 value
                         will ignore this embedding
     """
+
     def __init__(self,
                  hidden_size,
                  vocab_size,

@@ -143,7 +148,6 @@ class Embedding(MegatronModule):
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

-
     def add_tokentype_embeddings(self, num_tokentypes):
         """Add token-type embedding. This function is provided so we can add
         token-type embeddings in case the pretrained model does not have it.

@@ -160,7 +164,6 @@ class Embedding(MegatronModule):
         # Initialize the token-type embeddings.
         self.init_method(self.tokentype_embeddings.weight)

-
     def forward(self, input_ids, position_ids, tokentype_ids=None):
         # Embeddings.
         words_embeddings = self.word_embeddings(input_ids)

@@ -177,7 +180,6 @@ class Embedding(MegatronModule):
         return embeddings

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load."""

@@ -195,7 +197,6 @@ class Embedding(MegatronModule):
         return state_dict_

-
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""

@@ -242,7 +243,6 @@ class Embedding(MegatronModule):
                   'checkpoint but could not find it', flush=True)

-
 class TransformerLanguageModel(MegatronModule):
     """Transformer language model.

@@ -261,6 +261,7 @@ class TransformerLanguageModel(MegatronModule):
         num_tokentypes: size of the token-type embeddings. 0 value
                         will ignore this embedding
     """
+
     def __init__(self,
                  attention_mask_func,
                  mlp_activation_func,

@@ -298,7 +299,6 @@ class TransformerLanguageModel(MegatronModule):
             self.pooler = Pooler(self.hidden_size, self.init_method)
             self._pooler_key = 'pooler'

-
     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 pooling_sequence_index=0):

@@ -320,7 +320,6 @@ class TransformerLanguageModel(MegatronModule):
         return transformer_output

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load."""

@@ -339,7 +338,6 @@ class TransformerLanguageModel(MegatronModule):
         return state_dict_

-
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
megatron/model/multiple_choice.py
View file @ 5e56e563
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -51,7 +51,6 @@ class MultipleChoice(MegatronModule):
            init_method)
        self._multichoice_head_key = 'multichoice_head'

    def forward(self, input_ids, attention_mask, tokentype_ids):
        # [batch, choices, sequence] --> [batch * choices, sequence] -->
...
@@ -86,7 +85,6 @@ class MultipleChoice(MegatronModule):
        return multichoice_logits

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
...
@@ -101,7 +99,6 @@ class MultipleChoice(MegatronModule):
            destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
...
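The comment in forward describes the usual multiple-choice trick: the [batch, choices, sequence] inputs are flattened to [batch * choices, sequence] so the encoder sees one flat batch, and the per-choice logits are reshaped back afterwards. A shape-only sketch of that pattern, with a hypothetical encoder and head standing in for the Megatron modules:

import torch

def multiple_choice_logits(encoder, head, input_ids, attention_mask):
    # input_ids, attention_mask: [batch, choices, sequence]
    batch, choices, sequence = input_ids.shape
    flat_ids = input_ids.view(batch * choices, sequence)
    flat_mask = attention_mask.view(batch * choices, sequence)

    pooled = encoder(flat_ids, flat_mask)   # [batch * choices, hidden] (assumed)
    logits = head(pooled)                   # [batch * choices, 1]
    return logits.view(batch, choices)      # one score per choice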
megatron/model/transformer.py
View file @ 5e56e563
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -46,6 +46,7 @@ from megatron.module import MegatronModule
   unmaksed-attention-scores, attention-mask)
"""


class ParallelMLP(MegatronModule):
    """MLP.
...
@@ -63,7 +64,7 @@ class ParallelMLP(MegatronModule):
        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            4 * args.hidden_size,
            gather_output=False,
            init_method=init_method)
...
@@ -71,14 +72,13 @@ class ParallelMLP(MegatronModule):
        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method)
        self.dropout = torch.nn.Dropout(args.hidden_dropout)

    def forward(self, hidden_states):
        # [b, s, 4hp]
...
@@ -91,13 +91,13 @@ class ParallelMLP(MegatronModule):
        return output
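The data flow through ParallelMLP is the standard GPT feed-forward block: project h to 4h, apply the activation, project back to h, then dropout. A minimal single-GPU sketch, with plain torch.nn.Linear standing in for the tensor-parallel mpu layers (that substitution is an assumption for illustration only):

import torch

class SimpleMLP(torch.nn.Module):
    """Single-GPU stand-in for ParallelMLP: h -> 4h -> h with dropout."""
    def __init__(self, hidden_size, hidden_dropout=0.1,
                 activation=torch.nn.functional.gelu):
        super().__init__()
        # In Megatron these are ColumnParallelLinear / RowParallelLinear,
        # sharded across tensor-parallel ranks; plain Linear layers here.
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
        self.activation = activation
        self.dropout = torch.nn.Dropout(hidden_dropout)

    def forward(self, hidden_states):                                       # [b, s, h]
        intermediate = self.activation(self.dense_h_to_4h(hidden_states))   # [b, s, 4h]
        return self.dropout(self.dense_4h_to_h(intermediate))               # [b, s, h]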
class ParallelSelfAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
...
@@ -123,7 +123,7 @@ class ParallelSelfAttention(MegatronModule):
        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * args.hidden_size,
            stride=3,
            gather_output=False,
            init_method=init_method)
...
@@ -141,7 +141,6 @@ class ParallelSelfAttention(MegatronModule):
            init_method=output_layer_init_method)
        self.output_dropout = torch.nn.Dropout(args.hidden_dropout)

    def _transpose_for_scores(self, tensor):
        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
        size [b, np, s, hn].
...
@@ -152,7 +151,6 @@ class ParallelSelfAttention(MegatronModule):
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)

    def _get_query_key_value(self, hidden_states):
        """Get query, key, and value and transpose to
        get size [b, np, s, hn].
...
@@ -170,7 +168,6 @@ class ParallelSelfAttention(MegatronModule):
        return query_layer, key_layer, value_layer

    def _get_unmasked_attention_scores(self, query_layer, key_layer):
        """Unmasked attention scores with size [b, np, s, s]."""
        coeff = 1
...
@@ -179,9 +176,8 @@ class ParallelSelfAttention(MegatronModule):
        norm_factor = math.sqrt(coeff *
                                math.sqrt(self.hidden_size_per_attention_head))
        # Raw attention scores. [b, np, s, s]
        return torch.matmul(query_layer / norm_factor,
                            key_layer.transpose(-1, -2) / norm_factor)
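Note how the scaling is split: each operand is divided by norm_factor = sqrt(coeff * sqrt(hn)), so the matmul ends up scaled by 1 / (coeff * sqrt(hn)) overall — the usual 1/sqrt(d) attention scaling when coeff is 1 — while keeping each operand of the matmul smaller, which helps fp16 range. A quick check that the two forms agree:

import math
import torch

b, heads, s, hn = 2, 4, 8, 64        # batch, heads, sequence, head dim
q = torch.randn(b, heads, s, hn)
k = torch.randn(b, heads, s, hn)

coeff = 1                            # extra softmax scale factor; 1 in the default path
norm_factor = math.sqrt(coeff * math.sqrt(hn))

# Splitting the divisor across query and key...
scores_split = torch.matmul(q / norm_factor, k.transpose(-1, -2) / norm_factor)
# ...equals dividing the raw scores once by coeff * sqrt(hn).
scores_once = torch.matmul(q, k.transpose(-1, -2)) / (coeff * math.sqrt(hn))
assert torch.allclose(scores_split, scores_once, atol=1e-5)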
    def _get_attention_probs(self, attention_scores):
        """Attention probabilies with dropout. The output has
...
@@ -198,7 +194,6 @@ class ParallelSelfAttention(MegatronModule):
        return attention_probs

    def _get_attended_context(self, attention_probs, value_layer):
        """Final attended tesnor and transposed back to [b, s, hp]."""
        # Context layer.
...
@@ -213,7 +208,6 @@ class ParallelSelfAttention(MegatronModule):
        return context_layer

    def _get_output(self, context_layer):
        """Output layer with dropout."""
        # Output. [b, s, h]
...
@@ -222,7 +216,6 @@ class ParallelSelfAttention(MegatronModule):
        return output

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        # hidden_states: [b, s, h]
...
@@ -254,7 +247,7 @@ class ParallelSelfAttention(MegatronModule):
            if layer_past is not None:
                attention_mask = attention_mask[
                    ...,
                    attention_scores.size(3) - 1,
                    :attention_scores.size(3)].unsqueeze(2)
            else:
                attention_mask = attention_mask[
...
@@ -283,13 +276,13 @@ class ParallelSelfAttention(MegatronModule):
        return output
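The layer_past / get_key_value arguments support incremental decoding: keys and values cached from earlier steps are reused, and, as the hunk above shows, the attention mask is sliced down to the newest query position. A minimal sketch of that caching pattern in plain PyTorch (illustrative names, not the module's actual internals):

import torch

def attend_with_cache(query, key, value, layer_past=None, get_key_value=False):
    """query/key/value: [b, np, s_new, hn]; layer_past: (past_key, past_value) or None."""
    if layer_past is not None:
        past_key, past_value = layer_past
        # Prepend cached keys/values along the sequence dimension.
        key = torch.cat((past_key, key), dim=2)
        value = torch.cat((past_value, value), dim=2)
    present = (key, value) if get_key_value else None

    scores = torch.matmul(query, key.transpose(-1, -2)) / key.size(-1) ** 0.5
    probs = torch.softmax(scores, dim=-1)
    context = torch.matmul(probs, value)      # [b, np, s_new, hn]
    return (context, present) if get_key_value else context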
class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformore layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, attention_mask_func, mlp_activation_func,
                 init_method, output_layer_init_method, layer_number):
        args = get_args()
...
@@ -319,7 +312,6 @@ class ParallelTransformerLayer(MegatronModule):
        self.mlp = ParallelMLP(mlp_activation_func, init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        # hidden_states: [b, s, h]
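The rest of this forward is elided by the diff view. For orientation only, a typical pre-LayerNorm residual wiring of such a layer looks like the sketch below; the attribute names and exact order here are an assumption, not the file's actual body.

# Sketch only -- assumes input_layernorm, attention, post_attention_layernorm, mlp.
def forward_sketch(self, hidden_states, attention_mask):
    # Self-attention block with a residual connection.
    attention_output = self.attention(self.input_layernorm(hidden_states),
                                      attention_mask)
    hidden_states = hidden_states + attention_output
    # MLP block with a residual connection.
    mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
    return hidden_states + mlp_output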
...
@@ -375,14 +367,13 @@ class ParallelTransformer(MegatronModule):
        # Transformer layers.
        self.layers = torch.nn.ModuleList(
            [get_layer(i + 1) for i in range(args.num_layers)])

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
...
@@ -398,13 +389,12 @@ class ParallelTransformer(MegatronModule):
        num_layers = len(self.layers)
        while l < num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask)
            l += self.checkpoint_num_layers

        return hidden_states

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
...
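_checkpointed_forward trades compute for memory: the layers are run in chunks of checkpoint_num_layers, only the activations at chunk boundaries are kept, and everything inside a chunk is recomputed during the backward pass. The same chunking pattern with stock PyTorch, using torch.utils.checkpoint.checkpoint as an approximate stand-in for mpu.checkpoint (the mpu version additionally coordinates model-parallel RNG state):

import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_forward(layers, hidden_states, attention_mask, chunk_size=1):
    """Run `layers` in chunks, recomputing activations inside each chunk on backward.
    Assumes each layer is callable as layer(hidden_states, attention_mask)."""
    def make_chunk(start, end):
        def run_chunk(x, mask):
            for layer in layers[start:end]:
                x = layer(x, mask)
            return x
        return run_chunk

    l = 0
    while l < len(layers):
        hidden_states = checkpoint(make_chunk(l, l + chunk_size),
                                   hidden_states, attention_mask)
        l += chunk_size
    return hidden_states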
megatron/model/utils.py
View file @ 5e56e563
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -33,6 +33,7 @@ def init_method_normal(sigma):

def scaled_init_method_normal(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
...
@@ -53,8 +54,7 @@ def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))

-def openai_gelu(x):
+def gelu(x):
    return gelu_impl(x)
...
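gelu_impl is the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), where 0.7978845608028654 is sqrt(2/pi). A quick numerical check against the exact erf-based GELU:

import math
import torch

def gelu_tanh(x):
    # Tanh approximation used above; 0.7978845608028654 == sqrt(2 / pi).
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))

def gelu_exact(x):
    # Exact GELU: x * Phi(x), with the Gaussian CDF written via erf.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

x = torch.linspace(-5.0, 5.0, steps=101)
print(torch.max(torch.abs(gelu_tanh(x) - gelu_exact(x))))  # small, well below 1e-2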
megatron/module.py
View file @ 5e56e563
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -21,11 +21,9 @@ import torch

class MegatronModule(torch.nn.Module):
    """Megatron specific extentions of torch Module."""

    def __init__(self):
        super(MegatronModule, self).__init__()

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Use this function to override the state dict for
...
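state_dict_for_save_checkpoint is the hook the model classes above override: each submodule's state dict is nested under its own key (for example _pooler_key = 'pooler'), so checkpoints can be loaded piecewise when heads are added or dropped. A minimal sketch of that pattern with hypothetical submodule names:

import torch

class ExampleModule(torch.nn.Module):
    """Illustrative only: nests each submodule's state dict under its own key."""
    _encoder_key = 'encoder'
    _pooler_key = 'pooler'

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(8, 8)
        self.pooler = torch.nn.Linear(8, 8)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        state_dict_ = {}
        state_dict_[self._encoder_key] = self.encoder.state_dict(
            prefix=prefix, keep_vars=keep_vars)
        state_dict_[self._pooler_key] = self.pooler.state_dict(
            prefix=prefix, keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load: only restore the sub-dicts that are present."""
        self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict)
        if self._pooler_key in state_dict:
            self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict)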