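"""Streaming multilingual (en/ja/zh) pre-training dataset.

Interleaves the CulturaX corpora, shards the stream across DDP ranks, and
tokenizes each example into fixed 512-token sequences with the
fishaudio/speech-lm-300m tokenizer.
"""
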
import random
from transformers import AutoTokenizer
from datasets import load_dataset, interleave_datasets, IterableDataset
from functools import lru_cache
from torch.utils.data import DataLoader
from datasets.distributed import split_dataset_by_node
from torch.distributed import get_rank, get_world_size, is_initialized


@lru_cache(maxsize=1)
def get_tokenizer():
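    """Load the tokenizer once; lru_cache keeps it a per-process singleton."""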
    return AutoTokenizer.from_pretrained("fishaudio/speech-lm-300m", revision="init")

def encode(examples):
    # Randomly pick a 512-character window from each text so long documents
    # are not always truncated to their beginning.
    texts = []
    for text in examples["text"]:
        if len(text) <= 512:
            texts.append(text)
        else:
            start = random.randint(0, len(text) - 512)
            texts.append(text[start : start + 512])

    data = get_tokenizer()(
        texts,
        truncation=True,
        padding="max_length",
        max_length=512,
    )

    # Use input_ids as labels, masking padding positions with -100 so they are
    # ignored by the loss. The tokenizer returns plain Python lists here, so the
    # masking is done per token rather than with array indexing.
    data["labels"] = [
        [token if mask == 1 else -100 for token, mask in zip(tokens, mask_row)]
        for tokens, mask_row in zip(data["input_ids"], data["attention_mask"])
    ]

    return data


def build_dataset():
    en_dataset = load_dataset("uonlp/CulturaX", "en", split="train", streaming=True)
    ja_dataset = load_dataset("uonlp/CulturaX", "ja", split="train", streaming=True)
    zh_dataset = load_dataset("uonlp/CulturaX", "zh", split="train", streaming=True)

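    # Mix the three language streams with fixed sampling probabilities
    # (40% en, 30% ja, 30% zh); the seed makes the mixing reproducible.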
    multilingual_dataset: IterableDataset = interleave_datasets(
        [en_dataset, ja_dataset, zh_dataset], probabilities=[0.4, 0.3, 0.3], seed=42
    )

    # When running under DDP, shard the stream so each rank reads a disjoint
    # slice of the data.
    if is_initialized():
        multilingual_dataset = split_dataset_by_node(
            multilingual_dataset,
            rank=get_rank(),
            world_size=get_world_size(),
        )

    multilingual_dataset = multilingual_dataset.shuffle(seed=42, buffer_size=10000)

    multilingual_dataset = multilingual_dataset.map(
        encode, batched=True, remove_columns=multilingual_dataset.column_names
    )

    return multilingual_dataset

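
# --- Hypothetical usage sketch (not in the original file) --------------------
# One way the streaming dataset could feed the torch DataLoader imported above.
# Assumes a `datasets` version where IterableDataset.with_format("torch") is
# available; the batch size is illustrative only. No shuffling is done here
# because build_dataset() already shuffles via a streaming buffer.
def build_dataloader(batch_size: int = 8) -> DataLoader:
    dataset = build_dataset().with_format("torch")
    return DataLoader(dataset, batch_size=batch_size)
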

if __name__ == "__main__":
    dataset = build_dataset()
    print(list(dataset.take(16)))