preview.py 5.16 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
import shutil
import traceback
from typing import Any

import click
import torch

from megatron.energon import Sample, StandardWebdatasetFactory
from megatron.energon.dataset_config import load_config
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.empty_dataset_error import EmptyDatasetError
from megatron.energon.loader import get_loader
from megatron.energon.worker import WorkerConfig


def to_str(value: Any, indent: str) -> str:
    """Render ``value`` as a human-readable string for the console preview.

    Supported inputs:

    * ``torch.Tensor``: tensors that look like images (2d, or 3d with 1, 3 or 4
      leading channels) are converted to grayscale and drawn as ASCII art sized
      from the terminal dimensions; 1d tensors print their first 128 entries;
      everything else prints only shape and dtype.
    * ``str``/``int``/``float``/``bool``/``None``: ``repr`` of the value.
    * named tuples: one ``field=value`` entry per line.
    * lists/tuples: a compact ``repr`` if all elements are primitives (this
      includes the empty sequence), otherwise one element per line.
    * ``bytes``: the length and the first 128 bytes.
    * anything else: ``repr`` of the value.

    Args:
        value: The object to render.
        indent: Prefix for every continuation line of a multi-line rendering.

    Returns:
        The rendered string. The first line is never prefixed with ``indent``;
        the caller is expected to have positioned the cursor already.
    """
    primitives = (str, int, float, bool, type(None))

    if isinstance(value, torch.Tensor):
        orig_value = value
        # Probably image?
        if value.ndim == 3 and value.shape[0] in (1, 3, 4):
            # Convert to grayscale
            if value.shape[0] == 1:
                value = value[0]
            else:
                # With 4 channels only the first three are averaged; the fourth
                # channel is dropped.
                value = value[:3].to(dtype=torch.float32).mean(dim=0)
        if value.ndim == 2:
            if value.numel() == 0:
                # Nothing to draw, and the aspect ratio below would divide by zero.
                return f"Tensor(shape={orig_value.shape}, dtype={orig_value.dtype})"
            # 2d image -> ascii print
            # Resize based on the terminal size
            dst_w, dst_h = shutil.get_terminal_size((80, 24))
            orig_h, orig_w = value.shape
            # A long indent on a narrow terminal must not yield a non-positive width.
            dst_w = max(dst_w - len(indent), 1)
            # Vertical squeeze factor applied to the row count
            procrustes = 0.3
            # keep aspect ratio
            if orig_w / orig_h < dst_w / dst_h:
                dst_h = int(dst_w * procrustes * orig_h / orig_w)
            else:
                dst_w = int(dst_h / procrustes * orig_w / orig_h)
            # interpolate() rejects empty output sizes
            dst_h = max(dst_h, 1)
            dst_w = max(dst_w, 1)
            value = torch.nn.functional.interpolate(
                value[None, None, :, :].to(dtype=torch.float32), size=(dst_h, dst_w), mode="area"
            )[0, 0]
            # normalize to [0, 1]; a constant (or non-finite) image has no usable
            # range and would otherwise produce NaNs, so render it blank instead
            low = value.min()
            span = value.max() - low
            if torch.isfinite(span) and span > 0:
                value = (value - low) / span
            else:
                value = torch.zeros_like(value)
            # to ascii text
            palette = " .:-=+*#%@@"
            top = len(palette) - 1
            rows = ("".join(palette[int(v * top)] for v in row) for row in value.tolist())
            return (
                f"Tensor(shape={orig_value.shape}, dtype={orig_value.dtype}):\n{indent}"
                + f"\n{indent}".join(rows)
                + "\n"
            )
        elif value.ndim == 1:
            # 1d array: print at most the first 128 entries
            return f"Tensor(shape={value.shape}, dtype={value.dtype}): {value[:128].tolist()}"
        else:
            return f"Tensor(shape={value.shape}, dtype={value.dtype})"
    elif isinstance(value, primitives):
        return repr(value)
    elif isinstance(value, (list, tuple)):
        # Every entry of a multi-line container sits one level deeper than the
        # container itself, and nested values continue at that same level.
        item_indent = indent + "  "
        if hasattr(value, "_fields"):
            # Named tuple: `_fields` holds the field names as plain strings.
            return (
                f"{type(value).__name__}(\n{item_indent}"
                + f",\n{item_indent}".join(
                    f"{name}={to_str(item, item_indent)}"
                    for item, name in zip(value, value._fields)
                )
                + f"\n{indent})"
            )
        if all(isinstance(item, primitives) for item in value):
            # Flat sequence of primitives (or empty): keep it on a single line.
            return repr(value)
        return (
            f"[\n{item_indent}"
            + f"\n{item_indent}".join(to_str(item, item_indent) for item in value)
            + f"\n{indent}]"
        )
    elif isinstance(value, bytes):
        return f"bytes(length={len(value)}, value={value[:128]!r})"
    return repr(value)


def pprint(idx: int, sample: Sample):
    """Echo one sample to the console, one dataclass field per line.

    Args:
        idx: Number shown in the ``Sample <idx>`` header line.
        sample: Dataclass instance whose fields are printed. The bookkeeping
            fields ``__restore_key__``, ``__subflavors__`` and ``__sources__``
            are left out.
    """
    hidden = ("__restore_key__", "__subflavors__", "__sources__")

    click.echo(f"Sample {idx}")
    visible = (entry for entry in dataclasses.fields(sample) if entry.name not in hidden)
    for entry in visible:
        rendered = to_str(getattr(sample, entry.name), "")
        click.echo(f" - {entry.name} ({entry.type}): {rendered}")


@click.command(name="preview")
@click.argument(
    "path",
    type=click.Path(file_okay=False, dir_okay=True, path_type=EPath),
)
@click.option(
    "--split-parts", default="train,val,test", help="The splits to preview", show_default=True
)
@click.option(
    "--dataset-config", default="dataset.yaml", help="Dataset config file name", show_default=True
)
def command(path: EPath, split_parts: str, dataset_config: str):
    """Preview samples of a dataset on the console."""
    # NOTE: click shows the docstring above as the command's help text, so the
    # details live in comments instead.
    #
    # path:           dataset root; the config is read from
    #                 <path>/MAIN_FOLDER_NAME/<dataset_config>.
    # split_parts:    comma-separated split names, previewed in the given order.
    # dataset_config: file name of the dataset config inside the main folder.
    #
    # Raises click.ClickException if loading or printing a sample fails.

    worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0)

    for split_part in split_parts.split(","):
        # Tolerate "train, val" style input and stray commas.
        split_part = split_part.strip()
        if not split_part:
            continue

        try:
            dataset = load_config(
                EPath(path) / MAIN_FOLDER_NAME / dataset_config,
                default_kwargs=dict(
                    path=path,
                    split_part=split_part,
                    training=False,
                    worker_config=worker_config,
                ),
                default_type=StandardWebdatasetFactory,
            )
        except EmptyDatasetError:
            click.echo(f"Dataset {split_part} is empty. Skipping.")
            continue

        try:
            for idx, sample in enumerate(get_loader(dataset.build())):
                pprint(idx, sample)
                # Raises click.Abort when the user declines.
                click.confirm("Continue?", abort=True)
        except click.Abort:
            click.echo("Exiting Preview")
            # The user asked to stop: do not move on to the remaining splits.
            return
        except Exception as exc:
            # Only real errors are reported as failures; KeyboardInterrupt and
            # SystemExit propagate so that click can handle them itself.
            traceback.print_exc()
            raise click.ClickException(
                "Preview failed with errors, see logs for details."
            ) from exc


# Allow running this module directly as a script; `command` is the click entry point.
if __name__ == "__main__":
    command()