OpenDAS / Megatron-LM / Commits

Commit b7f1b050, authored Apr 14, 2020 by Neel Kant
Lint whole repo
parent c99fa80c

Changes: 63 files in total; this page shows 20 changed files with 201 additions and 199 deletions (+201 -199).
megatron/deprecated_data_utils/tokenization_gpt2.py   +32  -17
megatron/deprecated_data_utils/wordpiece.py            +9   -8
megatron/fp16/fp16.py                                 +96  -87
megatron/fp16/fp16util.py                              +9   -6
megatron/fp16/loss_scaler.py                          +21  -16
megatron/global_vars.py                                +1   -1
megatron/learning_rates.py                             +0   -5
megatron/model/bert_model.py                           +1   -6
megatron/model/classification.py                       +0   -3
megatron/model/distributed.py                          +2   -3
megatron/model/gpt2_model.py                           +1   -4
megatron/model/language_model.py                       +4  -11
megatron/model/multiple_choice.py                      +0   -3
megatron/model/transformer.py                         +14  -24
megatron/model/utils.py                                +2   -0
megatron/module.py                                     +0   -2
megatron/mpu/cross_entropy.py                          +0   -1
megatron/mpu/layers.py                                 +5   -2
megatron/mpu/mappings.py                               +3   -0
megatron/mpu/random.py                                 +1   -0
megatron/deprecated_data_utils/tokenization_gpt2.py (view file @ b7f1b050)

@@ -27,7 +27,8 @@ try:
     from functools import lru_cache
 except ImportError:
     # Just a dummy decorator to get the checks to run on python2
-    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    # because honestly I don't want to support a byte-level unicode BPE
+    # tokenizer on python 2 right now.
     def lru_cache():
         return lambda func: func

@@ -48,6 +49,7 @@ VOCAB_NAME = 'vocab.json'
 MERGES_NAME = 'merges.txt'
 SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 @lru_cache()
 def bytes_to_unicode():
     """

@@ -60,17 +62,19 @@ def bytes_to_unicode():
     And avoids mapping to whitespace/control characters the bpe code barfs on.
     """
     _chr = unichr if sys.version_info[0] == 2 else chr
-    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
+        list(range(ord("®"), ord("ÿ") + 1))
     cs = bs[:]
     n = 0
     for b in range(2**8):
         if b not in bs:
             bs.append(b)
             cs.append(2**8 + n)
             n += 1
     cs = [_chr(n) for n in cs]
     return dict(zip(bs, cs))

 def get_pairs(word):
     """Return set of symbol pairs in a word.

@@ -83,6 +87,7 @@ def get_pairs(word):
         prev_char = char
     return pairs

 class GPT2Tokenizer(object):
     """
     GPT-2 BPE tokenizer. Peculiarities:

@@ -138,23 +143,31 @@ class GPT2Tokenizer(object):
             special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
         else:
             special_tokens = kwargs.pop('special_tokens', [])
         tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
         return tokenizer

     def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
         self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v: k for k, v in self.encoder.items()}
         self.errors = errors  # how to handle errors in decoding
         self.byte_encoder = bytes_to_unicode()
         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
         bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
         bpe_merges = [tuple(merge.split()) for merge in bpe_data]
         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
         self.cache = {}
-        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        # Should haved added re.IGNORECASE so BPE merges can happen for
+        # capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
         self.special_tokens = {}
         self.special_tokens_decoder = {}

@@ -172,8 +185,9 @@ class GPT2Tokenizer(object):
             self.special_tokens = {}
             self.special_tokens_decoder = {}
             return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
+        self.special_tokens = dict((tok, len(self.encoder) + i)
+                                   for i, tok in enumerate(special_tokens))
         self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
         logger.info("Special tokens {}".format(self.special_tokens))

     def bpe(self, token):

@@ -186,7 +200,7 @@ class GPT2Tokenizer(object):
             return token
         while True:
             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
             if bigram not in self.bpe_ranks:
                 break
             first, second = bigram

@@ -197,12 +211,12 @@ class GPT2Tokenizer(object):
                 j = word.index(first, i)
                 new_word.extend(word[i:j])
                 i = j
-            except:
+            except BaseException:
                 new_word.extend(word[i:])
                 break
             if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                 new_word.append(first + second)
                 i += 2
             else:
                 new_word.append(word[i])

@@ -245,7 +259,8 @@ class GPT2Tokenizer(object):
             logger.warning(
                 "Token indices sequence length is longer than the specified maximum "
                 " sequence length for this OpenAI GPT model ({} > {}). Running this"
                 " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
             )
         return ids
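The bytes_to_unicode hunk above happens to show the full function body, so a standalone sketch can illustrate what the mapping produces. The code below mirrors that hunk (the function name and the printed sample are illustrative, not part of the repository):

def bytes_to_unicode_sketch():
    # Build a reversible map from all 256 byte values to printable unicode
    # characters, so BPE never has to operate on raw whitespace/control bytes.
    bs = list(range(ord("!"), ord("~") + 1)) + \
        list(range(ord("¡"), ord("¬") + 1)) + \
        list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


if __name__ == "__main__":
    mapping = bytes_to_unicode_sketch()
    assert len(mapping) == 256                 # every byte value is covered
    assert len(set(mapping.values())) == 256   # and the map is reversible
    print(mapping[ord(" ")])                   # the space byte maps to a printable char ('Ġ')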
megatron/deprecated_data_utils/wordpiece.py (view file @ b7f1b050)

@@ -123,7 +123,8 @@ class BertTokenizer(object):
             logger.warning(
                 "Token indices sequence length is longer than the specified maximum "
                 " sequence length for this BERT model ({} > {}). Running this"
                 " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
             )
         return ids
...
megatron/fp16/fp16.py (view file @ b7f1b050)

@@ -28,6 +28,7 @@ from megatron.module import MegatronModule
 FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
 HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)

 def conversion_helper(val, conversion):
     """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
     if not isinstance(val, (tuple, list)):

@@ -37,6 +38,7 @@ def conversion_helper(val, conversion):
         rtn = tuple(rtn)
     return rtn

 def fp32_to_fp16(val):
     """Convert fp32 `val` to fp16"""
     def half_conversion(val):

@@ -48,6 +50,7 @@ def fp32_to_fp16(val):
         return val
     return conversion_helper(val, half_conversion)

 def fp16_to_fp32(val):
     """Convert fp16 `val` to fp32"""
     def float_conversion(val):

@@ -59,6 +62,7 @@ def fp16_to_fp32(val):
         return val
     return conversion_helper(val, float_conversion)

 class FP16_Module(MegatronModule):
     def __init__(self, module):
         super(FP16_Module, self).__init__()

@@ -79,6 +83,8 @@ class FP16_Module(MegatronModule):
         self.module.load_state_dict(state_dict, strict=strict)

 # TODO: Update overflow check + downscale to use Carl's fused kernel.
 class FP16_Optimizer(object):
     """
     :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,

@@ -305,7 +311,8 @@ class FP16_Optimizer(object):
             master_params_to_model_params(fp32_from_fp16_group, fp16_group)

     # To consider: Integrate distributed with this wrapper by registering a hook on each variable
-    # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
+    # that does the overflow check, gradient copy + downscale, and fp32
+    # allreduce in a different stream.
     def _model_grads_to_master_grads(self):
         for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
             model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)

@@ -315,7 +322,7 @@ class FP16_Optimizer(object):
         for group in self.optimizer.param_groups:
             for param in group['params']:
                 if param.grad is not None:
                     param.grad.data.mul_(1. / self.loss_scale)

     def clip_master_grads(self, max_norm, norm_type=2):
         """

@@ -400,7 +407,8 @@ class FP16_Optimizer(object):
         # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
         # constructed in the same way as the one whose state_dict we are loading, the same master params
         # are guaranteed to exist, so we can just copy_() from the saved master params.
         for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
             for current, saved in zip(current_group, saved_group):
                 current.data.copy_(saved.data)

@@ -570,7 +578,8 @@ class FP16_Optimizer(object):
         """
         if self.dynamic_loss_scale:
             self._check_overflow()
             if self.overflow: return
         self._model_grads_to_master_grads()
         self._downscale_master()

@@ -607,8 +616,8 @@ class FP16_Optimizer(object):
             master_grads_data.append(master_grads_this_group)
         return master_grads_data

     # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
     def _get_loss_scale(self):
         return self.loss_scaler.loss_scale
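The conversion_helper / fp32_to_fp16 hunks above only show the edges of the recursion, so the sketch below re-creates the pattern for illustration. The middle of conversion_helper and the tensor-type test are plausible reconstructions, not copies of the repository's code, and the only assumption is that PyTorch is installed:

# Sketch of the recursive conversion pattern from megatron/fp16/fp16.py.
import torch


def conversion_helper(val, conversion):
    """Apply `conversion` to val, recursing into nested tuples/lists."""
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn


def fp32_to_fp16(val):
    """Convert any fp32 tensors inside `val` to fp16, leaving other values alone."""
    def half_conversion(v):
        if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
            return v.half()
        return v
    return conversion_helper(val, half_conversion)


if __name__ == "__main__":
    nested = (torch.randn(2, 2), [torch.randn(3), "not a tensor"])
    converted = fp32_to_fp16(nested)
    print(converted[0].dtype, converted[1][0].dtype)  # torch.float16 torch.float16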
megatron/fp16/fp16util.py (view file @ b7f1b050)

@@ -102,6 +102,7 @@ class FP16Model(nn.Module):
 def backwards_debug_hook(grad):
     raise RuntimeError("master_params recieved a gradient in the backward pass!")

 def prep_param_lists(model, flat_master=False):
     """
     Creates a list of FP32 master parameters for a given model, as in

@@ -131,7 +132,7 @@ def prep_param_lists(model, flat_master=False):
             # flatten_dense_tensors returns a contiguous flat array.
             # http://pytorch.org/docs/master/_modules/torch/_utils.html
             master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
-        except:
+        except BaseException:
             print("Error in prep_param_lists: model may contain a mixture of parameters "
                   "of different types. Use flat_master=False, or use F16_Optimizer.")
             raise

@@ -188,17 +189,19 @@ def master_params_to_model_params(model_params, master_params, flat_master=False
 # Backward compatibility fixes

 def to_python_float(t):
     if hasattr(t, 'item'):
         return t.item()
     else:
         return t[0]

 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])

 clip_grad_norm = mpu.clip_grad_norm
-#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
+# elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
 #     clip_grad_norm = torch.nn.utils.clip_grad_norm
-#else:
+# else:
 #     clip_grad_norm = torch.nn.utils.clip_grad_norm_
megatron/fp16/loss_scaler.py (view file @ b7f1b050)

@@ -17,12 +17,15 @@ import torch
 from megatron import mpu

 # item() is a recent addition, so this helps with backward compatibility.

 def to_python_float(t):
     if hasattr(t, 'item'):
         return t.item()
     else:
         return t[0]

 class LossScaler:
     """
     Class that manages a static loss scale. This class is intended to interact with

@@ -57,9 +60,10 @@ class LossScaler:
         return tuple(self.loss_scale * g for g in grad_in)

     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale
         scaled_loss.backward(retain_graph=retain_graph)

 class DynamicLossScaler:
     """
     Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`

@@ -122,8 +126,8 @@ class DynamicLossScaler:
         overflow = overflow_gpu[0].item()
         return bool(overflow)

     # `x` is a torch.Tensor
     def _has_inf_or_nan(x):
         try:
             # if x is half, the .float() incurs an additional deep copy, but it's necessary if

@@ -158,7 +162,7 @@ class DynamicLossScaler:
         if overflow:
             # self.cur_scale /= self.scale_factor
             if self.delayed_shift == 1 or self.cur_hysteresis == 1:
                 self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
             else:
                 self.cur_hysteresis -= 1
             self.last_overflow_iter = self.cur_iter

@@ -179,9 +183,10 @@ class DynamicLossScaler:
         return tuple(self.loss_scale * g for g in grad_in)

     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale
         scaled_loss.backward(retain_graph=retain_graph)

 ##############################################################
 # Example usage below here -- assuming it's in a separate file
 ##############################################################
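The LossScaler.backward hunks above show only the scaling step. The sketch below places that step in a minimal static-scaling loop; the toy model, data, and optimizer are illustrative assumptions, and the unscale loop mirrors the param.grad.data.mul_(1. / self.loss_scale) line from the fp16.py hunks earlier:

# Minimal static loss-scaling loop in the spirit of the LossScaler hunks above.
import torch

loss_scale = 128.0
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(8, 4), torch.randn(8, 1)

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)

# Scale before backward so small fp16 gradients would not underflow
# (this toy example runs in fp32, so the scale is purely illustrative).
(loss * loss_scale).backward()

# Unscale before the optimizer step, as the fp16.py hunk above does.
for param in model.parameters():
    if param.grad is not None:
        param.grad.data.mul_(1.0 / loss_scale)

optimizer.step()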
megatron/global_vars.py (view file @ b7f1b050)

@@ -124,7 +124,7 @@ def _set_adlr_autoresume(args):
     sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
     try:
         from userlib.auto_resume import AutoResume
-    except:
+    except BaseException:
         print('ADLR autoresume is not available, exiting ...')
         sys.exit()
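One change that recurs across this commit (tokenization_gpt2.py, fp16util.py, and global_vars.py above) is replacing a bare except with except BaseException, consistent with how automated pep8 fixers handle the bare-except warning (E722). The snippet below is an illustrative before/after; load_optional_dependency is a made-up helper, not repository code:

# Illustrative before/after for the recurring bare-except lint fix.

def load_optional_dependency_old():
    try:
        from userlib.auto_resume import AutoResume  # noqa: F401
        return True
    except:                   # flagged by linters (E722): catches everything, silently
        return False


def load_optional_dependency_new():
    try:
        from userlib.auto_resume import AutoResume  # noqa: F401
        return True
    except BaseException:     # the commit's mechanical replacement for the bare except
        return False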
megatron/learning_rates.py (view file @ b7f1b050)

@@ -48,7 +48,6 @@ class AnnealingLR(object):
         print_rank_0('> learning rate decay style: {}'.format(self.decay_style))

     def get_lr(self):
         """Learning rate decay functions from:
              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

@@ -71,7 +70,6 @@ class AnnealingLR(object):
             lr = self.start_lr
         return max(lr, self.min_lr)

     def step(self, step_num=None):
         """Set lr for all parameters groups."""
         if step_num is None:

@@ -81,7 +79,6 @@ class AnnealingLR(object):
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr

     def state_dict(self):
         state_dict = {
             'start_lr': self.start_lr,

@@ -93,7 +90,6 @@ class AnnealingLR(object):
         }
         return state_dict

     def _check_and_set(self, cls_value, sd_value, name):
         """Auxiliary function for checking the values in the checkpoint and
         setting them."""

@@ -108,7 +104,6 @@ class AnnealingLR(object):
                                        name))
         return sd_value

     def load_state_dict(self, sd):
         self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
megatron/model/bert_model.py (view file @ b7f1b050)

@@ -66,7 +66,6 @@ def bert_position_ids(token_ids):
     return position_ids

 class BertLMHead(MegatronModule):
     """Masked LM head for Bert

@@ -77,6 +76,7 @@ class BertLMHead(MegatronModule):
         layernorm_epsilon: tolerance for layer norm divisions
         parallel_output: wether output logits being distributed or not.
     """

     def __init__(self, mpu_vocab_size, hidden_size, init_method,
                  layernorm_epsilon, parallel_output):

@@ -91,7 +91,6 @@ class BertLMHead(MegatronModule):
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

     def forward(self, hidden_states, word_embeddings_weight):
         hidden_states = self.dense(hidden_states)
         hidden_states = gelu(hidden_states)

@@ -103,7 +102,6 @@ class BertLMHead(MegatronModule):
         return output

 class BertModel(MegatronModule):
     """Bert Language model."""

@@ -136,7 +134,6 @@ class BertModel(MegatronModule):
             init_method)
         self._binary_head_key = 'binary_head'

     def forward(self, input_ids, attention_mask, tokentype_ids=None):
         extended_attention_mask = bert_extended_attention_mask(

@@ -166,7 +163,6 @@ class BertModel(MegatronModule):
         return lm_logits, None

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,

@@ -184,7 +180,6 @@ class BertModel(MegatronModule):
                 = self.binary_head.state_dict(destination, prefix, keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
megatron/model/classification.py (view file @ b7f1b050)

@@ -53,7 +53,6 @@ class Classification(MegatronModule):
             init_method)
         self._classification_head_key = 'classification_head'

     def forward(self, input_ids, attention_mask, tokentype_ids):
         extended_attention_mask = bert_extended_attention_mask(

@@ -74,7 +73,6 @@ class Classification(MegatronModule):
         return classification_logits

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,

@@ -89,7 +87,6 @@ class Classification(MegatronModule):
             destination, prefix, keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
megatron/model/distributed.py (view file @ b7f1b050)

@@ -71,8 +71,8 @@ class DistributedDataParallel(MegatronModule):
             def allreduce_hook(*unused):
                 Variable._execution_engine.queue_callback(allreduce_params)
             # handle = param.register_hook(allreduce_hook)
-            #self.hooks.append(allreduce_hook)
-            #self.hook_handles.append(handle)
+            # self.hooks.append(allreduce_hook)
+            # self.hook_handles.append(handle)
         self.allreduce_params = allreduce_params

     def forward(self, *inputs, **kwargs):

@@ -114,4 +114,3 @@ class DistributedDataParallel(MegatronModule):
        super(DistributedDataParallel, self).train(mode)
        self.module.train(mode)
    '''
megatron/model/gpt2_model.py (view file @ b7f1b050)

@@ -49,7 +49,6 @@ class GPT2Model(MegatronModule):
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))

     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):

@@ -79,7 +78,6 @@ class GPT2Model(MegatronModule):
         return output

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):

@@ -89,7 +87,6 @@ class GPT2Model(MegatronModule):
             destination, prefix, keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
megatron/model/language_model.py (view file @ b7f1b050)

@@ -62,7 +62,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     return language_model, language_model_key

 class Pooler(MegatronModule):
     """Pooler layer.

@@ -74,11 +73,11 @@ class Pooler(MegatronModule):
         init_method: weight initialization method for the linear layer.
                      bias is set to zero.
     """

     def __init__(self, hidden_size, init_method):
         super(Pooler, self).__init__()
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

     def forward(self, hidden_states, sequence_index=0):
         # hidden_states: [b, s, h]
         # sequence_index: index of the token to pool.

@@ -101,6 +100,7 @@ class Embedding(MegatronModule):
         num_tokentypes: size of the token-type embeddings. 0 value
                         will ignore this embedding
     """

     def __init__(self,
                  hidden_size,
                  vocab_size,

@@ -142,7 +142,6 @@ class Embedding(MegatronModule):
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

     def add_tokentype_embeddings(self, num_tokentypes):
         """Add token-type embedding. This function is provided so we can add
         token-type embeddings in case the pretrained model does not have it.

@@ -159,7 +158,6 @@ class Embedding(MegatronModule):
         # Initialize the token-type embeddings.
         self.init_method(self.tokentype_embeddings.weight)

     def forward(self, input_ids, position_ids, tokentype_ids=None):
         # Embeddings.
         words_embeddings = self.word_embeddings(input_ids)

@@ -176,7 +174,6 @@ class Embedding(MegatronModule):
         return embeddings

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load."""

@@ -194,7 +191,6 @@ class Embedding(MegatronModule):
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""

@@ -241,7 +237,6 @@ class Embedding(MegatronModule):
                   'checkpoint but could not find it', flush=True)

 class TransformerLanguageModel(MegatronModule):
     """Transformer language model.

@@ -260,6 +255,7 @@ class TransformerLanguageModel(MegatronModule):
         num_tokentypes: size of the token-type embeddings. 0 value
                         will ignore this embedding
     """

     def __init__(self,
                  attention_mask_func,
                  mlp_activation_func,

@@ -295,7 +291,6 @@ class TransformerLanguageModel(MegatronModule):
             self.pooler = Pooler(self.hidden_size, self.init_method)
             self._pooler_key = 'pooler'

     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 pooling_sequence_index=0):

@@ -317,7 +312,6 @@ class TransformerLanguageModel(MegatronModule):
         return transformer_output

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load."""

@@ -336,7 +330,6 @@ class TransformerLanguageModel(MegatronModule):
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
megatron/model/multiple_choice.py (view file @ b7f1b050)

@@ -51,7 +51,6 @@ class MultipleChoice(MegatronModule):
             init_method)
         self._multichoice_head_key = 'multichoice_head'

     def forward(self, input_ids, attention_mask, tokentype_ids):
         # [batch, choices, sequence] --> [batch * choices, sequence] -->

@@ -86,7 +85,6 @@ class MultipleChoice(MegatronModule):
         return multichoice_logits

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,

@@ -101,7 +99,6 @@ class MultipleChoice(MegatronModule):
             destination, prefix, keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
megatron/model/transformer.py (view file @ b7f1b050)

@@ -46,6 +46,7 @@ from megatron.module import MegatronModule
            unmaksed-attention-scores, attention-mask)
 """

 class ParallelMLP(MegatronModule):
     """MLP.

@@ -63,7 +64,7 @@ class ParallelMLP(MegatronModule):
         # Project to 4h.
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
             args.hidden_size,
             4 * args.hidden_size,
             gather_output=False,
             init_method=init_method)

@@ -71,14 +72,13 @@ class ParallelMLP(MegatronModule):
         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
             4 * args.hidden_size,
             args.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method)

         self.dropout = torch.nn.Dropout(args.hidden_dropout)

     def forward(self, hidden_states):
         # [b, s, 4hp]

@@ -91,13 +91,13 @@ class ParallelMLP(MegatronModule):
         return output

 class ParallelSelfAttention(MegatronModule):
     """Parallel self-attention layer abstract class.

     Self-attention layer takes input with size [b, s, h]
     and returns output of the same size.
     """

     def __init__(self, attention_mask_func, init_method,
                  output_layer_init_method, layer_number):
         super(ParallelSelfAttention, self).__init__()

@@ -123,7 +123,7 @@ class ParallelSelfAttention(MegatronModule):
         # Strided linear layer.
         self.query_key_value = mpu.ColumnParallelLinear(
             args.hidden_size,
             3 * args.hidden_size,
             stride=3,
             gather_output=False,
             init_method=init_method)

@@ -141,7 +141,6 @@ class ParallelSelfAttention(MegatronModule):
             init_method=output_layer_init_method)
         self.output_dropout = torch.nn.Dropout(args.hidden_dropout)

     def _transpose_for_scores(self, tensor):
         """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
         size [b, np, s, hn].

@@ -152,7 +151,6 @@ class ParallelSelfAttention(MegatronModule):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)

     def _get_query_key_value(self, hidden_states):
         """Get query, key, and value and transpose to
         get size [b, np, s, hn].

@@ -170,7 +168,6 @@ class ParallelSelfAttention(MegatronModule):
         return query_layer, key_layer, value_layer

     def _get_unmasked_attention_scores(self, query_layer, key_layer):
         """Unmasked attention scores with size [b, np, s, s]."""
         coeff = 1

@@ -179,9 +176,8 @@ class ParallelSelfAttention(MegatronModule):
         norm_factor = math.sqrt(coeff *
                                 math.sqrt(self.hidden_size_per_attention_head))

         # Raw attention scores. [b, np, s, s]
         return torch.matmul(query_layer / norm_factor,
                             key_layer.transpose(-1, -2) / norm_factor)

     def _get_attention_probs(self, attention_scores):
         """Attention probabilies with dropout. The output has

@@ -198,7 +194,6 @@ class ParallelSelfAttention(MegatronModule):
         return attention_probs

     def _get_attended_context(self, attention_probs, value_layer):
         """Final attended tesnor and transposed back to [b, s, hp]."""
         # Context layer.

@@ -213,7 +208,6 @@ class ParallelSelfAttention(MegatronModule):
         return context_layer

     def _get_output(self, context_layer):
         """Output layer with dropout."""
         # Output. [b, s, h]

@@ -222,7 +216,6 @@ class ParallelSelfAttention(MegatronModule):
         return output

     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
         # hidden_states: [b, s, h]

@@ -254,7 +247,7 @@ class ParallelSelfAttention(MegatronModule):
             if layer_past is not None:
                 attention_mask = attention_mask[
                     ...,
                     attention_scores.size(3) - 1,
                     :attention_scores.size(3)].unsqueeze(2)
             else:
                 attention_mask = attention_mask[

@@ -283,13 +276,13 @@ class ParallelSelfAttention(MegatronModule):
         return output

 class ParallelTransformerLayer(MegatronModule):
     """A single transformer layer.

     Transformore layer takes input with size [b, s, h] and returns an
     output of the same size.
     """

     def __init__(self, attention_mask_func, mlp_activation_func,
                  init_method, output_layer_init_method, layer_number):
         args = get_args()

@@ -319,7 +312,6 @@ class ParallelTransformerLayer(MegatronModule):
         self.mlp = ParallelMLP(mlp_activation_func, init_method,
                                output_layer_init_method)

     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
         # hidden_states: [b, s, h]

@@ -375,14 +367,13 @@ class ParallelTransformer(MegatronModule):
         # Transformer layers.
         self.layers = torch.nn.ModuleList(
             [get_layer(i + 1) for i in range(args.num_layers)])

         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)

     def _checkpointed_forward(self, hidden_states, attention_mask):
         """Forward method with activation checkpointing."""
         def custom(start, end):

@@ -398,13 +389,12 @@ class ParallelTransformer(MegatronModule):
         num_layers = len(self.layers)
         while l < num_layers:
             hidden_states = mpu.checkpoint(
                 custom(l, l + self.checkpoint_num_layers),
                 hidden_states, attention_mask)
             l += self.checkpoint_num_layers

         return hidden_states

     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
megatron/model/utils.py (view file @ b7f1b050)

@@ -33,6 +33,7 @@ def init_method_normal(sigma):
 def scaled_init_method_normal(sigma, num_layers):
     """Init method based on N(0, sigma/sqrt(2*num_layers)."""
     std = sigma / math.sqrt(2.0 * num_layers)

     def init_(tensor):
         return torch.nn.init.normal_(tensor, mean=0.0, std=std)

@@ -54,6 +55,7 @@ def gelu_impl(x):
     return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                        (1.0 + 0.044715 * x * x)))

 def gelu(x):
     return gelu_impl(x)
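The gelu_impl hunk above contains the full tanh approximation; the constant 0.7978845608028654 is sqrt(2/pi). The sketch below is an illustrative numerical check against the exact erf-based GELU, not repository code:

import math
import torch


def gelu_tanh(x):
    # Same tanh approximation as the gelu_impl hunk above; 0.7978845608028654 == sqrt(2 / pi).
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))


def gelu_exact(x):
    # Exact GELU via the Gaussian CDF.
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))


if __name__ == "__main__":
    x = torch.linspace(-4.0, 4.0, steps=9)
    print((gelu_tanh(x) - gelu_exact(x)).abs().max())  # on the order of 1e-3 or smaller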
megatron/module.py (view file @ b7f1b050)

@@ -21,11 +21,9 @@ import torch
 class MegatronModule(torch.nn.Module):
     """Megatron specific extentions of torch Module."""

     def __init__(self):
         super(MegatronModule, self).__init__()

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """Use this function to override the state dict for
megatron/mpu/cross_entropy.py (view file @ b7f1b050)

@@ -72,7 +72,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
                                      op=torch.distributed.ReduceOp.SUM,
                                      group=get_model_parallel_group())

         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
megatron/mpu/layers.py (view file @ b7f1b050)

@@ -89,6 +89,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         embedding_dim: size of hidden state.
         init_method: method to initialize weights.
     """

     def __init__(self, num_embeddings, embedding_dim,
                  init_method=init.xavier_normal_):
         super(VocabParallelEmbedding, self).__init__()

@@ -147,6 +148,7 @@ class ParallelEmbedding(torch.nn.Module):
         embedding_dim: size of hidden state.
         init_method: method to initialize weights.
     """

     def __init__(self, num_embeddings, embedding_dim,
                  init_method=init.xavier_normal_,
                  keep_master_weight_for_test=False):

@@ -205,6 +207,7 @@ class ColumnParallelLinear(torch.nn.Module):
                                      set to False. It returns the master weights
                                      used for initialization.
     """

     def __init__(self, input_size, output_size, bias=True, gather_output=True,
                  init_method=init.xavier_normal_, stride=1,
                  keep_master_weight_for_test=False):

@@ -279,6 +282,7 @@ class RowParallelLinear(torch.nn.Module):
                                      set to False. It returns the master weights
                                      used for initialization.
     """

     def __init__(self, input_size, output_size, bias=True,
                  input_is_parallel=False,
                  init_method=init.xavier_normal_, stride=1,

@@ -327,4 +331,3 @@ class RowParallelLinear(torch.nn.Module):
         else:
             output = output_
         return output
megatron/mpu/mappings.py (view file @ b7f1b050)

@@ -131,11 +131,14 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
 def copy_to_model_parallel_region(input_):
     return _CopyToModelParallelRegion.apply(input_)

 def reduce_from_model_parallel_region(input_):
     return _ReduceFromModelParallelRegion.apply(input_)

 def scatter_to_model_parallel_region(input_):
     return _ScatterToModelParallelRegion.apply(input_)

 def gather_from_model_parallel_region(input_):
     return _GatherFromModelParallelRegion.apply(input_)
megatron/mpu/random.py (view file @ b7f1b050)

@@ -73,6 +73,7 @@ class CudaRNGStatesTracker:
     rng state, we can perform operations and return to our starting
     cuda state.
     """

     def __init__(self):
         # Map from a string name to the cuda rng state.
         self.states_ = {}