distilabel_example_axolotl.py 4.62 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts, LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
from distilabel.models import OpenAILLM


'''
https://distilabel.argilla.io/latest/api/models/llm/llm_gallery/#distilabel.models.llms.OpenAILLM
https://distilabel.argilla.io/latest/components-gallery/steps/loaddatafromhub/?h=loaddatafromhub#input-output-columns
https://distilabel.argilla.io/latest/sections/how_to_guides/advanced/serving_an_llm_for_reuse/#serving-llms-using-vllm
'''

from transformers import AutoTokenizer

# Cache of loaded tokenizers, keyed by name, so repeated calls (this runs
# once per dataset row in __main__) do not re-construct the tokenizer.
_TOKENIZER_CACHE = {}


def _get_tokenizer(tokenizer_name):
    """Return a cached ``AutoTokenizer`` for *tokenizer_name*, loading it at most once."""
    if tokenizer_name not in _TOKENIZER_CACHE:
        _TOKENIZER_CACHE[tokenizer_name] = AutoTokenizer.from_pretrained(tokenizer_name)
    return _TOKENIZER_CACHE[tokenizer_name]


def extract_token_ids_and_logprobs(response, tokenizer_name="Qwen/Qwen3-4B"):
    """Convert a distilabel logprobs payload into ``token_id``-tagged entries.

    Parameters
    ----------
    response:
        Expected shape is ``[[{"token": str, "logprob": float}, ...]]`` — a
        single-element outer list, as emitted per row by the pipeline.
        Empty or absent payloads (``None``, ``""``, ``[]``, ``[[]]``) all
        yield ``[[]]``; the previous version only guarded ``[[]]`` and
        crashed on the others.
    tokenizer_name:
        HuggingFace tokenizer used to map each token string back to an id.

    Returns
    -------
    list
        ``[[{"logprob": float, "token": "token_id:<id>"}, ...]]``; tokens
        that do not encode to at least one id are skipped.
    """
    # Guard every empty/absent shape BEFORE touching the tokenizer, so the
    # empty path stays cheap and safe.
    if not response or response == [[]] or not response[0]:
        return [[]]

    tokenizer = _get_tokenizer(tokenizer_name)
    result = []
    for item in response[0]:
        encoded = tokenizer.encode(item["token"], add_special_tokens=False)
        if not encoded:
            # Token does not round-trip to an id (e.g. empty string) — skip it.
            continue
        result.append({
            "logprob": item["logprob"],
            "token": f"token_id:{encoded[0]}",
        })

    return [result]

# Configuration: address of your vLLM server and the model name it serves.
VLLM_BASE_URL = "http://x.x.x.x/v1"  # OpenAI-compatible endpoint exposed by vLLM; use the inference machine's IP.
MODEL_NAME = "Qwen/Qwen3-4B"  # Model name (must match the name vLLM was launched with).

with Pipeline() as pipeline:
    # Step 1: load a dataset from the HuggingFace Hub (repo_id/split/batch_size
    # are supplied at run time via pipeline.run(parameters=...) in __main__).
    '''
    load_data = LoadDataFromDicts(
        data=[{"prompt": "Write a poem about the sun and moon."}]
    )
    '''
    load_data = LoadDataFromHub(
        output_mappings={"instruction": "instruction"}  # identity mapping: keeps the column named 'instruction'
    )
    '''
    load_data = LoadDataFromHub(
        repo_id="tatsu-lab/alpaca",
        split="train",
        batch_size=8,
        output_mappings={"instruction": "instruction"}
    )
    load_data.load()
    result = next(load_data.process())
    print(result)
    '''

    # Step 2: set up the LLM — connects to the vLLM server through its
    # OpenAI-compatible API.
    llm = OpenAILLM(
        model=MODEL_NAME,
        base_url=VLLM_BASE_URL,
        api_key="EMPTY",  # a locally-served vLLM does not need a real key
        max_retries=5,
        generation_kwargs={
            # NOTE(review): temperature 2.0 is unusually high and will produce
            # near-random sampling — confirm this is intentional.
            "temperature": 2.0,
            "max_new_tokens": 4096,
            "top_p": 0.9,
            "logprobs": True,       # request per-token logprobs in the response
            "top_logprobs": 20,     # number of top logprob alternatives returned per token
            # "skip_special_tokens": True,
        }
    )

    # Step 3: text-generation task (fills the 'generation' label column).
    text_generation = TextGeneration(
        name="text_generation",
        llm=llm,
        input_batch_size=1,  # tune according to available GPU memory
        # input_mappings={"instruction": "instruction"},
    )

    # Pipeline wiring: the dataset loader feeds the generation task.
    load_data >> text_generation

    '''
    # 查看vllm推理结果
    llm.load()
    output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Give three tips for staying healthy"}]])
    print(output)
    '''

if __name__ == "__main__":
    # Run the pipeline, supplying the loader step's runtime parameters.
    distiset = pipeline.run(
        parameters={
            load_data.name: {
                "repo_id": "distilabel-internal-testing/instructions",  # replace with your own dataset
                "split": "test",
                "batch_size": 1,
            },
        }
    )

    # Persist the raw distiset (optional).
    distiset.save_to_disk("distilabel-internal-testing/distilabel-example-instructions")
    # distiset.push_to_hub(repo_id="<user>/<repo>", token="<your-hf-token>")  # never hard-code real tokens
    # print(distiset)

    import pandas as pd

    for step_name in list(distiset.keys()):
        step_output = distiset[step_name]["train"]
        formatted_data = []  # (was assigned twice in the original — once is enough)
        for item in step_output:
            # Default to [[]] (not "") so an absent logprobs column is treated
            # as an empty payload instead of crashing the extractor.
            item["llm_text_generation_logprobs"] = extract_token_ids_and_logprobs(
                item.get("llm_text_generation_logprobs", [[]])
            )
            messages_combined = [
                {"role": "user", "content": item.get("instruction", "")},
                {"role": "assistant", "content": item.get("generation", "")},
            ]
            # Fix: "messages" must be a *list* of message dicts (the chat
            # format axolotl expects, consistent with messages_combined),
            # not a bare dict.
            item["messages"] = [{"role": "user", "content": item.get("instruction", "")}]
            item["messages_combined"] = messages_combined
            formatted_data.append(item)

        df = pd.DataFrame(formatted_data)
        # print(df)
        df.to_parquet(f"distilabel-internal-testing/{step_name}.parquet", engine="pyarrow")
        print(f"✅ {step_name}.parquet 已保存")

    '''
    from datasets import load_dataset
    ds = load_dataset("distilabel-internal-testing/distilabel-example-instructions")
    print(ds['train'][0])
    '''