ict_qualitative_test.py
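"""Qualitative test for a pretrained ICT (Inverse Cloze Task) model.

Loads an ICT checkpoint, rebuilds documents from the indexed dataset, embeds
each document with the model, and prints the resulting embedding matrix.
"""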
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import model_provider


def main():
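    """Embed the first 100 documents of the dataset with the ICT model and print the embeddings."""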
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()
    model = load_checkpoint()
    dataset = get_dataset()

    num_docs = 100
    all_doc_logits = np.zeros((num_docs, 128))
    for i in range(num_docs):
        doc_tokens = []
        doc_token_lists = dataset.get_sentence_split_doc(i)
        ptr = 0
        while len(doc_tokens) < args.seq_length and ptr < len(doc_token_lists):
            doc_tokens.extend(doc_token_lists[ptr])
            ptr += 1

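        # Pad the token sequence and build the token-type ids and pad mask the model expects.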
        doc_tokens, doc_token_types, doc_pad_mask = dataset.concat_and_pad_tokens(doc_tokens)
        doc_logits = model.embed_doc(np.array(doc_tokens), np.array(doc_pad_mask), np.array(doc_token_types))
        all_doc_logits[i] = doc_logits

    print(all_doc_logits, flush=True)


def load_checkpoint():
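    """Build the ICT model and load the weights of the latest checkpoint found under args.load."""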
    args = get_args()
    model = get_model(model_provider)

    if isinstance(model, torchDDP):
        model = model.module
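    # The tracker file records the iteration of the most recently saved checkpoint.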
    tracker_filename = get_checkpoint_tracker_filename(args.load)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    assert iteration > 0
    checkpoint_name = get_checkpoint_name(args.load, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model


def load_doc_embeds(path):
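    """Placeholder for loading precomputed document embeddings; not implemented yet."""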
    pass


def get_dataset():
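    """Build an InverseClozeDataset spanning every document of the indexed dataset at args.data_path."""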
    args = get_args()
    indexed_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)

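    # doc_idx stores document boundaries, so it has one more entry than there are documents.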
    doc_idx_ptr = indexed_dataset.get_doc_idx()
    total_num_documents = indexed_dataset.doc_idx.shape[0] - 1
    indexed_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])
    kwargs = dict(
        name='full',
        indexed_dataset=indexed_dataset,
        data_prefix=args.data_path,
        num_epochs=None,
        max_num_samples=total_num_documents,
        max_seq_length=288,  # doesn't matter
        short_seq_prob=0.01,  # doesn't matter
        seed=1
    )
    dataset = InverseClozeDataset(**kwargs)
    return dataset


if __name__ == "__main__":
    main()