# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Zero-shot datasets."""

import json
import math

import numpy as np
import torch

from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_tokenizer
from .detokenizer import get_detokenizer


def build_dataset(task):
    """Return the zero-shot evaluation dataset for ``task``.

    Supported task names are ``'LAMBADA'`` and ``'WIKITEXT103'``.

    Raises:
        NotImplementedError: for any other task name.
    """
    if task == 'LAMBADA':
        builder = _build_lambada_dataset
    elif task == 'WIKITEXT103':
        builder = _build_wikitext103_dataset
    else:
        message = 'dataset for {} task is not implemented.'.format(task)
        raise NotImplementedError(message)
    return builder()


class _LMDataset(torch.utils.data.Dataset):

    def __init__(self, tokens, seq_len, pad_idx, num_original_tokens,
                 num_tokenized_tokens, overalapping_eval=None):
        self.tokens = tokens
        self.seq_len = seq_len
        self.pad_idx = pad_idx
        self.overalapping_eval = overalapping_eval
        if self.overalapping_eval is None:
            self.overalapping_eval = self.seq_len
        self.overalapping_eval = max(1, self.overalapping_eval)
        self.num_original_tokens = num_original_tokens
        self.num_tokenized_tokens = num_tokenized_tokens
        self.total_targets = len(self.tokens) - 1
        # remove first sequence tokens
        targets = max(self.total_targets - self.overalapping_eval, 0)
        self.total_sequences = max(
            math.ceil(targets / self.overalapping_eval) + 1, 1)

    def __len__(self):
        return self.total_sequences

    def __getitem__(self, idx):
        start_idx = idx * self.overalapping_eval
        end_idx = start_idx + self.seq_len
Neel Kant's avatar
Neel Kant committed
54
        tokens = self.tokens[start_idx:end_idx + 1]
Mohammad's avatar
Mohammad committed
55
        num_tokens = len(tokens)
Neel Kant's avatar
Neel Kant committed
56
57
58
59
        pad_mask = [1] * num_tokens
        if num_tokens < self.seq_len + 1:
            num_pad = (self.seq_len + 1 - num_tokens)
            pad_mask += [0] * (num_pad)
Mohammad's avatar
Mohammad committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
            tokens += [self.pad_idx] * num_pad
        pad_mask = np.array(pad_mask[1:])
        if self.overalapping_eval != self.seq_len and idx != 0:
            pad_mask[:-self.overalapping_eval] *= 0

        return {'text': np.array(tokens), 'pad_mask': pad_mask}


class _LambadaDataset(torch.utils.data.Dataset):

    def __init__(self, path, pad_idx, tokenizer, seq_len, strict=False):
        print_rank_0('> building lambada dataset from {} ...'.format(path))
        self.seq_len = seq_len
        self.pad_idx = pad_idx
        self.tokenizer = tokenizer
        self.strict = strict

        self.tokens = []
        self.labels = []
        with open(path, 'r') as f:
            for line in f.readlines():
                text = json.loads(line)['text']
                tokens, labels = self.get_tokens(text)
                self.tokens.append(tokens)
                self.labels.append(labels)

    def get_tokens(self, text):
        if not self.strict:
            tokens = self.tokenizer.tokenize(text)
            return tokens[:-1], [tokens[-1]]
        last_token = text.split()[-1]
        start_idx = text.rfind(last_token)
        beginning_tokens = self.tokenizer.tokenize(text[:start_idx].strip())
Neel Kant's avatar
Neel Kant committed
93
        last_token = self.tokenizer.tokenize(' ' + last_token)
Mohammad's avatar
Mohammad committed
94
95
96
97
98
99
100
101
        return beginning_tokens, last_token

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        tokens = self.tokens[idx]
        num_tokens = len(tokens)
Neel Kant's avatar
Neel Kant committed
102
        pad_mask = [0] * num_tokens
Mohammad's avatar
Mohammad committed
103
        labels = self.labels[idx]
Neel Kant's avatar
Neel Kant committed
104
105
        pad_mask += [1] * len(labels)
        tokens = tokens + labels
Mohammad's avatar
Mohammad committed
106
        num_tokens = len(tokens)
Neel Kant's avatar
Neel Kant committed
107
108
109
        if num_tokens < self.seq_len + 1:
            num_pad = (self.seq_len + 1 - num_tokens)
            pad_mask += [0] * (num_pad)
Mohammad's avatar
Mohammad committed
110
111
112
113
114
115
116
117
118
119
120
121
            tokens += [self.pad_idx] * num_pad
        pad_mask = np.array(pad_mask[1:])

        return {'text': np.array(tokens), 'pad_mask': pad_mask}


def _build_lambada_dataset():
    """Build the LAMBADA evaluation dataset from ``args.valid_data[0]``.

    Reads the global args and tokenizer. The tokenizer's ``eod`` id is used
    as the padding token, and ``args.strict_lambada`` selects strict
    last-word labels.

    Returns:
        _LambadaDataset: the validation dataset.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    assert len(args.valid_data) == 1, \
        'expected exactly one valid_data path, got {}'.format(
            len(args.valid_data))
    val_dataset = _LambadaDataset(args.valid_data[0], tokenizer.eod, tokenizer,
                                  args.seq_length, args.strict_lambada)
    print_rank_0(' > found {} samples.'.format(len(val_dataset)))

    return val_dataset


def _build_wikitext103_dataset():
    """Build the WikiText-103 evaluation dataset from ``args.valid_data[0]``.

    Steps:
        1. Read the raw file and decode it as UTF-8.
        2. Set ``num_original_tokens`` to the number of single-space-separated
           pieces, counted before detokenization.
        3. Detokenize the text with the detokenizer chosen by
           ``get_detokenizer(path)``.
        4. Re-tokenize with the global tokenizer.
        5. Wrap the result in an ``_LMDataset`` that pads with ``eod`` and uses
           ``args.overlapping_eval`` as the window stride.

    Returns:
        _LMDataset: the validation dataset.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    assert len(args.valid_data) == 1, \
        'expected exactly one valid_data path, got {}'.format(
            len(args.valid_data))
    with open(args.valid_data[0], "rb") as reader:
        entire_data = reader.read().decode('utf-8')
    # Count the original (pre-detokenization) pieces; presumably used by the
    # caller as the word-level normalizer for perplexity -- confirm.
    num_original_tokens = len(entire_data.strip().split(" "))
    entire_data = get_detokenizer(args.valid_data[0])(entire_data)
    tokenized_data = tokenizer.tokenize(entire_data)
    num_tokenized_tokens = len(tokenized_data)

    val_dataset = _LMDataset(tokenized_data, args.seq_length, tokenizer.eod,
                             num_original_tokens, num_tokenized_tokens,
                             args.overlapping_eval)
    print_rank_0(' > number of original tokens: {}, number of detokenized '
                 'tokens: {}'.format(num_original_tokens, num_tokenized_tokens))

    return val_dataset