Commit 6f56b909 authored by Neel Kant

Remove debug statements and correct dataloader

parent 932c0970
@@ -924,7 +924,6 @@ class InverseClozeDataset(data.Dataset):
             'context_types': np.array(context_token_types),
             'context_pad_mask': np.array(context_pad_mask)
         }
-        print("got item")
         return sample
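
For reference, each sample is a dict of equal-length numpy arrays (only the context fields are visible in this hunk; the input-side fields follow the same shape). A minimal sketch of batching such dicts with a DataLoader; the collate function and names are illustrative, not from the repo:

import numpy as np
import torch
from torch.utils.data import DataLoader

def collate(samples):
    # Stack each per-sample numpy array into one [batch x seq_len] tensor.
    return {key: torch.from_numpy(np.stack([s[key] for s in samples]))
            for key in samples[0]}

# loader = DataLoader(dataset, batch_size=32, collate_fn=collate)
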
@@ -958,7 +957,7 @@ class InverseClozeDataset(data.Dataset):
         doc = self.get_sentence_split_doc(doc_idx)
         if not doc:
             doc = None
-        print("got doc sentences")
         # set up and tokenize the entire selected document
         num_sentences = len(doc)
         all_token_lists = []
@@ -968,39 +967,42 @@ class InverseClozeDataset(data.Dataset):
             all_token_lists.append(tokens)
             all_token_type_lists.append(token_types)
-        print("got tokenized sentences")
         sentence_token_lens = [len(l) for l in all_token_lists]
-        inclusion_mask = [True] * num_sentences
+        inclusion_mask = [False] * num_sentences
         # select a random sentence from the document as input
         input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
-        input_tokens = all_token_lists[input_sentence_idx].copy()
-        input_token_types = all_token_type_lists[input_sentence_idx].copy()
+        input_tokens = all_token_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
+        input_token_types = all_token_type_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
         # 10% of the time, the input sentence is left in the context.
         # The other 90% of the time, remove it.
-        if rng.random() > 0.1:
-            inclusion_mask[input_sentence_idx] = False
+        if rng.random() < 0.1:
+            inclusion_mask[input_sentence_idx] = True
         # parameters for examining sentences to remove from the context
-        remove_preceding = True
-        view_radius = 0
-        while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) > target_seq_length:
+        view_preceding = True
+        view_radius = 1
+        while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < self.max_seq_len - 2:
             # keep removing sentences while the context is too large.
-            if remove_preceding:
-                if view_radius < input_sentence_idx:
-                    inclusion_mask[view_radius] = False
-                view_radius += 1
-            elif not remove_preceding and num_sentences - view_radius > input_sentence_idx:
-                inclusion_mask[num_sentences - view_radius] = False
-            remove_preceding = not remove_preceding
-        print("got inclusion mask")
+            if view_preceding:
+                examine_idx = input_sentence_idx - view_radius
+                if examine_idx >= 0:
+                    inclusion_mask[examine_idx] = True
+            else:
+                examine_idx = input_sentence_idx + view_radius
+                if examine_idx < num_sentences:
+                    inclusion_mask[examine_idx] = True
+                view_radius += 1
+            view_preceding = not view_preceding
+            if view_radius > num_sentences:
+                break
         # assemble the tokens and token types of the context
-        context_tokens = list(itertools.chain(
-            *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))
-        context_token_types = list(itertools.chain(
-            *[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))
+        context_tokens = list(itertools.chain(
+            *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
+        context_token_types = list(itertools.chain(
+            *[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
         # concatenate 'CLS' and 'SEP' tokens and add extra token types
         input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(
@@ -1008,7 +1010,6 @@ class InverseClozeDataset(data.Dataset):
         context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(
             context_tokens, context_token_types)
-        print("got all tokens")
         return (input_tokens, input_token_types, input_pad_mask), \
                (context_tokens, context_token_types, context_pad_mask)
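
The corrected selection logic grows the context outward from the input sentence, alternating between the preceding and following neighbor and widening the radius, instead of trimming sentences inward from the document edges as the old code did. A standalone sketch of that pattern, assuming only sentence token lengths and a token budget (names are illustrative, not from the repo):

def build_inclusion_mask(sentence_lens, input_idx, budget, keep_input=False):
    # Start empty; optionally keep the input sentence (10% of the time above).
    num_sentences = len(sentence_lens)
    mask = [False] * num_sentences
    mask[input_idx] = keep_input
    view_preceding, radius = True, 1
    while sum(l for i, l in enumerate(sentence_lens) if mask[i]) < budget:
        if view_preceding:
            idx = input_idx - radius
            if idx >= 0:
                mask[idx] = True
        else:
            idx = input_idx + radius
            if idx < num_sentences:
                mask[idx] = True
            radius += 1  # widen only after both sides were examined
        view_preceding = not view_preceding
        if radius > num_sentences:
            break  # whole document examined; context stays short
    return mask

# build_inclusion_mask([12, 30, 25, 18, 40], input_idx=2, budget=60)
# -> includes sentences 1 and 3 first, then 0 (budget reached before 4).

The loop can overshoot the budget by one sentence; the dataset code compensates by slicing the assembled context to max_seq_len - 2 afterward.
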
@@ -1018,6 +1019,7 @@ class InverseClozeDataset(data.Dataset):
         tokens = [self.tokenizer.get_command('ENC').Id] + tokens + [self.tokenizer.get_command('sep').Id]
         token_types = [token_types[0]] + token_types + [token_types[0]]
+        assert len(tokens) <= self.max_seq_len
         num_pad = max(0, self.max_seq_len - len(tokens))
         pad_mask = [0] * len(tokens) + [1] * num_pad
         tokens += [self.tokenizer.get_command('pad').Id] * num_pad
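
concat_and_pad_tokens wraps a sequence in the 'ENC' (CLS) and 'sep' markers and right-pads to max_seq_len; the new assert catches any sequence the upstream max_seq_len - 2 truncation should have made impossible. A self-contained sketch with placeholder integer ids (the real code looks them up via tokenizer.get_command; the token-type padding shown is an assumption, since that line falls outside this hunk):

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # placeholders for tokenizer command ids

def concat_and_pad(tokens, token_types, max_seq_len):
    # Wrap with [CLS] ... [SEP]; token types extend to cover both markers.
    tokens = [CLS_ID] + tokens + [SEP_ID]
    token_types = [token_types[0]] + token_types + [token_types[0]]
    assert len(tokens) <= max_seq_len  # guaranteed by the "- 2" truncation upstream
    num_pad = max_seq_len - len(tokens)
    pad_mask = [0] * len(tokens) + [1] * num_pad  # 1 marks padding positions
    tokens += [PAD_ID] * num_pad
    token_types += [token_types[0]] * num_pad  # assumed fill value
    return tokens, token_types, pad_mask
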
...
@@ -292,13 +292,10 @@ class ICTBertModel(MegatronModule):
                 context_tokens, context_attention_mask, context_types):
         question_ict_logits, _ = self.question_model.forward(input_tokens, input_attention_mask, input_types)
-        print("(bert ict forward) got question logits")
         context_ict_logits, _ = self.context_model.forward(context_tokens, context_attention_mask, context_types)
-        print("(bert ict forward) got context logits")
         # [batch x h] * [h x batch]
         retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
-        print("(bert ict forward) got retrieval scores")
         return retrieval_scores
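
The matmul produces a [batch x batch] score matrix whose (i, j) entry scores question i against context j, so the diagonal holds the positive pairs and every off-diagonal entry is an in-batch negative. A minimal tensor-level sketch, with random embeddings standing in for the two BERT towers:

import torch

batch, hidden = 4, 8
question_emb = torch.randn(batch, hidden)  # stand-in for question_model output
context_emb = torch.randn(batch, hidden)   # stand-in for context_model output

# [batch x h] @ [h x batch] -> [batch x batch]
scores = question_emb.matmul(context_emb.t())
labels = torch.arange(batch)  # the matching context for question i is context i
assert scores.shape == (batch, batch)
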
...
@@ -253,7 +253,6 @@ def setup_model_and_optimizer(model_provider_func, args):
 def backward_step(optimizer, model, loss, args, timers):
     """Backward step."""
-    print("back1")
     # Backward pass.
     optimizer.zero_grad()
     if args.fp16:
@@ -261,7 +260,6 @@ def backward_step(optimizer, model, loss, args, timers):
     else:
         loss.backward()
-    print("back2")
     # All-reduce if needed.
     if args.DDP_impl == 'local':
         timers('allreduce').start()
@@ -269,12 +267,10 @@ def backward_step(optimizer, model, loss, args, timers):
                        fp32_allreduce=args.fp32_allreduce)
         timers('allreduce').stop()
-    print("back3")
     # Update master gradients.
     if args.fp16:
         optimizer.update_master_grads()
-    print("back4")
     # Clipping gradients helps prevent the exploding gradient.
     if args.clip_grad > 0:
         if not args.fp16:
@@ -282,7 +278,6 @@ def backward_step(optimizer, model, loss, args, timers):
     else:
         optimizer.clip_master_grads(args.clip_grad)
-    print("back5")

 def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
                args, timers):
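
backward_step orders the work as: zero gradients, backward pass (through the fp16 optimizer's scaled-loss path when enabled), optional data-parallel all-reduce, copy into fp32 master gradients, then clipping. A condensed sketch of that control flow; the bodies of the collapsed branches (the fp16 backward call and the allreduce_params arguments) are assumptions inferred from the surrounding lines, not verbatim from the repo:

import torch

def backward_step_sketch(optimizer, model, loss, args):
    optimizer.zero_grad()
    if args.fp16:
        optimizer.backward(loss)  # assumed: fp16 optimizer scales the loss internally
    else:
        loss.backward()
    if args.DDP_impl == 'local':
        # assumed signature, based on the fp32_allreduce kwarg visible above
        model.allreduce_params(fp32_allreduce=args.fp32_allreduce)
    if args.fp16:
        optimizer.update_master_grads()  # copy fp16 grads onto fp32 master weights
    if args.clip_grad > 0:
        if not args.fp16:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
        else:
            optimizer.clip_master_grads(args.clip_grad)

Clipping runs last so it sees the fully reduced, master-precision gradients.
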
@@ -293,21 +288,18 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
     loss, loss_reduced = forward_step_func(data_iterator, model, args, timers)
     timers('forward').stop()
     torch.cuda.synchronize()
-    print("confirm forward")
     # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
     backward_step(optimizer, model, loss, args, timers)
     timers('backward').stop()
     torch.cuda.synchronize()
-    print("did backward step")
     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
     timers('optimizer').stop()
     torch.cuda.synchronize()
-    print("did optim step")
     # Update learning rate.
     skipped_iter = 0
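
Each phase is bracketed by a named timer, and torch.cuda.synchronize() after each phase drains outstanding GPU work so one phase's kernels are not attributed to the next phase's timer. The same pattern in isolation, with a hypothetical minimal timer standing in for Megatron's:

import time
import torch

class Timer:
    # Hypothetical stand-in for Megatron's timers.
    def __init__(self):
        self.start_t = 0.0
        self.elapsed = 0.0
    def start(self):
        self.start_t = time.time()
    def stop(self):
        self.elapsed += time.time() - self.start_t

forward_timer = Timer()
forward_timer.start()
# ... launch forward-pass kernels here ...
forward_timer.stop()
if torch.cuda.is_available():
    torch.cuda.synchronize()  # fence the GPU before the next phase starts timing
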
...
@@ -79,9 +79,6 @@ def get_batch(data_iterator, timers):
     context_types = data_b['context_types'].long()
     context_pad_mask = data_b['context_pad_mask'].long()
-    global num_batches
-    print("got batch {}".format(num_batches))
     return input_tokens, input_types, input_pad_mask,\
            context_tokens, context_types, context_pad_mask
@@ -98,19 +95,11 @@ def forward_step(data_iterator, model, args, timers):
     # Forward model.
     retrieval_scores = model(input_tokens, 1 - input_pad_mask, input_types,
                              context_tokens, 1 - context_pad_mask, context_types)
-    print("ran model to get retrieval scores")
     softmaxed = F.softmax(retrieval_scores, dim=0)
     retrieval_loss = F.cross_entropy(softmaxed, torch.arange(softmaxed.shape[0]).cuda())
-    print(type(retrieval_loss))
     reduced_losses = reduce_losses([retrieval_loss])
-    global num_batches
-    print("did forward step {}".format(num_batches))
-    num_batches += 1
-    print(retrieval_loss, {'retrieval loss': reduced_losses[0]})
     return retrieval_loss, {'retrieval loss': reduced_losses[0]}
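
The loss treats row i of the score matrix as logits over all contexts in the batch, with the correct class at index i (the diagonal). One detail this commit leaves untouched: F.cross_entropy applies log-softmax internally, so feeding it pre-softmaxed scores, as the retained code does, compresses the logit range rather than computing the standard in-batch softmax loss. A minimal sketch of the usual formulation, passing raw scores directly:

import torch
import torch.nn.functional as F

scores = torch.randn(4, 4)  # stand-in for retrieval_scores: [batch x batch]
labels = torch.arange(scores.shape[0])  # positive context for row i is column i

# cross_entropy = log-softmax + NLL, so it expects unnormalized logits
loss = F.cross_entropy(scores, labels)
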
...