Commit b8bb0b49 authored by Mohammad

Debugging done on Circe

parent e3c57c82
@@ -391,6 +391,7 @@ def _add_data_args(parser):
     group.add_argument('--faiss-use-gpu', action='store_true')
     group.add_argument('--index-reload-interval', type=int, default=500)
     group.add_argument('--use-regular-masking', action='store_true')
+    group.add_argument('--use-random-spans', action='store_true')
     group.add_argument('--allow-trivial-doc', action='store_true')
     group.add_argument('--ner-data-path', type=str, default=None)
@@ -28,6 +28,9 @@ def build_realm_training_sample(sample, max_seq_length,
             cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
     elif block_ner_mask is not None:
         block_ner_mask = list(itertools.chain(*block_ner_mask))[:max_seq_length - 2]
+        if args.use_random_spans:
+            rand_idx = np.random.randint(len(block_ner_mask))
+            block_ner_mask = block_ner_mask[rand_idx:] + block_ner_mask[:rand_idx]
         block_ner_mask = [0] + block_ner_mask + [0]
         masked_tokens, masked_positions, masked_labels = get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id)
     else:
@@ -182,7 +185,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
     indexmap_filename += '.npy'

     # Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0 and \
+    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))
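The new --use-random-spans flag feeds the branch added above: when it is set, the flattened per-token span mask is rotated by a random offset before the [CLS]/[SEP] padding is attached, so the positions chosen for masking vary between visits to the same block. A minimal standalone sketch of that rotation (rotate_span_mask is a hypothetical helper name; the commit inlines this logic in build_realm_training_sample):

import numpy as np

def rotate_span_mask(block_ner_mask, rng=np.random):
    # Draw a random pivot and rotate the flat span mask around it, exactly as
    # the diff does with block_ner_mask[rand_idx:] + block_ner_mask[:rand_idx].
    rand_idx = rng.randint(len(block_ner_mask))
    return block_ner_mask[rand_idx:] + block_ner_mask[:rand_idx]

# Example: one three-token span marked with 1s.
mask = [0, 0, 1, 1, 1, 0, 0, 0]
rotated = rotate_span_mask(mask)
# For rand_idx == 2 this yields [1, 1, 1, 0, 0, 0, 0, 0]; note the rotation
# can also split a span across the wrap-around point.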
@@ -15,12 +15,16 @@ def detach(tensor):

 class BlockData(object):
-    def __init__(self):
-        args = get_args()
+    def __init__(self, block_data_path=None):
         self.embed_data = dict()
         self.meta_data = dict()
-        block_data_path = os.path.splitext(args.block_data_path)[0]
-        self.temp_dir_name = block_data_path + '_tmp'
+        if block_data_path is None:
+            args = get_args()
+            block_data_path = args.block_data_path
+        self.block_data_path = block_data_path
+        block_data_name = os.path.splitext(self.block_data_path)[0]
+        self.temp_dir_name = block_data_name + '_tmp'

     def state(self):
         return {
@@ -54,7 +58,7 @@ class BlockData(object):
     def save_shard(self, rank):
         if not os.path.isdir(self.temp_dir_name):
-            os.mkdir(self.temp_dir_name)
+            os.makedirs(self.temp_dir_name, exist_ok=True)

         # save the data for each shard
         with open('{}/{}.pkl'.format(self.temp_dir_name, rank), 'wb') as data_file:
@@ -73,8 +77,7 @@ class BlockData(object):
                 self.meta_data.update(data['meta_data'])
                 # assert (len(self.embed_data) == old_size + shard_size) or (str(ignore_shard) in fname)

-        args = get_args()
-        with open(args.block_data_path, 'wb') as final_file:
+        with open(self.block_data_path, 'wb') as final_file:
             pickle.dump(self.state(), final_file)

         shutil.rmtree(self.temp_dir_name, ignore_errors=True)
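The BlockData changes make the output path an explicit constructor argument rather than something read from get_args() at each call site, so the class can be built outside a full Megatron argument context. A small usage sketch under that assumption ('block_data.pkl' is a hypothetical path; the consolidation shown in the last hunk is what finally writes it):

# Point BlockData at an explicit pickle path; if no path is given it still
# falls back to args.block_data_path via get_args(), as before.
block_data = BlockData(block_data_path='block_data.pkl')

# Each rank dumps its shard into <path without extension>_tmp/,
# e.g. block_data_tmp/0.pkl, now created with os.makedirs(..., exist_ok=True).
block_data.save_shard(rank=0)

# The merge step in the last hunk then writes the combined state to
# self.block_data_path and removes the temporary shard directory.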
@@ -422,8 +422,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         elif iteration < 20:
-            print("moving right along", flush=True)
-            report_memory("iteration {}".format(iteration))
+            #print("moving right along", flush=True)
+            #report_memory("iteration {}".format(iteration))
+            pass

         loss_dict, skipped_iter = train_step(forward_step_func,
                                              train_data_iterator,
                                              model,