# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import dataclass
from typing import Any, Literal, Optional

from huggingface_hub import hf_hub_download

from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope, use_openmind


@dataclass
class DatasetAttr:
    r"""Dataset attributes.

    Describes where one dataset is loaded from and which columns / tags of the
    raw data hold each piece of a training example. Instances are created with
    ``load_from`` and ``dataset_name`` only; the remaining fields start at the
    defaults below and are overwritten by :meth:`join`.
    """

    # basic configs
    load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
    dataset_name: str
    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
    ranking: bool = False
    # extra configs
    subset: Optional[str] = None
    split: str = "train"
    folder: Optional[str] = None
    num_samples: Optional[int] = None
    # common columns
    system: Optional[str] = None
    tools: Optional[str] = None
    images: Optional[str] = None
    videos: Optional[str] = None
    audios: Optional[str] = None
    # dpo columns
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    kto_tag: Optional[str] = None
    # alpaca columns
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    # sharegpt columns
    messages: Optional[str] = "conversations"
    # sharegpt tags
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"

    def __repr__(self) -> str:
        # Deliberately replaces the dataclass-generated repr with just the name.
        return self.dataset_name

    def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
        r"""Set ``self.<key>`` to ``obj[key]``, or to ``default`` when the key is absent."""
        setattr(self, key, obj.get(key, default))

    def join(self, attr: dict[str, Any]) -> None:
        r"""Overwrite the optional fields of this instance from a config entry.

        Args:
            attr: Mapping that may contain the top-level keys ``formatting``,
                ``ranking``, ``subset``, ``split``, ``folder``, ``num_samples``
                and the nested mappings ``columns`` and ``tags``.

        The top-level fields are always assigned; a missing key yields the
        default passed to ``set_attr`` below.
        """
        self.set_attr("formatting", attr, default="alpaca")
        self.set_attr("ranking", attr, default=False)
        self.set_attr("subset", attr)
        self.set_attr("split", attr, default="train")
        self.set_attr("folder", attr)
        self.set_attr("num_samples", attr)

        # NOTE: once a "columns" mapping is present, every column it does NOT
        # list is reset to None (set_attr's default), not kept at the class default.
        if "columns" in attr:
            column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
            column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
            for column_name in column_names:
                self.set_attr(column_name, attr["columns"])

        # Same rule for tags: unlisted tags become None when "tags" is given.
        if "tags" in attr:
            tag_names = ["role_tag", "content_tag"]
            tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
            for tag in tag_names:
                self.set_attr(tag, attr["tags"])

def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
    r"""Get the attributes of the datasets.

    Args:
        dataset_names: Names of the datasets to resolve; ``None`` is treated as
            an empty list.
        dataset_dir: Location of the ``DATA_CONFIG`` file. ``"ONLINE"`` skips the
            config entirely, ``"REMOTE:<repo_id>"`` fetches it with
            ``hf_hub_download`` from that dataset repo, and any other value is
            used as a local directory.

    Returns:
        One ``DatasetAttr`` per requested name, in the order given.

    Raises:
        ValueError: If datasets were requested but the config file cannot be
            read or parsed, or if a requested name is not in the config.
        KeyError: If a config entry has no hub URL, no ``script_url`` and no
            ``file_name``.
    """
    if dataset_names is None:
        dataset_names = []

    if dataset_dir == "ONLINE":
        dataset_info = None
    else:
        remote_prefix = "REMOTE:"
        if dataset_dir.startswith(remote_prefix):
            # Download errors are not caught here; they propagate to the caller.
            config_path = hf_hub_download(
                repo_id=dataset_dir[len(remote_prefix) :], filename=DATA_CONFIG, repo_type="dataset"
            )
        else:
            config_path = os.path.join(dataset_dir, DATA_CONFIG)

        try:
            # JSON is UTF-8 by specification; do not depend on the locale encoding.
            with open(config_path, encoding="utf-8") as f:
                dataset_info = json.load(f)
        except Exception as err:
            # Best effort: an unreadable config only matters if datasets were requested.
            if dataset_names:
                raise ValueError(f"Cannot open {config_path} due to {str(err)}.") from err

            dataset_info = None

    dataset_list: list[DatasetAttr] = []
    for name in dataset_names:
        if dataset_info is None:  # dataset_dir is ONLINE
            # Without a config the name itself is the hub identifier; pick the hub
            # according to the use_modelscope() / use_openmind() switches.
            if use_modelscope():
                load_from = "ms_hub"
            elif use_openmind():
                load_from = "om_hub"
            else:
                load_from = "hf_hub"

            dataset_list.append(DatasetAttr(load_from, dataset_name=name))
            continue

        if name not in dataset_info:
            raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")

        entry = dataset_info[name]
        has_hf_url = "hf_hub_url" in entry
        has_ms_url = "ms_hub_url" in entry
        has_om_url = "om_hub_url" in entry

        if has_hf_url or has_ms_url or has_om_url:
            # A non-HF hub is chosen when it is explicitly enabled or when no HF URL
            # exists; ModelScope is checked before OpenMind. The final branch is only
            # reachable when an HF URL is present.
            if has_ms_url and (use_modelscope() or not has_hf_url):
                dataset_attr = DatasetAttr("ms_hub", dataset_name=entry["ms_hub_url"])
            elif has_om_url and (use_openmind() or not has_hf_url):
                dataset_attr = DatasetAttr("om_hub", dataset_name=entry["om_hub_url"])
            else:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=entry["hf_hub_url"])
        elif "script_url" in entry:
            dataset_attr = DatasetAttr("script", dataset_name=entry["script_url"])
        else:
            dataset_attr = DatasetAttr("file", dataset_name=entry["file_name"])

        # Apply the entry's formatting / columns / tags on top of the defaults.
        dataset_attr.join(entry)
        dataset_list.append(dataset_attr)

    return dataset_list