loader.py 2.76 KB
Newer Older
maming's avatar
maming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path
from typing import Any, Union

from megatron.energon.dataset_config import load_config
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type
from megatron.energon.metadataset.dataset_loader import DatasetLoader
from megatron.energon.metadataset.loader_interface import DatasetLoaderInterface
from megatron.energon.metadataset.metadataset import Metadataset
from megatron.energon.typed_converter import JsonParser


def load_dataset(
    path: Union[str, EPath, Path, dict],
    **kwargs,
) -> DatasetLoaderInterface:
    """Loads a (meta)dataset.

    Args:
        path: Either the location (``str``, ``pathlib.Path`` or ``EPath``) of a
            metadataset, webdataset or JSONL dataset, or an in-memory ``dict``
            holding a metadataset configuration.
        **kwargs: Extra keyword arguments. For metadatasets they are passed to
            ``load_config`` as ``default_kwargs`` (alongside ``path``); for
            webdataset/JSONL datasets they are passed to ``DatasetLoader``.

    Returns:
        The loaded (meta)dataset. Loaders created from a location have had
        ``post_initialize()`` called on them; a metadataset created from a
        ``dict`` is returned without that call.

    Raises:
        ValueError: If ``path`` is a location whose dataset type is neither a
            metadataset, a webdataset nor a JSONL dataset.
    """
    if isinstance(path, dict):
        # An in-memory config has no location of its own; EPath("/dict") is
        # passed as the placeholder ``path`` for the Metadataset.
        # NOTE(review): unlike the location-based branches below,
        # post_initialize() is not called here — confirm whether callers are
        # expected to call it themselves.
        return load_config(
            path,
            default_type=Metadataset,
            default_kwargs=dict(path=EPath("/dict"), **kwargs),
        )

    path = EPath(path)
    ds_type = get_dataset_type(path)
    if ds_type == EnergonDatasetType.METADATASET:
        loader = load_config(
            path,
            default_type=Metadataset,
            default_kwargs=dict(path=path, **kwargs),
        )
    elif ds_type in (EnergonDatasetType.WEBDATASET, EnergonDatasetType.JSONL):
        loader = DatasetLoader(path=path, **kwargs)
    else:
        raise ValueError(f"Invalid dataset at {path}")
    # Both location-based loader kinds must be post-initialized before use.
    loader.post_initialize()
    return loader


class MockJsonParser(JsonParser):
    """JSON parser that substitutes mock classes for objects whose module
    cannot be imported."""

    def _resolve_object(
        self,
        module_name: str,
        object_name: str,
        cls: type,
        is_type: bool,
        is_callable: bool,
        is_instantiating_class: bool,
        is_calling_function: bool,
    ) -> Any:
        """Resolve ``module_name``/``object_name`` using the parent parser.

        When the parent raises ``ModuleNotFoundError``, a freshly created
        subclass of ``cls`` is used instead. Its constructor accepts and
        ignores any arguments. That subclass is returned if any of the four
        kind flags is truthy; otherwise ``None`` is returned.
        """
        try:
            return super()._resolve_object(
                module_name,
                object_name,
                cls,
                is_type,
                is_callable,
                is_instantiating_class,
                is_calling_function,
            )
        except ModuleNotFoundError:
            # Stand-in for the unimportable object: it subclasses the expected
            # type and swallows every constructor argument.
            class MockObject(cls):
                def __init__(self, *_args, **_kwargs):
                    pass

            wants_mock = (
                is_type
                or is_instantiating_class
                or is_callable
                or is_calling_function
            )
            return MockObject if wants_mock else None


def prepare_metadataset(path: Union[str, EPath, Path]) -> None:
    """Load the metadataset config at ``path`` and run its ``prepare()`` step.

    The config is parsed with ``MockJsonParser(strict=True)``. A class whose
    module cannot be imported is therefore replaced by a mock stand-in rather
    than failing the load. ``post_initialize()`` is called on the loaded
    metadataset before ``prepare()``.

    Args:
        path: Location of the metadataset config, as a ``str``,
            ``pathlib.Path`` or ``EPath``. It is normalized with ``EPath(...)``,
            the same way ``load_dataset`` does it.
    """
    # load_config and Metadataset are already imported at module level, so no
    # function-local re-import is needed.
    path = EPath(path)
    meta_ds = load_config(
        path,
        default_type=Metadataset,
        default_kwargs=dict(path=path),
        parser=MockJsonParser(strict=True),
    )
    meta_ds.post_initialize()
    meta_ds.prepare()