Commit 09e05c6f authored by Mohammad Shoeybi
Browse files

Moved ALBERT to BERT: renamed the albert_dataset module and AlbertDataset class to bert_dataset and BertDataset

parent 3e4e1ab2
from . import indexed_dataset from . import indexed_dataset
from .bert_tokenization import FullTokenizer as FullBertTokenizer from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .albert_dataset import AlbertDataset
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""ALBERT Style dataset.""" """BERT Style dataset."""
import os import os
import time import time
...@@ -79,7 +79,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl, ...@@ -79,7 +79,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
# New doc_idx view. # New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly. # Build the dataset accordingly.
dataset = AlbertDataset( dataset = BertDataset(
name=name, name=name,
indexed_dataset=indexed_dataset, indexed_dataset=indexed_dataset,
tokenizer=tokenizer, tokenizer=tokenizer,
...@@ -105,7 +105,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl, ...@@ -105,7 +105,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
return (train_dataset, valid_dataset, test_dataset) return (train_dataset, valid_dataset, test_dataset)
class AlbertDataset(Dataset): class BertDataset(Dataset):
def __init__(self, name, indexed_dataset, tokenizer, data_prefix, def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
num_epochs, max_num_samples, masked_lm_prob, num_epochs, max_num_samples, masked_lm_prob,
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Pretrain ALBERT""" """Pretrain BERT"""
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -24,7 +24,7 @@ from megatron.utils import print_rank_0 ...@@ -24,7 +24,7 @@ from megatron.utils import print_rank_0
from megatron.utils import reduce_losses from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding from megatron.utils import vocab_size_with_padding
from megatron.training import run from megatron.training import run
from megatron.data.albert_dataset import build_train_valid_test_datasets from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler from megatron.data_utils.samplers import DistributedBatchSampler
...@@ -116,16 +116,16 @@ def get_train_val_test_data(args): ...@@ -116,16 +116,16 @@ def get_train_val_test_data(args):
# Data loader only on rank 0 of each model parallel group. # Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
print_rank_0('> building train, validation, and test datasets ' print_rank_0('> building train, validation, and test datasets '
'for ALBERT ...') 'for BERT ...')
if args.data_loader is None: if args.data_loader is None:
args.data_loader = 'binary' args.data_loader = 'binary'
if args.data_loader != 'binary': if args.data_loader != 'binary':
print('Unsupported {} data loader for ALBERT.'.format( print('Unsupported {} data loader for BERT.'.format(
args.data_loader)) args.data_loader))
exit(1) exit(1)
if not args.data_path: if not args.data_path:
print('ALBERT only supports a unified dataset specified ' print('BERT only supports a unified dataset specified '
'with --data-path') 'with --data-path')
exit(1) exit(1)
...@@ -157,7 +157,7 @@ def get_train_val_test_data(args): ...@@ -157,7 +157,7 @@ def get_train_val_test_data(args):
short_seq_prob=args.short_seq_prob, short_seq_prob=args.short_seq_prob,
seed=args.seed, seed=args.seed,
skip_warmup=args.skip_mmap_warmup) skip_warmup=args.skip_mmap_warmup)
print_rank_0("> finished creating ALBERT datasets ...") print_rank_0("> finished creating BERT datasets ...")
def make_data_loader_(dataset): def make_data_loader_(dataset):
if not dataset: if not dataset:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment