Commit 3354081f authored by Neel Kant's avatar Neel Kant
Browse files

More ict_merge changes and interactive testing

parent 6495927e
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch
from .package_info import ( from .package_info import (
__description__, __description__,
...@@ -30,7 +31,6 @@ from .global_vars import get_tensorboard_writer ...@@ -30,7 +31,6 @@ from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume from .global_vars import get_adlr_autoresume
from .global_vars import get_timers from .global_vars import get_timers
import torch
def print_rank_0(message): def print_rank_0(message):
"""If distributed is initialized print only on rank 0.""" """If distributed is initialized print only on rank 0."""
...@@ -38,4 +38,4 @@ def print_rank_0(message): ...@@ -38,4 +38,4 @@ def print_rank_0(message):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print(message, flush=True) print(message, flush=True)
else: else:
print(message, flush=True) print(message, flush=True)
\ No newline at end of file
...@@ -23,9 +23,8 @@ import numpy as np ...@@ -23,9 +23,8 @@ import numpy as np
import torch import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import mpu from megatron import mpu, print_rank_0
from megatron import get_args from megatron import get_args
from megatron import print_rank_0
def check_checkpoint_args(checkpoint_args): def check_checkpoint_args(checkpoint_args):
......
...@@ -22,15 +22,13 @@ import numpy as np ...@@ -22,15 +22,13 @@ import numpy as np
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from megatron import get_tokenizer, get_args from megatron import get_tokenizer, get_args, print_rank_0
from megatron import mpu from megatron import mpu
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.dataset_utils import get_a_and_b_segments from megatron.data.dataset_utils import get_a_and_b_segments
from megatron.data.dataset_utils import truncate_segments from megatron.data.dataset_utils import truncate_segments
from megatron.data.dataset_utils import create_tokens_and_tokentypes from megatron.data.dataset_utils import create_tokens_and_tokentypes
from megatron.data.dataset_utils import pad_and_convert_to_numpy from megatron.data.dataset_utils import pad_and_convert_to_numpy
from megatron.data.dataset_utils import create_masked_lm_predictions from megatron.data.dataset_utils import create_masked_lm_predictions
from megatron import print_rank_0
class BertDataset(Dataset): class BertDataset(Dataset):
...@@ -85,55 +83,6 @@ class BertDataset(Dataset): ...@@ -85,55 +83,6 @@ class BertDataset(Dataset):
self.masked_lm_prob, np_rng) self.masked_lm_prob, np_rng)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Build the indexed dataset for `data_prefix` and log basic stats.

    Prints only on rank 0. Sanity-checks that the sentence count matches
    the final entry of the document index before returning.
    """
    print_rank_0(' > building dataset index ...')
    t_start = time.time()
    dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    # The doc index must account for every sentence in the dataset.
    assert dataset.sizes.shape[0] == dataset.doc_idx[-1]
    elapsed = time.time() - t_start
    # NOTE(review): '{:4f}' is width-4, not precision-4 — presumably
    # '{:.4f}' was intended; kept as-is to preserve output byte-for-byte.
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(elapsed))
    print_rank_0(' > indexed dataset stats:')
    num_documents = dataset.doc_idx.shape[0] - 1
    num_sentences = dataset.sizes.shape[0]
    print_rank_0(' number of documents: {}'.format(num_documents))
    print_rank_0(' number of sentences: {}'.format(num_sentences))
    return dataset
def get_train_valid_test_split_(splits_string, size):
    """ Get dataset splits from comma or '/' separated string list."""
    # Parse the raw weights: comma-separated, slash-separated, or a
    # single bare number.
    if ',' in splits_string:
        weights = [float(s) for s in splits_string.split(',')]
    elif '/' in splits_string:
        weights = [float(s) for s in splits_string.split('/')]
    else:
        weights = [float(splits_string)]
    # Normalize to exactly three entries (train, valid, test).
    weights = (weights + [0., 0., 0.])[:3]
    total = sum(weights)
    assert total > 0.0
    weights = [w / total for w in weights]
    # Turn the normalized weights into cumulative index boundaries.
    boundaries = [0]
    for w in weights:
        boundaries.append(boundaries[-1] + int(round(w * float(size))))
    # Rounding may leave the last boundary off `size`; shift every
    # non-zero boundary by the overshoot so it lands exactly on `size`.
    overshoot = boundaries[-1] - size
    for i in range(1, len(boundaries)):
        boundaries[i] -= overshoot
    assert len(boundaries) == 4
    assert boundaries[-1] == size
    return boundaries
def get_samples_mapping_(indexed_dataset, def get_samples_mapping_(indexed_dataset,
data_prefix, data_prefix,
num_epochs, num_epochs,
......
...@@ -18,17 +18,19 @@ ...@@ -18,17 +18,19 @@
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py # https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications. # with some modifications.
import time
import collections import collections
import itertools
import numpy as np import numpy as np
from megatron import print_rank_0, get_args from megatron import get_args, print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
DSET_TYPE_STD = 'standard_bert' DSET_TYPE_STD = 'standard_bert'
DSET_TYPE_ICT = 'ict' DSET_TYPE_ICT = 'ict'
DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD] DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
def compile_helper(): def compile_helper():
"""Compile helper function ar runtime. Make sure this """Compile helper function ar runtime. Make sure this
is invoked on a single process.""" is invoked on a single process."""
...@@ -447,3 +449,54 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -447,3 +449,54 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
test_dataset = build_dataset(2, 'test') test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset) return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Create the indexed dataset for `data_prefix`, logging timing and
    document/sentence counts on rank 0.

    Asserts that the number of sentences equals the last document-index
    entry as an integrity check on the index files.
    """
    print_rank_0(' > building dataset index ...')
    started = time.time()
    indexed = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    # Integrity check: every sentence must be covered by the doc index.
    assert indexed.sizes.shape[0] == indexed.doc_idx[-1]
    # NOTE(review): '{:4f}' sets min width 4, not 4 decimal places —
    # likely meant '{:.4f}'; left unchanged to keep logs identical.
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - started))
    print_rank_0(' > indexed dataset stats:')
    print_rank_0(' number of documents: {}'.format(
        indexed.doc_idx.shape[0] - 1))
    print_rank_0(' number of sentences: {}'.format(
        indexed.sizes.shape[0]))
    return indexed
def get_train_valid_test_split_(splits_string, size):
    """ Get dataset splits from comma or '/' separated string list."""
    # Accept "a,b,c", "a/b/c", or a single number as the split spec.
    for sep in (',', '/'):
        if sep in splits_string:
            fractions = [float(part) for part in splits_string.split(sep)]
            break
    else:
        fractions = [float(splits_string)]
    # Pad missing entries with zero and keep only train/valid/test.
    while len(fractions) < 3:
        fractions.append(0.)
    fractions = fractions[:3]
    denom = sum(fractions)
    assert denom > 0.0
    fractions = [f / denom for f in fractions]
    # Accumulate rounded sample counts into index boundaries.
    split_points = [0]
    for frac in fractions:
        split_points.append(split_points[-1] +
                            int(round(frac * float(size))))
    # Correct any rounding drift so the final boundary equals `size`.
    drift = split_points[-1] - size
    for idx in range(1, len(split_points)):
        split_points[idx] -= drift
    assert len(split_points) == 4
    assert split_points[-1] == size
    return split_points
...@@ -21,8 +21,7 @@ import time ...@@ -21,8 +21,7 @@ import time
import numpy as np import numpy as np
import torch import torch
from megatron import print_rank_0 from megatron import mpu, print_rank_0
from megatron import mpu
from megatron.data.bert_dataset import get_train_valid_test_split_ from megatron.data.bert_dataset import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
......
...@@ -401,7 +401,8 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -401,7 +401,8 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
const uint64_t max_num_samples, const uint64_t max_num_samples,
const int32_t max_seq_length, const int32_t max_seq_length,
const int32_t seed, const int32_t seed,
const bool verbose) { const bool verbose,
const bool use_one_sent_blocks) {
/* Build a mapping of (start-index, end-index, sequence-length) where /* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length. and sequence-length is the target sequence length.
...@@ -442,6 +443,12 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -442,6 +443,12 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
int64_t num_samples = -1; int64_t num_samples = -1;
DocIdx* maps = NULL; DocIdx* maps = NULL;
// Acceptable number of sentences per block.
int min_num_sent = 2;
if (use_one_sent_blocks) {
min_num_sent = 1;
}
// Perform two iterations, in the first iteration get the size // Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map. // and allocate memory and in the second iteration populate the map.
bool second = false; bool second = false;
...@@ -453,6 +460,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -453,6 +460,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Current map index. // Current map index.
uint64_t map_index = 0; uint64_t map_index = 0;
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// For each epoch: // For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) { for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
// assign every block a unique id // assign every block a unique id
...@@ -480,19 +490,31 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -480,19 +490,31 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Remaining documents. // Remaining documents.
auto num_remain_sent = sent_index_last - sent_index_first; auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second)) {
if (num_remain_sent == 0) {
++empty_docs;
}
if (num_remain_sent == 1) {
++one_sent_docs;
}
}
// Detect documents with long sentences. // Detect documents with long sentences.
bool contains_long_sentence = false; bool contains_long_sentence = false;
if (num_remain_sent > 1) { if (num_remain_sent >= min_num_sent) {
for (auto sent_index=sent_index_first; for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) { sent_index < sent_index_last; ++sent_index) {
if (sizes[sent_index] > LONG_SENTENCE_LEN){ if (sizes[sent_index] > LONG_SENTENCE_LEN){
if ((epoch == 0) && (!second)) {
++long_sent_docs;
}
contains_long_sentence = true; contains_long_sentence = true;
break; break;
} }
} }
} }
// If we have more than two sentences. // If we have enough sentences and no long sentences.
if ((num_remain_sent > 1) && (!contains_long_sentence)) { if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values. // Set values.
auto seq_len = int32_t{0}; auto seq_len = int32_t{0};
...@@ -508,12 +530,12 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -508,12 +530,12 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
--num_remain_sent; --num_remain_sent;
// If we have reached the target length. // If we have reached the target length.
// and if not only one sentence is left in the document. // and there are an acceptable number of sentences left
// and if we have at least two sentneces. // and if we have at least the minimum number of sentences.
// or if we have reached end of the document. // or if we have reached end of the document.
if (((seq_len >= target_seq_len) && if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) && (num_remain_sent >= min_num_sent) &&
(num_sent > 1) ) || (num_remain_sent == 0)) { (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Populate the map. // Populate the map.
if (second) { if (second) {
...@@ -538,11 +560,16 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -538,11 +560,16 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
} // for (auto sent_index=sent_index_first; ... } // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) { } // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) { } // for (int doc=0; doc < num_docs; ++doc) {
block_id = 0;
} // for (int epoch=0; epoch < num_epochs; ++epoch) { } // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) { if (!second) {
if (verbose) { if (verbose) {
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " number of documents with long sentences: " <<
long_sent_docs << endl << std::flush;
cout << " will create mapping for " << map_index << cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush; " samples" << endl << std::flush;
} }
...@@ -554,9 +581,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -554,9 +581,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
} // for (int iteration=0; iteration < 2; ++iteration) { } // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle. Shuffle.
// We need a 64 bit random number generator as we might have more We need a 64 bit random number generator as we might have more
// than 2 billion samples. than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1); std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) { for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1)); const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
...@@ -591,20 +618,21 @@ py::array build_blocks_mapping(const py::array_t<int64_t>& docs_, ...@@ -591,20 +618,21 @@ py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
const uint64_t max_num_samples, const uint64_t max_num_samples,
const int max_seq_length, const int max_seq_length,
const int seed, const int seed,
const bool verbose) { const bool verbose,
const bool use_one_sent_blocks) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) { if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) { if (verbose) {
cout << " using uint64 for data mapping..." << endl << std::flush; cout << " using uint64 for data mapping..." << endl << std::flush;
} }
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_, return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose); num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
} else { } else {
if (verbose) { if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush; cout << " using uint32 for data mapping..." << endl << std::flush;
} }
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_, return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose); num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
} }
} }
......
...@@ -65,6 +65,7 @@ class ICTDataset(Dataset): ...@@ -65,6 +65,7 @@ class ICTDataset(Dataset):
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
print(self.tokenizer.decode_token_ids(block_tokens), '\n')
block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64) block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
sample = { sample = {
......
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
import numpy as np import numpy as np
import torch import torch
from megatron import print_rank_0, mpu from megatron import mpu, print_rank_0
def join_str_list(str_list): def join_str_list(str_list):
...@@ -19,7 +19,7 @@ def join_str_list(str_list): ...@@ -19,7 +19,7 @@ def join_str_list(str_list):
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs, def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name): max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
"""Get samples mapping for a dataset over fixed size blocks. This function also requires """Get samples mapping for a dataset over fixed size blocks. This function also requires
a dataset of the titles for the source documents since their lengths must be taken into account.""" a dataset of the titles for the source documents since their lengths must be taken into account."""
if not num_epochs: if not num_epochs:
...@@ -39,6 +39,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo ...@@ -39,6 +39,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
indexmap_filename += '_{}mns'.format(max_num_samples) indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length) indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{}s'.format(seed) indexmap_filename += '_{}s'.format(seed)
if use_one_sent_docs:
indexmap_filename += '_1sentok'
indexmap_filename += '.npy' indexmap_filename += '.npy'
# Build the indexed mapping if not exist. # Build the indexed mapping if not exist.
...@@ -67,7 +69,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo ...@@ -67,7 +69,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
max_num_samples, max_num_samples,
max_seq_length-3, # account for added tokens max_seq_length-3, # account for added tokens
seed, seed,
verbose) verbose,
use_one_sent_docs)
print_rank_0(' > done building samples index mapping') print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True) np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format( print_rank_0(' > saved the index mapping in {}'.format(
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import torch import torch
from megatron import get_args from megatron import get_args, print_rank_0
from megatron.model.bert_model import bert_attention_mask_func from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids from megatron.model.bert_model import bert_position_ids
...@@ -26,7 +26,6 @@ from megatron.model.utils import get_linear_layer ...@@ -26,7 +26,6 @@ from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule from megatron.module import MegatronModule
from megatron import print_rank_0
class Classification(MegatronModule): class Classification(MegatronModule):
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import torch import torch
from megatron import get_args from megatron import get_args, print_rank_0
from megatron.model.bert_model import bert_attention_mask_func from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids from megatron.model.bert_model import bert_position_ids
...@@ -26,7 +26,6 @@ from megatron.model.utils import get_linear_layer ...@@ -26,7 +26,6 @@ from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule from megatron.module import MegatronModule
from megatron import print_rank_0
class MultipleChoice(MegatronModule): class MultipleChoice(MegatronModule):
......
...@@ -125,7 +125,7 @@ class ICTBertModel(MegatronModule): ...@@ -125,7 +125,7 @@ class ICTBertModel(MegatronModule):
class IREncoderBertModel(MegatronModule): class IREncoderBertModel(MegatronModule):
"""Bert Language model.""" """BERT-based encoder for queries or blocks used for learned information retrieval."""
def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True): def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
super(IREncoderBertModel, self).__init__() super(IREncoderBertModel, self).__init__()
args = get_args() args = get_args()
...@@ -158,9 +158,8 @@ class IREncoderBertModel(MegatronModule): ...@@ -158,9 +158,8 @@ class IREncoderBertModel(MegatronModule):
tokentype_ids=tokentype_ids) tokentype_ids=tokentype_ids)
# Output. # Output.
if self.add_ict_head: ict_logits = self.ict_head(pooled_output)
ict_logits = self.ict_head(pooled_output) return ict_logits, None
return ict_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
......
...@@ -20,7 +20,6 @@ from abc import abstractmethod ...@@ -20,7 +20,6 @@ from abc import abstractmethod
from .bert_tokenization import FullTokenizer as FullBertTokenizer from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer from .gpt2_tokenization import GPT2Tokenizer
from megatron.data.realm_dataset_utils import join_str_list
def build_tokenizer(args): def build_tokenizer(args):
...@@ -160,8 +159,15 @@ class _BertWordPieceTokenizer(AbstractTokenizer): ...@@ -160,8 +159,15 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids) tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]'] exclude_list = ['[PAD]', '[CLS]']
non_pads = [t for t in tokens if t not in exclude_list] non_pads = [t for t in tokens if t not in exclude_list]
joined_strs = join_str_list(non_pads)
return joined_strs result = ""
for s in non_pads:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
@property @property
def cls(self): def cls(self):
......
...@@ -22,11 +22,10 @@ import torch ...@@ -22,11 +22,10 @@ import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam from apex.optimizers import FusedAdam as Adam
from megatron import get_args from megatron import get_args, print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import get_tensorboard_writer from megatron import get_tensorboard_writer
from megatron import mpu from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module from megatron.fp16 import FP16_Module
......
...@@ -19,10 +19,9 @@ import sys ...@@ -19,10 +19,9 @@ import sys
import torch import torch
from megatron import get_args from megatron import get_args, print_rank_0
from megatron import get_adlr_autoresume from megatron import get_adlr_autoresume
from megatron import mpu from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.data.samplers import DistributedBatchSampler from megatron.data.samplers import DistributedBatchSampler
from megatron.fp16 import FP16_Optimizer from megatron.fp16 import FP16_Optimizer
...@@ -173,3 +172,5 @@ def get_ltor_masks_and_position_ids(data, ...@@ -173,3 +172,5 @@ def get_ltor_masks_and_position_ids(data,
attention_mask = (attention_mask < 0.5) attention_mask = (attention_mask < 0.5)
return attention_mask, loss_mask, position_ids return attention_mask, loss_mask, position_ids
...@@ -18,10 +18,9 @@ ...@@ -18,10 +18,9 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args from megatron import get_args, print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron import print_rank_0
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel from megatron.model import BertModel
from megatron.training import pretrain from megatron.training import pretrain
......
...@@ -19,10 +19,9 @@ import torch ...@@ -19,10 +19,9 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args from megatron import get_args, print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron import print_rank_0
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ICTBertModel from megatron.model import ICTBertModel
from megatron.training import pretrain from megatron.training import pretrain
......
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
import torch import torch
from megatron import get_args from megatron import get_args, print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
from megatron import print_rank_0
from megatron.data.gpt2_dataset import build_train_valid_test_datasets from megatron.data.gpt2_dataset import build_train_valid_test_datasets
from megatron.model import GPT2Model from megatron.model import GPT2Model
from megatron.training import pretrain from megatron.training import pretrain
......
...@@ -20,9 +20,8 @@ import time ...@@ -20,9 +20,8 @@ import time
import torch import torch
from megatron import get_args from megatron import get_args, print_rank_0
from megatron import mpu from megatron import mpu
from megatron import print_rank_0
from tasks.finetune_utils import build_data_loader from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch from tasks.finetune_utils import process_batch
......
...@@ -17,10 +17,9 @@ ...@@ -17,10 +17,9 @@
import torch import torch
from megatron import get_args from megatron import get_args, print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results from megatron.training import evaluate_and_print_results
......
...@@ -15,9 +15,8 @@ ...@@ -15,9 +15,8 @@
"""GLUE finetuning/evaluation.""" """GLUE finetuning/evaluation."""
from megatron import get_args from megatron import get_args, print_rank_0
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.model.classification import Classification from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune from tasks.finetune_utils import finetune
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment