OpenDAS / Megatron-LM / Commits

Commit dedb2ef7
Authored Mar 30, 2020 by Mohammad
Parent: 1788c910

    removed building tokenizer from bert dataset

Showing 3 changed files with 49 additions and 23 deletions (+49 -23):

    megatron/data/bert_dataset.py      +12  -22
    megatron/tokenizer/tokenizer.py    +37  -0
    pretrain_bert.py                   +0   -1
megatron/data/bert_dataset.py (view file @ dedb2ef7)

@@ -22,24 +22,19 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
 
+from megatron import get_tokenizer
 from megatron import mpu
 from megatron.data import helpers
-from megatron.tokenizer.bert_tokenization import FullTokenizer as FullBertTokenizer
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron import print_rank_0
 
 
-def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
+def build_train_valid_test_datasets(data_prefix, data_impl,
                                     splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup):
 
-    # Tokenizer is the same
-    tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
-    print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
-        tokenizer.vocab_size()))
-
     # Indexed dataset.
     indexed_dataset = get_indexed_dataset_(data_prefix,
                                            data_impl,

@@ -82,7 +77,6 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
         dataset = BertDataset(
             name=name,
             indexed_dataset=indexed_dataset,
-            tokenizer=tokenizer,
             data_prefix=data_prefix,
             num_epochs=None,
             max_num_samples=train_valid_test_num_samples[index],

@@ -107,7 +101,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
 class BertDataset(Dataset):
 
-    def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
+    def __init__(self, name, indexed_dataset, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
                  max_seq_length, short_seq_prob, seed):

@@ -117,8 +111,7 @@ class BertDataset(Dataset):
         self.masked_lm_prob = masked_lm_prob
         self.max_seq_length = max_seq_length
 
-        # Tokenizer and dataset.
-        self.tokenizer = tokenizer
+        # Dataset.
         self.indexed_dataset = indexed_dataset

@@ -133,16 +126,13 @@ class BertDataset(Dataset):
                                                     self.name)
 
         # Vocab stuff.
-        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
-        self.vocab_id_to_token_dict = self.tokenizer.inv_vocab
-        self.cls_id = self.tokenizer.vocab['[CLS]']
-        self.sep_id = self.tokenizer.vocab['[SEP]']
-        self.mask_id = self.tokenizer.vocab['[MASK]']
-        self.pad_id = self.tokenizer.vocab['[PAD]']
-
-    def num_tokens(self):
-        return self.tokenizer.vocab_size()
+        tokenizer = get_tokenizer()
+        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_dict = tokenizer.inv_vocab
+        self.cls_id = tokenizer.cls
+        self.sep_id = tokenizer.sep
+        self.mask_id = tokenizer.mask
+        self.pad_id = tokenizer.pad
 
     def __len__(self):
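Taken together, the bert_dataset.py hunks move tokenizer construction out of the data pipeline: BertDataset no longer builds a FullBertTokenizer from a vocab file, it reads its vocabulary and special-token ids off the framework-wide tokenizer returned by megatron.get_tokenizer(). A minimal sketch of the new access pattern, reduced to just the members visible in this diff (the class name below is illustrative, not part of the commit):

    from megatron import get_tokenizer

    class _VocabIdsSketch:
        """Illustrative stand-in for the vocab-related part of BertDataset.__init__."""

        def __init__(self):
            tokenizer = get_tokenizer()  # shared tokenizer, built once elsewhere
            self.vocab_id_list = list(tokenizer.inv_vocab.keys())
            self.vocab_id_to_token_dict = tokenizer.inv_vocab
            self.cls_id = tokenizer.cls    # was: self.tokenizer.vocab['[CLS]']
            self.sep_id = tokenizer.sep    # was: self.tokenizer.vocab['[SEP]']
            self.mask_id = tokenizer.mask  # was: self.tokenizer.vocab['[MASK]']
            self.pad_id = tokenizer.pad    # was: self.tokenizer.vocab['[PAD]']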
megatron/tokenizer/tokenizer.py (view file @ dedb2ef7)

@@ -75,6 +75,18 @@ class AbstractTokenizer(ABC):
     def vocab_size(self):
         pass
 
+    @property
+    @abstractmethod
+    def vocab(self):
+        """Dictionary from vocab text token to id token."""
+        pass
+
+    @property
+    @abstractmethod
+    def inv_vocab(self):
+        """Dictionary from vocab id token to text token."""
+        pass
+
     @abstractmethod
     def tokenize(self, text):
         pass

@@ -99,6 +111,11 @@ class AbstractTokenizer(ABC):
         raise NotImplementedError('EOD is not provided for {} '
                                   'tokenizer'.format(self.name))
 
+    @property
+    def mask(self):
+        raise NotImplementedError('MASK is not provided for {} '
+                                  'tokenizer'.format(self.name))
+
 
 class _BertWordPieceTokenizer(AbstractTokenizer):
     """Original BERT wordpiece tokenizer."""

@@ -113,11 +130,20 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
         self.cls_id = self.tokenizer.vocab['[CLS]']
         self.sep_id = self.tokenizer.vocab['[SEP]']
         self.pad_id = self.tokenizer.vocab['[PAD]']
+        self.mask_id = self.tokenizer.vocab['[MASK]']
 
     @property
     def vocab_size(self):
         return self.tokenizer.vocab_size()
 
+    @property
+    def vocab(self):
+        return self.tokenizer.vocab
+
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.inv_vocab
+
     def tokenize(self, text):
         text_tokens = self.tokenizer.tokenize(text)
         return self.tokenizer.convert_tokens_to_ids(text_tokens)

@@ -134,6 +160,9 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
     def pad(self):
         return self.pad_id
 
+    @property
+    def mask(self):
+        return self.mask_id
 
 class _GPT2BPETokenizer(AbstractTokenizer):
     """Original GPT2 BPE tokenizer."""

@@ -150,6 +179,14 @@ class _GPT2BPETokenizer(AbstractTokenizer):
     def vocab_size(self):
         return len(self.tokenizer.encoder)
 
+    @property
+    def vocab(self):
+        return self.tokenizer.encoder
+
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.decoder
+
     def tokenize(self, text):
         return self.tokenizer.encode(text)
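The tokenizer.py hunks widen the shared AbstractTokenizer surface: vocab and inv_vocab become abstract properties, mask gains a default that raises, and both concrete tokenizers fill them in. A hypothetical helper (not part of this commit) suggests why that matters: dataset code can now be written once against whatever get_tokenizer() returns, whether it is _BertWordPieceTokenizer or _GPT2BPETokenizer.

    def describe_tokenizer(tokenizer):
        """Hypothetical helper: touches only the shared AbstractTokenizer surface."""
        print('vocab size   :', tokenizer.vocab_size)            # property, not a call
        print('some token id:', next(iter(tokenizer.inv_vocab)))  # inv_vocab maps id -> token
        try:
            print('mask token id:', tokenizer.mask)
        except NotImplementedError:
            # The new base-class default raises for tokenizers without a mask token.
            print('mask token id: not provided by this tokenizer')
        print('"hello world":', tokenizer.tokenize('hello world'))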
pretrain_bert.py (view file @ dedb2ef7)

@@ -137,7 +137,6 @@ def get_train_val_test_data():
         print_rank_0('    test: {}'.format(train_val_test_num_samples[2]))
 
         train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            vocab_file=args.vocab_file,
             data_prefix=args.data_path,
             data_impl=args.data_impl,
             splits_string=args.split,
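With vocab_file dropped from this call, the vocabulary now reaches both the model and the data pipeline only through the shared tokenizer. A hypothetical two-liner (not part of the commit) for anything that previously relied on the removed BertDataset.num_tokens():

    from megatron import get_tokenizer

    # Assumes the global tokenizer has already been built during initialization.
    tokenizer = get_tokenizer()
    vocab_size = tokenizer.vocab_size  # property on AbstractTokenizer subclasses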