Commit 242770dd authored by mshoeybi's avatar mshoeybi
Browse files

moved compile helper to initialize

parent a495871b
......@@ -128,8 +128,6 @@ def get_samples_mapping_(indexed_dataset,
print_rank_0(' > building sapmles index mapping for {} ...'.format(
name))
# First compile and then import.
from megatron.data.dataset_utils import compile_helper
compile_helper()
from megatron.data import helpers
samples_mapping = helpers.build_mapping(
indexed_dataset.doc_idx,
......
......@@ -49,13 +49,6 @@ class BlendableDataset(torch.utils.data.Dataset):
self.dataset_index = np.zeros(self.size, dtype=np.uint8)
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
if torch.distributed.get_rank() == 0:
from megatron.data.dataset_utils import compile_helper
compile_helper()
# Simple barrier
tmp = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(tmp, group=mpu.get_data_parallel_group())
from megatron.data import helpers
helpers.build_blending_indices(self.dataset_index,
self.dataset_sample_index,
......
......@@ -269,8 +269,6 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
start_time = time.time()
# Use C++ implementation for speed.
# First compile and then import.
from megatron.data.dataset_utils import compile_helper
compile_helper()
from megatron.data import helpers
assert doc_idx.dtype == np.int32
assert sizes.dtype == np.int32
......
......@@ -152,10 +152,6 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
# compile/bind the C++ helper code
from megatron.data.dataset_utils import compile_helper
compile_helper()
from megatron.data import helpers
mapping_array = helpers.build_blocks_mapping(
block_dataset.doc_idx,
......
......@@ -78,6 +78,16 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
# Autoresume.
_init_autoresume()
# Compile dataset C++ code.
try:
from megatron.data import helpers
except:
if torch.distributed.get_rank() == 0:
from megatron.data.dataset_utils import compile_helper
compile_helper()
# Simple barrier
torch.distributed.barrier()
# No continuation function
return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment