{ "cells": [ { "cell_type": "markdown", "source": [ "# Efficient LoRA Fine-Tuning of ChatGLM3-6B on a Single GPU\n", "This cookbook walks developers through LoRA fine-tuning of ChatGLM3-6B on the `AdvertiseGen` dataset so that the model can generate professional advertising copy.\n", "\n", "## Hardware requirements\n", "- GPU memory: 24 GB or more (NVIDIA GPUs with the sm80 architecture or newer, such as the RTX 30 series or the A10, are recommended)\n", "- System memory: 16 GB\n", "- Usage observed in one run (used/total): RAM 2.9/16 GB, GPU RAM 15.5/16.0 GB" ], "metadata": { "collapsed": false, "id": "89b89f64d8f8053d" }, "id": "89b89f64d8f8053d" }, { "cell_type": "markdown", "source": [ "## 0. Check the environment\n", "First, check the working directory and make sure it is `finetune_demo`.\n", "Also make sure the dependencies listed in `requirements.txt` are installed.\n", "\n", "> This demo does not need the `deepspeed` and `mpi4py` dependencies; if you run into problems installing them, you can skip them." ], "metadata": { "collapsed": false, "id": "a7bd9a514ed09ea6" }, "id": "a7bd9a514ed09ea6" }, { "cell_type": "code", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/media/zr/Data/Code/ChatGLM3/finetune_demo\r\n" ] } ], "source": [ "!pwd" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-14T05:29:22.200365Z", "start_time": "2024-04-14T05:29:22.080929Z" } }, "id": "f7703109d1443346", "execution_count": 1 }, { "cell_type": "markdown", "source": [ "## 1. 
Prepare the dataset\n", "We use the AdvertiseGen dataset for fine-tuning. Download the preprocessed AdvertiseGen dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), extract it, and place the resulting `AdvertiseGen` directory under `data/` in this directory, for example:\n", "> /media/zr/Data/Code/ChatGLM3/finetune_demo/data/AdvertiseGen\n", "\n", "The next cell converts every record in `train.json` and `dev.json` into a user/assistant conversation (`content` becomes the user turn, `summary` the assistant turn) and writes the results to `data/AdvertiseGen_fix`." ], "metadata": { "collapsed": false }, "id": "2f50e92810011977" }, { "cell_type": "code", "outputs": [], "source": [ "import json\n", "from typing import Union\n", "from pathlib import Path\n", "\n", "\n", "def _resolve_path(path: Union[str, Path]) -> Path:\n", "    return Path(path).expanduser().resolve()\n", "\n", "\n", "def _mkdir(dir_name: Union[str, Path]):\n", "    dir_name = _resolve_path(dir_name)\n", "    if not dir_name.is_dir():\n", "        dir_name.mkdir(parents=True, exist_ok=False)\n", "\n", "\n", "def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]):\n", "    def _convert(in_file: Path, out_file: Path):\n", "        _mkdir(out_file.parent)\n", "        with open(in_file, encoding='utf-8') as fin:\n", "            with open(out_file, 'wt', encoding='utf-8') as fout:\n", "                # Each input line is a JSON object: 'content' holds the product attributes, 'summary' the ad copy.\n", "                for line in fin:\n", "                    dct = json.loads(line)\n", "                    sample = {'conversations': [{'role': 'user', 'content': dct['content']},\n", "                                                {'role': 'assistant', 'content': dct['summary']}]}\n", "                    # Write one conversation per line (JSON Lines), keeping Chinese characters unescaped.\n", "                    fout.write(json.dumps(sample, ensure_ascii=False) + '\\n')\n", "\n", "    data_dir = _resolve_path(data_dir)\n", "    save_dir = _resolve_path(save_dir)\n", "\n", "    # Convert train.json and dev.json if present; output files keep their names under save_dir.\n", "    train_file = data_dir / 'train.json'\n", "    if train_file.is_file():\n", "        out_file = save_dir / train_file.relative_to(data_dir)\n", "        _convert(train_file, out_file)\n", "\n", "    dev_file = data_dir / 'dev.json'\n", "    if dev_file.is_file():\n", "        out_file = save_dir / dev_file.relative_to(data_dir)\n", "        _convert(dev_file, out_file)\n", "\n", "\n", "convert_adgen('data/AdvertiseGen', 'data/AdvertiseGen_fix')" ], "metadata": { "collapsed": true, "cellView": "form", "id": "initial_id", "ExecuteTime": { "end_time": 
"2024-04-14T05:29:23.809255Z", "start_time": "2024-04-14T05:29:22.202731Z" } }, "id": "initial_id", "execution_count": 2 }, { "cell_type": "markdown", "source": [ "## 2. Fine-tune with LoRA from the command line\n", "Next, all we need to do is pass the configured parameters to the fine-tuning program as command-line arguments. With LoRA, only about 1.95M of the model's roughly 6.25B parameters (about 0.03%) are trainable, as the log output below shows." ], "metadata": { "collapsed": false, "id": "a1b7a99923349056" }, "id": "a1b7a99923349056" }, { "cell_type": "code", "execution_count": 3, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Setting eos_token is not supported, use the default one.\r\n", "Setting pad_token is not supported, use the default one.\r\n", "Setting unk_token is not supported, use the default one.\r\n", "Loading checkpoint shards: 100%|██████████████████| 7/7 [00:02<00:00, 2.77it/s]\r\n", "trainable params: 1,949,696 || all params: 6,245,533,696 || trainable%: 0.031217444255383614\r\n", "--> Model\r\n", "\r\n", "--> model has 1.949696M params\r\n", "\r\n", "Setting num_proc from 16 back to 1 for the train split to disable multiprocessing as it only contains one shard.\r\n", "Generating train split: 114599 examples [00:00, 836881.77 examples/s]\r\n", "Setting num_proc from 16 back to 1 for the validation split to disable multiprocessing as it only contains one shard.\r\n", "Generating validation split: 1070 examples [00:00, 252512.53 examples/s]\r\n", "Setting num_proc from 16 back to 1 for the test split to disable multiprocessing as it only contains one shard.\r\n", "Generating test split: 1070 examples [00:00, 313510.67 examples/s]\r\n", "Map (num_proc=16): 100%|██████| 114599/114599 [00:02<00:00, 39254.76 examples/s]\r\n", "train_dataset: Dataset({\r\n", " features: ['input_ids', 'labels'],\r\n", " num_rows: 114599\r\n", "})\r\n", "Map (num_proc=16): 100%|███████████| 1070/1070 [00:00<00:00, 1399.56 examples/s]\r\n", "val_dataset: Dataset({\r\n", " features: ['input_ids', 'output_ids'],\r\n", " num_rows: 1070\r\n", "})\r\n", "Map (num_proc=16): 100%|███████████| 1070/1070 [00:00<00:00, 1339.19 examples/s]\r\n", "test_dataset: 
Dataset({\r\n", " features: ['input_ids', 'output_ids'],\r\n", " num_rows: 1070\r\n", "})\r\n", "--> Sanity check\r\n", " '[gMASK]': 64790 -> -100\r\n", " 'sop': 64792 -> -100\r\n", " '<|user|>': 64795 -> -100\r\n", " '': 30910 -> -100\r\n", " '\\n': 13 -> -100\r\n", " '': 30910 -> -100\r\n", " '类型': 33467 -> -100\r\n", " '#': 31010 -> -100\r\n", " '裤': 56532 -> -100\r\n", " '*': 30998 -> -100\r\n", " '版': 55090 -> -100\r\n", " '型': 54888 -> -100\r\n", " '#': 31010 -> -100\r\n", " '宽松': 40833 -> -100\r\n", " '*': 30998 -> -100\r\n", " '风格': 32799 -> -100\r\n", " '#': 31010 -> -100\r\n", " '性感': 40589 -> -100\r\n", " '*': 30998 -> -100\r\n", " '图案': 37505 -> -100\r\n", " '#': 31010 -> -100\r\n", " '线条': 37216 -> -100\r\n", " '*': 30998 -> -100\r\n", " '裤': 56532 -> -100\r\n", " '型': 54888 -> -100\r\n", " '#': 31010 -> -100\r\n", " '阔': 56529 -> -100\r\n", " '腿': 56158 -> -100\r\n", " '裤': 56532 -> -100\r\n", " '<|assistant|>': 64796 -> -100\r\n", " '': 30910 -> 30910\r\n", " '\\n': 13 -> 13\r\n", " '': 30910 -> 30910\r\n", " '宽松': 40833 -> 40833\r\n", " '的': 54530 -> 54530\r\n", " '阔': 56529 -> 56529\r\n", " '腿': 56158 -> 56158\r\n", " '裤': 56532 -> 56532\r\n", " '这': 54551 -> 54551\r\n", " '两年': 33808 -> 33808\r\n", " '真的': 32041 -> 32041\r\n", " '吸': 55360 -> 55360\r\n", " '粉': 55486 -> 55486\r\n", " '不少': 32138 -> 32138\r\n", " ',': 31123 -> 31123\r\n", " '明星': 32943 -> 32943\r\n", " '时尚': 33481 -> 33481\r\n", " '达': 54880 -> 54880\r\n", " '人的': 31664 -> 31664\r\n", " '心头': 46565 -> 46565\r\n", " '爱': 54799 -> 54799\r\n", " '。': 31155 -> 31155\r\n", " '毕竟': 33051 -> 33051\r\n", " '好': 54591 -> 54591\r\n", " '穿': 55432 -> 55432\r\n", " '时尚': 33481 -> 33481\r\n", " ',': 31123 -> 31123\r\n", " '谁': 55622 -> 55622\r\n", " '都能': 32904 -> 32904\r\n", " '穿': 55432 -> 55432\r\n", " '出': 54557 -> 54557\r\n", " '腿': 56158 -> 56158\r\n", " '长': 54625 -> 54625\r\n", " '2': 30943 -> 30943\r\n", " '米': 55055 -> 55055\r\n", " '的效果': 35590 -> 35590\r\n", " '宽松': 40833 -> 
40833\r\n", " '的': 54530 -> 54530\r\n", " '裤': 56532 -> 56532\r\n", " '腿': 56158 -> 56158\r\n", " ',': 31123 -> 31123\r\n", " '当然是': 48466 -> 48466\r\n", " '遮': 57148 -> 57148\r\n", " '肉': 55343 -> 55343\r\n", " '小': 54603 -> 54603\r\n", " '能手': 49355 -> 49355\r\n", " '啊': 55674 -> 55674\r\n", " '。': 31155 -> 31155\r\n", " '上身': 51605 -> 51605\r\n", " '随': 55119 -> 55119\r\n", " '性': 54642 -> 54642\r\n", " '自然': 31799 -> 31799\r\n", " '不': 54535 -> 54535\r\n", " '拘': 57036 -> 57036\r\n", " '束': 55625 -> 55625\r\n", " ',': 31123 -> 31123\r\n", " '面料': 46839 -> 46839\r\n", " '亲': 55113 -> 55113\r\n", " '肤': 56089 -> 56089\r\n", " '舒适': 33894 -> 33894\r\n", " '贴': 55778 -> 55778\r\n", " '身体': 31902 -> 31902\r\n", " '验': 55017 -> 55017\r\n", " '感': 54706 -> 54706\r\n", " '棒': 56382 -> 56382\r\n", " '棒': 56382 -> 56382\r\n", " '哒': 59230 -> 59230\r\n", " '。': 31155 -> 31155\r\n", " '系': 54712 -> 54712\r\n", " '带': 54882 -> 54882\r\n", " '部分': 31726 -> 31726\r\n", " '增加': 31917 -> 31917\r\n", " '设计': 31735 -> 31735\r\n", " '看点': 45032 -> 45032\r\n", " ',': 31123 -> 31123\r\n", " '还': 54656 -> 54656\r\n", " '让': 54772 -> 54772\r\n", " '单品': 46539 -> 46539\r\n", " '的设计': 34481 -> 34481\r\n", " '感': 54706 -> 54706\r\n", " '更强': 43084 -> 43084\r\n", " '。': 31155 -> 31155\r\n", " '腿部': 46799 -> 46799\r\n", " '线条': 37216 -> 37216\r\n", " '若': 55351 -> 55351\r\n", " '隐': 55733 -> 55733\r\n", " '若': 55351 -> 55351\r\n", " '现': 54600 -> 54600\r\n", " '的': 54530 -> 54530\r\n", " ',': 31123 -> 31123\r\n", " '性感': 40589 -> 40589\r\n", " '撩': 58521 -> 58521\r\n", " '人': 54533 -> 54533\r\n", " '。': 31155 -> 31155\r\n", " '颜色': 33692 -> 33692\r\n", " '敲': 57004 -> 57004\r\n", " '温柔': 34678 -> 34678\r\n", " '的': 54530 -> 54530\r\n", " ',': 31123 -> 31123\r\n", " '与': 54619 -> 54619\r\n", " '裤子': 44722 -> 44722\r\n", " '本身': 32754 -> 32754\r\n", " '所': 54626 -> 54626\r\n", " '呈现': 33169 -> 33169\r\n", " '的风格': 48084 -> 48084\r\n", " '有点': 33149 -> 33149\r\n", " '反': 54955 -> 54955\r\n", 
" '差': 55342 -> 55342\r\n", " '萌': 56842 -> 56842\r\n", " '。': 31155 -> 31155\r\n", " '': 2 -> 2\r\n", "/media/zr/Data/Code/ChatGLM3/venv/lib/python3.10/site-packages/accelerate/accelerator.py:436: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \r\n", "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\r\n", " warnings.warn(\r\n", "max_steps is given, it will override any value given in num_train_epochs\r\n", "***** Running training *****\r\n", " Num examples = 114,599\r\n", " Num Epochs = 1\r\n", " Instantaneous batch size per device = 4\r\n", " Total train batch size (w. parallel, distributed & accumulation) = 4\r\n", " Gradient Accumulation steps = 1\r\n", " Total optimization steps = 4,000\r\n", " Number of trainable parameters = 1,949,696\r\n", "{'loss': 4.832, 'grad_norm': 2.1177706718444824, 'learning_rate': 4.9875000000000006e-05, 'epoch': 0.0}\r\n", "{'loss': 4.6094, 'grad_norm': 3.104412078857422, 'learning_rate': 4.975e-05, 'epoch': 0.0}\r\n", "{'loss': 4.5043, 'grad_norm': 2.9755077362060547, 'learning_rate': 4.962500000000001e-05, 'epoch': 0.0}\r\n", "{'loss': 4.14, 'grad_norm': 3.3869752883911133, 'learning_rate': 4.9500000000000004e-05, 'epoch': 0.0}\r\n", "{'loss': 4.1275, 'grad_norm': 2.698483467102051, 'learning_rate': 4.937500000000001e-05, 'epoch': 0.0}\r\n", "{'loss': 3.8748, 'grad_norm': 2.9052674770355225, 'learning_rate': 4.9250000000000004e-05, 'epoch': 0.0}\r\n", "{'loss': 3.8506, 'grad_norm': 2.8566994667053223, 'learning_rate': 4.9125e-05, 'epoch': 0.0}\r\n", "{'loss': 3.7518, 'grad_norm': 2.9119534492492676, 'learning_rate': 4.9e-05, 'epoch': 0.0}\r\n", "{'loss': 3.6375, 'grad_norm': 3.1845204830169678, 'learning_rate': 
4.8875e-05, 'epoch': 0.0}\r\n", "{'loss': 3.7219, 'grad_norm': 3.359720230102539, 'learning_rate': 4.875e-05, 'epoch': 0.0}\r\n", "{'loss': 3.676, 'grad_norm': 3.559992790222168, 'learning_rate': 4.8625e-05, 'epoch': 0.0}\r\n", "{'loss': 3.849, 'grad_norm': 3.822449207305908, 'learning_rate': 4.85e-05, 'epoch': 0.0}\r\n", "{'loss': 3.6154, 'grad_norm': 3.4438886642456055, 'learning_rate': 4.8375000000000004e-05, 'epoch': 0.0}\r\n", "{'loss': 3.7326, 'grad_norm': 4.374788284301758, 'learning_rate': 4.825e-05, 'epoch': 0.0}\r\n", "{'loss': 3.6854, 'grad_norm': 3.5999808311462402, 'learning_rate': 4.8125000000000004e-05, 'epoch': 0.01}\r\n", "{'loss': 3.7447, 'grad_norm': 3.8460822105407715, 'learning_rate': 4.8e-05, 'epoch': 0.01}\r\n", "{'loss': 3.5766, 'grad_norm': 4.053386211395264, 'learning_rate': 4.7875000000000005e-05, 'epoch': 0.01}\r\n", "{'loss': 3.5758, 'grad_norm': 4.296564102172852, 'learning_rate': 4.775e-05, 'epoch': 0.01}\r\n", "{'loss': 3.5486, 'grad_norm': 4.701301574707031, 'learning_rate': 4.7625000000000006e-05, 'epoch': 0.01}\r\n", "{'loss': 3.5775, 'grad_norm': 4.4896979331970215, 'learning_rate': 4.75e-05, 'epoch': 0.01}\r\n", "{'loss': 3.55, 'grad_norm': 4.9407429695129395, 'learning_rate': 4.7375e-05, 'epoch': 0.01}\r\n", "{'loss': 3.6437, 'grad_norm': 4.0624542236328125, 'learning_rate': 4.7249999999999997e-05, 'epoch': 0.01}\r\n", "{'loss': 3.6098, 'grad_norm': 4.786097049713135, 'learning_rate': 4.7125e-05, 'epoch': 0.01}\r\n", "{'loss': 3.5107, 'grad_norm': 4.457597255706787, 'learning_rate': 4.7e-05, 'epoch': 0.01}\r\n", "{'loss': 3.4723, 'grad_norm': 5.279415130615234, 'learning_rate': 4.6875e-05, 'epoch': 0.01}\r\n", "{'loss': 3.6016, 'grad_norm': 5.297557353973389, 'learning_rate': 4.6750000000000005e-05, 'epoch': 0.01}\r\n", "{'loss': 3.5475, 'grad_norm': 5.397997856140137, 'learning_rate': 4.6625e-05, 'epoch': 0.01}\r\n", "{'loss': 3.6115, 'grad_norm': 4.472784519195557, 'learning_rate': 4.6500000000000005e-05, 'epoch': 0.01}\r\n", 
"{'loss': 3.6273, 'grad_norm': 4.7433905601501465, 'learning_rate': 4.6375e-05, 'epoch': 0.01}\r\n", "{'loss': 3.5379, 'grad_norm': 5.81007194519043, 'learning_rate': 4.6250000000000006e-05, 'epoch': 0.01}\r\n", "{'loss': 3.4654, 'grad_norm': 5.297420501708984, 'learning_rate': 4.6125e-05, 'epoch': 0.01}\r\n", "{'loss': 3.6057, 'grad_norm': 5.738197326660156, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.01}\r\n", "{'loss': 3.4168, 'grad_norm': 5.207597732543945, 'learning_rate': 4.5875000000000004e-05, 'epoch': 0.01}\r\n", "{'loss': 3.4932, 'grad_norm': 5.2784833908081055, 'learning_rate': 4.575e-05, 'epoch': 0.01}\r\n", "{'loss': 3.518, 'grad_norm': 5.428376197814941, 'learning_rate': 4.5625e-05, 'epoch': 0.01}\r\n", "{'loss': 3.5727, 'grad_norm': 5.190096855163574, 'learning_rate': 4.55e-05, 'epoch': 0.01}\r\n", "{'loss': 3.3615, 'grad_norm': 4.818575859069824, 'learning_rate': 4.5375e-05, 'epoch': 0.01}\r\n", "{'loss': 3.5275, 'grad_norm': 5.174643039703369, 'learning_rate': 4.525e-05, 'epoch': 0.01}\r\n", "{'loss': 3.5232, 'grad_norm': 5.241923809051514, 'learning_rate': 4.5125e-05, 'epoch': 0.01}\r\n", "{'loss': 3.4699, 'grad_norm': 5.603521823883057, 'learning_rate': 4.5e-05, 'epoch': 0.01}\r\n", "{'loss': 3.6916, 'grad_norm': 5.468681335449219, 'learning_rate': 4.4875e-05, 'epoch': 0.01}\r\n", "{'loss': 3.4975, 'grad_norm': 4.969369888305664, 'learning_rate': 4.4750000000000004e-05, 'epoch': 0.01}\r\n", "{'loss': 3.6207, 'grad_norm': 5.575362682342529, 'learning_rate': 4.4625e-05, 'epoch': 0.02}\r\n", "{'loss': 3.4152, 'grad_norm': 6.52517032623291, 'learning_rate': 4.4500000000000004e-05, 'epoch': 0.02}\r\n", "{'loss': 3.4098, 'grad_norm': 5.987551212310791, 'learning_rate': 4.4375e-05, 'epoch': 0.02}\r\n", "{'loss': 3.4244, 'grad_norm': 5.613704681396484, 'learning_rate': 4.4250000000000005e-05, 'epoch': 0.02}\r\n", "{'loss': 3.5303, 'grad_norm': 5.790269374847412, 'learning_rate': 4.4125e-05, 'epoch': 0.02}\r\n", "{'loss': 3.4475, 'grad_norm': 
7.037369728088379, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.02}\r\n", "{'loss': 3.4562, 'grad_norm': 5.771510601043701, 'learning_rate': 4.3875e-05, 'epoch': 0.02}\r\n", "{'loss': 3.5623, 'grad_norm': 5.876147747039795, 'learning_rate': 4.375e-05, 'epoch': 0.02}\r\n", " 12%|█████ | 500/4000 [04:39<37:01, 1.58it/s]***** Running Evaluation *****\r\n", " Num examples = 50\r\n", " Batch size = 16\r\n", "\r\n", " 0%| | 0/4 [00:00