utils.py

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

import os
import torch

from megatron import get_retro_args
from megatron.tokenizer.tokenizer import (
    _BertWordPieceTokenizer,
    _GPT2BPETokenizer,
    _GPTSentencePieceTokenizer,
)


def get_args_path(workdir):
    '''Path to the argument copy stored within the retro workdir.'''
    return os.path.join(workdir, "args.json")
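
# Example (hypothetical path): get_args_path("/lustre/retro_workdir")
# returns "/lustre/retro_workdir/args.json".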


def get_num_chunks_per_sample():
    '''Compute seq_length // chunk_length.'''
    args = get_retro_args()
    sample_length = args.retro_gpt_seq_length
    chunk_length = args.retro_gpt_chunk_length
    assert sample_length % chunk_length == 0
    return sample_length // chunk_length
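
# Worked example (hypothetical values): with retro_gpt_seq_length=2048 and
# retro_gpt_chunk_length=64, each sample splits into 2048 // 64 = 32 chunks.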


def get_gpt_tokenizer():
    '''GPT tokenizer (GPT-2 BPE or SentencePiece, selected by retro_gpt_tokenizer_type).'''
    args = get_retro_args()
    tokenizer_type = args.retro_gpt_tokenizer_type
    if tokenizer_type == "GPT2BPETokenizer":
        assert args.retro_gpt_vocab_file and args.retro_gpt_merge_file
        return _GPT2BPETokenizer(
            vocab_file=args.retro_gpt_vocab_file,
            merge_file=args.retro_gpt_merge_file,
        )
    elif tokenizer_type == 'GPTSentencePieceTokenizer':
        assert args.retro_gpt_tokenizer_model is not None
        return _GPTSentencePieceTokenizer(args.retro_gpt_tokenizer_model)
    else:
        raise ValueError(f"unrecognized GPT tokenizer type '{tokenizer_type}'.")
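
# Usage sketch (assumes retro args have already been initialized by the
# caller, e.g. via tools/retro's argument loading; `tokenize`/`detokenize`
# follow megatron.tokenizer's AbstractTokenizer interface):
#
#   tokenizer = get_gpt_tokenizer()
#   token_ids = tokenizer.tokenize("retrieval-augmented pretraining")
#   text = tokenizer.detokenize(token_ids)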


def get_bert_tokenizer():
    '''Bert (WordPiece) tokenizer.'''
    args = get_retro_args()
    lower_case = {
        "BertWordPieceLowerCase": True,
        "BertWordPieceCase": False,
    }[args.retro_bert_tokenizer_type]
    return _BertWordPieceTokenizer(
        vocab_file=args.retro_bert_vocab_file,
        lower_case=lower_case,
    )
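
# Example: retro_bert_tokenizer_type="BertWordPieceLowerCase" selects the
# lowercasing (uncased) vocab; "BertWordPieceCase" preserves case. Any other
# value raises a KeyError from the mapping above.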


class GPTToTextDataset(torch.utils.data.Dataset):
    '''Dataset to convert GPT tokens to text.'''

    def __init__(self, gpt_dataset):
        super().__init__()
        self.gpt_dataset = gpt_dataset
        self.gpt_tokenizer = get_gpt_tokenizer()

    def __len__(self):
        return len(self.gpt_dataset)

    def __getitem__(self, idx):
        gpt_token_ids = self.gpt_dataset[idx]["text"].tolist()
        text = self.gpt_tokenizer.detokenize(gpt_token_ids)
        return {"text": text}
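
# Usage sketch (hypothetical `gpt_dataset`; assumes each sample is a dict
# whose "text" entry is an array/tensor of GPT token ids, as produced by
# megatron's GPT datasets):
#
#   text_dataset = GPTToTextDataset(gpt_dataset)
#   sample = text_dataset[0]  # -> {"text": "<detokenized string>"}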