Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
b7f1b050
Commit
b7f1b050
authored
Apr 14, 2020
by
Neel Kant
Browse files
Lint whole repo
parent
c99fa80c
Changes
63
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
456 additions
and
453 deletions
+456
-453
megatron/mpu/tests/commons.py
megatron/mpu/tests/commons.py
+1
-0
megatron/mpu/tests/test_cross_entropy.py
megatron/mpu/tests/test_cross_entropy.py
+8
-10
megatron/mpu/tests/test_data.py
megatron/mpu/tests/test_data.py
+5
-9
megatron/mpu/tests/test_initialize.py
megatron/mpu/tests/test_initialize.py
+4
-7
megatron/mpu/tests/test_layers.py
megatron/mpu/tests/test_layers.py
+21
-20
megatron/mpu/tests/test_random.py
megatron/mpu/tests/test_random.py
+4
-7
megatron/text_generation_utils.py
megatron/text_generation_utils.py
+10
-10
megatron/tokenizer/bert_tokenization.py
megatron/tokenizer/bert_tokenization.py
+325
-325
megatron/tokenizer/gpt2_tokenization.py
megatron/tokenizer/gpt2_tokenization.py
+32
-17
megatron/tokenizer/tokenizer.py
megatron/tokenizer/tokenizer.py
+4
-3
megatron/training.py
megatron/training.py
+6
-6
megatron/utils.py
megatron/utils.py
+3
-3
tasks/data_utils.py
tasks/data_utils.py
+7
-7
tasks/ensemble_classifier.py
tasks/ensemble_classifier.py
+17
-13
tasks/finetune_utils.py
tasks/finetune_utils.py
+4
-4
tasks/glue/data.py
tasks/glue/data.py
+0
-3
tasks/glue/finetune.py
tasks/glue/finetune.py
+2
-3
tasks/glue/mnli.py
tasks/glue/mnli.py
+0
-1
tasks/glue/qqp.py
tasks/glue/qqp.py
+0
-1
tasks/main.py
tasks/main.py
+3
-4
No files found.
megatron/mpu/tests/commons.py
View file @
b7f1b050
...
...
@@ -26,6 +26,7 @@ class IdentityLayer(torch.nn.Module):
def
__init__
(
self
,
size
,
scale
=
1.0
):
super
(
IdentityLayer
,
self
).
__init__
()
self
.
weight
=
torch
.
nn
.
Parameter
(
scale
*
torch
.
randn
(
size
))
def
forward
(
self
):
return
self
.
weight
...
...
megatron/mpu/tests/test_cross_entropy.py
View file @
b7f1b050
...
...
@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
commons
import
set_random_seed
from
commons
import
IdentityLayer
from
commons
import
print_separator
from
commons
import
initialize_distributed
from
mpu.cross_entropy
import
vocab_parallel_cross_entropy
import
mpu
import
torch.nn.functional
as
F
import
torch
import
random
import
sys
sys
.
path
.
append
(
"../.."
)
import
torch
import
torch.nn.functional
as
F
import
mpu
from
mpu.cross_entropy
import
vocab_parallel_cross_entropy
from
commons
import
initialize_distributed
from
commons
import
print_separator
from
commons
import
IdentityLayer
from
commons
import
set_random_seed
def
torch_cross_entropy
(
batch_size
,
seq_length
,
vocab_size
,
logits_scale
,
seed
):
...
...
megatron/mpu/tests/test_data.py
View file @
b7f1b050
...
...
@@ -13,18 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
commons
import
print_separator
from
commons
import
initialize_distributed
from
mpu
import
data
as
data_utils
import
mpu
import
torch
import
functools
import
operator
import
sys
sys
.
path
.
append
(
"../.."
)
import
torch
import
mpu
from
mpu
import
data
as
data_utils
from
commons
import
initialize_distributed
from
commons
import
print_separator
def
test_boradcast_data
(
model_parallel_size
):
...
...
@@ -88,5 +86,3 @@ if __name__ == '__main__':
print_separator
(
'test test boradcast data'
)
test_boradcast_data
(
model_parallel_size
)
model_parallel_size
*=
2
megatron/mpu/tests/test_initialize.py
View file @
b7f1b050
...
...
@@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
commons
import
print_separator
from
commons
import
initialize_distributed
import
mpu
import
torch
import
sys
sys
.
path
.
append
(
"../.."
)
import
torch
import
mpu
from
commons
import
initialize_distributed
from
commons
import
print_separator
def
test_initialize_model_parallel
(
model_parallel_size
):
...
...
@@ -46,7 +44,6 @@ def test_initialize_model_parallel(model_parallel_size):
assert
rank
==
mpu
.
get_model_parallel_rank
()
check
(
mpu
.
get_model_parallel_group
(),
world_size
,
rank
)
# Data parallel.
world_size
=
torch
.
distributed
.
get_world_size
()
//
model_parallel_size_
rank
=
torch
.
distributed
.
get_rank
()
//
model_parallel_size
...
...
megatron/mpu/tests/test_layers.py
View file @
b7f1b050
...
...
@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
mpu
import
layers
from
commons
import
set_random_seed
from
commons
import
print_separator
from
commons
import
initialize_distributed
import
mpu
from
torch.nn.parameter
import
Parameter
import
torch.nn.init
as
init
import
torch
import
random
import
sys
sys
.
path
.
append
(
"../.."
)
import
torch
import
torch.nn.init
as
init
from
torch.nn.parameter
import
Parameter
import
mpu
from
commons
import
initialize_distributed
from
commons
import
print_separator
from
commons
import
set_random_seed
from
mpu
import
layers
def
test_parallel_embedding
(
model_parallel_size
):
...
...
@@ -45,7 +43,7 @@ def test_parallel_embedding(model_parallel_size):
set_random_seed
(
123
)
input_data
=
torch
.
LongTensor
(
size
=
(
batch_size
,
seq_length
)).
random_
(
0
,
vocab_size
).
cuda
()
size
=
(
batch_size
,
seq_length
)).
random_
(
0
,
vocab_size
).
cuda
()
loss_weight
=
torch
.
randn
([
batch_size
,
seq_length
,
hidden_size
]).
cuda
()
set_random_seed
(
seed
)
...
...
@@ -57,7 +55,7 @@ def test_parallel_embedding(model_parallel_size):
set_random_seed
(
seed
)
embedding_parallel
=
layers
.
ParallelEmbedding
(
vocab_size
,
hidden_size
,
init_method
=
init
.
normal_
).
cuda
()
vocab_size
,
hidden_size
,
init_method
=
init
.
normal_
).
cuda
()
output
=
embedding_parallel
(
input_data
)
loss_parallel
=
torch
.
mul
(
output
,
loss_weight
).
sum
()
loss_parallel
.
backward
()
...
...
@@ -176,10 +174,11 @@ def test_initialize_affine_weight(model_parallel_size):
class
IdentityLayer2D
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
m
,
n
):
def
__init__
(
self
,
m
,
n
):
super
(
IdentityLayer2D
,
self
).
__init__
()
self
.
weight
=
Parameter
(
torch
.
Tensor
(
m
,
n
))
torch
.
nn
.
init
.
xavier_normal_
(
self
.
weight
)
def
forward
(
self
):
return
self
.
weight
...
...
@@ -317,10 +316,11 @@ def test_row_parallel_linear(model_parallel_size):
class
IdentityLayer3D
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
m
,
n
,
k
):
def
__init__
(
self
,
m
,
n
,
k
):
super
(
IdentityLayer3D
,
self
).
__init__
()
self
.
weight
=
Parameter
(
torch
.
Tensor
(
m
,
n
,
k
))
torch
.
nn
.
init
.
xavier_normal_
(
self
.
weight
)
def
forward
(
self
):
return
self
.
weight
...
...
@@ -335,14 +335,14 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
set_random_seed
(
seed
)
num_att_heads
=
num_att_heads_per_partition
*
\
torch
.
distributed
.
get_world_size
()
torch
.
distributed
.
get_world_size
()
hidden_size
=
hidden_size_per_att_head
*
num_att_heads
# Network
identity_layer
=
IdentityLayer3D
(
batch_size
,
sequence_length
,
hidden_size
).
cuda
()
attention_layer
=
mpu
.
BertParallelSelfAttention
(
hidden_size
,
num_att_heads
,
dropout_prob
).
cuda
()
dropout_prob
).
cuda
()
loss_weight
=
torch
.
randn
([
batch_size
,
sequence_length
,
hidden_size
]).
cuda
()
attention_mask
=
torch
.
randn
([
batch_size
,
1
,
1
,
sequence_length
]).
cuda
()
# Forward
...
...
@@ -366,17 +366,17 @@ def test_parallel_self_attention(model_parallel_size):
num_att_heads_per_partition
=
3
hidden_size_per_att_head
=
7
dropout_prob
=
0.0
# has to be zero
dropout_prob
=
0.0
# has to be zero
batch_size
=
5
sequence_length
=
13
rank_1
,
hideen_size_1
,
model_parallel_size_1
,
loss_1
,
\
attention_layer_1
,
identity_layer_1
=
parallel_self_attention
(
attention_layer_1
,
identity_layer_1
=
parallel_self_attention
(
1
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
)
rank
,
hidden_size
,
model_parallel_size
,
loss
,
\
attention_layer
,
identity_layer
=
parallel_self_attention
(
attention_layer
,
identity_layer
=
parallel_self_attention
(
model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
)
assert
hideen_size_1
==
hidden_size
...
...
@@ -409,6 +409,7 @@ def test_parallel_self_attention(model_parallel_size):
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
' >> passed the test :-)'
)
def
parallel_transformer
(
model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
batch_size
,
sequence_length
):
...
...
@@ -419,7 +420,7 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
set_random_seed
(
seed
)
num_att_heads
=
num_att_heads_per_partition
*
\
torch
.
distributed
.
get_world_size
()
torch
.
distributed
.
get_world_size
()
hidden_size
=
hidden_size_per_att_head
*
num_att_heads
intermediate_size
=
4
*
hidden_size
...
...
megatron/mpu/tests/test_random.py
View file @
b7f1b050
...
...
@@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
commons
import
print_separator
from
commons
import
initialize_distributed
import
mpu
import
torch
import
sys
sys
.
path
.
append
(
"../.."
)
import
torch
import
mpu
from
commons
import
initialize_distributed
from
commons
import
print_separator
def
test_set_cuda_rng_state
(
model_parallel_size
):
...
...
@@ -204,4 +202,3 @@ if __name__ == '__main__':
print_separator
(
'test model parallel cuda manual seed'
)
test_model_parallel_cuda_manual_seed
(
model_parallel_size
)
model_parallel_size
*=
2
megatron/text_generation_utils.py
View file @
b7f1b050
...
...
@@ -120,8 +120,8 @@ def generate_samples_input_from_file(model):
context_length
=
len
(
context_tokens
)
if
context_length
>=
(
args
.
seq_length
//
2
):
print
(
"
\n
Context length"
,
context_length
,
\
"
\n
Please give smaller context (half of the "
print
(
"
\n
Context length"
,
context_length
,
"
\n
Please give smaller context (half of the "
"sequence length)!"
,
flush
=
True
)
continue
else
:
...
...
@@ -187,8 +187,8 @@ def generate_samples_interactive(model, print_frequency=24):
context_length
=
len
(
context_tokens
)
if
context_length
>=
(
args
.
seq_length
//
2
):
print
(
"
\n
Context length"
,
context_length
,
\
"
\n
Please give smaller context (half of the "
print
(
"
\n
Context length"
,
context_length
,
"
\n
Please give smaller context (half of the "
"sequence length)!"
,
flush
=
True
)
continue
else
:
...
...
@@ -246,7 +246,7 @@ def generate_samples_unconditional(model):
for
token_stream
in
get_token_stream
(
model
,
copy
.
deepcopy
(
context_tokens
)):
pass
if
ctr
%
args
.
log_interval
==
0
:
if
ctr
%
args
.
log_interval
==
0
:
print
(
'Avg s/batch:'
,
(
time
.
time
()
-
start_time
)
/
min
(
args
.
log_interval
,
ctr
+
1
))
start_time
=
time
.
time
()
...
...
@@ -254,10 +254,10 @@ def generate_samples_unconditional(model):
token_batch
=
token_stream
[
0
].
cpu
().
numpy
().
tolist
()
length_batch
=
token_stream
[
1
].
cpu
().
numpy
().
tolist
()
for
tokens
,
length
in
zip
(
token_batch
,
length_batch
):
tokens
=
tokens
[
1
:
length
-
1
]
tokens
=
tokens
[
1
:
length
-
1
]
text
=
tokenizer
.
detokenize
(
tokens
)
is_finished
=
length
<
args
.
seq_length
-
1
datum
=
{
'text'
:
text
,
'length'
:
length
-
1
,
'finished'
:
is_finished
}
datum
=
{
'text'
:
text
,
'length'
:
length
-
1
,
'finished'
:
is_finished
}
yield
datum
ctr
+=
1
if
ctr
>=
num_samples
:
...
...
@@ -272,7 +272,7 @@ def generate_and_write_samples_unconditional(model):
assert
args
.
genfile
is
not
None
with
open
(
args
.
genfile
,
'w'
)
as
f
:
for
datum
in
generate_samples_unconditional
(
model
):
f
.
write
(
json
.
dumps
(
datum
)
+
'
\n
'
)
f
.
write
(
json
.
dumps
(
datum
)
+
'
\n
'
)
def
pad_batch
(
batch
,
pad_id
,
args
):
...
...
@@ -281,7 +281,7 @@ def pad_batch(batch, pad_id, args):
for
tokens
in
batch
:
context_length
=
len
(
tokens
)
if
context_length
<
args
.
seq_length
:
tokens
.
extend
([
pad_id
]
*
(
args
.
seq_length
-
context_length
))
tokens
.
extend
([
pad_id
]
*
(
args
.
seq_length
-
context_length
))
context_lengths
.
append
(
context_length
)
return
batch
,
context_lengths
...
...
@@ -345,7 +345,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
if
maxlen
>
(
org_context_length
+
args
.
out_seq_length
):
maxlen
=
org_context_length
+
args
.
out_seq_length
lengths
=
torch
.
ones
([
batch_size
]).
long
().
cuda
()
*
maxlen
lengths
=
torch
.
ones
([
batch_size
]).
long
().
cuda
()
*
maxlen
while
context_length
<=
(
maxlen
):
...
...
megatron/tokenizer/bert_tokenization.py
View file @
b7f1b050
...
...
@@ -25,377 +25,377 @@ import six
def
validate_case_matches_checkpoint
(
do_lower_case
,
init_checkpoint
):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if
not
init_checkpoint
:
return
m
=
re
.
match
(
"^.*?([A-Za-z0-9_-]+)/bert_model.ckpt"
,
init_checkpoint
)
if
m
is
None
:
return
model_name
=
m
.
group
(
1
)
lower_models
=
[
"uncased_L-24_H-1024_A-16"
,
"uncased_L-12_H-768_A-12"
,
"multilingual_L-12_H-768_A-12"
,
"chinese_L-12_H-768_A-12"
]
cased_models
=
[
"cased_L-12_H-768_A-12"
,
"cased_L-24_H-1024_A-16"
,
"multi_cased_L-12_H-768_A-12"
]
is_bad_config
=
False
if
model_name
in
lower_models
and
not
do_lower_case
:
is_bad_config
=
True
actual_flag
=
"False"
case_name
=
"lowercased"
opposite_flag
=
"True"
if
model_name
in
cased_models
and
do_lower_case
:
is_bad_config
=
True
actual_flag
=
"True"
case_name
=
"cased"
opposite_flag
=
"False"
if
is_bad_config
:
raise
ValueError
(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check."
%
(
actual_flag
,
init_checkpoint
,
model_name
,
case_name
,
opposite_flag
))
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if
not
init_checkpoint
:
return
m
=
re
.
match
(
"^.*?([A-Za-z0-9_-]+)/bert_model.ckpt"
,
init_checkpoint
)
if
m
is
None
:
return
model_name
=
m
.
group
(
1
)
lower_models
=
[
"uncased_L-24_H-1024_A-16"
,
"uncased_L-12_H-768_A-12"
,
"multilingual_L-12_H-768_A-12"
,
"chinese_L-12_H-768_A-12"
]
cased_models
=
[
"cased_L-12_H-768_A-12"
,
"cased_L-24_H-1024_A-16"
,
"multi_cased_L-12_H-768_A-12"
]
is_bad_config
=
False
if
model_name
in
lower_models
and
not
do_lower_case
:
is_bad_config
=
True
actual_flag
=
"False"
case_name
=
"lowercased"
opposite_flag
=
"True"
if
model_name
in
cased_models
and
do_lower_case
:
is_bad_config
=
True
actual_flag
=
"True"
case_name
=
"cased"
opposite_flag
=
"False"
if
is_bad_config
:
raise
ValueError
(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check."
%
(
actual_flag
,
init_checkpoint
,
model_name
,
case_name
,
opposite_flag
))
def
convert_to_unicode
(
text
):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if
six
.
PY3
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
bytes
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if
six
.
PY3
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
bytes
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
elif
six
.
PY2
:
if
isinstance
(
text
,
str
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
elif
isinstance
(
text
,
unicode
):
return
text
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
elif
six
.
PY2
:
if
isinstance
(
text
,
str
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
elif
isinstance
(
text
,
unicode
):
return
text
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
else
:
raise
ValueError
(
"Not running on Python2 or Python 3?"
)
raise
ValueError
(
"Not running on Python2 or Python 3?"
)
def
printable_text
(
text
):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if
six
.
PY3
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
bytes
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
elif
six
.
PY2
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
unicode
):
return
text
.
encode
(
"utf-8"
)
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if
six
.
PY3
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
bytes
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
elif
six
.
PY2
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
unicode
):
return
text
.
encode
(
"utf-8"
)
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
else
:
raise
ValueError
(
"Not running on Python2 or Python 3?"
)
raise
ValueError
(
"Not running on Python2 or Python 3?"
)
def
load_vocab
(
vocab_file
):
"""Loads a vocabulary file into a dictionary."""
vocab
=
collections
.
OrderedDict
()
index
=
0
with
open
(
vocab_file
,
"r"
)
as
reader
:
while
True
:
token
=
convert_to_unicode
(
reader
.
readline
())
if
not
token
:
break
token
=
token
.
strip
()
vocab
[
token
]
=
index
index
+=
1
return
vocab
"""Loads a vocabulary file into a dictionary."""
vocab
=
collections
.
OrderedDict
()
index
=
0
with
open
(
vocab_file
,
"r"
)
as
reader
:
while
True
:
token
=
convert_to_unicode
(
reader
.
readline
())
if
not
token
:
break
token
=
token
.
strip
()
vocab
[
token
]
=
index
index
+=
1
return
vocab
def
convert_by_vocab
(
vocab
,
items
):
"""Converts a sequence of [tokens|ids] using the vocab."""
output
=
[]
for
item
in
items
:
output
.
append
(
vocab
[
item
])
return
output
"""Converts a sequence of [tokens|ids] using the vocab."""
output
=
[]
for
item
in
items
:
output
.
append
(
vocab
[
item
])
return
output
def
convert_tokens_to_ids
(
vocab
,
tokens
):
return
convert_by_vocab
(
vocab
,
tokens
)
return
convert_by_vocab
(
vocab
,
tokens
)
def
convert_ids_to_tokens
(
inv_vocab
,
ids
):
return
convert_by_vocab
(
inv_vocab
,
ids
)
return
convert_by_vocab
(
inv_vocab
,
ids
)
def
whitespace_tokenize
(
text
):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text
=
text
.
strip
()
if
not
text
:
return
[]
tokens
=
text
.
split
()
return
tokens
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text
=
text
.
strip
()
if
not
text
:
return
[]
tokens
=
text
.
split
()
return
tokens
class
FullTokenizer
(
object
):
"""Runs end-to-end tokenziation."""
"""Runs end-to-end tokenziation."""
def
__init__
(
self
,
vocab_file
,
do_lower_case
=
True
):
self
.
vocab
=
load_vocab
(
vocab_file
)
self
.
inv_vocab
=
{
v
:
k
for
k
,
v
in
self
.
vocab
.
items
()}
self
.
basic_tokenizer
=
BasicTokenizer
(
do_lower_case
=
do_lower_case
)
self
.
wordpiece_tokenizer
=
WordpieceTokenizer
(
vocab
=
self
.
vocab
)
def
__init__
(
self
,
vocab_file
,
do_lower_case
=
True
):
self
.
vocab
=
load_vocab
(
vocab_file
)
self
.
inv_vocab
=
{
v
:
k
for
k
,
v
in
self
.
vocab
.
items
()}
self
.
basic_tokenizer
=
BasicTokenizer
(
do_lower_case
=
do_lower_case
)
self
.
wordpiece_tokenizer
=
WordpieceTokenizer
(
vocab
=
self
.
vocab
)
def
tokenize
(
self
,
text
):
split_tokens
=
[]
for
token
in
self
.
basic_tokenizer
.
tokenize
(
text
):
for
sub_token
in
self
.
wordpiece_tokenizer
.
tokenize
(
token
):
split_tokens
.
append
(
sub_token
)
def
tokenize
(
self
,
text
):
split_tokens
=
[]
for
token
in
self
.
basic_tokenizer
.
tokenize
(
text
):
for
sub_token
in
self
.
wordpiece_tokenizer
.
tokenize
(
token
):
split_tokens
.
append
(
sub_token
)
return
split_tokens
return
split_tokens
def
convert_tokens_to_ids
(
self
,
tokens
):
return
convert_by_vocab
(
self
.
vocab
,
tokens
)
def
convert_tokens_to_ids
(
self
,
tokens
):
return
convert_by_vocab
(
self
.
vocab
,
tokens
)
def
convert_ids_to_tokens
(
self
,
ids
):
return
convert_by_vocab
(
self
.
inv_vocab
,
ids
)
def
convert_ids_to_tokens
(
self
,
ids
):
return
convert_by_vocab
(
self
.
inv_vocab
,
ids
)
def
vocab_size
(
self
):
return
len
(
self
.
vocab
)
def
vocab_size
(
self
):
return
len
(
self
.
vocab
)
class
BasicTokenizer
(
object
):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def
__init__
(
self
,
do_lower_case
=
True
):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self
.
do_lower_case
=
do_lower_case
def
tokenize
(
self
,
text
):
"""Tokenizes a piece of text."""
text
=
convert_to_unicode
(
text
)
text
=
self
.
_clean_text
(
text
)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text
=
self
.
_tokenize_chinese_chars
(
text
)
orig_tokens
=
whitespace_tokenize
(
text
)
split_tokens
=
[]
for
token
in
orig_tokens
:
if
self
.
do_lower_case
:
token
=
token
.
lower
()
token
=
self
.
_run_strip_accents
(
token
)
split_tokens
.
extend
(
self
.
_run_split_on_punc
(
token
))
output_tokens
=
whitespace_tokenize
(
" "
.
join
(
split_tokens
))
return
output_tokens
def
_run_strip_accents
(
self
,
text
):
"""Strips accents from a piece of text."""
text
=
unicodedata
.
normalize
(
"NFD"
,
text
)
output
=
[]
for
char
in
text
:
cat
=
unicodedata
.
category
(
char
)
if
cat
==
"Mn"
:
continue
output
.
append
(
char
)
return
""
.
join
(
output
)
def
_run_split_on_punc
(
self
,
text
):
"""Splits punctuation on a piece of text."""
chars
=
list
(
text
)
i
=
0
start_new_word
=
True
output
=
[]
while
i
<
len
(
chars
):
char
=
chars
[
i
]
if
_is_punctuation
(
char
):
output
.
append
([
char
])
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def
__init__
(
self
,
do_lower_case
=
True
):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self
.
do_lower_case
=
do_lower_case
def
tokenize
(
self
,
text
):
"""Tokenizes a piece of text."""
text
=
convert_to_unicode
(
text
)
text
=
self
.
_clean_text
(
text
)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text
=
self
.
_tokenize_chinese_chars
(
text
)
orig_tokens
=
whitespace_tokenize
(
text
)
split_tokens
=
[]
for
token
in
orig_tokens
:
if
self
.
do_lower_case
:
token
=
token
.
lower
()
token
=
self
.
_run_strip_accents
(
token
)
split_tokens
.
extend
(
self
.
_run_split_on_punc
(
token
))
output_tokens
=
whitespace_tokenize
(
" "
.
join
(
split_tokens
))
return
output_tokens
def
_run_strip_accents
(
self
,
text
):
"""Strips accents from a piece of text."""
text
=
unicodedata
.
normalize
(
"NFD"
,
text
)
output
=
[]
for
char
in
text
:
cat
=
unicodedata
.
category
(
char
)
if
cat
==
"Mn"
:
continue
output
.
append
(
char
)
return
""
.
join
(
output
)
def
_run_split_on_punc
(
self
,
text
):
"""Splits punctuation on a piece of text."""
chars
=
list
(
text
)
i
=
0
start_new_word
=
True
else
:
if
start_new_word
:
output
.
append
([])
start_new_word
=
False
output
[
-
1
].
append
(
char
)
i
+=
1
return
[
""
.
join
(
x
)
for
x
in
output
]
def
_tokenize_chinese_chars
(
self
,
text
):
"""Adds whitespace around any CJK character."""
output
=
[]
for
char
in
text
:
cp
=
ord
(
char
)
if
self
.
_is_chinese_char
(
cp
):
output
.
append
(
" "
)
output
.
append
(
char
)
output
.
append
(
" "
)
else
:
output
.
append
(
char
)
return
""
.
join
(
output
)
def
_is_chinese_char
(
self
,
cp
):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if
((
cp
>=
0x4E00
and
cp
<=
0x9FFF
)
or
#
(
cp
>=
0x3400
and
cp
<=
0x4DBF
)
or
#
(
cp
>=
0x20000
and
cp
<=
0x2A6DF
)
or
#
(
cp
>=
0x2A700
and
cp
<=
0x2B73F
)
or
#
(
cp
>=
0x2B740
and
cp
<=
0x2B81F
)
or
#
(
cp
>=
0x2B820
and
cp
<=
0x2CEAF
)
or
(
cp
>=
0xF900
and
cp
<=
0xFAFF
)
or
#
(
cp
>=
0x2F800
and
cp
<=
0x2FA1F
)):
#
return
True
return
False
def
_clean_text
(
self
,
text
):
"""Performs invalid character removal and whitespace cleanup on text."""
output
=
[]
for
char
in
text
:
cp
=
ord
(
char
)
if
cp
==
0
or
cp
==
0xfffd
or
_is_control
(
char
):
continue
if
_is_whitespace
(
char
):
output
.
append
(
" "
)
else
:
output
.
append
(
char
)
return
""
.
join
(
output
)
output
=
[]
while
i
<
len
(
chars
):
char
=
chars
[
i
]
if
_is_punctuation
(
char
):
output
.
append
([
char
])
start_new_word
=
True
else
:
if
start_new_word
:
output
.
append
([])
start_new_word
=
False
output
[
-
1
].
append
(
char
)
i
+=
1
return
[
""
.
join
(
x
)
for
x
in
output
]
def
_tokenize_chinese_chars
(
self
,
text
):
"""Adds whitespace around any CJK character."""
output
=
[]
for
char
in
text
:
cp
=
ord
(
char
)
if
self
.
_is_chinese_char
(
cp
):
output
.
append
(
" "
)
output
.
append
(
char
)
output
.
append
(
" "
)
else
:
output
.
append
(
char
)
return
""
.
join
(
output
)
def
_is_chinese_char
(
self
,
cp
):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if
((
cp
>=
0x4E00
and
cp
<=
0x9FFF
)
or
#
(
cp
>=
0x3400
and
cp
<=
0x4DBF
)
or
#
(
cp
>=
0x20000
and
cp
<=
0x2A6DF
)
or
#
(
cp
>=
0x2A700
and
cp
<=
0x2B73F
)
or
#
(
cp
>=
0x2B740
and
cp
<=
0x2B81F
)
or
#
(
cp
>=
0x2B820
and
cp
<=
0x2CEAF
)
or
(
cp
>=
0xF900
and
cp
<=
0xFAFF
)
or
#
(
cp
>=
0x2F800
and
cp
<=
0x2FA1F
)):
#
return
True
return
False
def
_clean_text
(
self
,
text
):
"""Performs invalid character removal and whitespace cleanup on text."""
output
=
[]
for
char
in
text
:
cp
=
ord
(
char
)
if
cp
==
0
or
cp
==
0xfffd
or
_is_control
(
char
):
continue
if
_is_whitespace
(
char
):
output
.
append
(
" "
)
else
:
output
.
append
(
char
)
return
""
.
join
(
output
)
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization.

    Splits each whitespace-separated word into the longest pieces found in
    `vocab`, scanning left to right. Pieces that do not start a word are
    looked up (and emitted) with a "##" prefix.
    """

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        """Stores the tokenizer configuration.

        Args:
          vocab: Container of known word pieces; only membership tests
            (`piece in vocab`) are performed on it.
          unk_token: Token emitted for a word that cannot be segmented.
          max_input_chars_per_word: Words longer than this many characters
            are mapped to `unk_token` without attempting segmentation.
        """
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example:
          input = "unaffable"
          output = ["un", "##aff", "##able"]

        Args:
          text: A single token or whitespace separated tokens. This should
            have already been passed through `BasicTokenizer`.

        Returns:
          A list of wordpiece tokens. A word that is too long, or that
          cannot be fully covered by vocabulary pieces, contributes a single
          `unk_token` instead of its pieces.
        """
        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                # Try the longest remaining span first and shrink it from
                # the right until a vocabulary entry matches.
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        # Non-initial pieces carry the continuation marker.
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    # No piece starts at `start`: the whole word is unknown.
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
def
_is_whitespace
(
char
):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if
char
==
" "
or
char
==
"
\t
"
or
char
==
"
\n
"
or
char
==
"
\r
"
:
return
True
cat
=
unicodedata
.
category
(
char
)
if
cat
==
"Zs"
:
return
True
return
False
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if
char
==
" "
or
char
==
"
\t
"
or
char
==
"
\n
"
or
char
==
"
\r
"
:
return
True
cat
=
unicodedata
.
category
(
char
)
if
cat
==
"Zs"
:
return
True
return
False
def
_is_control
(
char
):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if
char
==
"
\t
"
or
char
==
"
\n
"
or
char
==
"
\r
"
:
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if
char
==
"
\t
"
or
char
==
"
\n
"
or
char
==
"
\r
"
:
return
False
cat
=
unicodedata
.
category
(
char
)
if
cat
in
(
"Cc"
,
"Cf"
):
return
True
return
False
cat
=
unicodedata
.
category
(
char
)
if
cat
in
(
"Cc"
,
"Cf"
):
return
True
return
False
def
_is_punctuation
(
char
):
"""Checks whether `chars` is a punctuation character."""
cp
=
ord
(
char
)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if
((
cp
>=
33
and
cp
<=
47
)
or
(
cp
>=
58
and
cp
<=
64
)
or
(
cp
>=
91
and
cp
<=
96
)
or
(
cp
>=
123
and
cp
<=
126
)):
return
True
cat
=
unicodedata
.
category
(
char
)
if
cat
.
startswith
(
"P"
):
return
True
return
False
"""Checks whether `chars` is a punctuation character."""
cp
=
ord
(
char
)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if
((
cp
>=
33
and
cp
<=
47
)
or
(
cp
>=
58
and
cp
<=
64
)
or
(
cp
>=
91
and
cp
<=
96
)
or
(
cp
>=
123
and
cp
<=
126
)):
return
True
cat
=
unicodedata
.
category
(
char
)
if
cat
.
startswith
(
"P"
):
return
True
return
False
megatron/tokenizer/gpt2_tokenization.py
View file @
b7f1b050
...
...
@@ -29,7 +29,8 @@ try:
from
functools
import
lru_cache
except
ImportError
:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def
lru_cache
():
return
lambda
func
:
func
...
...
@@ -49,6 +50,7 @@ VOCAB_NAME = 'vocab.json'
MERGES_NAME
=
'merges.txt'
SPECIAL_TOKENS_NAME
=
'special_tokens.txt'
@
lru_cache
()
def
bytes_to_unicode
():
"""
...
...
@@ -61,17 +63,19 @@ def bytes_to_unicode():
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr
=
unichr
if
sys
.
version_info
[
0
]
==
2
else
chr
bs
=
list
(
range
(
ord
(
"!"
),
ord
(
"~"
)
+
1
))
+
list
(
range
(
ord
(
"¡"
),
ord
(
"¬"
)
+
1
))
+
list
(
range
(
ord
(
"®"
),
ord
(
"ÿ"
)
+
1
))
bs
=
list
(
range
(
ord
(
"!"
),
ord
(
"~"
)
+
1
))
+
list
(
range
(
ord
(
"¡"
),
ord
(
"¬"
)
+
1
))
+
\
list
(
range
(
ord
(
"®"
),
ord
(
"ÿ"
)
+
1
))
cs
=
bs
[:]
n
=
0
for
b
in
range
(
2
**
8
):
if
b
not
in
bs
:
bs
.
append
(
b
)
cs
.
append
(
2
**
8
+
n
)
cs
.
append
(
2
**
8
+
n
)
n
+=
1
cs
=
[
_chr
(
n
)
for
n
in
cs
]
return
dict
(
zip
(
bs
,
cs
))
def
get_pairs
(
word
):
"""Return set of symbol pairs in a word.
...
...
@@ -84,6 +88,7 @@ def get_pairs(word):
prev_char
=
char
return
pairs
class
GPT2Tokenizer
(
object
):
"""
GPT-2 BPE tokenizer. Peculiarities:
...
...
@@ -140,23 +145,31 @@ class GPT2Tokenizer(object):
special_tokens
=
open
(
special_tokens_file
,
encoding
=
'utf-8'
).
read
().
split
(
'
\n
'
)[:
-
1
]
else
:
special_tokens
=
kwargs
.
pop
(
'special_tokens'
,
[])
tokenizer
=
cls
(
resolved_vocab_file
,
resolved_merges_file
,
special_tokens
=
special_tokens
,
*
inputs
,
**
kwargs
)
tokenizer
=
cls
(
resolved_vocab_file
,
resolved_merges_file
,
special_tokens
=
special_tokens
,
*
inputs
,
**
kwargs
)
return
tokenizer
def
__init__
(
self
,
vocab_file
,
merges_file
,
errors
=
'replace'
,
special_tokens
=
None
,
max_len
=
None
):
def
__init__
(
self
,
vocab_file
,
merges_file
,
errors
=
'replace'
,
special_tokens
=
None
,
max_len
=
None
):
self
.
max_len
=
max_len
if
max_len
is
not
None
else
int
(
1e12
)
self
.
encoder
=
json
.
load
(
open
(
vocab_file
))
self
.
decoder
=
{
v
:
k
for
k
,
v
in
self
.
encoder
.
items
()}
self
.
errors
=
errors
# how to handle errors in decoding
self
.
decoder
=
{
v
:
k
for
k
,
v
in
self
.
encoder
.
items
()}
self
.
errors
=
errors
# how to handle errors in decoding
self
.
byte_encoder
=
bytes_to_unicode
()
self
.
byte_decoder
=
{
v
:
k
for
k
,
v
in
self
.
byte_encoder
.
items
()}
self
.
byte_decoder
=
{
v
:
k
for
k
,
v
in
self
.
byte_encoder
.
items
()}
bpe_data
=
open
(
merges_file
,
encoding
=
'utf-8'
).
read
().
split
(
'
\n
'
)[
1
:
-
1
]
bpe_merges
=
[
tuple
(
merge
.
split
())
for
merge
in
bpe_data
]
self
.
bpe_ranks
=
dict
(
zip
(
bpe_merges
,
range
(
len
(
bpe_merges
))))
self
.
cache
=
{}
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self
.
pat
=
re
.
compile
(
r
"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
# Should have added re.IGNORECASE so BPE merges can happen for
# capitalized versions of contractions
self
.
pat
=
re
.
compile
(
r
"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
self
.
special_tokens
=
{}
self
.
special_tokens_decoder
=
{}
...
...
@@ -174,8 +187,9 @@ class GPT2Tokenizer(object):
self
.
special_tokens
=
{}
self
.
special_tokens_decoder
=
{}
return
self
.
special_tokens
=
dict
((
tok
,
len
(
self
.
encoder
)
+
i
)
for
i
,
tok
in
enumerate
(
special_tokens
))
self
.
special_tokens_decoder
=
{
v
:
k
for
k
,
v
in
self
.
special_tokens
.
items
()}
self
.
special_tokens
=
dict
((
tok
,
len
(
self
.
encoder
)
+
i
)
for
i
,
tok
in
enumerate
(
special_tokens
))
self
.
special_tokens_decoder
=
{
v
:
k
for
k
,
v
in
self
.
special_tokens
.
items
()}
logger
.
info
(
"Special tokens {}"
.
format
(
self
.
special_tokens
))
def
bpe
(
self
,
token
):
...
...
@@ -188,7 +202,7 @@ class GPT2Tokenizer(object):
return
token
while
True
:
bigram
=
min
(
pairs
,
key
=
lambda
pair
:
self
.
bpe_ranks
.
get
(
pair
,
float
(
'inf'
)))
bigram
=
min
(
pairs
,
key
=
lambda
pair
:
self
.
bpe_ranks
.
get
(
pair
,
float
(
'inf'
)))
if
bigram
not
in
self
.
bpe_ranks
:
break
first
,
second
=
bigram
...
...
@@ -199,12 +213,12 @@ class GPT2Tokenizer(object):
j
=
word
.
index
(
first
,
i
)
new_word
.
extend
(
word
[
i
:
j
])
i
=
j
except
:
except
BaseException
:
new_word
.
extend
(
word
[
i
:])
break
if
word
[
i
]
==
first
and
i
<
len
(
word
)
-
1
and
word
[
i
+
1
]
==
second
:
new_word
.
append
(
first
+
second
)
if
word
[
i
]
==
first
and
i
<
len
(
word
)
-
1
and
word
[
i
+
1
]
==
second
:
new_word
.
append
(
first
+
second
)
i
+=
2
else
:
new_word
.
append
(
word
[
i
])
...
...
@@ -247,7 +261,8 @@ class GPT2Tokenizer(object):
logger
.
warning
(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors"
.
format
(
len
(
ids
),
self
.
max_len
)
" sequence through the model will result in indexing errors"
.
format
(
len
(
ids
),
self
.
max_len
)
)
return
ids
...
...
megatron/tokenizer/tokenizer.py
View file @
b7f1b050
...
...
@@ -32,7 +32,7 @@ def build_tokenizer(args):
assert
args
.
vocab_file
is
not
None
if
args
.
tokenizer_type
==
'BertWordPieceLowerCase'
:
tokenizer
=
_BertWordPieceTokenizer
(
vocab_file
=
args
.
vocab_file
,
lower_case
=
True
)
lower_case
=
True
)
elif
args
.
tokenizer_type
==
'GPT2BPETokenizer'
:
assert
args
.
merge_file
is
not
None
tokenizer
=
_GPT2BPETokenizer
(
args
.
vocab_file
,
args
.
merge_file
)
...
...
@@ -53,7 +53,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
after
=
orig_vocab_size
multiple
=
args
.
make_vocab_size_divisible_by
*
\
args
.
model_parallel_size
args
.
model_parallel_size
while
(
after
%
multiple
)
!=
0
:
after
+=
1
if
args
.
rank
==
0
:
...
...
@@ -134,7 +134,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
self
.
cls_id
=
self
.
tokenizer
.
vocab
[
'[CLS]'
]
self
.
sep_id
=
self
.
tokenizer
.
vocab
[
'[SEP]'
]
self
.
pad_id
=
self
.
tokenizer
.
vocab
[
'[PAD]'
]
self
.
mask_id
=
self
.
tokenizer
.
vocab
[
'[MASK]'
]
self
.
mask_id
=
self
.
tokenizer
.
vocab
[
'[MASK]'
]
@
property
def
vocab_size
(
self
):
...
...
@@ -168,6 +168,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
def
mask
(
self
):
return
self
.
mask_id
class
_GPT2BPETokenizer
(
AbstractTokenizer
):
"""Original GPT2 BPE tokenizer."""
...
...
megatron/training.py
View file @
b7f1b050
...
...
@@ -97,7 +97,6 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
model
,
optimizer
,
lr_scheduler
,
train_data_iterator
,
valid_data_iterator
)
if
args
.
do_valid
:
prefix
=
'the end of training for val data'
evaluate_and_print_results
(
prefix
,
forward_step_func
,
...
...
@@ -174,7 +173,7 @@ def get_optimizer(model):
dynamic_loss_scale
=
args
.
dynamic_loss_scale
,
dynamic_loss_args
=
{
'scale_window'
:
args
.
loss_scale_window
,
'min_scale'
:
args
.
min_scale
,
'min_scale'
:
args
.
min_scale
,
'delayed_shift'
:
args
.
hysteresis
})
return
optimizer
...
...
@@ -297,6 +296,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
# Logging.
timers_to_log
=
[]
def
add_to_logging
(
name
):
if
name
in
timers
.
timers
:
timers_to_log
.
append
(
name
)
...
...
@@ -431,7 +431,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
# Reduce across processes.
for
key
in
loss_dict
:
total_loss_dict
[
key
]
=
total_loss_dict
.
get
(
key
,
0.
)
+
\
loss_dict
[
key
]
loss_dict
[
key
]
# Move model back to the train mode.
model
.
train
()
...
...
@@ -521,14 +521,14 @@ def build_train_valid_test_data_iterators(
# Shift the start iterations.
if
train_dataloader
is
not
None
:
train_dataloader
.
batch_sampler
.
start_iter
=
args
.
iteration
%
\
len
(
train_dataloader
)
len
(
train_dataloader
)
print_rank_0
(
'setting training data start iteration to {}'
.
format
(
train_dataloader
.
batch_sampler
.
start_iter
))
if
valid_dataloader
is
not
None
:
start_iter_val
=
(
args
.
iteration
//
args
.
eval_interval
)
*
\
args
.
eval_iters
args
.
eval_iters
valid_dataloader
.
batch_sampler
.
start_iter
=
start_iter_val
%
\
len
(
valid_dataloader
)
len
(
valid_dataloader
)
print_rank_0
(
'setting validation data start iteration to {}'
.
format
(
valid_dataloader
.
batch_sampler
.
start_iter
))
...
...
megatron/utils.py
View file @
b7f1b050
...
...
@@ -48,7 +48,7 @@ def report_memory(name):
torch
.
cuda
.
max_memory_allocated
()
/
mega_bytes
)
string
+=
' | cached: {}'
.
format
(
torch
.
cuda
.
memory_cached
()
/
mega_bytes
)
string
+=
' | max cached: {}'
.
format
(
torch
.
cuda
.
max_memory_cached
()
/
mega_bytes
)
torch
.
cuda
.
max_memory_cached
()
/
mega_bytes
)
print_rank_0
(
string
)
...
...
@@ -164,10 +164,10 @@ def get_ltor_masks_and_position_ids(data,
i
=
eod_index
[
j
]
# Mask attention loss.
if
reset_attention_mask
:
attention_mask
[
b
,
0
,
(
i
+
1
):,
:(
i
+
1
)]
=
0
attention_mask
[
b
,
0
,
(
i
+
1
):,
:(
i
+
1
)]
=
0
# Reset positions.
if
reset_position_ids
:
position_ids
[
b
,
(
i
+
1
):]
-=
(
i
+
1
-
prev_index
)
position_ids
[
b
,
(
i
+
1
):]
-=
(
i
+
1
-
prev_index
)
prev_index
=
i
+
1
# Convert
...
...
tasks/data_utils.py
View file @
b7f1b050
...
...
@@ -75,8 +75,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
# A.
len_text_a
=
len
(
text_a_ids
)
ids
.
extend
(
text_a_ids
)
types
.
extend
([
0
]
*
len_text_a
)
paddings
.
extend
([
1
]
*
len_text_a
)
types
.
extend
([
0
]
*
len_text_a
)
paddings
.
extend
([
1
]
*
len_text_a
)
# [SEP].
ids
.
append
(
sep_id
)
...
...
@@ -87,8 +87,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
if
text_b_ids
is
not
None
:
len_text_b
=
len
(
text_b_ids
)
ids
.
extend
(
text_b_ids
)
types
.
extend
([
1
]
*
len_text_b
)
paddings
.
extend
([
1
]
*
len_text_b
)
types
.
extend
([
1
]
*
len_text_b
)
paddings
.
extend
([
1
]
*
len_text_b
)
# Cap the size.
trimmed
=
False
...
...
@@ -111,8 +111,8 @@ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
# Padding.
padding_length
=
max_seq_length
-
len
(
ids
)
if
padding_length
>
0
:
ids
.
extend
([
pad_id
]
*
padding_length
)
types
.
extend
([
pad_id
]
*
padding_length
)
paddings
.
extend
([
0
]
*
padding_length
)
ids
.
extend
([
pad_id
]
*
padding_length
)
types
.
extend
([
pad_id
]
*
padding_length
)
paddings
.
extend
([
0
]
*
padding_length
)
return
ids
,
types
,
paddings
tasks/ensemble_classifier.py
View file @
b7f1b050
...
...
@@ -5,6 +5,7 @@ import collections
import
numpy
as
np
import
torch
def
process_files
(
args
):
all_predictions
=
collections
.
OrderedDict
()
all_labels
=
collections
.
OrderedDict
()
...
...
@@ -40,12 +41,12 @@ def get_threshold(all_predictions, all_labels, one_threshold=False):
for
dataset
in
all_predictions
:
preds
=
all_predictions
[
dataset
]
labels
=
all_labels
[
dataset
]
out_thresh
.
append
(
calc_threshold
(
preds
,
labels
))
out_thresh
.
append
(
calc_threshold
(
preds
,
labels
))
return
out_thresh
def
calc_threshold
(
p
,
l
):
trials
=
[(
i
)
*
(
1.
/
100.
)
for
i
in
range
(
100
)]
trials
=
[(
i
)
*
(
1.
/
100.
)
for
i
in
range
(
100
)]
best_acc
=
float
(
'-inf'
)
best_thresh
=
0
for
t
in
trials
:
...
...
@@ -58,7 +59,7 @@ def calc_threshold(p, l):
def
apply_threshold
(
preds
,
t
):
assert
(
np
.
allclose
(
preds
.
sum
(
-
1
),
np
.
ones
(
preds
.
shape
[
0
])))
prob
=
preds
[:,
-
1
]
prob
=
preds
[:,
-
1
]
thresholded
=
(
prob
>=
t
).
astype
(
int
)
preds
=
np
.
zeros_like
(
preds
)
preds
[
np
.
arange
(
len
(
thresholded
)),
thresholded
.
reshape
(
-
1
)]
=
1
...
...
@@ -66,8 +67,8 @@ def apply_threshold(preds, t):
def
threshold_predictions
(
all_predictions
,
threshold
):
if
len
(
threshold
)
!=
len
(
all_predictions
):
threshold
=
[
threshold
[
-
1
]]
*
(
len
(
all_predictions
)
-
len
(
threshold
))
if
len
(
threshold
)
!=
len
(
all_predictions
):
threshold
=
[
threshold
[
-
1
]]
*
(
len
(
all_predictions
)
-
len
(
threshold
))
for
i
,
dataset
in
enumerate
(
all_predictions
):
thresh
=
threshold
[
i
]
preds
=
all_predictions
[
dataset
]
...
...
@@ -77,7 +78,7 @@ def threshold_predictions(all_predictions, threshold):
def
postprocess_predictions
(
all_predictions
,
all_labels
,
args
):
for
d
in
all_predictions
:
all_predictions
[
d
]
=
all_predictions
[
d
]
/
len
(
args
.
paths
)
all_predictions
[
d
]
=
all_predictions
[
d
]
/
len
(
args
.
paths
)
if
args
.
calc_threshold
:
args
.
threshold
=
get_threshold
(
all_predictions
,
all_labels
,
args
.
one_threshold
)
...
...
@@ -98,19 +99,22 @@ def write_predictions(all_predictions, all_labels, all_uid, args):
if
args
.
eval
:
correct
=
(
preds
==
all_labels
[
dataset
]).
sum
()
num
=
len
(
all_labels
[
dataset
])
accuracy
=
correct
/
num
accuracy
=
correct
/
num
count
+=
num
all_correct
+=
correct
accuracy
=
(
preds
==
all_labels
[
dataset
]).
mean
()
print
(
accuracy
)
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
outdir
,
dataset
)):
os
.
makedirs
(
os
.
path
.
join
(
args
.
outdir
,
dataset
))
outpath
=
os
.
path
.
join
(
args
.
outdir
,
dataset
,
os
.
path
.
splitext
(
args
.
prediction_name
)[
0
]
+
'.tsv'
)
outpath
=
os
.
path
.
join
(
args
.
outdir
,
dataset
,
os
.
path
.
splitext
(
args
.
prediction_name
)[
0
]
+
'.tsv'
)
with
open
(
outpath
,
'w'
)
as
f
:
f
.
write
(
'id
\t
label
\n
'
)
f
.
write
(
'
\n
'
.
join
(
str
(
uid
)
+
'
\t
'
+
str
(
args
.
labels
[
p
])
for
uid
,
p
in
zip
(
all_uid
[
dataset
],
preds
.
tolist
())))
f
.
write
(
'
\n
'
.
join
(
str
(
uid
)
+
'
\t
'
+
str
(
args
.
labels
[
p
])
for
uid
,
p
in
zip
(
all_uid
[
dataset
],
preds
.
tolist
())))
if
args
.
eval
:
print
(
all_correct
/
count
)
print
(
all_correct
/
count
)
def
ensemble_predictions
(
args
):
...
...
@@ -119,7 +123,7 @@ def ensemble_predictions(args):
write_predictions
(
all_predictions
,
all_labels
,
all_uid
,
args
)
def
main
():
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--paths'
,
required
=
True
,
nargs
=
'+'
,
help
=
'paths to checkpoint directories used in ensemble'
)
...
...
@@ -135,11 +139,11 @@ def main():
help
=
'use on threshold for all subdatasets'
)
parser
.
add_argument
(
'--threshold'
,
nargs
=
'+'
,
default
=
None
,
type
=
float
,
help
=
'user supplied threshold for classification'
)
parser
.
add_argument
(
'--labels'
,
nargs
=
'+'
,
default
=
None
,
parser
.
add_argument
(
'--labels'
,
nargs
=
'+'
,
default
=
None
,
help
=
'whitespace separated list of label names'
)
args
=
parser
.
parse_args
()
ensemble_predictions
(
args
)
if
__name__
==
'__main__'
:
main
()
\ No newline at end of file
main
()
tasks/finetune_utils.py
View file @
b7f1b050
...
...
@@ -21,7 +21,7 @@ from megatron import get_args
from
megatron
import
get_timers
from
megatron
import
mpu
from
megatron
import
print_rank_0
from
megatron.checkpointing
import
load_checkpoint
from
megatron.checkpointing
import
load_checkpoint
from
megatron.checkpointing
import
save_checkpoint
from
megatron.training
import
evaluate_and_print_results
from
megatron.training
import
setup_model_and_optimizer
...
...
@@ -53,7 +53,7 @@ def _cross_entropy_forward_step(batch, model):
timers
(
'batch generator'
).
start
()
try
:
batch_
=
next
(
batch
)
except
:
except
BaseException
:
batch_
=
batch
tokens
,
types
,
labels
,
attention_mask
=
process_batch
(
batch_
)
timers
(
'batch generator'
).
stop
()
...
...
@@ -146,7 +146,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
# For each remaining epoch
timers
(
'interval time'
).
start
()
for
epoch
in
range
(
start_epoch
,
args
.
epochs
):
print_rank_0
(
'working on epoch {} ...'
.
format
(
epoch
+
1
))
print_rank_0
(
'working on epoch {} ...'
.
format
(
epoch
+
1
))
# Set the data loader epoch to shuffle the index iterator.
train_dataloader
.
sampler
.
set_epoch
(
args
.
seed
+
epoch
)
...
...
@@ -172,7 +172,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
report_memory_flag
)
# Autoresume
if
args
.
adlr_autoresume
and
\
if
args
.
adlr_autoresume
and
\
(
iteration
%
args
.
adlr_autoresume_interval
==
0
):
check_adlr_autoresume_termination
(
iteration
,
model
,
optimizer
,
lr_scheduler
)
...
...
tasks/glue/data.py
View file @
b7f1b050
...
...
@@ -48,11 +48,9 @@ class GLUEAbstractDataset(ABC, Dataset):
print_rank_0
(
' >> total number of samples: {}'
.
format
(
len
(
self
.
samples
)))
def
__len__
(
self
):
return
len
(
self
.
samples
)
def
__getitem__
(
self
,
idx
):
raw_sample
=
self
.
samples
[
idx
]
ids
,
types
,
paddings
=
build_tokens_types_paddings_from_text
(
...
...
@@ -62,7 +60,6 @@ class GLUEAbstractDataset(ABC, Dataset):
raw_sample
[
'label'
],
raw_sample
[
'uid'
])
return
sample
@
abstractmethod
def
process_samples_from_single_path
(
self
,
datapath
):
"""Abstract method that takes a single path / filename and
...
...
tasks/glue/finetune.py
View file @
b7f1b050
...
...
@@ -38,7 +38,6 @@ def glue_classification(num_classes, Dataset,
return
train_dataset
,
valid_dataset
def
model_provider
():
"""Build the model."""
args
=
get_args
()
...
...
@@ -48,7 +47,6 @@ def glue_classification(num_classes, Dataset,
return
Classification
(
num_classes
=
num_classes
,
num_tokentypes
=
2
)
def
metrics_func_provider
():
"""Provide metrics callback function."""
def
single_dataset_provider
(
datapath
):
...
...
@@ -59,7 +57,6 @@ def glue_classification(num_classes, Dataset,
return
Dataset
(
name
,
[
datapath
],
tokenizer
,
args
.
seq_length
)
return
accuracy_func_provider
(
single_dataset_provider
)
"""Finetune/evaluate."""
finetune
(
train_valid_datasets_provider
,
model_provider
,
end_of_epoch_callback_provider
=
metrics_func_provider
)
...
...
@@ -72,6 +69,7 @@ def main():
num_classes
=
3
from
tasks.glue.mnli
import
MNLIDataset
as
Dataset
def
name_from_datapath
(
datapath
):
return
datapath
.
split
(
'MNLI'
)[
-
1
].
strip
(
'.tsv'
).
strip
(
'/'
).
replace
(
'_'
,
'-'
)
...
...
@@ -80,6 +78,7 @@ def main():
num_classes
=
2
from
tasks.glue.qqp
import
QQPDataset
as
Dataset
def
name_from_datapath
(
datapath
):
return
datapath
.
split
(
'QQP'
)[
-
1
].
strip
(
'.tsv'
).
strip
(
'/'
).
replace
(
'_'
,
'-'
)
...
...
tasks/glue/mnli.py
View file @
b7f1b050
...
...
@@ -31,7 +31,6 @@ class MNLIDataset(GLUEAbstractDataset):
super
().
__init__
(
'MNLI'
,
name
,
datapaths
,
tokenizer
,
max_seq_length
)
def
process_samples_from_single_path
(
self
,
filename
):
""""Implement abstract method."""
print_rank_0
(
' > Processing {} ...'
.
format
(
filename
))
...
...
tasks/glue/qqp.py
View file @
b7f1b050
...
...
@@ -31,7 +31,6 @@ class QQPDataset(GLUEAbstractDataset):
super
().
__init__
(
'QQP'
,
name
,
datapaths
,
tokenizer
,
max_seq_length
)
def
process_samples_from_single_path
(
self
,
filename
):
""""Implement abstract method."""
print_rank_0
(
' > Processing {} ...'
.
format
(
filename
))
...
...
tasks/main.py
View file @
b7f1b050
...
...
@@ -15,14 +15,13 @@
"""Main tasks functionality."""
from
megatron.initialize
import
initialize_megatron
from
megatron
import
get_args
import
os
import
sys
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
os
.
path
.
pardir
)))
from
megatron
import
get_args
from
megatron.initialize
import
initialize_megatron
def
get_tasks_args
(
parser
):
"""Provide extra arguments required for tasks."""
...
...
@@ -46,7 +45,7 @@ def get_tasks_args(parser):
group
.
add_argument
(
'--overlapping-eval'
,
type
=
int
,
default
=
32
,
help
=
'Sliding window for overlapping evaluation.'
)
group
.
add_argument
(
'--strict-lambada'
,
action
=
'store_true'
,
help
=
'Use more difficult formulation of lambada.'
)
help
=
'Use more difficult formulation of lambada.'
)
return
parser
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment