OpenDAS / Megatron-LM · Commits

Commit 09e05c6f
Authored Mar 16, 2020 by Mohammad Shoeybi
Parent: 3e4e1ab2

    moved albert to bert

Showing 3 changed files, with 10 additions and 10 deletions (+10 −10)
megatron/data/__init__.py      +1 −1
megatron/data/bert_dataset.py  +3 −3
pretrain_bert.py               +6 −6
megatron/data/__init__.py

 from . import indexed_dataset
 from .bert_tokenization import FullTokenizer as FullBertTokenizer
-from .albert_dataset import AlbertDataset
+from .bert_dataset import BertDataset
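With the import swap above, downstream code pulls the dataset through the package namespace under its new name. A minimal sketch, assuming the reconstructed `+` line is correct (it is inferred from the +1 −1 stats and the class rename in the next file):

# Sketch: importing the renamed class after this commit (re-export assumed).
from megatron.data import BertDataset
# The old name is gone; this would now raise ImportError:
# from megatron.data import AlbertDataset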
megatron/data/albert_dataset.py → megatron/data/bert_dataset.py

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""ALBERT Style dataset."""
+"""BERT Style dataset."""
 import os
 import time
@@ -79,7 +79,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
         # New doc_idx view.
         indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
         # Build the dataset accordingly.
-        dataset = AlbertDataset(
+        dataset = BertDataset(
             name=name,
             indexed_dataset=indexed_dataset,
             tokenizer=tokenizer,
@@ -105,7 +105,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
     return (train_dataset, valid_dataset, test_dataset)

-class AlbertDataset(Dataset):
+class BertDataset(Dataset):
     def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
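The @@ -79 hunk shows the pattern this builder relies on: a single shared document index is narrowed to a per-split view with set_doc_idx before each split's dataset is constructed. Below is a self-contained toy sketch of that slicing pattern; ToyIndexedDataset and the split boundaries are hypothetical stand-ins, and only the set_doc_idx(doc_idx_ptr[start_index:end_index]) idiom comes from the diff.

import numpy as np

class ToyIndexedDataset:
    """Hypothetical stand-in for megatron.data.indexed_dataset."""
    def __init__(self, doc_idx):
        self.doc_idx = doc_idx
    def set_doc_idx(self, doc_idx):
        # Same entry point the diff calls to narrow the document view.
        self.doc_idx = doc_idx

doc_idx_ptr = np.arange(100)   # one entry per document (toy data)
splits = [0, 80, 90, 100]      # hypothetical train/valid/test boundaries
dataset = ToyIndexedDataset(doc_idx_ptr)

for name, start_index, end_index in zip(
        ['train', 'valid', 'test'], splits[:-1], splits[1:]):
    # New doc_idx view, as in the hunk: each split sees only its slice.
    dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
    print(name, len(dataset.doc_idx))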
pretrain_bert.py

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Pretrain ALBERT"""
+"""Pretrain BERT"""
 import torch
 import torch.nn.functional as F
@@ -24,7 +24,7 @@ from megatron.utils import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
-from megatron.data.albert_dataset import build_train_valid_test_datasets
+from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.data_utils.samplers import DistributedBatchSampler
@@ -116,16 +116,16 @@ def get_train_val_test_data(args):
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
         print_rank_0('> building train, validation, and test datasets '
-                     'for ALBERT ...')
+                     'for BERT ...')
         if args.data_loader is None:
             args.data_loader = 'binary'
         if args.data_loader != 'binary':
-            print('Unsupported {} data loader for ALBERT.'.format(
+            print('Unsupported {} data loader for BERT.'.format(
                 args.data_loader))
             exit(1)
         if not args.data_path:
-            print('ALBERT only supports a unified dataset specified '
+            print('BERT only supports a unified dataset specified '
                   'with --data-path')
             exit(1)
@@ -157,7 +157,7 @@ def get_train_val_test_data(args):
             short_seq_prob=args.short_seq_prob,
             seed=args.seed,
             skip_warmup=args.skip_mmap_warmup)
-        print_rank_0("> finished creating ALBERT datasets ...")
+        print_rank_0("> finished creating BERT datasets ...")

     def make_data_loader_(dataset):
         if not dataset:
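Seen as a whole, the commit is a mechanical rename: one file move plus an Albert→Bert identifier rewrite across the three files above. A hedged sketch of an equivalent refactor script, run from the repository root; the rename table mirrors the identifiers visible in this diff, but the script is an illustration, not the author's actual procedure.

import os
import re

# Identifier rewrites visible in this diff.
RENAMES = [
    (r'\bAlbertDataset\b', 'BertDataset'),
    (r'\balbert_dataset\b', 'bert_dataset'),
    (r'\bALBERT\b', 'BERT'),
]

def rewrite(path):
    # Apply every rename pattern to one file, in place.
    with open(path) as f:
        text = f.read()
    for pattern, repl in RENAMES:
        text = re.sub(pattern, repl, text)
    with open(path, 'w') as f:
        f.write(text)

# The file move recorded in the diff header above.
os.rename('megatron/data/albert_dataset.py', 'megatron/data/bert_dataset.py')
for path in ('megatron/data/__init__.py',
             'megatron/data/bert_dataset.py',
             'pretrain_bert.py'):
    rewrite(path)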