Commit 09e05c6f authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

moved albert to bert

parent 3e4e1ab2
from . import indexed_dataset
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .albert_dataset import AlbertDataset
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""ALBERT Style dataset."""
"""BERT Style dataset."""
import os
import time
......@@ -79,7 +79,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
# New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly.
dataset = AlbertDataset(
dataset = BertDataset(
name=name,
indexed_dataset=indexed_dataset,
tokenizer=tokenizer,
......@@ -105,7 +105,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
return (train_dataset, valid_dataset, test_dataset)
class AlbertDataset(Dataset):
class BertDataset(Dataset):
def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
......
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain ALBERT"""
"""Pretrain BERT"""
import torch
import torch.nn.functional as F
......@@ -24,7 +24,7 @@ from megatron.utils import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import run
from megatron.data.albert_dataset import build_train_valid_test_datasets
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler
......@@ -116,16 +116,16 @@ def get_train_val_test_data(args):
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
print_rank_0('> building train, validation, and test datasets '
'for ALBERT ...')
'for BERT ...')
if args.data_loader is None:
args.data_loader = 'binary'
if args.data_loader != 'binary':
print('Unsupported {} data loader for ALBERT.'.format(
print('Unsupported {} data loader for BERT.'.format(
args.data_loader))
exit(1)
if not args.data_path:
print('ALBERT only supports a unified dataset specified '
print('BERT only supports a unified dataset specified '
'with --data-path')
exit(1)
......@@ -157,7 +157,7 @@ def get_train_val_test_data(args):
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=args.skip_mmap_warmup)
print_rank_0("> finished creating ALBERT datasets ...")
print_rank_0("> finished creating BERT datasets ...")
def make_data_loader_(dataset):
if not dataset:
......
Markdown is supported
0% or drag and drop a file here.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment