import functools
import os

from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts, LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
from distilabel.models import OpenAILLM


'''
https://distilabel.argilla.io/latest/api/models/llm/llm_gallery/#distilabel.models.llms.OpenAILLM
https://distilabel.argilla.io/latest/components-gallery/steps/loaddatafromhub/?h=loaddatafromhub#input-output-columns
https://distilabel.argilla.io/latest/sections/how_to_guides/advanced/serving_an_llm_for_reuse/#serving-llms-using-vllm
'''

from transformers import AutoTokenizer

@functools.lru_cache(maxsize=8)
def _load_tokenizer(tokenizer_name):
    """Load and memoize a Hugging Face tokenizer by name.

    ``AutoTokenizer.from_pretrained`` is expensive; caching means it runs once
    per tokenizer name per process instead of once per converted row.
    """
    return AutoTokenizer.from_pretrained(tokenizer_name)


def extract_token_ids_and_logprobs(response, tokenizer_name="Qwen/Qwen3-4B", tokenizer=None):
    """Convert token-text logprobs into ``token_id:<id>`` logprob records.

    Args:
        response: Logprobs for one row. Only ``response[0]`` is read; each item
            in it must be a mapping with ``"token"`` (token text) and
            ``"logprob"`` keys. Falsy values (``None``, ``""``, ``[]``) and
            responses whose first element is empty (e.g. ``[[]]``) are treated
            as "no logprobs".
        tokenizer_name: Hugging Face tokenizer to load when ``tokenizer`` is not
            given. It is loaded at most once per name per process.
        tokenizer: Optional preloaded tokenizer exposing
            ``encode(text, add_special_tokens=False)``; takes precedence over
            ``tokenizer_name``.

    Returns:
        A one-element list wrapping a list of ``{"logprob": ..., "token":
        "token_id:<id>"}`` dicts. Items whose text encodes to no ids are skipped.
    """
    # NOTE(review): assumes response[0] is a flat list of {"token", "logprob"}
    # dicts — confirm against the logprobs format emitted by the LLM step.
    if not response or not response[0]:
        return [[]]

    # Resolve the tokenizer only after the cheap empty check, so rows without
    # logprobs never trigger a tokenizer load.
    if tokenizer is None:
        tokenizer = _load_tokenizer(tokenizer_name)

    result = []
    for item in response[0]:
        # Re-encoding decoded token text does not always round-trip to the id
        # the model actually sampled; only the first resulting id is kept.
        encoded = tokenizer.encode(item["token"], add_special_tokens=False)
        if not encoded:
            continue
        result.append({
            "logprob": item["logprob"],
            "token": f"token_id:{encoded[0]}",
        })

    return [result]

# vLLM service endpoint and model name. Both can be overridden through
# environment variables so the script need not be edited per deployment; the
# defaults are the original placeholder values.
VLLM_BASE_URL = os.environ.get("VLLM_BASE_URL", "http://x.x.x.x/v1")  # vLLM's OpenAI-compatible API; use the inference machine's IP.
MODEL_NAME = os.environ.get("VLLM_MODEL_NAME", "Qwen/Qwen3-4B")  # Must match the model name vLLM was launched with.

with Pipeline() as pipeline:
    # Step 1: load the dataset from the Hugging Face Hub. `repo_id`, `split` and
    # `batch_size` are supplied at run time through `pipeline.run(parameters=...)`.
    #
    # For a quick offline test an in-memory loader can be used instead; its rows
    # must carry the "instruction" column that TextGeneration reads:
    #   load_data = LoadDataFromDicts(
    #       data=[{"instruction": "Write a poem about the sun and moon."}]
    #   )
    #
    # To preview one batch from a Hub loader while debugging:
    #   preview = LoadDataFromHub(repo_id="tatsu-lab/alpaca", split="train", batch_size=8)
    #   preview.load()
    #   print(next(preview.process()))
    load_data = LoadDataFromHub(
        # Identity mapping (a no-op). To rename a differently named dataset
        # column to "instruction", use {"<dataset_column>": "instruction"}.
        output_mappings={"instruction": "instruction"}
    )

    # Step 2: configure the LLM client against the vLLM server.
    llm = OpenAILLM(
        model=MODEL_NAME,
        base_url=VLLM_BASE_URL,
        api_key="EMPTY",  # A local vLLM deployment does not need a real key.
        max_retries=5,
        generation_kwargs={
            # NOTE(review): 2.0 is a very high temperature (the upper bound of
            # the OpenAI API range) and yields highly random text; confirm this
            # is intended for label generation.
            "temperature": 2.0,
            "max_new_tokens": 4096,
            "top_p": 0.9,
            "logprobs": True,  # Request per-token logprobs.
            "top_logprobs": 20,  # Number of top alternatives returned per token.
            # "skip_special_tokens": True,
        },
    )

    # Step 3: text-generation task that produces the labels.
    text_generation = TextGeneration(
        name="text_generation",
        llm=llm,
        input_batch_size=1,  # Tune to the serving capacity / GPU memory.
    )

    # Wire the steps: the loader's rows feed the generation task.
    load_data >> text_generation

    # To inspect raw vLLM output directly while debugging:
    #   llm.load()
    #   output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Give three tips for staying healthy"}]])
    #   print(output)

if __name__ == "__main__":
    # Imported here because it is only needed for the parquet export; importing
    # before the run fails fast instead of after a long generation job.
    import pandas as pd

    # Run the pipeline; the Hub loader's runtime parameters are supplied here.
    distiset = pipeline.run(
        parameters={
            load_data.name: {
                "repo_id": "distilabel-internal-testing/instructions",  # Replace with your own dataset.
                "split": "test",
                "batch_size": 1,
            },
        }
    )

    # Save the results locally (optional).
    distiset.save_to_disk("distilabel-internal-testing/distilabel-example-instructions")
    # To publish to the Hub, authenticate with `huggingface-cli login` or the
    # HF_TOKEN environment variable; never hard-code an access token in source:
    # distiset.push_to_hub(repo_id="sankexin/distilabel-example-instructions-Qwen3-8B")

    # Export each step's "train" split to its own parquet file, adding chat-style
    # message columns and converting logprob tokens to token-id records.
    for step_name in distiset:
        step_output = distiset[step_name]["train"]
        formatted_data = []
        for item in step_output:
            # Fall back to [[]] when the column is missing, None or empty so the
            # converter always receives its expected nested-list shape.
            item["llm_text_generation_logprobs"] = extract_token_ids_and_logprobs(
                item.get("llm_text_generation_logprobs") or [[]]
            )
            instruction = item.get("instruction", "")
            # NOTE(review): "messages" holds a single user-turn dict, whereas
            # "messages_combined" is a list of turns; confirm consumers expect
            # this shape rather than a one-element list.
            item["messages"] = {"role": "user", "content": instruction}
            item["messages_combined"] = [
                {"role": "user", "content": instruction},
                {"role": "assistant", "content": item.get("generation", "")},
            ]
            formatted_data.append(item)

        df = pd.DataFrame(formatted_data)
        df.to_parquet(f"distilabel-internal-testing/{step_name}.parquet", engine="pyarrow")
        print(f"✅ {step_name}.parquet 已保存")

    # To reload the saved dataset for inspection:
    #   from datasets import load_dataset
    #   ds = load_dataset("distilabel-internal-testing/distilabel-example-instructions")
    #   print(ds["train"][0])
