OpenDAS / Megatron-LM / Commits

Commit 5e56e563, authored Apr 28, 2020 by Neel Kant

Merge master into realm-mlm

Parents: 6c0a5bd8, 569b3dab
Changes: 107 files; showing 20 changed files with 498 additions and 434 deletions (+498 −434).
megatron/mpu/__init__.py                      +1    −1
megatron/mpu/cross_entropy.py                 +1    −2
megatron/mpu/data.py                          +1    −1
megatron/mpu/grads.py                         +54   −5
megatron/mpu/initialize.py                    +1    −1
megatron/mpu/layers.py                        +6    −3
megatron/mpu/mappings.py                      +4    −1
megatron/mpu/random.py                        +2    −1
megatron/mpu/tests/commons.py                 +2    −1
megatron/mpu/tests/test_cross_entropy.py      +9    −11
megatron/mpu/tests/test_data.py               +6    −10
megatron/mpu/tests/test_initialize.py         +5    −8
megatron/mpu/tests/test_layers.py             +22   −21
megatron/mpu/tests/test_random.py             +5    −8
megatron/mpu/utils.py                         +1    −1
megatron/text_generation_utils.py             +12   −13
megatron/tokenizer/__init__.py                +1    −1
megatron/tokenizer/bert_tokenization.py       +326  −325
megatron/tokenizer/gpt2_tokenization.py       +32   −17
megatron/tokenizer/tokenizer.py               +7    −3

megatron/mpu/__init__.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

megatron/mpu/cross_entropy.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -72,7 +72,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
                                    op=torch.distributed.ReduceOp.SUM,
                                    group=get_model_parallel_group())

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits
...
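The comment in this hunk is the standard log-softmax identity: for a target token $t$ with logits $x_i$, $-\log \mathrm{softmax}(x)_t = \log\big(\sum_i e^{x_i}\big) - x_t$. That is why the vocab-parallel kernel only needs the all-reduced `sum_exp_logits` and the locally gathered `predicted_logits` to form the loss.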

megatron/mpu/data.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

megatron/mpu/grads.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -21,10 +21,52 @@
import torch
from torch._six import inf

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank


def l2_grad_clipper(parameters, max_norm):
    """Efficient L2 norm gradient clipping."""
    overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
    # Make sure we have an iterable.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # Filter parameters with gradients.
    parameters_with_grads = list(filter(
        lambda p: p.grad is not None, parameters))
    # Filter parameters for norm calculations.
    mp_rank_is_zero = (get_model_parallel_rank() == 0)
    parameters_for_norm = list(filter(
        lambda p: p.model_parallel or mp_rank_is_zero, parameters_with_grads))
    # Calculate L2 norm.
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        overflow_buf,
        [parameters_for_norm],
        False  # no per-parameter norm
    )
    # Sum across all model parallel GPUs.
    norm_2 = norm * norm
    torch.distributed.all_reduce(norm_2,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=get_model_parallel_group())
    total_norm = norm_2.item() ** 0.5
    # Scale to get max_norm.
    clip_coef = float(max_norm) / (total_norm + 1.0e-6)
    grads = [p.grad for p in parameters_with_grads]
    if clip_coef < 1.0:
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [grads, grads],
                             clip_coef)
    return total_norm


def clip_grad_norm(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters.
...
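A note on the `p.model_parallel or mp_rank_is_zero` filter: parameters that are not partitioned are replicated on every model-parallel rank, so only rank 0 counts them in the local squared norm; otherwise the subsequent all-reduce would multiply their contribution by the model-parallel world size. A minimal sketch of how the new helper would be called from a training step (the `model` and `optimizer` names are hypothetical, and the sketch assumes a Megatron-built model whose parameters carry the `model_parallel` attribute):

    loss = model(batch)
    loss.backward()
    # Rescale all gradients so their global L2 norm is at most 1.0,
    # then step on the clipped gradients.
    grad_norm = l2_grad_clipper(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()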
@@ -55,6 +97,13 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
                                  op=torch.distributed.ReduceOp.MAX,
                                  group=get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for p in parameters:
                p.grad.data.mul_(clip_coef)
    #elif norm_type == 2:
    #    total_norm = l2_grad_clipper(parameters, max_norm)
    else:
        total_norm = 0
        for p in parameters:
...
@@ -67,8 +116,8 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
                                    op=torch.distributed.ReduceOp.SUM,
                                    group=get_model_parallel_group())
        total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.data.mul_(clip_coef)
    return total_norm
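The clipping arithmetic itself is framework-free and easy to sanity-check:

    max_norm, total_norm = 1.0, 5.0
    clip_coef = max_norm / (total_norm + 1e-6)  # ~0.2
    # Multiplying every gradient by ~0.2 shrinks the global L2 norm
    # from 5.0 to ~1.0, i.e. exactly to max_norm; when total_norm is
    # already below max_norm, clip_coef >= 1 and nothing is scaled.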

megatron/mpu/initialize.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

megatron/mpu/layers.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -89,6 +89,7 @@ class VocabParallelEmbedding(torch.nn.Module):
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
...

@@ -108,7 +109,7 @@ class VocabParallelEmbedding(torch.nn.Module):
                self.num_embeddings, get_model_parallel_rank(),
                get_model_parallel_world_size())
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index
        # Allocate weights.
        self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition,
...

@@ -147,6 +148,7 @@ class ParallelEmbedding(torch.nn.Module):
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_,
                 keep_master_weight_for_test=False):
...

@@ -205,6 +207,7 @@ class ColumnParallelLinear(torch.nn.Module):
                       set to False. It returns the master weights
                       used for initialization.
    """

    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False):
...

@@ -279,6 +282,7 @@ class RowParallelLinear(torch.nn.Module):
                       set to False. It returns the master weights
                       used for initialization.
    """

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
...

@@ -327,4 +331,3 @@ class RowParallelLinear(torch.nn.Module):
        else:
            output = output_
        return output

megatron/mpu/mappings.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -131,11 +131,14 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):

def copy_to_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)
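These four wrappers are the whole communication surface of Megatron-style tensor parallelism: copy/reduce are conjugates (identity in the forward pass, all-reduce in the backward pass, and vice versa), while scatter/gather split and reassemble activations along the last dimension. As a rough sketch of how a column-parallel linear composes them (simplified from layers.py; bias handling and weight initialization omitted):

    import torch.nn.functional as F

    def column_parallel_forward(input_, weight_partition, gather_output=True):
        # Identity in forward, all-reduce of gradients in backward.
        input_parallel = copy_to_model_parallel_region(input_)
        # Each rank multiplies by its own column slice of the weight.
        output_parallel = F.linear(input_parallel, weight_partition)
        if gather_output:
            # Concatenate the per-rank output slices along the last dim.
            return gather_from_model_parallel_region(output_parallel)
        return output_parallel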

megatron/mpu/random.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -73,6 +73,7 @@ class CudaRNGStatesTracker:
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
...
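The docstring describes a fork-and-restore discipline; here is a sketch of the intended use, assuming the tracker's `add`/`fork` methods as in the Megatron implementation (this diff only shows the class in part):

    tracker = CudaRNGStatesTracker()
    tracker.add('model-parallel-rng', seed=2718)

    # Dropout inside the block draws from the tracked state; on exit the
    # default CUDA rng state is restored untouched.
    with tracker.fork('model-parallel-rng'):
        y = torch.nn.functional.dropout(x, p=0.1, training=True)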

megatron/mpu/tests/commons.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -26,6 +26,7 @@ class IdentityLayer(torch.nn.Module):
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight
...
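IdentityLayer's job in these tests is to make the *input* of the layer under test a trainable parameter, so input gradients can be compared between the serial and model-parallel runs. The pattern, condensed (the Linear layer here is just an illustrative stand-in for the layer being tested):

    identity = IdentityLayer([4, 8]).cuda()
    layer_under_test = torch.nn.Linear(8, 8).cuda()
    layer_under_test(identity()).sum().backward()
    # identity.weight.grad now holds dLoss/dInput and can be checked
    # against the gradient produced by the parallel implementation.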

megatron/mpu/tests/test_cross_entropy.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from commons import set_random_seed
-from commons import IdentityLayer
-from commons import print_separator
-from commons import initialize_distributed
-from mpu.cross_entropy import vocab_parallel_cross_entropy
-import mpu
-import torch.nn.functional as F
-import torch
-import random
import sys
sys.path.append("../..")
+import torch
+import torch.nn.functional as F
+import mpu
+from mpu.cross_entropy import vocab_parallel_cross_entropy
+from commons import initialize_distributed
+from commons import print_separator
+from commons import IdentityLayer
+from commons import set_random_seed


def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
...

megatron/mpu/tests/test_data.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -13,18 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from commons import print_separator
-from commons import initialize_distributed
-from mpu import data as data_utils
-import mpu
-import torch
-import functools
-import operator
import sys
sys.path.append("../..")
+import torch
+import mpu
+from mpu import data as data_utils
+from commons import initialize_distributed
+from commons import print_separator


def test_boradcast_data(model_parallel_size):
...

@@ -88,5 +86,3 @@ if __name__ == '__main__':
        print_separator('test test boradcast data')
        test_boradcast_data(model_parallel_size)
        model_parallel_size *= 2

megatron/mpu/tests/test_initialize.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from commons import print_separator
-from commons import initialize_distributed
-import mpu
-import torch
import sys
sys.path.append("../..")
+import torch
+import mpu
+from commons import initialize_distributed
+from commons import print_separator


def test_initialize_model_parallel(model_parallel_size):
...

@@ -46,7 +44,6 @@ def test_initialize_model_parallel(model_parallel_size):
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size
...

megatron/mpu/tests/test_layers.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from mpu import layers
-from commons import set_random_seed
-from commons import print_separator
-from commons import initialize_distributed
-import mpu
-from torch.nn.parameter import Parameter
-import torch.nn.init as init
-import torch
-import random
import sys
sys.path.append("../..")
+import torch
+import torch.nn.init as init
+from torch.nn.parameter import Parameter
+import mpu
+from commons import initialize_distributed
+from commons import print_separator
+from commons import set_random_seed
+from mpu import layers


def test_parallel_embedding(model_parallel_size):
...

@@ -45,7 +43,7 @@ def test_parallel_embedding(model_parallel_size):
    set_random_seed(123)
    input_data = torch.LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
...

@@ -57,7 +55,7 @@ def test_parallel_embedding(model_parallel_size):
    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()
...

@@ -176,10 +174,11 @@ def test_initialize_affine_weight(model_parallel_size):

class IdentityLayer2D(torch.nn.Module):

    def __init__(self, m, n):
        super(IdentityLayer2D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight
...

@@ -317,10 +316,11 @@ def test_row_parallel_linear(model_parallel_size):

class IdentityLayer3D(torch.nn.Module):

    def __init__(self, m, n, k):
        super(IdentityLayer3D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n, k))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight
...

@@ -335,14 +335,14 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
    set_random_seed(seed)
    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
                                                    dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
...

@@ -366,17 +366,17 @@ def test_parallel_self_attention(model_parallel_size):
    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    dropout_prob = 0.0  # has to be zero
    batch_size = 5
    sequence_length = 13

    rank_1, hideen_size_1, model_parallel_size_1, loss_1, \
        attention_layer_1, identity_layer_1 = parallel_self_attention(
            1, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)

    rank, hidden_size, model_parallel_size, loss, \
        attention_layer, identity_layer = parallel_self_attention(
            model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
    assert hideen_size_1 == hidden_size
...

@@ -409,6 +409,7 @@ def test_parallel_self_attention(model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):
...

@@ -419,7 +420,7 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
    set_random_seed(seed)
    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size
...
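All of these tests follow one comparison recipe: seed identically, run the network once with model parallelism effectively off (size 1) and once with it on, then require near-exact agreement. Condensed to its assertions (the 1e-6 tolerance below is illustrative, not the exact constant used in the file):

    error = loss_1.sub(loss).abs().max().item()
    assert error < 1.0e-6, 'loss mismatch between serial and parallel runs'

    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max().item()
    assert error < 1.0e-6, 'input-gradient mismatch'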

megatron/mpu/tests/test_random.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from commons import print_separator
-from commons import initialize_distributed
-import mpu
-import torch
import sys
sys.path.append("../..")
+import torch
+import mpu
+from commons import initialize_distributed
+from commons import print_separator


def test_set_cuda_rng_state(model_parallel_size):
...

@@ -204,4 +202,3 @@ if __name__ == '__main__':
        print_separator('test model parallel cuda manual seed')
        test_model_parallel_cuda_manual_seed(model_parallel_size)
        model_parallel_size *= 2

megatron/mpu/utils.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

megatron/text_generation_utils.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -42,8 +42,7 @@ def get_batch(context_tokens):
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)

    return tokens, attention_mask, position_ids
...

@@ -120,8 +119,8 @@ def generate_samples_input_from_file(model):
            context_length = len(context_tokens)

            if context_length >= (args.seq_length // 2):
-                print("\nContext length", context_length, \
-                    "\nPlease give smaller context (half of the "
+                print("\nContext length", context_length,
+                      "\nPlease give smaller context (half of the "
                      "sequence length)!", flush=True)
                continue
        else:
...

@@ -187,8 +186,8 @@ def generate_samples_interactive(model, print_frequency=24):
            context_length = len(context_tokens)

            if context_length >= (args.seq_length // 2):
-                print("\nContext length", context_length, \
-                    "\nPlease give smaller context (half of the "
+                print("\nContext length", context_length,
+                      "\nPlease give smaller context (half of the "
                      "sequence length)!", flush=True)
                continue
        else:
...

@@ -246,7 +245,7 @@ def generate_samples_unconditional(model):
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if ctr % args.log_interval == 0:
            print('Avg s/batch:',
                  (time.time() - start_time) / min(args.log_interval, ctr + 1))
            start_time = time.time()
...

@@ -254,10 +253,10 @@ def generate_samples_unconditional(model):
        token_batch = token_stream[0].cpu().numpy().tolist()
        length_batch = token_stream[1].cpu().numpy().tolist()
        for tokens, length in zip(token_batch, length_batch):
            tokens = tokens[1:length - 1]
            text = tokenizer.detokenize(tokens)
            is_finished = length < args.seq_length - 1
            datum = {'text': text, 'length': length - 1,
                     'finished': is_finished}
            yield datum
            ctr += 1
            if ctr >= num_samples:
...

@@ -272,7 +271,7 @@ def generate_and_write_samples_unconditional(model):
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            f.write(json.dumps(datum) + '\n')
...

@@ -281,7 +280,7 @@ def pad_batch(batch, pad_id, args):
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths
...

@@ -345,7 +344,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
    if maxlen > (org_context_length + args.out_seq_length):
        maxlen = org_context_length + args.out_seq_length

    lengths = torch.ones([batch_size]).long().cuda() * maxlen

    while context_length <= (maxlen):
...
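The contract of `pad_batch` above is easiest to see on a toy input (the `_Args` stand-in below is hypothetical; in Megatron the real parsed arguments object is passed in):

    class _Args:
        seq_length = 8

    batch = [[5, 6, 7], [1, 2, 3, 4, 5]]
    padded, lengths = pad_batch(batch, pad_id=0, args=_Args())
    # padded  -> [[5, 6, 7, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 0, 0, 0]]
    # lengths -> [3, 5]   (original context lengths, recorded before padding)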

megatron/tokenizer/__init__.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

megatron/tokenizer/bert_tokenization.py  ·  View file @ 5e56e563

@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tokenization classes."""

from __future__ import absolute_import
...

@@ -25,377 +26,377 @@ import six
(This hunk re-flows the entire module body, apparently an automated re-indentation, so the raw page showed every definition twice, once on each side of the diff; each is reconstructed once below in its new form.)

def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
    """Checks whether the casing config is consistent with the checkpoint name."""

    # The casing has to be passed in by the user and there is no explicit check
    # as to whether it matches the checkpoint. The casing information probably
    # should have been stored in the bert_config.json file, but it's not, so
    # we have to heuristically detect it to validate.

    if not init_checkpoint:
        return

    m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
    if m is None:
        return

    model_name = m.group(1)

    lower_models = [
        "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
        "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
    ]

    cased_models = [
        "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
        "multi_cased_L-12_H-768_A-12"
    ]

    is_bad_config = False
    if model_name in lower_models and not do_lower_case:
        is_bad_config = True
        actual_flag = "False"
        case_name = "lowercased"
        opposite_flag = "True"

    if model_name in cased_models and do_lower_case:
        is_bad_config = True
        actual_flag = "True"
        case_name = "cased"
        opposite_flag = "False"

    if is_bad_config:
        raise ValueError(
            "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
            "However, `%s` seems to be a %s model, so you "
            "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
            "how the model was pre-training. If this error is wrong, please "
            "just comment out this check." %
            (actual_flag, init_checkpoint, model_name, case_name, opposite_flag))


def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


def printable_text(text):
    """Returns text encoded in a way suitable for print or `tf.logging`."""

    # These functions want `str` for both Python2 and Python3, but in one case
    # it's a Unicode string and in the other it's a byte string.
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text
        elif isinstance(text, unicode):
            return text.encode("utf-8")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r") as reader:
        while True:
            token = convert_to_unicode(reader.readline())
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


def convert_by_vocab(vocab, items):
    """Converts a sequence of [tokens|ids] using the vocab."""
    output = []
    for item in items:
        output.append(vocab[item])
    return output


def convert_tokens_to_ids(vocab, tokens):
    return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
    return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class FullTokenizer(object):
    """Runs end-to-end tokenziation."""

    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)

    def vocab_size(self):
        return len(self.vocab)


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True):
        """Constructs a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
        """
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)

        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #     https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenziation."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def _is_whitespace(char):
    """Checks whether `chars` is a whitespace character."""
    # \t, \n, and \r are technically contorl characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `chars` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat in ("Cc", "Cf"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `chars` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
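The `while start < end` loop above is the entire greedy longest-match-first algorithm; the docstring's example can be checked standalone with a made-up three-entry vocabulary:

    vocab = {"un": 0, "##aff": 1, "##able": 2}
    tokenizer = WordpieceTokenizer(vocab=vocab)
    print(tokenizer.tokenize("unaffable"))  # ['un', '##aff', '##able']
    print(tokenizer.tokenize("unafraid"))   # ['[UNK]'] -- no full segmentation exists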

megatron/tokenizer/gpt2_tokenization.py  ·  View file @ 5e56e563

...

@@ -29,7 +29,8 @@ try:
    from functools import lru_cache
except ImportError:
    # Just a dummy decorator to get the checks to run on python2
-    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    # because honestly I don't want to support a byte-level unicode BPE
+    # tokenizer on python 2 right now.
    def lru_cache():
        return lambda func: func
...

@@ -49,6 +50,7 @@ VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'


@lru_cache()
def bytes_to_unicode():
    """
...

@@ -61,17 +63,19 @@ def bytes_to_unicode():
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    _chr = unichr if sys.version_info[0] == 2 else chr
-    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
+        list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [_chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
...

@@ -84,6 +88,7 @@ def get_pairs(word):
        prev_char = char
    return pairs


class GPT2Tokenizer(object):
    """
    GPT-2 BPE tokenizer. Peculiarities:
...

@@ -140,23 +145,31 @@ class GPT2Tokenizer(object):
            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
        else:
            special_tokens = kwargs.pop('special_tokens', [])
        tokenizer = cls(resolved_vocab_file, resolved_merges_file,
                        special_tokens=special_tokens, *inputs, **kwargs)
        return tokenizer

    def __init__(self, vocab_file, merges_file, errors='replace',
                 special_tokens=None, max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
-        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        # Should haved added re.IGNORECASE so BPE merges can happen for
+        # capitalized versions of contractions
        self.pat = re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        self.special_tokens = {}
        self.special_tokens_decoder = {}
...

@@ -174,8 +187,9 @@ class GPT2Tokenizer(object):
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
+        self.special_tokens = dict((tok, len(self.encoder) + i)
+                                   for i, tok in enumerate(special_tokens))
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
        logger.info("Special tokens {}".format(self.special_tokens))

    def bpe(self, token):
...

@@ -188,7 +202,7 @@ class GPT2Tokenizer(object):
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
...

@@ -199,12 +213,12 @@ class GPT2Tokenizer(object):
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
-                except:
+                except BaseException:
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
...

@@ -247,7 +261,8 @@ class GPT2Tokenizer(object):
            logger.warning(
                "Token indices sequence length is longer than the specified maximum "
                " sequence length for this OpenAI GPT model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
+                " sequence through the model will result in indexing errors".format(
+                    len(ids), self.max_len)
            )
        return ids
...
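`bytes_to_unicode` builds a reversible map from all 256 byte values to printable unicode code points, which lets the BPE operate on arbitrary UTF-8 bytes without ever seeing the whitespace/control characters it "barfs on". A quick round-trip check:

    b2u = bytes_to_unicode()
    u2b = {v: k for k, v in b2u.items()}

    raw = 'héllo'.encode('utf-8')                # b'h\xc3\xa9llo'
    mapped = ''.join(b2u[b] for b in raw)        # 'hÃ©llo' -- all printable
    assert bytes(u2b[c] for c in mapped) == raw  # exact byte-level inverse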

megatron/tokenizer/tokenizer.py  ·  View file @ 5e56e563

# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...

@@ -33,6 +33,9 @@ def build_tokenizer(args):
    if args.tokenizer_type == 'BertWordPieceLowerCase':
        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=True)
    elif args.tokenizer_type == 'BertWordPieceCase':
        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=False)
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
...

@@ -53,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
    after = orig_vocab_size
    multiple = args.make_vocab_size_divisible_by * \
        args.model_parallel_size
    while (after % multiple) != 0:
        after += 1
    if args.rank == 0:
...

@@ -134,7 +137,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
        self.cls_id = self.tokenizer.vocab['[CLS]']
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.pad_id = self.tokenizer.vocab['[PAD]']
        self.mask_id = self.tokenizer.vocab['[MASK]']

    @property
    def vocab_size(self):
...

@@ -168,6 +171,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
    def mask(self):
        return self.mask_id


class _GPT2BPETokenizer(AbstractTokenizer):
    """Original GPT2 BPE tokenizer."""
...
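`_vocab_size_with_padding` rounds the vocabulary up to a multiple of `make_vocab_size_divisible_by * model_parallel_size`, so every model-parallel rank gets an equal, hardware-friendly slice of the vocabulary-split embedding. A worked instance of the loop:

    multiple = 128 * 2   # make_vocab_size_divisible_by=128, model_parallel_size=2
    after = 50257        # GPT-2 vocabulary size
    while after % multiple != 0:
        after += 1
    print(after)         # 50432 == 197 * 256, the first multiple >= 50257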