Commit 3dc39c4a authored by Neel Kant
Browse files

Correct title index in helpers and complete the ICDataset functionality

parent 064a6881
...@@ -83,7 +83,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -83,7 +83,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
# Build the dataset accordingly. # Build the dataset accordingly.
kwargs = dict( kwargs = dict(
name=name, name=name,
indexed_dataset=indexed_dataset, context_dataset=indexed_dataset,
data_prefix=data_prefix, data_prefix=data_prefix,
num_epochs=None, num_epochs=None,
max_num_samples=train_valid_test_num_samples[index], max_num_samples=train_valid_test_num_samples[index],
...@@ -93,10 +93,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -93,10 +93,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
) )
if ict_dataset: if ict_dataset:
titles_idx_ptr = titles_dataset.get_doc_idx()
titles_dataset.set_doc_idx(titles_idx_ptr[start_index:end_index])
dataset = InverseClozeDataset(titles_dataset=titles_dataset, **kwargs) dataset = InverseClozeDataset(titles_dataset=titles_dataset, **kwargs)
titles_dataset.set_doc_idx(titles_idx_ptr)
else: else:
dataset = BertDataset(masked_lm_prob=masked_lm_prob, **kwargs) dataset = BertDataset(masked_lm_prob=masked_lm_prob, **kwargs)
# Set the original pointer so dataset remains the main dataset. # Set the original pointer so dataset remains the main dataset.
......
...@@ -428,7 +428,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -428,7 +428,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
const auto map_index_0 = 3 * map_index; const auto map_index_0 = 3 * map_index;
maps[map_index_0] = static_cast<DocIdx>(prev_start_index); maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1); maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len); maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
} }
// Update indices / counters. // Update indices / counters.
......
import itertools
import random import random
import os import os
import sys
import time import time
import numpy as np import numpy as np
...@@ -45,19 +47,26 @@ class InverseClozeDataset(Dataset): ...@@ -45,19 +47,26 @@ class InverseClozeDataset(Dataset):
return self.samples_mapping.shape[0] return self.samples_mapping.shape[0]
def __getitem__(self, idx): def __getitem__(self, idx):
start_index, end_index, _ = self.samples_mapping[idx] start_index, end_index, doc_index = self.samples_mapping[idx]
context = [self.indexed_dataset[i] for i in range(start_index, end_index)] context = [list(self.context_dataset[i]) for i in range(start_index, end_index)]
assert len(context) > 1 assert len(context) > 1
title = self.titles_dataset[idx] title = list(self.titles_dataset[int(doc_index)])
assert sum(len(c) for c in context) + len(title) <= self.max_seq_length - 3 full_sum = sum(len(c) for c in context) + len(title)
if len(context) == 2:
rand_sent_idx = int(self.rng.random() > 0.5)
else:
rand_sent_idx = self.rng.randint(1, len(context) - 2)
rand_sent_idx = self.rng.randint(0, len(context) - 1)
if self.rng.random() < 0.1: if self.rng.random() < 0.1:
input = list(context[rand_sent_idx]) input = list(context[rand_sent_idx])
else: else:
input = context.pop(rand_sent_idx) input = context.pop(rand_sent_idx)
input = input[:self.max_seq_length - 2]
context = list(itertools.chain(*context))[:self.max_seq_length - (3 + len(title))]
input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input) input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input)
context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context, title) context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context, title)
...@@ -77,7 +86,7 @@ class InverseClozeDataset(Dataset): ...@@ -77,7 +86,7 @@ class InverseClozeDataset(Dataset):
tokens = [self.cls_id] + tokens + [self.sep_id] tokens = [self.cls_id] + tokens + [self.sep_id]
if title is not None: if title is not None:
tokens += title + [self.sep_id] tokens += title + [self.sep_id]
assert len(tokens) <= self.max_seq_length assert len(tokens) <= self.max_seq_length, len(tokens)
num_pad = self.max_seq_length - len(tokens) num_pad = self.max_seq_length - len(tokens)
pad_mask = [0] * len(tokens) + [1] * num_pad pad_mask = [0] * len(tokens) + [1] * num_pad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment