Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
22c0e300
Commit
22c0e300
authored
Apr 16, 2020
by
Jared Casper
Browse files
Merge branch 'staging' into 'master'
April 2020 Release See merge request ADLR/megatron-lm!69
parents
70174ae3
acf8780b
Changes
122
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1468 additions
and
75 deletions
+1468
-75
megatron/mpu/__init__.py
megatron/mpu/__init__.py
+1
-1
megatron/mpu/cross_entropy.py
megatron/mpu/cross_entropy.py
+1
-2
megatron/mpu/data.py
megatron/mpu/data.py
+1
-1
megatron/mpu/grads.py
megatron/mpu/grads.py
+49
-5
megatron/mpu/initialize.py
megatron/mpu/initialize.py
+1
-1
megatron/mpu/layers.py
megatron/mpu/layers.py
+6
-3
megatron/mpu/mappings.py
megatron/mpu/mappings.py
+4
-1
megatron/mpu/random.py
megatron/mpu/random.py
+2
-1
megatron/mpu/tests/commons.py
megatron/mpu/tests/commons.py
+2
-1
megatron/mpu/tests/test_cross_entropy.py
megatron/mpu/tests/test_cross_entropy.py
+9
-11
megatron/mpu/tests/test_data.py
megatron/mpu/tests/test_data.py
+6
-10
megatron/mpu/tests/test_initialize.py
megatron/mpu/tests/test_initialize.py
+5
-8
megatron/mpu/tests/test_layers.py
megatron/mpu/tests/test_layers.py
+22
-21
megatron/mpu/tests/test_random.py
megatron/mpu/tests/test_random.py
+5
-8
megatron/mpu/utils.py
megatron/mpu/utils.py
+1
-1
megatron/text_generation_utils.py
megatron/text_generation_utils.py
+410
-0
megatron/tokenizer/__init__.py
megatron/tokenizer/__init__.py
+17
-0
megatron/tokenizer/bert_tokenization.py
megatron/tokenizer/bert_tokenization.py
+402
-0
megatron/tokenizer/gpt2_tokenization.py
megatron/tokenizer/gpt2_tokenization.py
+321
-0
megatron/tokenizer/tokenizer.py
megatron/tokenizer/tokenizer.py
+203
-0
No files found.
megatron/mpu/__init__.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
megatron/mpu/cross_entropy.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -72,7 +72,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
get_model_parallel_group
())
# Loss = log(sum(exp(logits))) - predicted-logit.
loss
=
torch
.
log
(
sum_exp_logits
)
-
predicted_logits
...
...
megatron/mpu/data.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
megatron/mpu/grads.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -21,10 +21,47 @@
import
torch
from
torch._six
import
inf
from
apex.multi_tensor_apply
import
multi_tensor_applier
import
amp_C
from
.initialize
import
get_model_parallel_group
from
.initialize
import
get_model_parallel_rank
def
l2_grad_clipper
(
parameters
,
max_norm
):
"""Efficient L2 norm gradient clipping."""
overflow_buf
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int
,
device
=
'cuda'
)
if
isinstance
(
parameters
,
torch
.
Tensor
):
parameters
=
[
parameters
]
mp_rank_is_zero
=
(
get_model_parallel_rank
()
==
0
)
parameters
=
list
(
filter
(
lambda
p
:
(
p
.
grad
is
not
None
)
and
(
p
.
model_parallel
or
mp_rank_is_zero
),
parameters
))
norm
,
_
=
multi_tensor_applier
(
amp_C
.
multi_tensor_l2norm
,
overflow_buf
,
[
parameters
],
False
# no per-parameter norm
)
# Sum across all model parallel GPUs.
norm_2
=
norm
*
norm
torch
.
distributed
.
all_reduce
(
norm_2
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
get_model_parallel_group
())
total_norm
=
norm_2
.
item
()
**
0.5
clip_coef
=
max_norm
/
(
total_norm
+
1e-6
)
grads
=
[
p
.
grad
for
p
in
parameters
]
if
clip_coef
<
1
:
multi_tensor_applier
(
amp_C
.
multi_tensor_scale
,
overflow_buf
,
[
grads
,
grads
],
clip_coef
)
return
total_norm
def
clip_grad_norm
(
parameters
,
max_norm
,
norm_type
=
2
):
"""Clips gradient norm of an iterable of parameters.
...
...
@@ -55,6 +92,13 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
get_model_parallel_group
())
total_norm
=
total_norm_cuda
[
0
].
item
()
clip_coef
=
max_norm
/
(
total_norm
+
1e-6
)
if
clip_coef
<
1
:
for
p
in
parameters
:
p
.
grad
.
data
.
mul_
(
clip_coef
)
elif
norm_type
==
2
:
total_norm
=
l2_grad_clipper
(
parameters
,
max_norm
)
else
:
total_norm
=
0
for
p
in
parameters
:
...
...
@@ -67,8 +111,8 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
get_model_parallel_group
())
total_norm
=
total_norm_cuda
[
0
].
item
()
**
(
1.
/
norm_type
)
clip_coef
=
max_norm
/
(
total_norm
+
1e-6
)
if
clip_coef
<
1
:
for
p
in
parameters
:
p
.
grad
.
data
.
mul_
(
clip_coef
)
clip_coef
=
max_norm
/
(
total_norm
+
1e-6
)
if
clip_coef
<
1
:
for
p
in
parameters
:
p
.
grad
.
data
.
mul_
(
clip_coef
)
return
total_norm
megatron/mpu/initialize.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
megatron/mpu/layers.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -89,6 +89,7 @@ class VocabParallelEmbedding(torch.nn.Module):
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def
__init__
(
self
,
num_embeddings
,
embedding_dim
,
init_method
=
init
.
xavier_normal_
):
super
(
VocabParallelEmbedding
,
self
).
__init__
()
...
...
@@ -108,7 +109,7 @@ class VocabParallelEmbedding(torch.nn.Module):
self
.
num_embeddings
,
get_model_parallel_rank
(),
get_model_parallel_world_size
())
self
.
num_embeddings_per_partition
=
self
.
vocab_end_index
-
\
self
.
vocab_start_index
self
.
vocab_start_index
# Allocate weights.
self
.
weight
=
Parameter
(
torch
.
Tensor
(
self
.
num_embeddings_per_partition
,
...
...
@@ -147,6 +148,7 @@ class ParallelEmbedding(torch.nn.Module):
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def
__init__
(
self
,
num_embeddings
,
embedding_dim
,
init_method
=
init
.
xavier_normal_
,
keep_master_weight_for_test
=
False
):
...
...
@@ -205,6 +207,7 @@ class ColumnParallelLinear(torch.nn.Module):
set to False. It returns the master weights
used for initialization.
"""
def
__init__
(
self
,
input_size
,
output_size
,
bias
=
True
,
gather_output
=
True
,
init_method
=
init
.
xavier_normal_
,
stride
=
1
,
keep_master_weight_for_test
=
False
):
...
...
@@ -279,6 +282,7 @@ class RowParallelLinear(torch.nn.Module):
set to False. It returns the master weights
used for initialization.
"""
def
__init__
(
self
,
input_size
,
output_size
,
bias
=
True
,
input_is_parallel
=
False
,
init_method
=
init
.
xavier_normal_
,
stride
=
1
,
...
...
@@ -327,4 +331,3 @@ class RowParallelLinear(torch.nn.Module):
else
:
output
=
output_
return
output
megatron/mpu/mappings.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -131,11 +131,14 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
def
copy_to_model_parallel_region
(
input_
):
return
_CopyToModelParallelRegion
.
apply
(
input_
)
def
reduce_from_model_parallel_region
(
input_
):
return
_ReduceFromModelParallelRegion
.
apply
(
input_
)
def
scatter_to_model_parallel_region
(
input_
):
return
_ScatterToModelParallelRegion
.
apply
(
input_
)
def
gather_from_model_parallel_region
(
input_
):
return
_GatherFromModelParallelRegion
.
apply
(
input_
)
megatron/mpu/random.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -73,6 +73,7 @@ class CudaRNGStatesTracker:
rng state, we can perform operations and return to our starting
cuda state.
"""
def
__init__
(
self
):
# Map from a string name to the cuda rng state.
self
.
states_
=
{}
...
...
megatron/mpu/tests/commons.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -26,6 +26,7 @@ class IdentityLayer(torch.nn.Module):
def
__init__
(
self
,
size
,
scale
=
1.0
):
super
(
IdentityLayer
,
self
).
__init__
()
self
.
weight
=
torch
.
nn
.
Parameter
(
scale
*
torch
.
randn
(
size
))
def
forward
(
self
):
return
self
.
weight
...
...
megatron/mpu/tests/test_cross_entropy.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
commons
import
set_random_seed
from
commons
import
IdentityLayer
from
commons
import
print_separator
from
commons
import
initialize_distributed
from
mpu.cross_entropy
import
vocab_parallel_cross_entropy
import
mpu
import
torch.nn.functional
as
F
import
torch
import
random
import
sys
sys
.
path
.
append
(
"../.."
)
import
torch
import
torch.nn.functional
as
F
import
mpu
from
mpu.cross_entropy
import
vocab_parallel_cross_entropy
from
commons
import
initialize_distributed
from
commons
import
print_separator
from
commons
import
IdentityLayer
from
commons
import
set_random_seed
def
torch_cross_entropy
(
batch_size
,
seq_length
,
vocab_size
,
logits_scale
,
seed
):
...
...
megatron/mpu/tests/test_data.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,18 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
commons
import
print_separator
from
commons
import
initialize_distributed
from
mpu
import
data
as
data_utils
import
mpu
import
torch
import
functools
import
operator
import
sys
sys
.
path
.
append
(
"../.."
)
import
torch
import
mpu
from
mpu
import
data
as
data_utils
from
commons
import
initialize_distributed
from
commons
import
print_separator
def
test_boradcast_data
(
model_parallel_size
):
...
...
@@ -88,5 +86,3 @@ if __name__ == '__main__':
print_separator
(
'test test boradcast data'
)
test_boradcast_data
(
model_parallel_size
)
model_parallel_size
*=
2
megatron/mpu/tests/test_initialize.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
commons
import
print_separator
from
commons
import
initialize_distributed
import
mpu
import
torch
import
sys
sys
.
path
.
append
(
"../.."
)
import
torch
import
mpu
from
commons
import
initialize_distributed
from
commons
import
print_separator
def
test_initialize_model_parallel
(
model_parallel_size
):
...
...
@@ -46,7 +44,6 @@ def test_initialize_model_parallel(model_parallel_size):
assert
rank
==
mpu
.
get_model_parallel_rank
()
check
(
mpu
.
get_model_parallel_group
(),
world_size
,
rank
)
# Data parallel.
world_size
=
torch
.
distributed
.
get_world_size
()
//
model_parallel_size_
rank
=
torch
.
distributed
.
get_rank
()
//
model_parallel_size
...
...
megatron/mpu/tests/test_layers.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
mpu
import
layers
from
commons
import
set_random_seed
from
commons
import
print_separator
from
commons
import
initialize_distributed
import
mpu
from
torch.nn.parameter
import
Parameter
import
torch.nn.init
as
init
import
torch
import
random
import
sys
sys
.
path
.
append
(
"../.."
)
import
torch
import
torch.nn.init
as
init
from
torch.nn.parameter
import
Parameter
import
mpu
from
commons
import
initialize_distributed
from
commons
import
print_separator
from
commons
import
set_random_seed
from
mpu
import
layers
def
test_parallel_embedding
(
model_parallel_size
):
...
...
@@ -45,7 +43,7 @@ def test_parallel_embedding(model_parallel_size):
set_random_seed
(
123
)
input_data
=
torch
.
LongTensor
(
size
=
(
batch_size
,
seq_length
)).
random_
(
0
,
vocab_size
).
cuda
()
size
=
(
batch_size
,
seq_length
)).
random_
(
0
,
vocab_size
).
cuda
()
loss_weight
=
torch
.
randn
([
batch_size
,
seq_length
,
hidden_size
]).
cuda
()
set_random_seed
(
seed
)
...
...
@@ -57,7 +55,7 @@ def test_parallel_embedding(model_parallel_size):
set_random_seed
(
seed
)
embedding_parallel
=
layers
.
ParallelEmbedding
(
vocab_size
,
hidden_size
,
init_method
=
init
.
normal_
).
cuda
()
vocab_size
,
hidden_size
,
init_method
=
init
.
normal_
).
cuda
()
output
=
embedding_parallel
(
input_data
)
loss_parallel
=
torch
.
mul
(
output
,
loss_weight
).
sum
()
loss_parallel
.
backward
()
...
...
@@ -176,10 +174,11 @@ def test_initialize_affine_weight(model_parallel_size):
class
IdentityLayer2D
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
m
,
n
):
def
__init__
(
self
,
m
,
n
):
super
(
IdentityLayer2D
,
self
).
__init__
()
self
.
weight
=
Parameter
(
torch
.
Tensor
(
m
,
n
))
torch
.
nn
.
init
.
xavier_normal_
(
self
.
weight
)
def
forward
(
self
):
return
self
.
weight
...
...
@@ -317,10 +316,11 @@ def test_row_parallel_linear(model_parallel_size):
class
IdentityLayer3D
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
m
,
n
,
k
):
def
__init__
(
self
,
m
,
n
,
k
):
super
(
IdentityLayer3D
,
self
).
__init__
()
self
.
weight
=
Parameter
(
torch
.
Tensor
(
m
,
n
,
k
))
torch
.
nn
.
init
.
xavier_normal_
(
self
.
weight
)
def
forward
(
self
):
return
self
.
weight
...
...
@@ -335,14 +335,14 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
set_random_seed
(
seed
)
num_att_heads
=
num_att_heads_per_partition
*
\
torch
.
distributed
.
get_world_size
()
torch
.
distributed
.
get_world_size
()
hidden_size
=
hidden_size_per_att_head
*
num_att_heads
# Network
identity_layer
=
IdentityLayer3D
(
batch_size
,
sequence_length
,
hidden_size
).
cuda
()
attention_layer
=
mpu
.
BertParallelSelfAttention
(
hidden_size
,
num_att_heads
,
dropout_prob
).
cuda
()
dropout_prob
).
cuda
()
loss_weight
=
torch
.
randn
([
batch_size
,
sequence_length
,
hidden_size
]).
cuda
()
attention_mask
=
torch
.
randn
([
batch_size
,
1
,
1
,
sequence_length
]).
cuda
()
# Forward
...
...
@@ -366,17 +366,17 @@ def test_parallel_self_attention(model_parallel_size):
num_att_heads_per_partition
=
3
hidden_size_per_att_head
=
7
dropout_prob
=
0.0
# has to be zero
dropout_prob
=
0.0
# has to be zero
batch_size
=
5
sequence_length
=
13
rank_1
,
hideen_size_1
,
model_parallel_size_1
,
loss_1
,
\
attention_layer_1
,
identity_layer_1
=
parallel_self_attention
(
attention_layer_1
,
identity_layer_1
=
parallel_self_attention
(
1
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
)
rank
,
hidden_size
,
model_parallel_size
,
loss
,
\
attention_layer
,
identity_layer
=
parallel_self_attention
(
attention_layer
,
identity_layer
=
parallel_self_attention
(
model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
)
assert
hideen_size_1
==
hidden_size
...
...
@@ -409,6 +409,7 @@ def test_parallel_self_attention(model_parallel_size):
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
' >> passed the test :-)'
)
def
parallel_transformer
(
model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
batch_size
,
sequence_length
):
...
...
@@ -419,7 +420,7 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
set_random_seed
(
seed
)
num_att_heads
=
num_att_heads_per_partition
*
\
torch
.
distributed
.
get_world_size
()
torch
.
distributed
.
get_world_size
()
hidden_size
=
hidden_size_per_att_head
*
num_att_heads
intermediate_size
=
4
*
hidden_size
...
...
megatron/mpu/tests/test_random.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
commons
import
print_separator
from
commons
import
initialize_distributed
import
mpu
import
torch
import
sys
sys
.
path
.
append
(
"../.."
)
import
torch
import
mpu
from
commons
import
initialize_distributed
from
commons
import
print_separator
def
test_set_cuda_rng_state
(
model_parallel_size
):
...
...
@@ -204,4 +202,3 @@ if __name__ == '__main__':
print_separator
(
'test model parallel cuda manual seed'
)
test_model_parallel_cuda_manual_seed
(
model_parallel_size
)
model_parallel_size
*=
2
megatron/mpu/utils.py
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
generate_sample
s.py
→
megatron/text_generation_util
s.py
100755 → 100644
View file @
22c0e300
# coding=utf-8
# Copyright (c) 20
19
, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 20
20
, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,150 +13,98 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sample Generate GPT2
"""
"""
Utilities for generating text.
"""
import
os
import
random
import
json
import
copy
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
argparse
import
json
import
os
import
time
from
arguments
import
get_args
from
megatron.utils
import
Timers
from
megatron.utils
import
initialize_distributed
from
megatron.utils
import
set_random_seed
from
megatron.utils
import
get_ltor_masks_and_position_ids
from
megatron.utils
import
load_checkpoint
from
megatron.data_utils
import
make_tokenizer
from
configure_data
import
configure_data
from
megatron
import
mpu
from
megatron.fp16
import
FP16_Module
from
megatron.model
import
GPT2Model
from
megatron.model
import
DistributedDataParallel
as
DDP
from
megatron.utils
import
print_rank_0
def
get_model
(
args
):
"""Build the model."""
print_rank_0
(
'building GPT2 model ...'
)
model
=
GPT2Model
(
num_layers
=
args
.
num_layers
,
vocab_size
=
args
.
vocab_size
,
hidden_size
=
args
.
hidden_size
,
num_attention_heads
=
args
.
num_attention_heads
,
embedding_dropout_prob
=
args
.
hidden_dropout
,
attention_dropout_prob
=
args
.
attention_dropout
,
output_dropout_prob
=
args
.
hidden_dropout
,
max_sequence_length
=
args
.
max_position_embeddings
,
checkpoint_activations
=
args
.
checkpoint_activations
,
checkpoint_num_layers
=
args
.
checkpoint_num_layers
,
parallel_output
=
False
)
if
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
' > number of parameters on model parallel rank {}: {}'
.
format
(
mpu
.
get_model_parallel_rank
(),
sum
([
p
.
nelement
()
for
p
in
model
.
parameters
()])),
flush
=
True
)
# GPU allocation.
model
.
cuda
(
torch
.
cuda
.
current_device
())
# Fp16 conversion.
if
args
.
fp16
:
model
=
FP16_Module
(
model
)
# Wrap model for distributed training.
model
=
DDP
(
model
)
return
model
def
setup_model
(
args
):
"""Setup model and optimizer."""
model
=
get_model
(
args
)
if
args
.
load
is
not
None
:
_
=
load_checkpoint
(
model
,
None
,
None
,
args
)
import
torch
import
torch.nn.functional
as
F
return
model
from
megatron
import
get_args
from
megatron
import
get_tokenizer
from
megatron
import
mpu
from
megatron.utils
import
get_ltor_masks_and_position_ids
def
get_batch
(
context_tokens
,
args
):
tokens
=
context_tokens
tokens
=
tokens
.
view
(
args
.
batch_size
,
-
1
).
contiguous
()
device
=
args
.
device
tokens
=
tokens
.
to
(
device
)
def
get_batch
(
context_tokens
):
"""Generate batch from context tokens."""
args
=
get_args
()
tokenizer
=
get_tokenizer
()
# Get the masks and postition ids.
attention_mask
,
loss_mask
,
position_ids
=
get_ltor_masks_and_position_ids
(
# Move to GPU.
tokens
=
context_tokens
.
view
(
args
.
batch_size
,
-
1
).
contiguous
().
cuda
()
# Get the attention mask and postition ids.
attention_mask
,
_
,
position_ids
=
get_ltor_masks_and_position_ids
(
tokens
,
args
.
eod_token
,
tokenizer
.
eod
,
args
.
reset_position_ids
,
args
.
reset_attention_mask
,
False
)
# Fp16 conversion.
if
args
.
fp16
:
attention_mask
=
attention_mask
.
half
()
args
.
eod_mask_loss
)
return
tokens
,
attention_mask
,
position_ids
def
top_k_logits
(
logits
,
top_k
=
0
,
top_p
=
0.0
,
filter_value
=-
float
(
'Inf'
)):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
""" This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313 """
if
top_k
>
0
:
# Remove all tokens with a probability less than the last token of the top-k
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove
=
logits
<
torch
.
topk
(
logits
,
top_k
)[
0
][...,
-
1
,
None
]
logits
[
indices_to_remove
]
=
filter_value
if
top_p
>
0.0
:
#convert to 1D
# logits=logits.view(logits.size()[1]).contiguous()
sorted_logits
,
sorted_indices
=
torch
.
sort
(
logits
,
descending
=
True
,
dim
=-
1
)
cumulative_probs
=
torch
.
cumsum
(
F
.
softmax
(
sorted_logits
,
dim
=-
1
),
dim
=-
1
)
# Cconvert to 1D
sorted_logits
,
sorted_indices
=
torch
.
sort
(
logits
,
descending
=
True
,
dim
=-
1
)
cumulative_probs
=
torch
.
cumsum
(
F
.
softmax
(
sorted_logits
,
dim
=-
1
),
dim
=-
1
)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove
=
cumulative_probs
>
top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove
[...,
1
:]
=
sorted_indices_to_remove
[...,
:
-
1
].
clone
()
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove
[...,
1
:]
\
=
sorted_indices_to_remove
[...,
:
-
1
].
clone
()
sorted_indices_to_remove
[...,
0
]
=
0
for
i
in
range
(
sorted_indices
.
size
(
0
)):
indices_to_remove
=
sorted_indices
[
i
][
sorted_indices_to_remove
[
i
]]
logits
[
i
][
indices_to_remove
]
=
filter_value
#going back to 2D
# logits=logits.view(1, -1).contiguous()
return
logits
def
generate_samples_input_from_file
(
model
,
tokenizer
,
args
):
if
args
.
sample_input_file
==
""
:
if
mpu
.
get_model_parallel_rank
()
==
0
:
print
(
"args.sample_input_file CAN NOT BE empty!
\n
"
)
return
def
generate_samples_input_from_file
(
model
):
args
=
get_args
()
tokenizer
=
get_tokenizer
()
# Read the sample file and open the output file.
assert
args
.
sample_input_file
is
not
None
,
\
'sample input file is not provided.'
if
mpu
.
get_model_parallel_rank
()
==
0
:
fname
=
open
(
args
.
sample_input_file
,
"r"
)
all_raw_text
=
fname
.
readlines
()
input_count
=
len
(
all_raw_text
)
input_pos
=
0
if
args
.
sample_output_file
==
""
:
print
(
"Argument:
sample
-
output
-
file
can't be empty, setting it to
\n
"
)
print
(
"
\t
args.
sample
_in
put
_
file
.out"
)
args
.
sample_output_file
=
args
.
sample_input_file
+
".out"
fname_out
=
open
(
args
.
sample_output_file
,
"w+"
)
if
args
.
sample_output_file
is
None
:
sample
_
output
_
file
=
args
.
sample_input_file
+
".out"
print
(
'could not find `
sample
-out
put
-
file
`, setting '
'it to {}'
.
format
(
sample_output_file
))
fname_out
=
open
(
sample_output_file
,
"w+"
)
context_count
=
0
context_count
=
0
model
.
eval
()
with
torch
.
no_grad
():
while
True
:
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_model_parallel_group
())
terminate_runs
=
0
terminate_runs
=
0
if
mpu
.
get_model_parallel_rank
()
==
0
:
raw_text
=
all_raw_text
[
input_pos
]
...
...
@@ -167,63 +115,62 @@ def generate_samples_input_from_file(model, tokenizer, args):
if
"stop"
in
raw_text
:
terminate_runs
=
1
else
:
context_tokens
=
tokenizer
.
EncodeAsIds
(
raw_text
).
tokenization
context_tokens
=
tokenizer
.
tokenize
(
raw_text
)
context_length
=
len
(
context_tokens
)
if
context_length
>=
args
.
seq_length
//
2
:
print
(
"
\n
Context length"
,
context_length
,
\
"
\n
Please give smaller context (half of the sequence length)!"
)
if
context_length
>=
(
args
.
seq_length
//
2
):
print
(
"
\n
Context length"
,
context_length
,
"
\n
Please give smaller context (half of the "
"sequence length)!"
,
flush
=
True
)
continue
else
:
context_tokens
=
tokenizer
.
EncodeAsIds
(
"EMPTY TEXT"
)
.
tokenization
context_tokens
=
tokenizer
.
tokenize
(
"EMPTY TEXT"
)
context_length
=
len
(
context_tokens
)
terminate_runs_tensor
=
torch
.
cuda
.
LongTensor
([
terminate_runs
])
torch
.
distributed
.
broadcast
(
terminate_runs_tensor
,
mpu
.
get_model_parallel_src_rank
(),
group
=
mpu
.
get_model_parallel_group
())
torch
.
distributed
.
broadcast
(
terminate_runs_tensor
,
mpu
.
get_model_parallel_src_rank
(),
group
=
mpu
.
get_model_parallel_group
())
terminate_runs
=
terminate_runs_tensor
[
0
].
item
()
if
terminate_runs
==
1
:
return
start_time
=
time
.
time
()
token_stream
=
get_token_stream
(
model
,
[
context_tokens
],
tokenizer
,
args
)
for
counter
,
decode_tokens
in
enumerate
(
token_stream
):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
token_stream
=
get_token_stream
(
model
,
[
context_tokens
])
for
_
,
decode_tokens
in
enumerate
(
token_stream
):
decode_tokens
,
_
=
decode_tokens
decode_tokens
=
decode_tokens
[
0
].
cpu
().
numpy
().
tolist
()
if
mpu
.
get_model_parallel_rank
()
==
0
:
os
.
system
(
'clear'
)
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print
(
"
\n
Context:"
,
raw_text
,
flush
=
True
)
trim_decode_tokens
=
tokenizer
.
DecodeIds
(
decode_tokens
)[
len
(
raw_text
):]
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
trim_decode_tokens
=
tokenizer
.
detokenize
(
decode_tokens
)[
len
(
raw_text
):]
print
(
"
\n
Megatron-LM:"
,
trim_decode_tokens
,
flush
=
True
)
fname_out
.
write
(
"
\n
Context:"
)
fname_out
.
write
(
raw_text
)
fname_out
.
write
(
"
\n\n
Megatron-LM:"
)
fname_out
.
write
(
trim_decode_tokens
)
#fname_out.write(trim_decode_tokens.replace("\n", "\n\n"))
fname_out
.
write
(
"
\n
"
)
raw_text
=
None
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_model_parallel_group
())
context_count
+=
1
def
generate_samples_interactive
(
model
,
tokenizer
,
args
):
print_frequency
=
24
context_count
=
0
def
generate_samples_interactive
(
model
,
print_frequency
=
24
):
args
=
get_args
()
tokenizer
=
get_tokenizer
()
context_count
=
0
model
.
eval
()
with
torch
.
no_grad
():
while
True
:
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_model_parallel_group
())
terminate_runs
=
0
terminate_runs
=
0
if
mpu
.
get_model_parallel_rank
()
==
0
:
os
.
system
(
'clear'
)
...
...
@@ -231,85 +178,85 @@ def generate_samples_interactive(model, tokenizer, args):
while
not
raw_text
:
print
(
'Prompt should not be empty!'
)
raw_text
=
input
(
"
\n
Context prompt (stop to exit) >>> "
)
if
"stop"
in
raw_text
:
terminate_runs
=
1
else
:
context_tokens
=
tokenizer
.
EncodeAsIds
(
raw_text
).
tokenization
context_tokens
=
tokenizer
.
tokenize
(
raw_text
)
context_length
=
len
(
context_tokens
)
if
context_length
>=
args
.
seq_length
//
2
:
print
(
"
\n
Context length"
,
context_length
,
\
"
\n
Please give smaller context (half of the sequence length)!"
)
if
context_length
>=
(
args
.
seq_length
//
2
):
print
(
"
\n
Context length"
,
context_length
,
"
\n
Please give smaller context (half of the "
"sequence length)!"
,
flush
=
True
)
continue
else
:
context_tokens
=
tokenizer
.
EncodeAsIds
(
"EMPTY TEXT"
)
.
tokenization
context_tokens
=
tokenizer
.
tokenize
(
"EMPTY TEXT"
)
context_length
=
len
(
context_tokens
)
terminate_runs_tensor
=
torch
.
cuda
.
LongTensor
([
terminate_runs
])
torch
.
distributed
.
broadcast
(
terminate_runs_tensor
,
mpu
.
get_model_parallel_src_rank
(),
group
=
mpu
.
get_model_parallel_group
())
torch
.
distributed
.
broadcast
(
terminate_runs_tensor
,
mpu
.
get_model_parallel_src_rank
(),
group
=
mpu
.
get_model_parallel_group
())
terminate_runs
=
terminate_runs_tensor
[
0
].
item
()
if
terminate_runs
==
1
:
return
start_time
=
time
.
time
()
token_stream
=
get_token_stream
(
model
,
[
context_tokens
],
tokenizer
,
args
)
token_stream
=
get_token_stream
(
model
,
[
context_tokens
])
for
counter
,
decode_tokens
in
enumerate
(
token_stream
):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens
,
_
=
decode_tokens
decode_tokens
=
decode_tokens
[
0
].
cpu
().
numpy
().
tolist
()
if
mpu
.
get_model_parallel_rank
()
==
0
and
counter
%
print_frequency
==
0
:
if
mpu
.
get_model_parallel_rank
()
==
0
and
\
counter
%
print_frequency
==
0
:
os
.
system
(
'clear'
)
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print
(
"
\n
Context:"
,
raw_text
,
flush
=
True
)
trim_decode_tokens
=
tokenizer
.
DecodeIds
(
decode_tokens
)[
len
(
raw_text
):]
#print("\nGPT2:", trim_decode_tokens, flush=True)
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
trim_decode_tokens
=
tokenizer
.
detokenize
(
decode_tokens
)[
len
(
raw_text
):]
print
(
"
\n
Megatron-LM:"
,
trim_decode_tokens
,
flush
=
True
)
if
mpu
.
get_model_parallel_rank
()
==
0
:
os
.
system
(
'clear'
)
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print
(
"
\n
Context:"
,
raw_text
,
flush
=
True
)
trim_decode_tokens
=
tokenizer
.
DecodeIds
(
decode_tokens
)[
len
(
raw_text
):]
#print("\nGPT2:", trim_decode_tokens, flush=True)
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
trim_decode_tokens
=
tokenizer
.
detokenize
(
decode_tokens
)[
len
(
raw_text
):]
print
(
"
\n
Megatron-LM:"
,
trim_decode_tokens
,
flush
=
True
)
raw_text
=
None
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_model_parallel_group
())
context_count
+=
1
if
mpu
.
get_model_parallel_rank
()
==
0
:
input
(
"
\n
Press any key to continue >>>"
)
def
generate_samples_unconditional
(
model
,
tokenizer
,
args
):
def
generate_samples_unconditional
(
model
):
args
=
get_args
()
tokenizer
=
get_tokenizer
()
num_samples
=
args
.
num_samples
context_tokens
=
[[
tokenizer
.
get_command
(
'pad'
).
Id
]
for
_
in
range
(
args
.
batch_size
)]
samples
=
[]
# with open(args.genfile, 'w') as f:
context_tokens
=
[[
tokenizer
.
eod
]
for
_
in
range
(
args
.
batch_size
)]
ctr
=
0
while
True
:
start_time
=
time
.
time
()
for
token_stream
in
get_token_stream
(
model
,
copy
.
deepcopy
(
context_tokens
),
tokenizer
,
args
):
for
token_stream
in
get_token_stream
(
model
,
copy
.
deepcopy
(
context_tokens
)):
pass
# token_stream = list(get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args))
if
ctr
%
args
.
log_interval
==
0
:
print
(
'Avg s/batch:'
,
(
time
.
time
()
-
start_time
)
/
min
(
args
.
log_interval
,
ctr
+
1
))
if
ctr
%
args
.
log_interval
==
0
:
print
(
'Avg s/batch:'
,
(
time
.
time
()
-
start_time
)
/
min
(
args
.
log_interval
,
ctr
+
1
))
start_time
=
time
.
time
()
length
=
len
(
token_stream
)
token_batch
=
token_stream
[
0
].
cpu
().
numpy
().
tolist
()
length_batch
=
token_stream
[
1
].
cpu
().
numpy
().
tolist
()
for
tokens
,
length
in
zip
(
token_batch
,
length_batch
):
tokens
=
tokens
[
1
:
length
-
1
]
text
=
tokenizer
.
DecodeIds
(
tokens
)
tokens
=
tokens
[
1
:
length
-
1
]
text
=
tokenizer
.
detokenize
(
tokens
)
is_finished
=
length
<
args
.
seq_length
-
1
datum
=
{
'text'
:
text
,
'length'
:
length
-
1
,
'finished'
:
is_finished
}
datum
=
{
'text'
:
text
,
'length'
:
length
-
1
,
'finished'
:
is_finished
}
yield
datum
ctr
+=
1
if
ctr
>=
num_samples
:
...
...
@@ -317,65 +264,73 @@ def generate_samples_unconditional(model, tokenizer, args):
if
ctr
>=
num_samples
:
break
def
write_and_generate_samples_unconditional
(
model
,
tokenizer
,
args
):
def
generate_and_write_samples_unconditional
(
model
):
args
=
get_args
()
assert
args
.
genfile
is
not
None
with
open
(
args
.
genfile
,
'w'
)
as
f
:
for
datum
in
generate_samples_unconditional
(
model
,
tokenizer
,
args
):
f
.
write
(
json
.
dumps
(
datum
)
+
'
\n
'
)
for
datum
in
generate_samples_unconditional
(
model
):
f
.
write
(
json
.
dumps
(
datum
)
+
'
\n
'
)
def
pad_batch
(
batch
,
pad_id
,
args
):
def
pad_batch
(
batch
,
tokenizer
,
args
):
pad_id
=
tokenizer
.
get_command
(
'pad'
).
Id
context_lengths
=
[]
for
tokens
in
batch
:
context_length
=
len
(
tokens
)
if
context_length
<
args
.
seq_length
:
tokens
.
extend
([
pad_id
]
*
(
args
.
seq_length
-
context_length
))
tokens
.
extend
([
pad_id
]
*
(
args
.
seq_length
-
context_length
))
context_lengths
.
append
(
context_length
)
return
batch
,
context_lengths
def
get_token_stream
(
model
,
context_tokens
,
tokenizer
,
args
):
pad_id
=
tokenizer
.
get_command
(
'pad'
).
Id
# context_length = len(context_tokens)
# if context_length < args.seq_length:
# context_tokens = context_tokens + [pad_id] * (args.seq_length - context_length)
context_tokens
,
context_lengths
=
pad_batch
(
context_tokens
,
tokenizer
,
args
)
def
get_token_stream
(
model
,
context_tokens
):
args
=
get_args
()
tokenizer
=
get_tokenizer
()
context_tokens
,
context_lengths
=
pad_batch
(
context_tokens
,
tokenizer
.
eod
,
args
)
context_tokens_tensor
=
torch
.
cuda
.
LongTensor
(
context_tokens
)
context_length_tensor
=
torch
.
cuda
.
LongTensor
(
context_lengths
)
# context_length_tensor = torch.cuda.LongTensor([context_length])
torch
.
distributed
.
broadcast
(
context_length_tensor
,
mpu
.
get_model_parallel_src_rank
(),
group
=
mpu
.
get_model_parallel_group
())
torch
.
distributed
.
broadcast
(
context_tokens_tensor
,
mpu
.
get_model_parallel_src_rank
(),
group
=
mpu
.
get_model_parallel_group
())
torch
.
distributed
.
broadcast
(
context_length_tensor
,
mpu
.
get_model_parallel_src_rank
(),
group
=
mpu
.
get_model_parallel_group
())
torch
.
distributed
.
broadcast
(
context_tokens_tensor
,
mpu
.
get_model_parallel_src_rank
(),
group
=
mpu
.
get_model_parallel_group
())
context_length
=
context_length_tensor
.
min
().
item
()
tokens
,
attention_mask
,
position_ids
=
get_batch
(
context_tokens_tensor
,
args
)
tokens
,
attention_mask
,
position_ids
=
get_batch
(
context_tokens_tensor
)
counter
=
0
org_context_length
=
context_length
layer_past
=
None
batch_token_iterator
=
sample_sequence_batch
(
model
,
context_tokens_tensor
,
context_length_tensor
,
attention_mask
,
position_ids
,
tokenizer
,
args
)
batch_token_iterator
=
sample_sequence_batch
(
model
,
context_tokens_tensor
,
context_length_tensor
,
attention_mask
,
position_ids
)
for
tokens
,
lengths
in
batch_token_iterator
:
context_length
+=
1
yield
tokens
[:,
:
context_length
],
lengths
def
switch
(
val1
,
val2
,
boolean
):
boolean
=
boolean
.
type_as
(
val1
)
return
(
1
-
boolean
)
*
val1
+
boolean
*
val2
def
sample_sequence_batch
(
model
,
context_tokens
,
context_lengths
,
attention_mask
,
position_ids
,
tokenizer
,
args
,
maxlen
=
None
,
type_ids
=
None
):
if
isinstance
(
model
,
DDP
):
model
=
model
.
module
if
isinstance
(
model
,
FP16_Module
):
model
=
model
.
module
original_output_parallel
=
model
.
parallel_output
model
.
parallel_output
=
False
return
(
1
-
boolean
)
*
val1
+
boolean
*
val2
def
sample_sequence_batch
(
model
,
context_tokens
,
context_lengths
,
attention_mask
,
position_ids
,
maxlen
=
None
,
type_ids
=
None
):
args
=
get_args
()
tokenizer
=
get_tokenizer
()
model
.
eval
()
with
torch
.
no_grad
():
context_length
=
context_lengths
.
min
().
item
()
eos_id
=
tokenizer
.
get_command
(
'eos'
).
I
d
eos_id
=
tokenizer
.
eo
d
counter
=
0
org_context_length
=
context_length
...
...
@@ -389,12 +344,16 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
if
maxlen
>
(
org_context_length
+
args
.
out_seq_length
):
maxlen
=
org_context_length
+
args
.
out_seq_length
lengths
=
torch
.
ones
([
batch_size
]).
long
().
cuda
()
*
maxlen
lengths
=
torch
.
ones
([
batch_size
]).
long
().
cuda
()
*
maxlen
while
context_length
<=
(
maxlen
):
if
args
.
recompute
:
logits
=
model
(
tokens
,
position_ids
,
attention_mask
,
tokentype_ids
=
type_ids
)
logits
=
model
(
tokens
,
position_ids
,
attention_mask
,
tokentype_ids
=
type_ids
,
forward_method_parallel_output
=
False
)
logits
=
logits
[:,
context_length
-
1
,
:]
else
:
types2use
=
None
...
...
@@ -404,113 +363,48 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
if
type_ids
is
not
None
:
types2use
=
type_ids
[:,
:
context_length
]
else
:
tokens2use
=
tokens
[:,
context_length
-
1
].
view
(
batch_size
,
-
1
)
positions2use
=
position_ids
[:,
context_length
-
1
].
view
(
batch_size
,
-
1
)
tokens2use
=
tokens
[:,
context_length
-
1
].
view
(
batch_size
,
-
1
)
positions2use
=
position_ids
[:,
context_length
-
1
].
view
(
batch_size
,
-
1
)
if
type_ids
is
not
None
:
types2use
=
type_ids
[:,
context_length
-
1
].
view
(
batch_size
,
-
1
)
logits
,
layer_past
=
model
(
tokens2use
,
positions2use
,
attention_mask
,
layer_past
=
layer_past
,
get_key_value
=
True
,
tokentype_ids
=
types2use
)
logits
=
logits
[:,
-
1
].
view
(
batch_size
,
-
1
).
contiguous
()
types2use
=
type_ids
[:,
context_length
-
1
].
view
(
batch_size
,
-
1
)
logits
,
layer_past
=
model
(
tokens2use
,
positions2use
,
attention_mask
,
layer_past
=
layer_past
,
get_key_value
=
True
,
tokentype_ids
=
types2use
,
forward_method_parallel_output
=
False
)
logits
=
logits
[:,
-
1
].
view
(
batch_size
,
-
1
).
contiguous
()
if
args
.
greedy
:
prev
=
torch
.
argmax
(
logits
,
dim
=-
1
).
view
(
-
1
)
else
:
logits
=
logits
.
float
()
logits
/=
args
.
temperature
logits
=
top_k_logits
(
logits
,
top_k
=
args
.
top_k
,
top_p
=
args
.
top_p
)
logits
=
top_k_logits
(
logits
,
top_k
=
args
.
top_k
,
top_p
=
args
.
top_p
)
log_probs
=
F
.
softmax
(
logits
,
dim
=-
1
)
prev
=
torch
.
multinomial
(
log_probs
,
num_samples
=
1
).
view
(
-
1
)
print_logits
=
[]
for
p
in
prev
:
print_logits
.
append
([
logits
[
i
,
p
].
item
()
for
i
in
range
(
batch_size
)])
print_logits
.
append
([
logits
[
i
,
p
].
item
()
for
i
in
range
(
batch_size
)])
started
=
context_lengths
<=
context_length
tokens
[:,
context_length
]
=
switch
(
tokens
[:,
context_length
].
view
(
-
1
),
prev
,
started
)
tokens
[:,
context_length
]
=
switch
(
tokens
[:,
context_length
].
view
(
-
1
),
prev
,
started
)
context_length
+=
1
counter
+=
1
done_token
=
(
prev
==
eos_id
).
byte
()
&
started
.
byte
()
just_finished
=
(
done_token
&
~
is_done
).
bool
()
lengths
[
just_finished
.
view
(
-
1
)]
=
context_length
was_done
=
is_done
is_done
=
is_done
|
done_token
done
=
torch
.
all
(
is_done
)
yield
tokens
,
lengths
if
done
:
break
model
.
parallel_output
=
original_output_parallel
def
prepare_tokenizer
(
args
):
tokenizer_args
=
{
'tokenizer_type'
:
args
.
tokenizer_type
,
'corpus'
:
None
,
'model_path'
:
args
.
tokenizer_path
,
'vocab_size'
:
args
.
vocab_size
,
'model_type'
:
args
.
tokenizer_model_type
,
'cache_dir'
:
args
.
cache_dir
}
tokenizer
=
make_tokenizer
(
**
tokenizer_args
)
args
.
tokenizer_num_tokens
=
tokenizer
.
num_tokens
args
.
tokenizer_num_type_tokens
=
tokenizer
.
num_type_tokens
args
.
eod_token
=
tokenizer
.
get_command
(
'eos'
).
Id
after
=
tokenizer
.
num_tokens
multiple
=
args
.
make_vocab_size_divisible_by
*
\
mpu
.
get_model_parallel_world_size
()
if
multiple
!=
0
:
while
(
after
%
multiple
)
!=
0
:
after
+=
1
args
.
vocab_size
=
after
print
(
"prepare tokenizer done"
,
flush
=
True
)
return
tokenizer
def
main
():
"""Main training program."""
print
(
'Generate Samples'
)
# Disable CuDNN.
torch
.
backends
.
cudnn
.
enabled
=
False
# Timer.
timers
=
Timers
()
# Arguments.
args
=
get_args
()
# Pytorch distributed.
initialize_distributed
(
args
)
# Random seeds for reproducability.
set_random_seed
(
args
.
seed
)
#get the tokenizer
tokenizer
=
prepare_tokenizer
(
args
)
# Model, optimizer, and learning rate.
model
=
setup_model
(
args
)
#setting default batch size to 1
# args.batch_size = 1
args
.
device
=
torch
.
cuda
.
current_device
()
#generate samples
if
args
.
num_samples
==
0
:
args
.
batch_size
=
1
if
args
.
sample_input_file
!=
""
:
generate_samples_input_from_file
(
model
,
tokenizer
,
args
)
else
:
generate_samples_interactive
(
model
,
tokenizer
,
args
)
else
:
write_and_generate_samples_unconditional
(
model
,
tokenizer
,
args
)
if
__name__
==
"__main__"
:
main
()
megatron/tokenizer/__init__.py
0 → 100644
View file @
22c0e300
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.tokenizer
import
build_tokenizer
megatron/tokenizer/bert_tokenization.py
0 → 100644
View file @
22c0e300
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
collections
import
re
import
unicodedata
import
six
def
validate_case_matches_checkpoint
(
do_lower_case
,
init_checkpoint
):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if
not
init_checkpoint
:
return
m
=
re
.
match
(
"^.*?([A-Za-z0-9_-]+)/bert_model.ckpt"
,
init_checkpoint
)
if
m
is
None
:
return
model_name
=
m
.
group
(
1
)
lower_models
=
[
"uncased_L-24_H-1024_A-16"
,
"uncased_L-12_H-768_A-12"
,
"multilingual_L-12_H-768_A-12"
,
"chinese_L-12_H-768_A-12"
]
cased_models
=
[
"cased_L-12_H-768_A-12"
,
"cased_L-24_H-1024_A-16"
,
"multi_cased_L-12_H-768_A-12"
]
is_bad_config
=
False
if
model_name
in
lower_models
and
not
do_lower_case
:
is_bad_config
=
True
actual_flag
=
"False"
case_name
=
"lowercased"
opposite_flag
=
"True"
if
model_name
in
cased_models
and
do_lower_case
:
is_bad_config
=
True
actual_flag
=
"True"
case_name
=
"cased"
opposite_flag
=
"False"
if
is_bad_config
:
raise
ValueError
(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check."
%
(
actual_flag
,
init_checkpoint
,
model_name
,
case_name
,
opposite_flag
))
def
convert_to_unicode
(
text
):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if
six
.
PY3
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
bytes
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
elif
six
.
PY2
:
if
isinstance
(
text
,
str
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
elif
isinstance
(
text
,
unicode
):
return
text
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
else
:
raise
ValueError
(
"Not running on Python2 or Python 3?"
)
def
printable_text
(
text
):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if
six
.
PY3
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
bytes
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
elif
six
.
PY2
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
unicode
):
return
text
.
encode
(
"utf-8"
)
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
else
:
raise
ValueError
(
"Not running on Python2 or Python 3?"
)
def
load_vocab
(
vocab_file
):
"""Loads a vocabulary file into a dictionary."""
vocab
=
collections
.
OrderedDict
()
index
=
0
with
open
(
vocab_file
,
"r"
)
as
reader
:
while
True
:
token
=
convert_to_unicode
(
reader
.
readline
())
if
not
token
:
break
token
=
token
.
strip
()
vocab
[
token
]
=
index
index
+=
1
return
vocab
def
convert_by_vocab
(
vocab
,
items
):
"""Converts a sequence of [tokens|ids] using the vocab."""
output
=
[]
for
item
in
items
:
output
.
append
(
vocab
[
item
])
return
output
def
convert_tokens_to_ids
(
vocab
,
tokens
):
return
convert_by_vocab
(
vocab
,
tokens
)
def
convert_ids_to_tokens
(
inv_vocab
,
ids
):
return
convert_by_vocab
(
inv_vocab
,
ids
)
def
whitespace_tokenize
(
text
):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text
=
text
.
strip
()
if
not
text
:
return
[]
tokens
=
text
.
split
()
return
tokens
class
FullTokenizer
(
object
):
"""Runs end-to-end tokenziation."""
def
__init__
(
self
,
vocab_file
,
do_lower_case
=
True
):
self
.
vocab
=
load_vocab
(
vocab_file
)
self
.
inv_vocab
=
{
v
:
k
for
k
,
v
in
self
.
vocab
.
items
()}
self
.
basic_tokenizer
=
BasicTokenizer
(
do_lower_case
=
do_lower_case
)
self
.
wordpiece_tokenizer
=
WordpieceTokenizer
(
vocab
=
self
.
vocab
)
def
tokenize
(
self
,
text
):
split_tokens
=
[]
for
token
in
self
.
basic_tokenizer
.
tokenize
(
text
):
for
sub_token
in
self
.
wordpiece_tokenizer
.
tokenize
(
token
):
split_tokens
.
append
(
sub_token
)
return
split_tokens
def
convert_tokens_to_ids
(
self
,
tokens
):
return
convert_by_vocab
(
self
.
vocab
,
tokens
)
def
convert_ids_to_tokens
(
self
,
ids
):
return
convert_by_vocab
(
self
.
inv_vocab
,
ids
)
def
vocab_size
(
self
):
return
len
(
self
.
vocab
)
class
BasicTokenizer
(
object
):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def
__init__
(
self
,
do_lower_case
=
True
):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self
.
do_lower_case
=
do_lower_case
def
tokenize
(
self
,
text
):
"""Tokenizes a piece of text."""
text
=
convert_to_unicode
(
text
)
text
=
self
.
_clean_text
(
text
)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text
=
self
.
_tokenize_chinese_chars
(
text
)
orig_tokens
=
whitespace_tokenize
(
text
)
split_tokens
=
[]
for
token
in
orig_tokens
:
if
self
.
do_lower_case
:
token
=
token
.
lower
()
token
=
self
.
_run_strip_accents
(
token
)
split_tokens
.
extend
(
self
.
_run_split_on_punc
(
token
))
output_tokens
=
whitespace_tokenize
(
" "
.
join
(
split_tokens
))
return
output_tokens
def
_run_strip_accents
(
self
,
text
):
"""Strips accents from a piece of text."""
text
=
unicodedata
.
normalize
(
"NFD"
,
text
)
output
=
[]
for
char
in
text
:
cat
=
unicodedata
.
category
(
char
)
if
cat
==
"Mn"
:
continue
output
.
append
(
char
)
return
""
.
join
(
output
)
def
_run_split_on_punc
(
self
,
text
):
"""Splits punctuation on a piece of text."""
chars
=
list
(
text
)
i
=
0
start_new_word
=
True
output
=
[]
while
i
<
len
(
chars
):
char
=
chars
[
i
]
if
_is_punctuation
(
char
):
output
.
append
([
char
])
start_new_word
=
True
else
:
if
start_new_word
:
output
.
append
([])
start_new_word
=
False
output
[
-
1
].
append
(
char
)
i
+=
1
return
[
""
.
join
(
x
)
for
x
in
output
]
def
_tokenize_chinese_chars
(
self
,
text
):
"""Adds whitespace around any CJK character."""
output
=
[]
for
char
in
text
:
cp
=
ord
(
char
)
if
self
.
_is_chinese_char
(
cp
):
output
.
append
(
" "
)
output
.
append
(
char
)
output
.
append
(
" "
)
else
:
output
.
append
(
char
)
return
""
.
join
(
output
)
def
_is_chinese_char
(
self
,
cp
):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if
((
cp
>=
0x4E00
and
cp
<=
0x9FFF
)
or
#
(
cp
>=
0x3400
and
cp
<=
0x4DBF
)
or
#
(
cp
>=
0x20000
and
cp
<=
0x2A6DF
)
or
#
(
cp
>=
0x2A700
and
cp
<=
0x2B73F
)
or
#
(
cp
>=
0x2B740
and
cp
<=
0x2B81F
)
or
#
(
cp
>=
0x2B820
and
cp
<=
0x2CEAF
)
or
(
cp
>=
0xF900
and
cp
<=
0xFAFF
)
or
#
(
cp
>=
0x2F800
and
cp
<=
0x2FA1F
)):
#
return
True
return
False
def
_clean_text
(
self
,
text
):
"""Performs invalid character removal and whitespace cleanup on text."""
output
=
[]
for
char
in
text
:
cp
=
ord
(
char
)
if
cp
==
0
or
cp
==
0xfffd
or
_is_control
(
char
):
continue
if
_is_whitespace
(
char
):
output
.
append
(
" "
)
else
:
output
.
append
(
char
)
return
""
.
join
(
output
)
class
WordpieceTokenizer
(
object
):
"""Runs WordPiece tokenziation."""
def
__init__
(
self
,
vocab
,
unk_token
=
"[UNK]"
,
max_input_chars_per_word
=
200
):
self
.
vocab
=
vocab
self
.
unk_token
=
unk_token
self
.
max_input_chars_per_word
=
max_input_chars_per_word
def
tokenize
(
self
,
text
):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer.
Returns:
A list of wordpiece tokens.
"""
text
=
convert_to_unicode
(
text
)
output_tokens
=
[]
for
token
in
whitespace_tokenize
(
text
):
chars
=
list
(
token
)
if
len
(
chars
)
>
self
.
max_input_chars_per_word
:
output_tokens
.
append
(
self
.
unk_token
)
continue
is_bad
=
False
start
=
0
sub_tokens
=
[]
while
start
<
len
(
chars
):
end
=
len
(
chars
)
cur_substr
=
None
while
start
<
end
:
substr
=
""
.
join
(
chars
[
start
:
end
])
if
start
>
0
:
substr
=
"##"
+
substr
if
substr
in
self
.
vocab
:
cur_substr
=
substr
break
end
-=
1
if
cur_substr
is
None
:
is_bad
=
True
break
sub_tokens
.
append
(
cur_substr
)
start
=
end
if
is_bad
:
output_tokens
.
append
(
self
.
unk_token
)
else
:
output_tokens
.
extend
(
sub_tokens
)
return
output_tokens
def
_is_whitespace
(
char
):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if
char
==
" "
or
char
==
"
\t
"
or
char
==
"
\n
"
or
char
==
"
\r
"
:
return
True
cat
=
unicodedata
.
category
(
char
)
if
cat
==
"Zs"
:
return
True
return
False
def
_is_control
(
char
):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if
char
==
"
\t
"
or
char
==
"
\n
"
or
char
==
"
\r
"
:
return
False
cat
=
unicodedata
.
category
(
char
)
if
cat
in
(
"Cc"
,
"Cf"
):
return
True
return
False
def
_is_punctuation
(
char
):
"""Checks whether `chars` is a punctuation character."""
cp
=
ord
(
char
)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if
((
cp
>=
33
and
cp
<=
47
)
or
(
cp
>=
58
and
cp
<=
64
)
or
(
cp
>=
91
and
cp
<=
96
)
or
(
cp
>=
123
and
cp
<=
126
)):
return
True
cat
=
unicodedata
.
category
(
char
)
if
cat
.
startswith
(
"P"
):
return
True
return
False
megatron/tokenizer/gpt2_tokenization.py
0 → 100644
View file @
22c0e300
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from
__future__
import
(
absolute_import
,
division
,
print_function
,
unicode_literals
)
import
sys
import
json
import
logging
import
os
import
regex
as
re
from
io
import
open
try
:
from
functools
import
lru_cache
except
ImportError
:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def
lru_cache
():
return
lambda
func
:
func
logger
=
logging
.
getLogger
(
__name__
)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'gpt2': 1024,
}

VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    _chr = unichr if sys.version_info[0] == 2 else chr
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
        list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [_chr(n) for n in cs]
    return dict(zip(bs, cs))
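A quick sanity check of the lookup table built above: it is a bijection over all 256 byte values, printable ASCII maps to itself, and bytes the BPE code would choke on (such as the space byte) are remapped to printable characters.

from megatron.tokenizer.gpt2_tokenization import bytes_to_unicode

b2u = bytes_to_unicode()
print(len(b2u), len(set(b2u.values())))   # 256 256 -> one unique character per byte
print(b2u[ord('A')], b2u[ord(' ')])       # A Ġ -> space maps to 'Ġ' (U+0120)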
def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
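For example, applied to the symbol tuple for "hello" this yields the adjacent-pair set that the BPE merge loop below consumes.

from megatron.tokenizer.gpt2_tokenization import get_pairs

print(get_pairs(('h', 'e', 'l', 'l', 'o')))
# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')} -- a set, so print order may vary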
class GPT2Tokenizer(object):
    """
    GPT-2 BPE tokenizer. Peculiarities:
        - Byte-level BPE
    """
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path,
                        cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
            special_tokens_file = None
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
            special_tokens_file = os.path.join(pretrained_model_name_or_path,
                                               SPECIAL_TOKENS_NAME)
            if not os.path.exists(special_tokens_file):
                special_tokens_file = None
            else:
                logger.info("loading special tokens file {}".format(special_tokens_file))
        # redirect to the cache, if necessary
        try:
            from .file_utils import cached_path
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} and {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    vocab_file, merges_file))
            return None
        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
            logger.info("loading merges file {}".format(merges_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                vocab_file, resolved_vocab_file))
            logger.info("loading merges file {} from cache at {}".format(
                merges_file, resolved_merges_file))
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer won't index sequences
            # longer than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        if special_tokens_file and 'special_tokens' not in kwargs:
            special_tokens = open(special_tokens_file,
                                  encoding='utf-8').read().split('\n')[:-1]
        else:
            special_tokens = kwargs.pop('special_tokens', [])
        tokenizer = cls(
            resolved_vocab_file,
            resolved_merges_file,
            special_tokens=special_tokens,
            *inputs,
            **kwargs)
        return tokenizer
    def __init__(self, vocab_file, merges_file, errors='replace',
                 special_tokens=None, max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for
        # capitalized versions of contractions
        self.pat = re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        self.special_tokens = {}
        self.special_tokens_decoder = {}
        self.set_special_tokens(special_tokens)
    def __len__(self):
        return len(self.encoder) + len(self.special_tokens)
    def set_special_tokens(self, special_tokens):
        """ Add a list of additional tokens to the encoder.
            The additional tokens are indexed starting from the last index of the
            current vocabulary in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = dict((tok, len(self.encoder) + i)
                                   for i, tok in enumerate(special_tokens))
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
        logger.info("Special tokens {}".format(self.special_tokens))
    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except BaseException:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word
    def tokenize(self, text):
        """ Tokenize a string. """
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            if sys.version_info[0] == 2:
                token = ''.join(self.byte_encoder[ord(b)] for b in token)
            else:
                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens
    def convert_tokens_to_ids(self, tokens):
        """ Converts a sequence of tokens into ids using the vocab. """
        ids = []
        if isinstance(tokens, str) or (sys.version_info[0] == 2
                                       and isinstance(tokens, unicode)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.encoder.get(tokens, 0)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.encoder.get(token, 0))
        if len(ids) > self.max_len:
            logger.warning(
                "Token indices sequence length is longer than the specified maximum "
                " sequence length for this OpenAI GPT model ({} > {}). Running this"
                " sequence through the model will result in indexing errors".format(
                    len(ids), self.max_len))
        return ids
    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Converts a sequence of ids in BPE tokens using the vocab."""
        tokens = []
        for i in ids:
            if i in self.special_tokens_decoder:
                if not skip_special_tokens:
                    tokens.append(self.special_tokens_decoder[i])
            else:
                tokens.append(self.decoder[i])
        return tokens
    def encode(self, text):
        return self.convert_tokens_to_ids(self.tokenize(text))
    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            'utf-8', errors=self.errors)
        return text
    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(vocab_path):
            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
            return
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        merge_file = os.path.join(vocab_path, MERGES_NAME)
        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)

        with open(vocab_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write(u'#version: 0.2\n')
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1

        index = len(self.encoder)
        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1

        return vocab_file, merge_file, special_tokens_file
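A minimal usage sketch for the class above, assuming local copies of a standard GPT-2 vocab.json/merges.txt pair; the file names below are placeholders.

from megatron.tokenizer.gpt2_tokenization import GPT2Tokenizer

tok = GPT2Tokenizer('gpt2-vocab.json', 'gpt2-merges.txt')  # placeholder paths
ids = tok.encode("Hello world!")           # text -> list of BPE ids
assert tok.decode(ids) == "Hello world!"   # byte-level BPE round-trips losslessly
print(tok.tokenize("Hello world!"))        # intermediate pieces, e.g. ['Hello', 'Ġworld', '!']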
megatron/tokenizer/tokenizer.py
0 → 100644
View file @
22c0e300
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron tokenizers."""
from abc import ABC
from abc import abstractmethod

from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer
def build_tokenizer(args):
    """Initialize tokenizer."""
    if args.rank == 0:
        print('> building {} tokenizer ...'.format(args.tokenizer_type),
              flush=True)

    # Select and instantiate the tokenizer.
    assert args.vocab_file is not None
    if args.tokenizer_type == 'BertWordPieceLowerCase':
        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=True)
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
    else:
        raise NotImplementedError('{} tokenizer is not '
                                  'implemented.'.format(args.tokenizer_type))

    # Add vocab size.
    args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
                                                      args)

    return tokenizer
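A hedged usage sketch: the Namespace stands in for Megatron's parsed arguments and its fields mirror exactly the attributes read above, while the file paths are placeholders.

from argparse import Namespace
from megatron.tokenizer.tokenizer import build_tokenizer

args = Namespace(rank=0, tokenizer_type='GPT2BPETokenizer',
                 vocab_file='gpt2-vocab.json', merge_file='gpt2-merges.txt',  # placeholder paths
                 make_vocab_size_divisible_by=128, model_parallel_size=1)
tokenizer = build_tokenizer(args)
print(args.padded_vocab_size)   # e.g. the 50257-token GPT-2 vocab is padded to 50304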
def _vocab_size_with_padding(orig_vocab_size, args):
    """Pad vocab size so it is divisible by model parallel size and
    still has a GPU-friendly size."""

    after = orig_vocab_size
    multiple = args.make_vocab_size_divisible_by * \
        args.model_parallel_size
    while (after % multiple) != 0:
        after += 1
    if args.rank == 0:
        print(' > padded vocab (size: {}) with {} dummy tokens '
              '(new size: {})'.format(
                  orig_vocab_size, after - orig_vocab_size, after), flush=True)
    return after
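A worked example of the padding rule above, assuming make_vocab_size_divisible_by=128, model_parallel_size=2 and the 30522-token BERT vocabulary.

multiple = 128 * 2            # make_vocab_size_divisible_by * model_parallel_size
after = 30522                 # original vocab size
while after % multiple != 0:
    after += 1
print(after, after - 30522)   # 30720 198 -> 198 dummy tokens are appended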
class AbstractTokenizer(ABC):
    """Abstract class for tokenizer."""

    def __init__(self, name):
        self.name = name
        super().__init__()

    @property
    @abstractmethod
    def vocab_size(self):
        pass

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token."""
        pass

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""
        pass

    @abstractmethod
    def tokenize(self, text):
        pass

    def detokenize(self, token_ids):
        raise NotImplementedError('detokenizer is not implemented for {} '
                                  'tokenizer'.format(self.name))

    @property
    def cls(self):
        raise NotImplementedError('CLS is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def sep(self):
        raise NotImplementedError('SEP is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def pad(self):
        raise NotImplementedError('PAD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def eod(self):
        raise NotImplementedError('EOD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def mask(self):
        raise NotImplementedError('MASK is not provided for {} '
                                  'tokenizer'.format(self.name))
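The two wrappers below implement this interface. As a hypothetical illustration, a minimal character-level subclass only needs to supply the abstract members; everything here is an invented sketch, not part of the module.

from megatron.tokenizer.tokenizer import AbstractTokenizer

class _CharTokenizer(AbstractTokenizer):
    """Toy character-level tokenizer (illustrative sketch only)."""

    def __init__(self):
        super().__init__('Char')
        self._vocab = {chr(i): i for i in range(128)}       # ASCII only
        self._inv_vocab = {i: c for c, i in self._vocab.items()}

    @property
    def vocab_size(self):
        return len(self._vocab)

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv_vocab

    def tokenize(self, text):
        return [self._vocab[c] for c in text]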
class _BertWordPieceTokenizer(AbstractTokenizer):
    """Original BERT wordpiece tokenizer."""

    def __init__(self, vocab_file, lower_case=True):
        if lower_case:
            name = 'BERT Lower Case'
        else:
            name = 'BERT Upper Case'
        super().__init__(name)
        self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
        self.cls_id = self.tokenizer.vocab['[CLS]']
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.pad_id = self.tokenizer.vocab['[PAD]']
        self.mask_id = self.tokenizer.vocab['[MASK]']

    @property
    def vocab_size(self):
        return self.tokenizer.vocab_size()

    @property
    def vocab(self):
        return self.tokenizer.vocab

    @property
    def inv_vocab(self):
        return self.tokenizer.inv_vocab

    def tokenize(self, text):
        text_tokens = self.tokenizer.tokenize(text)
        return self.tokenizer.convert_tokens_to_ids(text_tokens)

    @property
    def cls(self):
        return self.cls_id

    @property
    def sep(self):
        return self.sep_id

    @property
    def pad(self):
        return self.pad_id

    @property
    def mask(self):
        return self.mask_id
class _GPT2BPETokenizer(AbstractTokenizer):
    """Original GPT2 BPE tokenizer."""

    def __init__(self, vocab_file, merge_file):
        name = 'GPT2 BPE'
        super().__init__(name)

        self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
                                       special_tokens=[], max_len=None)
        self.eod_id = self.tokenizer.encoder['<|endoftext|>']

    @property
    def vocab_size(self):
        return len(self.tokenizer.encoder)

    @property
    def vocab(self):
        return self.tokenizer.encoder

    @property
    def inv_vocab(self):
        return self.tokenizer.decoder

    def tokenize(self, text):
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids):
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        return self.eod_id
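A hedged usage sketch for the GPT-2 wrapper, again with placeholder paths for a standard vocab/merges pair.

from megatron.tokenizer.tokenizer import _GPT2BPETokenizer

tokenizer = _GPT2BPETokenizer('gpt2-vocab.json', 'gpt2-merges.txt')  # placeholder paths
ids = tokenizer.tokenize("Megatron-LM")
print(tokenizer.detokenize(ids))   # Megatron-LM
print(tokenizer.eod)               # id of '<|endoftext|>' (50256 in the standard GPT-2 vocab)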