{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MiniCPM-2B 参数高效微调(LoRA)A100 80G 单卡示例\n",
    "\n",
    "显存更小的显卡可用 batch size 和 grad_accum 间时间换空间\n",
    "\n",
    "本 notebook 是一个使用 `AdvertiseGen` 数据集对 MiniCPM-2B 进行 LoRA 微调,使其具备专业的广告生成能力的代码示例。\n",
    "\n",
    "## 最低硬件需求\n",
    "- 显存:12GB\n",
    "- 显卡架构:安培架构(推荐)\n",
    "- 内存:16GB"
   ]
  },
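  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is a minimal sketch of that trade-off: the effective batch size is the product of the per-device batch size, the gradient accumulation steps, and the number of GPUs. The names `per_device_train_batch_size` and `gradient_accumulation_steps` follow the standard Hugging Face `TrainingArguments` convention and are assumptions here; check `lora_finetune.sh` for the arguments it actually passes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative arithmetic only -- the real values are set in lora_finetune.sh.\n",
    "per_device_train_batch_size = 4   # assumed per-GPU micro-batch size\n",
    "gradient_accumulation_steps = 8   # assumed accumulation steps\n",
    "num_gpus = 1                      # single-card example\n",
    "\n",
    "# Halving the micro-batch while doubling accumulation keeps this product (and the\n",
    "# optimization trajectory) unchanged, at the cost of more forward/backward passes.\n",
    "effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus\n",
    "print(f\"effective batch size: {effective_batch_size}\")"
   ]
  },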
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 准备数据集\n",
    "\n",
    "将数据集转换为更通用的格式\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 转换为 ChatML 格式\n",
    "import os\n",
    "import shutil\n",
    "import json\n",
    "\n",
    "input_dir = \"data/AdvertiseGen\"\n",
    "output_dir = \"data/AdvertiseGenChatML\"\n",
    "if os.path.exists(output_dir):\n",
    "    shutil.rmtree(output_dir)\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "for fn in [\"train.json\", \"dev.json\"]:\n",
    "    data_out_list = []\n",
    "    with open(os.path.join(input_dir, fn), \"r\") as f, open(os.path.join(output_dir, fn), \"w\") as fo:\n",
    "        for line in f:\n",
    "            if len(line.strip()) > 0:\n",
    "                data = json.loads(line)\n",
    "                data_out = {\n",
    "                    \"messages\": [\n",
    "                        {\n",
    "                            \"role\": \"user\",\n",
    "                            \"content\": data[\"content\"],\n",
    "                        },\n",
    "                        {\n",
    "                            \"role\": \"assistant\",\n",
    "                            \"content\": data[\"summary\"],\n",
    "                        },\n",
    "                    ]\n",
    "                }\n",
    "                data_out_list.append(data_out)\n",
    "        json.dump(data_out_list, fo, ensure_ascii=False, indent=4)\n"
   ]
  },
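  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the cell below reloads the converted `train.json` and prints the first record; it only relies on `output_dir`, `os`, and `json` from the cell above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the first converted record to confirm the ChatML-style structure.\n",
    "with open(os.path.join(output_dir, \"train.json\"), \"r\") as f:\n",
    "    converted = json.load(f)\n",
    "print(f\"{len(converted)} training records\")\n",
    "print(json.dumps(converted[0], ensure_ascii=False, indent=2))"
   ]
  },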
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 使用 LoRA 进行微调\n",
    "\n",
    "命令行一键运行"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!bash lora_finetune.sh"
   ]
  },
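  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The actual training hyperparameters live in `lora_finetune.sh`. The cell below is only a minimal sketch of what a LoRA setup with the `peft` library typically looks like; the base model id `openbmb/MiniCPM-2B-sft-bf16`, the rank/alpha/dropout values, and the `target_modules` list are illustrative assumptions, not the script's actual settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal LoRA sketch with peft -- not the configuration used by lora_finetune.sh.\n",
    "import torch\n",
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"openbmb/MiniCPM-2B-sft-bf16\",  # assumed base checkpoint\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "lora_config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    "    r=8,                # assumed LoRA rank\n",
    "    lora_alpha=16,      # assumed scaling factor\n",
    "    lora_dropout=0.05,  # assumed dropout\n",
    "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],  # assumed attention projections\n",
    ")\n",
    "peft_model = get_peft_model(base_model, lora_config)\n",
    "peft_model.print_trainable_parameters()  # only the LoRA adapter weights are trainable"
   ]
  },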
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 推理验证"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"output/AdvertiseGenLoRA/20240315224356/checkpoint-3000\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(path)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    path, torch_dtype=torch.bfloat16, device_map=\"cuda\", trust_remote_code=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res, history = model.chat(tokenizer, query=\"<用户>类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞<AI>\", max_length=80, top_p=0.5)\n",
    "res, history"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}