parser.py 5.6 KB
Newer Older
chenych's avatar
chenych committed
1
# Copyright 2025 the LlamaFactory team.
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
15
16
17
import json
import os
from dataclasses import dataclass
shihm's avatar
uodata  
shihm committed
18
from typing import Any, Literal
chenych's avatar
chenych committed
19

chenych's avatar
chenych committed
20
from huggingface_hub import hf_hub_download
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
21
22

from ..extras.constants import DATA_CONFIG
luopl's avatar
luopl committed
23
from ..extras.misc import use_modelscope, use_openmind
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
24
25
26
27


@dataclass
class DatasetAttr:
chenych's avatar
chenych committed
28
    r"""Dataset attributes."""
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
29

chenych's avatar
chenych committed
30
    # basic configs
luopl's avatar
luopl committed
31
    load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
32
    dataset_name: str
shihm's avatar
uodata  
shihm committed
33
    formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
chenych's avatar
chenych committed
34
35
    ranking: bool = False
    # extra configs
shihm's avatar
uodata  
shihm committed
36
    subset: str | None = None
chenych's avatar
chenych committed
37
    split: str = "train"
shihm's avatar
uodata  
shihm committed
38
39
    folder: str | None = None
    num_samples: int | None = None
chenych's avatar
chenych committed
40
    # common columns
shihm's avatar
uodata  
shihm committed
41
42
43
44
45
    system: str | None = None
    tools: str | None = None
    images: str | None = None
    videos: str | None = None
    audios: str | None = None
chenych's avatar
chenych committed
46
    # dpo columns
shihm's avatar
uodata  
shihm committed
47
48
49
    chosen: str | None = None
    rejected: str | None = None
    kto_tag: str | None = None
chenych's avatar
chenych committed
50
    # alpaca columns
shihm's avatar
uodata  
shihm committed
51
52
53
54
    prompt: str | None = "instruction"
    query: str | None = "input"
    response: str | None = "output"
    history: str | None = None
chenych's avatar
chenych committed
55
    # sharegpt columns
shihm's avatar
uodata  
shihm committed
56
    messages: str | None = "conversations"
chenych's avatar
chenych committed
57
    # sharegpt tags
shihm's avatar
uodata  
shihm committed
58
59
60
61
62
63
64
    role_tag: str | None = "from"
    content_tag: str | None = "value"
    user_tag: str | None = "human"
    assistant_tag: str | None = "gpt"
    observation_tag: str | None = "observation"
    function_tag: str | None = "function_call"
    system_tag: str | None = "system"
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
65
66
67
68

    def __repr__(self) -> str:
        return self.dataset_name

shihm's avatar
uodata  
shihm committed
69
    def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
70
71
        setattr(self, key, obj.get(key, default))

chenych's avatar
chenych committed
72
    def join(self, attr: dict[str, Any]) -> None:
chenych's avatar
chenych committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
        self.set_attr("formatting", attr, default="alpaca")
        self.set_attr("ranking", attr, default=False)
        self.set_attr("subset", attr)
        self.set_attr("split", attr, default="train")
        self.set_attr("folder", attr)
        self.set_attr("num_samples", attr)

        if "columns" in attr:
            column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
            column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
            for column_name in column_names:
                self.set_attr(column_name, attr["columns"])

        if "tags" in attr:
            tag_names = ["role_tag", "content_tag"]
            tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
            for tag in tag_names:
                self.set_attr(tag, attr["tags"])

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
92

shihm's avatar
uodata  
shihm committed
93
def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]:
chenych's avatar
chenych committed
94
    r"""Get the attributes of the datasets."""
chenych's avatar
chenych committed
95
    if dataset_names is None:
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
96
97
        dataset_names = []

shihm's avatar
uodata  
shihm committed
98
99
100
    if isinstance(dataset_dir, dict):
        dataset_info = dataset_dir
    elif dataset_dir == "ONLINE":
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
101
102
        dataset_info = None
    else:
chenych's avatar
chenych committed
103
        if dataset_dir.startswith("REMOTE:"):
chenych's avatar
chenych committed
104
            config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
chenych's avatar
chenych committed
105
106
107
        else:
            config_path = os.path.join(dataset_dir, DATA_CONFIG)

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
108
        try:
luopl's avatar
luopl committed
109
            with open(config_path) as f:
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
110
111
112
                dataset_info = json.load(f)
        except Exception as err:
            if len(dataset_names) != 0:
luopl's avatar
luopl committed
113
                raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
114

chenych's avatar
chenych committed
115
            dataset_info = None
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
116

chenych's avatar
chenych committed
117
    dataset_list: list[DatasetAttr] = []
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
118
    for name in dataset_names:
chenych's avatar
chenych committed
119
        if dataset_info is None:  # dataset_dir is ONLINE
chenych's avatar
chenych committed
120
            load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
121
122
123
124
125
            dataset_attr = DatasetAttr(load_from, dataset_name=name)
            dataset_list.append(dataset_attr)
            continue

        if name not in dataset_info:
luopl's avatar
luopl committed
126
            raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
127
128
129

        has_hf_url = "hf_hub_url" in dataset_info[name]
        has_ms_url = "ms_hub_url" in dataset_info[name]
luopl's avatar
luopl committed
130
        has_om_url = "om_hub_url" in dataset_info[name]
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
131

luopl's avatar
luopl committed
132
133
        if has_hf_url or has_ms_url or has_om_url:
            if has_ms_url and (use_modelscope() or not has_hf_url):
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
134
                dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
luopl's avatar
luopl committed
135
136
            elif has_om_url and (use_openmind() or not has_hf_url):
                dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
137
138
139
140
            else:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
        elif "script_url" in dataset_info[name]:
            dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
chenych's avatar
chenych committed
141
142
        elif "cloud_file_name" in dataset_info[name]:
            dataset_attr = DatasetAttr("cloud_file", dataset_name=dataset_info[name]["cloud_file_name"])
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
143
144
145
        else:
            dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])

chenych's avatar
chenych committed
146
        dataset_attr.join(dataset_info[name])
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
147
148
149
        dataset_list.append(dataset_attr)

    return dataset_list