{
"cells": [
{
"cell_type": "markdown",
"source": [
"# 单卡GPU 进行 ChatGLM3-6B模型 LORA 高效微调\n",
"本 Cookbook 将带领开发者使用 `AdvertiseGen` 对 ChatGLM3-6B 数据集进行 lora微调,使其具备专业的广告生成能力。\n",
"\n",
"## 硬件需求\n",
"显存:24GB及以上(推荐使用30系或A10等sm80架构以上的NVIDIA显卡进行尝试)\n",
"内存:16GB\n",
"RAM: 2.9 /16 GB\n",
"GPU RAM: 15.5/16.0 GB"
],
"metadata": {
"collapsed": false,
"id": "89b89f64d8f8053d"
},
"id": "89b89f64d8f8053d"
},
{
"cell_type": "markdown",
"source": [
"## 0. 环境检查\n",
"首先,先检查代码的运行地址,确保运行地址处于 `finetune_demo` 中。\n",
"并且,确保已经安装了 `requirements.txt`中的依赖。\n",
"\n",
"> 本 demo 中,不需要使用 deepspeed, mpi4py 两个依赖,如果您安装这两个依赖遇到问题,可以不安装这两个依赖。"
],
"metadata": {
"collapsed": false,
"id": "a7bd9a514ed09ea6"
},
"id": "a7bd9a514ed09ea6"
},
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/media/zr/Data/Code/ChatGLM3/finetune_demo\r\n"
]
}
],
"source": [
"!pwd"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-14T05:29:22.200365Z",
"start_time": "2024-04-14T05:29:22.080929Z"
}
},
"id": "f7703109d1443346",
"execution_count": 1
},
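{
"cell_type": "markdown",
"source": [
"As an optional sanity check, the sketch below tries to import a few packages that this demo relies on (the exact package list is an assumption based on a typical `finetune_demo` environment, not read from `requirements.txt`):"
],
"metadata": {
"collapsed": false
},
"id": "dep_check_md"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Minimal dependency check (sketch): the package list is an assumption based on\n",
"# a typical finetune_demo environment, not read from requirements.txt.\n",
"import importlib\n",
"\n",
"for pkg in ['transformers', 'peft', 'datasets', 'accelerate']:\n",
"    try:\n",
"        mod = importlib.import_module(pkg)\n",
"        print(f'{pkg}: {getattr(mod, \"__version__\", \"unknown\")}')\n",
"    except ImportError:\n",
"        print(f'{pkg}: MISSING, run pip install -r requirements.txt')"
],
"metadata": {
"collapsed": false
},
"id": "dep_check_code"
},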
{
"cell_type": "markdown",
"source": [
"## 1. 准备数据集\n",
"我们使用 AdvertiseGen 数据集来进行微调。从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 AdvertiseGen 数据集,将解压后的 AdvertiseGen 目录放到本目录的 `/data/` 下, 例如。\n",
"> /media/zr/Data/Code/ChatGLM3/finetune_demo/data/AdvertiseGen"
],
"metadata": {
"collapsed": false
},
"id": "2f50e92810011977"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"import json\n",
"from typing import Union\n",
"from pathlib import Path\n",
"\n",
"\n",
"def _resolve_path(path: Union[str, Path]) -> Path:\n",
" return Path(path).expanduser().resolve()\n",
"\n",
"\n",
"def _mkdir(dir_name: Union[str, Path]):\n",
" dir_name = _resolve_path(dir_name)\n",
" if not dir_name.is_dir():\n",
" dir_name.mkdir(parents=True, exist_ok=False)\n",
"\n",
"\n",
"def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]):\n",
" def _convert(in_file: Path, out_file: Path):\n",
" _mkdir(out_file.parent)\n",
" with open(in_file, encoding='utf-8') as fin:\n",
" with open(out_file, 'wt', encoding='utf-8') as fout:\n",
" for line in fin:\n",
" dct = json.loads(line)\n",
" sample = {'conversations': [{'role': 'user', 'content': dct['content']},\n",
" {'role': 'assistant', 'content': dct['summary']}]}\n",
" fout.write(json.dumps(sample, ensure_ascii=False) + '\\n')\n",
"\n",
" data_dir = _resolve_path(data_dir)\n",
" save_dir = _resolve_path(save_dir)\n",
"\n",
" train_file = data_dir / 'train.json'\n",
" if train_file.is_file():\n",
" out_file = save_dir / train_file.relative_to(data_dir)\n",
" _convert(train_file, out_file)\n",
"\n",
" dev_file = data_dir / 'dev.json'\n",
" if dev_file.is_file():\n",
" out_file = save_dir / dev_file.relative_to(data_dir)\n",
" _convert(dev_file, out_file)\n",
"\n",
"\n",
"convert_adgen('data/AdvertiseGen', 'data/AdvertiseGen_fix')"
],
"metadata": {
"collapsed": true,
"cellView": "form",
"id": "initial_id",
"ExecuteTime": {
"end_time": "2024-04-14T05:29:23.809255Z",
"start_time": "2024-04-14T05:29:22.202731Z"
}
},
"id": "initial_id",
"execution_count": 2
},
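{
"cell_type": "markdown",
"source": [
"As a quick sanity check (a sketch, assuming the conversion above produced `data/AdvertiseGen_fix/train.json`), we can print the first converted sample to confirm the conversation format:"
],
"metadata": {
"collapsed": false
},
"id": "sample_peek_md"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Peek at the first converted line: each line is a standalone JSON object with a\n",
"# 'conversations' list holding one user turn and one assistant turn.\n",
"import json\n",
"\n",
"with open('data/AdvertiseGen_fix/train.json', encoding='utf-8') as f:\n",
"    print(json.dumps(json.loads(f.readline()), ensure_ascii=False, indent=2))"
],
"metadata": {
"collapsed": false
},
"id": "sample_peek_code"
},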
{
"cell_type": "markdown",
"source": [
"## 2. 使用命令行开始微调,我们使用 lora 进行微调\n",
"接着,我们仅需要将配置好的参数以命令行的形式传参给程序,就可以使用命令行进行高效微调。"
],
"metadata": {
"collapsed": false,
"id": "a1b7a99923349056"
},
"id": "a1b7a99923349056"
},
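{
"cell_type": "markdown",
"source": [
"Before launching, you can print `configs/lora.yaml` to see exactly which LoRA and training hyperparameters the run below will use (output omitted here):"
],
"metadata": {
"collapsed": false
},
"id": "show_lora_cfg_md"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"!cat configs/lora.yaml"
],
"metadata": {
"collapsed": false
},
"id": "show_lora_cfg_code"
},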
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Setting eos_token is not supported, use the default one.\r\n",
"Setting pad_token is not supported, use the default one.\r\n",
"Setting unk_token is not supported, use the default one.\r\n",
"Loading checkpoint shards: 100%|██████████████████| 7/7 [00:02<00:00, 2.77it/s]\r\n",
"trainable params: 1,949,696 || all params: 6,245,533,696 || trainable%: 0.031217444255383614\r\n",
"--> Model\r\n",
"\r\n",
"--> model has 1.949696M params\r\n",
"\r\n",
"Setting num_proc from 16 back to 1 for the train split to disable multiprocessing as it only contains one shard.\r\n",
"Generating train split: 114599 examples [00:00, 836881.77 examples/s]\r\n",
"Setting num_proc from 16 back to 1 for the validation split to disable multiprocessing as it only contains one shard.\r\n",
"Generating validation split: 1070 examples [00:00, 252512.53 examples/s]\r\n",
"Setting num_proc from 16 back to 1 for the test split to disable multiprocessing as it only contains one shard.\r\n",
"Generating test split: 1070 examples [00:00, 313510.67 examples/s]\r\n",
"Map (num_proc=16): 100%|██████| 114599/114599 [00:02<00:00, 39254.76 examples/s]\r\n",
"train_dataset: Dataset({\r\n",
" features: ['input_ids', 'labels'],\r\n",
" num_rows: 114599\r\n",
"})\r\n",
"Map (num_proc=16): 100%|███████████| 1070/1070 [00:00<00:00, 1399.56 examples/s]\r\n",
"val_dataset: Dataset({\r\n",
" features: ['input_ids', 'output_ids'],\r\n",
" num_rows: 1070\r\n",
"})\r\n",
"Map (num_proc=16): 100%|███████████| 1070/1070 [00:00<00:00, 1339.19 examples/s]\r\n",
"test_dataset: Dataset({\r\n",
" features: ['input_ids', 'output_ids'],\r\n",
" num_rows: 1070\r\n",
"})\r\n",
"--> Sanity check\r\n",
" '[gMASK]': 64790 -> -100\r\n",
" 'sop': 64792 -> -100\r\n",
" '<|user|>': 64795 -> -100\r\n",
" '': 30910 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" '': 30910 -> -100\r\n",
" '类型': 33467 -> -100\r\n",
" '#': 31010 -> -100\r\n",
" '裤': 56532 -> -100\r\n",
" '*': 30998 -> -100\r\n",
" '版': 55090 -> -100\r\n",
" '型': 54888 -> -100\r\n",
" '#': 31010 -> -100\r\n",
" '宽松': 40833 -> -100\r\n",
" '*': 30998 -> -100\r\n",
" '风格': 32799 -> -100\r\n",
" '#': 31010 -> -100\r\n",
" '性感': 40589 -> -100\r\n",
" '*': 30998 -> -100\r\n",
" '图案': 37505 -> -100\r\n",
" '#': 31010 -> -100\r\n",
" '线条': 37216 -> -100\r\n",
" '*': 30998 -> -100\r\n",
" '裤': 56532 -> -100\r\n",
" '型': 54888 -> -100\r\n",
" '#': 31010 -> -100\r\n",
" '阔': 56529 -> -100\r\n",
" '腿': 56158 -> -100\r\n",
" '裤': 56532 -> -100\r\n",
" '<|assistant|>': 64796 -> -100\r\n",
" '': 30910 -> 30910\r\n",
" '\\n': 13 -> 13\r\n",
" '': 30910 -> 30910\r\n",
" '宽松': 40833 -> 40833\r\n",
" '的': 54530 -> 54530\r\n",
" '阔': 56529 -> 56529\r\n",
" '腿': 56158 -> 56158\r\n",
" '裤': 56532 -> 56532\r\n",
" '这': 54551 -> 54551\r\n",
" '两年': 33808 -> 33808\r\n",
" '真的': 32041 -> 32041\r\n",
" '吸': 55360 -> 55360\r\n",
" '粉': 55486 -> 55486\r\n",
" '不少': 32138 -> 32138\r\n",
" ',': 31123 -> 31123\r\n",
" '明星': 32943 -> 32943\r\n",
" '时尚': 33481 -> 33481\r\n",
" '达': 54880 -> 54880\r\n",
" '人的': 31664 -> 31664\r\n",
" '心头': 46565 -> 46565\r\n",
" '爱': 54799 -> 54799\r\n",
" '。': 31155 -> 31155\r\n",
" '毕竟': 33051 -> 33051\r\n",
" '好': 54591 -> 54591\r\n",
" '穿': 55432 -> 55432\r\n",
" '时尚': 33481 -> 33481\r\n",
" ',': 31123 -> 31123\r\n",
" '谁': 55622 -> 55622\r\n",
" '都能': 32904 -> 32904\r\n",
" '穿': 55432 -> 55432\r\n",
" '出': 54557 -> 54557\r\n",
" '腿': 56158 -> 56158\r\n",
" '长': 54625 -> 54625\r\n",
" '2': 30943 -> 30943\r\n",
" '米': 55055 -> 55055\r\n",
" '的效果': 35590 -> 35590\r\n",
" '宽松': 40833 -> 40833\r\n",
" '的': 54530 -> 54530\r\n",
" '裤': 56532 -> 56532\r\n",
" '腿': 56158 -> 56158\r\n",
" ',': 31123 -> 31123\r\n",
" '当然是': 48466 -> 48466\r\n",
" '遮': 57148 -> 57148\r\n",
" '肉': 55343 -> 55343\r\n",
" '小': 54603 -> 54603\r\n",
" '能手': 49355 -> 49355\r\n",
" '啊': 55674 -> 55674\r\n",
" '。': 31155 -> 31155\r\n",
" '上身': 51605 -> 51605\r\n",
" '随': 55119 -> 55119\r\n",
" '性': 54642 -> 54642\r\n",
" '自然': 31799 -> 31799\r\n",
" '不': 54535 -> 54535\r\n",
" '拘': 57036 -> 57036\r\n",
" '束': 55625 -> 55625\r\n",
" ',': 31123 -> 31123\r\n",
" '面料': 46839 -> 46839\r\n",
" '亲': 55113 -> 55113\r\n",
" '肤': 56089 -> 56089\r\n",
" '舒适': 33894 -> 33894\r\n",
" '贴': 55778 -> 55778\r\n",
" '身体': 31902 -> 31902\r\n",
" '验': 55017 -> 55017\r\n",
" '感': 54706 -> 54706\r\n",
" '棒': 56382 -> 56382\r\n",
" '棒': 56382 -> 56382\r\n",
" '哒': 59230 -> 59230\r\n",
" '。': 31155 -> 31155\r\n",
" '系': 54712 -> 54712\r\n",
" '带': 54882 -> 54882\r\n",
" '部分': 31726 -> 31726\r\n",
" '增加': 31917 -> 31917\r\n",
" '设计': 31735 -> 31735\r\n",
" '看点': 45032 -> 45032\r\n",
" ',': 31123 -> 31123\r\n",
" '还': 54656 -> 54656\r\n",
" '让': 54772 -> 54772\r\n",
" '单品': 46539 -> 46539\r\n",
" '的设计': 34481 -> 34481\r\n",
" '感': 54706 -> 54706\r\n",
" '更强': 43084 -> 43084\r\n",
" '。': 31155 -> 31155\r\n",
" '腿部': 46799 -> 46799\r\n",
" '线条': 37216 -> 37216\r\n",
" '若': 55351 -> 55351\r\n",
" '隐': 55733 -> 55733\r\n",
" '若': 55351 -> 55351\r\n",
" '现': 54600 -> 54600\r\n",
" '的': 54530 -> 54530\r\n",
" ',': 31123 -> 31123\r\n",
" '性感': 40589 -> 40589\r\n",
" '撩': 58521 -> 58521\r\n",
" '人': 54533 -> 54533\r\n",
" '。': 31155 -> 31155\r\n",
" '颜色': 33692 -> 33692\r\n",
" '敲': 57004 -> 57004\r\n",
" '温柔': 34678 -> 34678\r\n",
" '的': 54530 -> 54530\r\n",
" ',': 31123 -> 31123\r\n",
" '与': 54619 -> 54619\r\n",
" '裤子': 44722 -> 44722\r\n",
" '本身': 32754 -> 32754\r\n",
" '所': 54626 -> 54626\r\n",
" '呈现': 33169 -> 33169\r\n",
" '的风格': 48084 -> 48084\r\n",
" '有点': 33149 -> 33149\r\n",
" '反': 54955 -> 54955\r\n",
" '差': 55342 -> 55342\r\n",
" '萌': 56842 -> 56842\r\n",
" '。': 31155 -> 31155\r\n",
" '': 2 -> 2\r\n",
"/media/zr/Data/Code/ChatGLM3/venv/lib/python3.10/site-packages/accelerate/accelerator.py:436: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \r\n",
"dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\r\n",
" warnings.warn(\r\n",
"max_steps is given, it will override any value given in num_train_epochs\r\n",
"***** Running training *****\r\n",
" Num examples = 114,599\r\n",
" Num Epochs = 1\r\n",
" Instantaneous batch size per device = 4\r\n",
" Total train batch size (w. parallel, distributed & accumulation) = 4\r\n",
" Gradient Accumulation steps = 1\r\n",
" Total optimization steps = 4,000\r\n",
" Number of trainable parameters = 1,949,696\r\n",
"{'loss': 4.832, 'grad_norm': 2.1177706718444824, 'learning_rate': 4.9875000000000006e-05, 'epoch': 0.0}\r\n",
"{'loss': 4.6094, 'grad_norm': 3.104412078857422, 'learning_rate': 4.975e-05, 'epoch': 0.0}\r\n",
"{'loss': 4.5043, 'grad_norm': 2.9755077362060547, 'learning_rate': 4.962500000000001e-05, 'epoch': 0.0}\r\n",
"{'loss': 4.14, 'grad_norm': 3.3869752883911133, 'learning_rate': 4.9500000000000004e-05, 'epoch': 0.0}\r\n",
"{'loss': 4.1275, 'grad_norm': 2.698483467102051, 'learning_rate': 4.937500000000001e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.8748, 'grad_norm': 2.9052674770355225, 'learning_rate': 4.9250000000000004e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.8506, 'grad_norm': 2.8566994667053223, 'learning_rate': 4.9125e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.7518, 'grad_norm': 2.9119534492492676, 'learning_rate': 4.9e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.6375, 'grad_norm': 3.1845204830169678, 'learning_rate': 4.8875e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.7219, 'grad_norm': 3.359720230102539, 'learning_rate': 4.875e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.676, 'grad_norm': 3.559992790222168, 'learning_rate': 4.8625e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.849, 'grad_norm': 3.822449207305908, 'learning_rate': 4.85e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.6154, 'grad_norm': 3.4438886642456055, 'learning_rate': 4.8375000000000004e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.7326, 'grad_norm': 4.374788284301758, 'learning_rate': 4.825e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.6854, 'grad_norm': 3.5999808311462402, 'learning_rate': 4.8125000000000004e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.7447, 'grad_norm': 3.8460822105407715, 'learning_rate': 4.8e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5766, 'grad_norm': 4.053386211395264, 'learning_rate': 4.7875000000000005e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5758, 'grad_norm': 4.296564102172852, 'learning_rate': 4.775e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5486, 'grad_norm': 4.701301574707031, 'learning_rate': 4.7625000000000006e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5775, 'grad_norm': 4.4896979331970215, 'learning_rate': 4.75e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.55, 'grad_norm': 4.9407429695129395, 'learning_rate': 4.7375e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6437, 'grad_norm': 4.0624542236328125, 'learning_rate': 4.7249999999999997e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6098, 'grad_norm': 4.786097049713135, 'learning_rate': 4.7125e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5107, 'grad_norm': 4.457597255706787, 'learning_rate': 4.7e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4723, 'grad_norm': 5.279415130615234, 'learning_rate': 4.6875e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6016, 'grad_norm': 5.297557353973389, 'learning_rate': 4.6750000000000005e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5475, 'grad_norm': 5.397997856140137, 'learning_rate': 4.6625e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6115, 'grad_norm': 4.472784519195557, 'learning_rate': 4.6500000000000005e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6273, 'grad_norm': 4.7433905601501465, 'learning_rate': 4.6375e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5379, 'grad_norm': 5.81007194519043, 'learning_rate': 4.6250000000000006e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4654, 'grad_norm': 5.297420501708984, 'learning_rate': 4.6125e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6057, 'grad_norm': 5.738197326660156, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4168, 'grad_norm': 5.207597732543945, 'learning_rate': 4.5875000000000004e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4932, 'grad_norm': 5.2784833908081055, 'learning_rate': 4.575e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.518, 'grad_norm': 5.428376197814941, 'learning_rate': 4.5625e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5727, 'grad_norm': 5.190096855163574, 'learning_rate': 4.55e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.3615, 'grad_norm': 4.818575859069824, 'learning_rate': 4.5375e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5275, 'grad_norm': 5.174643039703369, 'learning_rate': 4.525e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5232, 'grad_norm': 5.241923809051514, 'learning_rate': 4.5125e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4699, 'grad_norm': 5.603521823883057, 'learning_rate': 4.5e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6916, 'grad_norm': 5.468681335449219, 'learning_rate': 4.4875e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4975, 'grad_norm': 4.969369888305664, 'learning_rate': 4.4750000000000004e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6207, 'grad_norm': 5.575362682342529, 'learning_rate': 4.4625e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4152, 'grad_norm': 6.52517032623291, 'learning_rate': 4.4500000000000004e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4098, 'grad_norm': 5.987551212310791, 'learning_rate': 4.4375e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4244, 'grad_norm': 5.613704681396484, 'learning_rate': 4.4250000000000005e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5303, 'grad_norm': 5.790269374847412, 'learning_rate': 4.4125e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4475, 'grad_norm': 7.037369728088379, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4562, 'grad_norm': 5.771510601043701, 'learning_rate': 4.3875e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5623, 'grad_norm': 5.876147747039795, 'learning_rate': 4.375e-05, 'epoch': 0.02}\r\n",
" 12%|█████ | 500/4000 [04:39<37:01, 1.58it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:16<00:16, 8.09s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:32<00:11, 11.45s/it]\u001B[A\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:49<00:00, 13.52s/it]\u001B[ABuilding prefix dict from the default dictionary ...\r\n",
"Dumping model to file cache /tmp/jieba.cache\r\n",
"Loading model cost 0.580 seconds.\r\n",
"Prefix dict has been built successfully.\r\n",
" \r\n",
"\u001B[A{'eval_rouge-1': 31.645344, 'eval_rouge-2': 6.79404, 'eval_rouge-l': 23.83732, 'eval_bleu-4': 0.03250689604242964, 'eval_runtime': 54.3911, 'eval_samples_per_second': 0.919, 'eval_steps_per_second': 0.074, 'epoch': 0.02}\r\n",
" 12%|█████ | 500/4000 [05:34<37:01, 1.58it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:50<00:00, 13.52s/it]\u001B[A\r\n",
"{'loss': 3.3207, 'grad_norm': 5.6840596199035645, 'learning_rate': 4.3625e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5459, 'grad_norm': 6.672524929046631, 'learning_rate': 4.35e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5822, 'grad_norm': 5.989180564880371, 'learning_rate': 4.3375000000000004e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4859, 'grad_norm': 5.341927528381348, 'learning_rate': 4.325e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5219, 'grad_norm': 5.3769707679748535, 'learning_rate': 4.3125000000000005e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.6453, 'grad_norm': 5.812618732452393, 'learning_rate': 4.3e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4934, 'grad_norm': 5.726740837097168, 'learning_rate': 4.2875000000000005e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.3719, 'grad_norm': 5.551002025604248, 'learning_rate': 4.275e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4236, 'grad_norm': 6.213701248168945, 'learning_rate': 4.2625000000000006e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4887, 'grad_norm': 6.39825963973999, 'learning_rate': 4.25e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4365, 'grad_norm': 6.213500499725342, 'learning_rate': 4.237500000000001e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4559, 'grad_norm': 6.593310356140137, 'learning_rate': 4.2250000000000004e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4463, 'grad_norm': 5.9485673904418945, 'learning_rate': 4.2125e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4531, 'grad_norm': 6.2323737144470215, 'learning_rate': 4.2e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5338, 'grad_norm': 5.925570964813232, 'learning_rate': 4.1875e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4822, 'grad_norm': 6.287123203277588, 'learning_rate': 4.175e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5402, 'grad_norm': 6.1548848152160645, 'learning_rate': 4.1625e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.3025, 'grad_norm': 6.961801052093506, 'learning_rate': 4.15e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4016, 'grad_norm': 6.60474967956543, 'learning_rate': 4.1375e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.3547, 'grad_norm': 6.296048641204834, 'learning_rate': 4.125e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4992, 'grad_norm': 7.013551712036133, 'learning_rate': 4.1125000000000004e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5275, 'grad_norm': 6.747519493103027, 'learning_rate': 4.1e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.2475, 'grad_norm': 6.900665283203125, 'learning_rate': 4.0875000000000004e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5727, 'grad_norm': 5.7873334884643555, 'learning_rate': 4.075e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3961, 'grad_norm': 6.46198844909668, 'learning_rate': 4.0625000000000005e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4777, 'grad_norm': 6.117852687835693, 'learning_rate': 4.05e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.6215, 'grad_norm': 6.421164035797119, 'learning_rate': 4.0375e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4736, 'grad_norm': 6.280588626861572, 'learning_rate': 4.025e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3248, 'grad_norm': 6.418524265289307, 'learning_rate': 4.0125e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5496, 'grad_norm': 6.983282089233398, 'learning_rate': 4e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.2926, 'grad_norm': 6.696746349334717, 'learning_rate': 3.9875e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3609, 'grad_norm': 6.474392414093018, 'learning_rate': 3.9750000000000004e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.458, 'grad_norm': 7.111743450164795, 'learning_rate': 3.9625e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4062, 'grad_norm': 6.317008018493652, 'learning_rate': 3.9500000000000005e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5057, 'grad_norm': 6.232912540435791, 'learning_rate': 3.9375e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5305, 'grad_norm': 6.192782402038574, 'learning_rate': 3.9250000000000005e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.2908, 'grad_norm': 7.155930042266846, 'learning_rate': 3.9125e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4904, 'grad_norm': 6.664801597595215, 'learning_rate': 3.9000000000000006e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4529, 'grad_norm': 7.4175615310668945, 'learning_rate': 3.8875e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.2643, 'grad_norm': 7.862004280090332, 'learning_rate': 3.875e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4562, 'grad_norm': 7.8772687911987305, 'learning_rate': 3.8625e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4186, 'grad_norm': 6.901059150695801, 'learning_rate': 3.85e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4582, 'grad_norm': 7.472389221191406, 'learning_rate': 3.8375e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5643, 'grad_norm': 7.333090305328369, 'learning_rate': 3.825e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3639, 'grad_norm': 6.445948600769043, 'learning_rate': 3.8125e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4389, 'grad_norm': 7.957160949707031, 'learning_rate': 3.8e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5336, 'grad_norm': 5.9428324699401855, 'learning_rate': 3.7875e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3242, 'grad_norm': 6.897878646850586, 'learning_rate': 3.775e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4594, 'grad_norm': 7.274386882781982, 'learning_rate': 3.7625e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3949, 'grad_norm': 7.8012471199035645, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.03}\r\n",
" 25%|█████████▊ | 1000/4000 [10:11<28:52, 1.73it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:03<00:03, 1.53s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:05<00:01, 1.97s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 32.134831999999996, 'eval_rouge-2': 6.325576000000001, 'eval_rouge-l': 25.315346000000005, 'eval_bleu-4': 0.03137707571044217, 'eval_runtime': 9.9272, 'eval_samples_per_second': 5.037, 'eval_steps_per_second': 0.403, 'epoch': 0.03}\r\n",
" 25%|█████████▊ | 1000/4000 [10:21<28:52, 1.73it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:07<00:00, 1.77s/it]\u001B[A\r\n",
"{'loss': 3.4504, 'grad_norm': 6.908702373504639, 'learning_rate': 3.737500000000001e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4596, 'grad_norm': 7.377086639404297, 'learning_rate': 3.7250000000000004e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.6484, 'grad_norm': 8.061379432678223, 'learning_rate': 3.7125e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4, 'grad_norm': 6.452291011810303, 'learning_rate': 3.7e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3891, 'grad_norm': 8.560649871826172, 'learning_rate': 3.6875e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3551, 'grad_norm': 7.644310474395752, 'learning_rate': 3.675e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3895, 'grad_norm': 7.036133766174316, 'learning_rate': 3.6625e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4611, 'grad_norm': 7.2408528327941895, 'learning_rate': 3.65e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.5271, 'grad_norm': 7.058151721954346, 'learning_rate': 3.6375e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4662, 'grad_norm': 6.564244747161865, 'learning_rate': 3.625e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3428, 'grad_norm': 6.844818115234375, 'learning_rate': 3.6125000000000004e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.5244, 'grad_norm': 7.949232578277588, 'learning_rate': 3.6e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4357, 'grad_norm': 7.32559871673584, 'learning_rate': 3.5875000000000005e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3572, 'grad_norm': 8.051689147949219, 'learning_rate': 3.575e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3174, 'grad_norm': 7.550294399261475, 'learning_rate': 3.5625000000000005e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3588, 'grad_norm': 7.240135669708252, 'learning_rate': 3.55e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4516, 'grad_norm': 6.720525741577148, 'learning_rate': 3.5375e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4717, 'grad_norm': 6.3586320877075195, 'learning_rate': 3.525e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3574, 'grad_norm': 6.693387985229492, 'learning_rate': 3.5125e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.407, 'grad_norm': 6.322566509246826, 'learning_rate': 3.5e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.2439, 'grad_norm': 6.481217384338379, 'learning_rate': 3.4875e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3391, 'grad_norm': 7.359728813171387, 'learning_rate': 3.475e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3771, 'grad_norm': 7.4071478843688965, 'learning_rate': 3.4625e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3758, 'grad_norm': 7.325416564941406, 'learning_rate': 3.45e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4434, 'grad_norm': 6.780652046203613, 'learning_rate': 3.4375e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.2818, 'grad_norm': 7.619284152984619, 'learning_rate': 3.4250000000000006e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4562, 'grad_norm': 7.123080253601074, 'learning_rate': 3.4125e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3322, 'grad_norm': 7.0780863761901855, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3887, 'grad_norm': 6.898688316345215, 'learning_rate': 3.3875000000000003e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4793, 'grad_norm': 7.293100357055664, 'learning_rate': 3.375000000000001e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4607, 'grad_norm': 6.927903175354004, 'learning_rate': 3.3625000000000004e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4535, 'grad_norm': 6.639427661895752, 'learning_rate': 3.35e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4008, 'grad_norm': 10.613078117370605, 'learning_rate': 3.3375e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3059, 'grad_norm': 7.491557598114014, 'learning_rate': 3.325e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3484, 'grad_norm': 7.497087001800537, 'learning_rate': 3.3125e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.2969, 'grad_norm': 8.017332077026367, 'learning_rate': 3.3e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.5152, 'grad_norm': 7.311262130737305, 'learning_rate': 3.2875e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3871, 'grad_norm': 7.2260003089904785, 'learning_rate': 3.275e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3563, 'grad_norm': 7.222864151000977, 'learning_rate': 3.2625e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4166, 'grad_norm': 6.612077713012695, 'learning_rate': 3.2500000000000004e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3465, 'grad_norm': 7.431714057922363, 'learning_rate': 3.2375e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.2621, 'grad_norm': 7.619777202606201, 'learning_rate': 3.2250000000000005e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3795, 'grad_norm': 7.628826141357422, 'learning_rate': 3.2125e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3551, 'grad_norm': 7.093392848968506, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.2658, 'grad_norm': 6.70922327041626, 'learning_rate': 3.1875e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3914, 'grad_norm': 7.325173377990723, 'learning_rate': 3.175e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4367, 'grad_norm': 9.542543411254883, 'learning_rate': 3.1624999999999996e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.2979, 'grad_norm': 6.646926403045654, 'learning_rate': 3.15e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4375, 'grad_norm': 7.366168975830078, 'learning_rate': 3.1375e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4574, 'grad_norm': 6.800962924957275, 'learning_rate': 3.125e-05, 'epoch': 0.05}\r\n",
" 38%|██████████████▋ | 1500/4000 [14:57<20:28, 2.03it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:02<00:02, 1.43s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:18<00:07, 7.54s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 31.905676000000007, 'eval_rouge-2': 6.630377999999999, 'eval_rouge-l': 25.126853999999998, 'eval_bleu-4': 0.03152151596531457, 'eval_runtime': 23.6793, 'eval_samples_per_second': 2.112, 'eval_steps_per_second': 0.169, 'epoch': 0.05}\r\n",
" 38%|██████████████▋ | 1500/4000 [15:21<20:28, 2.03it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:20<00:00, 5.41s/it]\u001B[A\r\n",
"{'loss': 3.3451, 'grad_norm': 6.90294075012207, 'learning_rate': 3.1125000000000004e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3844, 'grad_norm': 8.37482738494873, 'learning_rate': 3.1e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4359, 'grad_norm': 8.105109214782715, 'learning_rate': 3.0875000000000005e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3988, 'grad_norm': 7.031566143035889, 'learning_rate': 3.075e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4945, 'grad_norm': 7.260471343994141, 'learning_rate': 3.0625000000000006e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4061, 'grad_norm': 8.252367973327637, 'learning_rate': 3.05e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4643, 'grad_norm': 7.982962131500244, 'learning_rate': 3.0375000000000003e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4326, 'grad_norm': 7.5859808921813965, 'learning_rate': 3.025e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.5098, 'grad_norm': 9.218013763427734, 'learning_rate': 3.0125000000000004e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3924, 'grad_norm': 7.129590034484863, 'learning_rate': 3e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3645, 'grad_norm': 7.882465362548828, 'learning_rate': 2.9875000000000004e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3656, 'grad_norm': 8.374431610107422, 'learning_rate': 2.975e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.4676, 'grad_norm': 7.145497798919678, 'learning_rate': 2.9625000000000002e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3199, 'grad_norm': 7.946256160736084, 'learning_rate': 2.95e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3682, 'grad_norm': 7.46930456161499, 'learning_rate': 2.9375000000000003e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.2996, 'grad_norm': 6.9753265380859375, 'learning_rate': 2.925e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.475, 'grad_norm': 8.484821319580078, 'learning_rate': 2.9125000000000003e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3715, 'grad_norm': 7.118030548095703, 'learning_rate': 2.9e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3742, 'grad_norm': 7.3347368240356445, 'learning_rate': 2.8875e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.5146, 'grad_norm': 6.8588714599609375, 'learning_rate': 2.8749999999999997e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.4602, 'grad_norm': 7.292227745056152, 'learning_rate': 2.8625e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.499, 'grad_norm': 7.423632621765137, 'learning_rate': 2.8499999999999998e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.4059, 'grad_norm': 7.430981636047363, 'learning_rate': 2.8375000000000002e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.398, 'grad_norm': 7.364171981811523, 'learning_rate': 2.825e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.4631, 'grad_norm': 7.548583984375, 'learning_rate': 2.8125000000000003e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.442, 'grad_norm': 7.765754699707031, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3605, 'grad_norm': 8.27833366394043, 'learning_rate': 2.7875e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3459, 'grad_norm': 8.09084415435791, 'learning_rate': 2.7750000000000004e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3928, 'grad_norm': 8.150015830993652, 'learning_rate': 2.7625e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3408, 'grad_norm': 7.760500907897949, 'learning_rate': 2.7500000000000004e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3803, 'grad_norm': 8.982950210571289, 'learning_rate': 2.7375e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3381, 'grad_norm': 7.609743118286133, 'learning_rate': 2.725e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.5785, 'grad_norm': 7.900216102600098, 'learning_rate': 2.7125000000000002e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3395, 'grad_norm': 8.472111701965332, 'learning_rate': 2.7000000000000002e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.4895, 'grad_norm': 8.781264305114746, 'learning_rate': 2.6875e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3846, 'grad_norm': 7.472824573516846, 'learning_rate': 2.6750000000000003e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3115, 'grad_norm': 8.073516845703125, 'learning_rate': 2.6625e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3037, 'grad_norm': 7.2763519287109375, 'learning_rate': 2.6500000000000004e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3965, 'grad_norm': 7.201462268829346, 'learning_rate': 2.6375e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3717, 'grad_norm': 7.831448554992676, 'learning_rate': 2.625e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.391, 'grad_norm': 7.940402507781982, 'learning_rate': 2.6124999999999998e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.477, 'grad_norm': 7.303577899932861, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.2766, 'grad_norm': 7.596188545227051, 'learning_rate': 2.5875e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4998, 'grad_norm': 7.545307159423828, 'learning_rate': 2.5750000000000002e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3592, 'grad_norm': 6.786509990692139, 'learning_rate': 2.5625e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.2854, 'grad_norm': 8.573935508728027, 'learning_rate': 2.5500000000000003e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3727, 'grad_norm': 7.578614234924316, 'learning_rate': 2.5375e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.2307, 'grad_norm': 7.565990447998047, 'learning_rate': 2.525e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.41, 'grad_norm': 7.094372749328613, 'learning_rate': 2.5124999999999997e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4619, 'grad_norm': 7.98245096206665, 'learning_rate': 2.5e-05, 'epoch': 0.07}\r\n",
" 50%|███████████████████▌ | 2000/4000 [19:57<17:54, 1.86it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:16<00:16, 8.01s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:32<00:11, 11.33s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 31.442076, 'eval_rouge-2': 7.156823999999999, 'eval_rouge-l': 23.246924000000003, 'eval_bleu-4': 0.03405216374744, 'eval_runtime': 64.2793, 'eval_samples_per_second': 0.778, 'eval_steps_per_second': 0.062, 'epoch': 0.07}\r\n",
" 50%|███████████████████▌ | 2000/4000 [21:01<17:54, 1.86it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:48<00:00, 12.97s/it]\u001B[A\r\n",
" \u001B[ASaving model checkpoint to ./output/checkpoint-2000\r\n",
"/media/zr/Data/Code/ChatGLM3/venv/lib/python3.10/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /media/zr/Data/Models/LLM/chatglm3-6b - will assume that the vocabulary was not modified.\r\n",
" warnings.warn(\r\n",
"{'loss': 3.3818, 'grad_norm': 8.677833557128906, 'learning_rate': 2.4875e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4928, 'grad_norm': 7.391153812408447, 'learning_rate': 2.4750000000000002e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.5547, 'grad_norm': 8.77245044708252, 'learning_rate': 2.4625000000000002e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4939, 'grad_norm': 8.10531997680664, 'learning_rate': 2.45e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3687, 'grad_norm': 8.14376449584961, 'learning_rate': 2.4375e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3307, 'grad_norm': 7.644017219543457, 'learning_rate': 2.425e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4414, 'grad_norm': 7.982100486755371, 'learning_rate': 2.4125e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4115, 'grad_norm': 8.171486854553223, 'learning_rate': 2.4e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4326, 'grad_norm': 7.437331199645996, 'learning_rate': 2.3875e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3533, 'grad_norm': 7.70622444152832, 'learning_rate': 2.375e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.2926, 'grad_norm': 7.60914945602417, 'learning_rate': 2.3624999999999998e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.5812, 'grad_norm': 8.040843963623047, 'learning_rate': 2.35e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.2502, 'grad_norm': 7.3959574699401855, 'learning_rate': 2.3375000000000002e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3521, 'grad_norm': 8.238727569580078, 'learning_rate': 2.3250000000000003e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3969, 'grad_norm': 7.359251022338867, 'learning_rate': 2.3125000000000003e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.5178, 'grad_norm': 8.128018379211426, 'learning_rate': 2.3000000000000003e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.393, 'grad_norm': 7.082696914672852, 'learning_rate': 2.2875e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4172, 'grad_norm': 7.790773868560791, 'learning_rate': 2.275e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3604, 'grad_norm': 7.583011150360107, 'learning_rate': 2.2625e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4316, 'grad_norm': 7.347414970397949, 'learning_rate': 2.25e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4496, 'grad_norm': 6.759352207183838, 'learning_rate': 2.2375000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4145, 'grad_norm': 7.640699863433838, 'learning_rate': 2.2250000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4189, 'grad_norm': 8.391305923461914, 'learning_rate': 2.2125000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3705, 'grad_norm': 8.04839038848877, 'learning_rate': 2.2000000000000003e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.2355, 'grad_norm': 8.35435962677002, 'learning_rate': 2.1875e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3584, 'grad_norm': 7.815989017486572, 'learning_rate': 2.175e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4268, 'grad_norm': 8.53368854522705, 'learning_rate': 2.1625e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.467, 'grad_norm': 7.677575588226318, 'learning_rate': 2.15e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.2885, 'grad_norm': 8.361733436584473, 'learning_rate': 2.1375e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3535, 'grad_norm': 8.110257148742676, 'learning_rate': 2.125e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3191, 'grad_norm': 8.498170852661133, 'learning_rate': 2.1125000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3271, 'grad_norm': 8.709260940551758, 'learning_rate': 2.1e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3629, 'grad_norm': 9.01534366607666, 'learning_rate': 2.0875e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3635, 'grad_norm': 7.54719352722168, 'learning_rate': 2.075e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.2623, 'grad_norm': 8.59843635559082, 'learning_rate': 2.0625e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3803, 'grad_norm': 8.170056343078613, 'learning_rate': 2.05e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3506, 'grad_norm': 7.873594284057617, 'learning_rate': 2.0375e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4871, 'grad_norm': 8.418689727783203, 'learning_rate': 2.025e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.2262, 'grad_norm': 8.624137878417969, 'learning_rate': 2.0125e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4514, 'grad_norm': 7.584123611450195, 'learning_rate': 2e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4514, 'grad_norm': 7.975276470184326, 'learning_rate': 1.9875000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.2789, 'grad_norm': 7.9726481437683105, 'learning_rate': 1.9750000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3652, 'grad_norm': 7.4362945556640625, 'learning_rate': 1.9625000000000003e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3795, 'grad_norm': 8.107170104980469, 'learning_rate': 1.9500000000000003e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.2727, 'grad_norm': 7.757025241851807, 'learning_rate': 1.9375e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3055, 'grad_norm': 7.5721869468688965, 'learning_rate': 1.925e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.2545, 'grad_norm': 8.496746063232422, 'learning_rate': 1.9125e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4332, 'grad_norm': 7.52405309677124, 'learning_rate': 1.9e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4711, 'grad_norm': 7.90508508682251, 'learning_rate': 1.8875e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.39, 'grad_norm': 9.309752464294434, 'learning_rate': 1.8750000000000002e-05, 'epoch': 0.09}\r\n",
" 62%|████████████████████████▍ | 2500/4000 [25:37<13:33, 1.84it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:03<00:03, 1.72s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:06<00:02, 2.25s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 31.633207999999996, 'eval_rouge-2': 6.800014, 'eval_rouge-l': 25.123896000000006, 'eval_bleu-4': 0.03327400496195634, 'eval_runtime': 25.5968, 'eval_samples_per_second': 1.953, 'eval_steps_per_second': 0.156, 'epoch': 0.09}\r\n",
" 62%|████████████████████████▍ | 2500/4000 [26:03<13:33, 1.84it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:22<00:00, 7.31s/it]\u001B[A\r\n",
"{'loss': 3.2988, 'grad_norm': 8.42829704284668, 'learning_rate': 1.8625000000000002e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3408, 'grad_norm': 9.460935592651367, 'learning_rate': 1.85e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.2467, 'grad_norm': 7.881652355194092, 'learning_rate': 1.8375e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3906, 'grad_norm': 8.49362564086914, 'learning_rate': 1.825e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3859, 'grad_norm': 7.6069016456604, 'learning_rate': 1.8125e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3982, 'grad_norm': 8.237305641174316, 'learning_rate': 1.8e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.465, 'grad_norm': 7.80671501159668, 'learning_rate': 1.7875e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4805, 'grad_norm': 8.655023574829102, 'learning_rate': 1.775e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3734, 'grad_norm': 8.358222961425781, 'learning_rate': 1.7625e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4732, 'grad_norm': 8.640260696411133, 'learning_rate': 1.75e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3471, 'grad_norm': 8.130788803100586, 'learning_rate': 1.7375e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4129, 'grad_norm': 7.604771614074707, 'learning_rate': 1.725e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.5184, 'grad_norm': 7.612947463989258, 'learning_rate': 1.7125000000000003e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4441, 'grad_norm': 8.518109321594238, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3992, 'grad_norm': 7.822119235992432, 'learning_rate': 1.6875000000000004e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3439, 'grad_norm': 7.961773872375488, 'learning_rate': 1.675e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4062, 'grad_norm': 8.931722640991211, 'learning_rate': 1.6625e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.2609, 'grad_norm': 7.5368194580078125, 'learning_rate': 1.65e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4715, 'grad_norm': 8.477120399475098, 'learning_rate': 1.6375e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4461, 'grad_norm': 9.24991512298584, 'learning_rate': 1.6250000000000002e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4182, 'grad_norm': 8.294699668884277, 'learning_rate': 1.6125000000000002e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.2432, 'grad_norm': 7.574826717376709, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3834, 'grad_norm': 8.255449295043945, 'learning_rate': 1.5875e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.385, 'grad_norm': 8.229700088500977, 'learning_rate': 1.575e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.449, 'grad_norm': 8.934239387512207, 'learning_rate': 1.5625e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3947, 'grad_norm': 8.390064239501953, 'learning_rate': 1.55e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3486, 'grad_norm': 8.181641578674316, 'learning_rate': 1.5375e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2568, 'grad_norm': 8.498324394226074, 'learning_rate': 1.525e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2709, 'grad_norm': 7.9656147956848145, 'learning_rate': 1.5125e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2258, 'grad_norm': 7.652721405029297, 'learning_rate': 1.5e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.4379, 'grad_norm': 8.255173683166504, 'learning_rate': 1.4875e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3639, 'grad_norm': 7.929840564727783, 'learning_rate': 1.475e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3836, 'grad_norm': 8.210647583007812, 'learning_rate': 1.4625e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.4367, 'grad_norm': 8.759031295776367, 'learning_rate': 1.45e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.4047, 'grad_norm': 8.681133270263672, 'learning_rate': 1.4374999999999999e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.327, 'grad_norm': 8.468674659729004, 'learning_rate': 1.4249999999999999e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3654, 'grad_norm': 8.48736572265625, 'learning_rate': 1.4125e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.5008, 'grad_norm': 9.581798553466797, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2943, 'grad_norm': 8.112646102905273, 'learning_rate': 1.3875000000000002e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3182, 'grad_norm': 8.913463592529297, 'learning_rate': 1.3750000000000002e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2932, 'grad_norm': 7.881869792938232, 'learning_rate': 1.3625e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2365, 'grad_norm': 7.5258941650390625, 'learning_rate': 1.3500000000000001e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3527, 'grad_norm': 9.253165245056152, 'learning_rate': 1.3375000000000002e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.248, 'grad_norm': 8.01251220703125, 'learning_rate': 1.3250000000000002e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.36, 'grad_norm': 8.332780838012695, 'learning_rate': 1.3125e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2068, 'grad_norm': 9.181897163391113, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.4514, 'grad_norm': 8.965094566345215, 'learning_rate': 1.2875000000000001e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.424, 'grad_norm': 8.944855690002441, 'learning_rate': 1.2750000000000002e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.4562, 'grad_norm': 8.20882511138916, 'learning_rate': 1.2625e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.358, 'grad_norm': 7.769922733306885, 'learning_rate': 1.25e-05, 'epoch': 0.1}\r\n",
" 75%|█████████████████████████████▎ | 3000/4000 [30:40<08:42, 1.91it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:02<00:02, 1.43s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:05<00:01, 1.94s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 33.007998, 'eval_rouge-2': 7.157356, 'eval_rouge-l': 25.306306000000003, 'eval_bleu-4': 0.0348571644891679, 'eval_runtime': 38.0831, 'eval_samples_per_second': 1.313, 'eval_steps_per_second': 0.105, 'epoch': 0.1}\r\n",
" 75%|█████████████████████████████▎ | 3000/4000 [31:18<08:42, 1.91it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:21<00:00, 7.25s/it]\u001B[A\r\n",
"{'loss': 3.4711, 'grad_norm': 8.417685508728027, 'learning_rate': 1.2375000000000001e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3418, 'grad_norm': 8.048948287963867, 'learning_rate': 1.225e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3564, 'grad_norm': 8.270435333251953, 'learning_rate': 1.2125e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.2293, 'grad_norm': 7.761234760284424, 'learning_rate': 1.2e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3873, 'grad_norm': 8.1546049118042, 'learning_rate': 1.1875e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.5338, 'grad_norm': 7.905092239379883, 'learning_rate': 1.175e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.2963, 'grad_norm': 8.120687484741211, 'learning_rate': 1.1625000000000001e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.292, 'grad_norm': 9.561246871948242, 'learning_rate': 1.1500000000000002e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.2029, 'grad_norm': 9.09880542755127, 'learning_rate': 1.1375e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3873, 'grad_norm': 7.879208087921143, 'learning_rate': 1.125e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3383, 'grad_norm': 8.732316970825195, 'learning_rate': 1.1125000000000001e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3205, 'grad_norm': 8.577627182006836, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3717, 'grad_norm': 9.737064361572266, 'learning_rate': 1.0875e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.2996, 'grad_norm': 8.619685173034668, 'learning_rate': 1.075e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.4496, 'grad_norm': 8.600975036621094, 'learning_rate': 1.0625e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.4277, 'grad_norm': 8.75851821899414, 'learning_rate': 1.05e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.4809, 'grad_norm': 7.5685930252075195, 'learning_rate': 1.0375e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.226, 'grad_norm': 8.321500778198242, 'learning_rate': 1.025e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3586, 'grad_norm': 7.587204933166504, 'learning_rate': 1.0125e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.4166, 'grad_norm': 8.86058235168457, 'learning_rate': 1e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.382, 'grad_norm': 9.254091262817383, 'learning_rate': 9.875000000000001e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.3961, 'grad_norm': 7.718448162078857, 'learning_rate': 9.750000000000002e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.4699, 'grad_norm': 8.792988777160645, 'learning_rate': 9.625e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.2145, 'grad_norm': 8.899701118469238, 'learning_rate': 9.5e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.4141, 'grad_norm': 8.802495956420898, 'learning_rate': 9.375000000000001e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.3627, 'grad_norm': 9.895890235900879, 'learning_rate': 9.25e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.4182, 'grad_norm': 8.153362274169922, 'learning_rate': 9.125e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.2916, 'grad_norm': 8.173482894897461, 'learning_rate': 9e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.2963, 'grad_norm': 9.929978370666504, 'learning_rate': 8.875e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.4039, 'grad_norm': 7.541258335113525, 'learning_rate': 8.75e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3602, 'grad_norm': 7.881056785583496, 'learning_rate': 8.625e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2324, 'grad_norm': 8.763860702514648, 'learning_rate': 8.500000000000002e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.4018, 'grad_norm': 9.141348838806152, 'learning_rate': 8.375e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3771, 'grad_norm': 8.166316032409668, 'learning_rate': 8.25e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2783, 'grad_norm': 9.261619567871094, 'learning_rate': 8.125000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.4312, 'grad_norm': 8.153901100158691, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.327, 'grad_norm': 7.708031177520752, 'learning_rate': 7.875e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3779, 'grad_norm': 7.920627117156982, 'learning_rate': 7.75e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2857, 'grad_norm': 9.732666015625, 'learning_rate': 7.625e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3588, 'grad_norm': 8.037003517150879, 'learning_rate': 7.5e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2002, 'grad_norm': 8.716700553894043, 'learning_rate': 7.375e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2863, 'grad_norm': 9.12403678894043, 'learning_rate': 7.25e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3447, 'grad_norm': 8.44495677947998, 'learning_rate': 7.1249999999999995e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3088, 'grad_norm': 8.425846099853516, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3281, 'grad_norm': 8.53967571258545, 'learning_rate': 6.875000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3451, 'grad_norm': 9.039155960083008, 'learning_rate': 6.750000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2674, 'grad_norm': 9.248905181884766, 'learning_rate': 6.625000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2703, 'grad_norm': 10.257024765014648, 'learning_rate': 6.5000000000000004e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.4084, 'grad_norm': 8.447395324707031, 'learning_rate': 6.375000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.4488, 'grad_norm': 8.430671691894531, 'learning_rate': 6.25e-06, 'epoch': 0.12}\r\n",
" 88%|██████████████████████████████████▏ | 3500/4000 [35:52<04:30, 1.85it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:04<00:04, 2.18s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:06<00:02, 2.23s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 32.222722, 'eval_rouge-2': 6.6331180000000005, 'eval_rouge-l': 25.087382, 'eval_bleu-4': 0.03253227960558209, 'eval_runtime': 25.0679, 'eval_samples_per_second': 1.995, 'eval_steps_per_second': 0.16, 'epoch': 0.12}\r\n",
" 88%|██████████████████████████████████▏ | 3500/4000 [36:17<04:30, 1.85it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:08<00:00, 2.14s/it]\u001B[A\r\n",
"{'loss': 3.3912, 'grad_norm': 9.152791976928711, 'learning_rate': 6.125e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3229, 'grad_norm': 9.17188549041748, 'learning_rate': 6e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2846, 'grad_norm': 8.172340393066406, 'learning_rate': 5.875e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.308, 'grad_norm': 8.928167343139648, 'learning_rate': 5.750000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3578, 'grad_norm': 8.738048553466797, 'learning_rate': 5.625e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2117, 'grad_norm': 8.161530494689941, 'learning_rate': 5.500000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3182, 'grad_norm': 7.672643184661865, 'learning_rate': 5.375e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.4324, 'grad_norm': 9.408201217651367, 'learning_rate': 5.25e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2418, 'grad_norm': 9.635400772094727, 'learning_rate': 5.125e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.1869, 'grad_norm': 8.71308708190918, 'learning_rate': 5e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2719, 'grad_norm': 10.24747085571289, 'learning_rate': 4.875000000000001e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.5238, 'grad_norm': 8.207618713378906, 'learning_rate': 4.75e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3982, 'grad_norm': 9.101743698120117, 'learning_rate': 4.625e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2068, 'grad_norm': 9.008282661437988, 'learning_rate': 4.5e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3084, 'grad_norm': 9.63040828704834, 'learning_rate': 4.375e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.1973, 'grad_norm': 8.8562593460083, 'learning_rate': 4.250000000000001e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.298, 'grad_norm': 8.217488288879395, 'learning_rate': 4.125e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3773, 'grad_norm': 8.624151229858398, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3955, 'grad_norm': 8.07646369934082, 'learning_rate': 3.875e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.4082, 'grad_norm': 9.692364692687988, 'learning_rate': 3.75e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3699, 'grad_norm': 9.671299934387207, 'learning_rate': 3.625e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.39, 'grad_norm': 9.423399925231934, 'learning_rate': 3.5000000000000004e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.168, 'grad_norm': 10.555978775024414, 'learning_rate': 3.3750000000000003e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.4062, 'grad_norm': 9.081645011901855, 'learning_rate': 3.2500000000000002e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2328, 'grad_norm': 8.238192558288574, 'learning_rate': 3.125e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2117, 'grad_norm': 8.344420433044434, 'learning_rate': 3e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2488, 'grad_norm': 9.779040336608887, 'learning_rate': 2.8750000000000004e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2828, 'grad_norm': 8.346026420593262, 'learning_rate': 2.7500000000000004e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.4674, 'grad_norm': 8.168132781982422, 'learning_rate': 2.625e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2598, 'grad_norm': 7.97592830657959, 'learning_rate': 2.5e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3447, 'grad_norm': 10.082160949707031, 'learning_rate': 2.375e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2311, 'grad_norm': 8.935636520385742, 'learning_rate': 2.25e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3922, 'grad_norm': 8.796125411987305, 'learning_rate': 2.1250000000000004e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.315, 'grad_norm': 8.807939529418945, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2951, 'grad_norm': 8.721334457397461, 'learning_rate': 1.875e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3289, 'grad_norm': 9.166098594665527, 'learning_rate': 1.7500000000000002e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.46, 'grad_norm': 8.010759353637695, 'learning_rate': 1.6250000000000001e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.4809, 'grad_norm': 8.220529556274414, 'learning_rate': 1.5e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.4166, 'grad_norm': 8.10384750366211, 'learning_rate': 1.3750000000000002e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.458, 'grad_norm': 8.7192964553833, 'learning_rate': 1.25e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.2795, 'grad_norm': 8.834420204162598, 'learning_rate': 1.125e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.3441, 'grad_norm': 9.3894681930542, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.3844, 'grad_norm': 7.872992038726807, 'learning_rate': 8.750000000000001e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.5111, 'grad_norm': 8.390124320983887, 'learning_rate': 7.5e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.3422, 'grad_norm': 9.196588516235352, 'learning_rate': 6.25e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.2922, 'grad_norm': 8.946027755737305, 'learning_rate': 5.000000000000001e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.4168, 'grad_norm': 7.884989261627197, 'learning_rate': 3.75e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.4125, 'grad_norm': 9.072811126708984, 'learning_rate': 2.5000000000000004e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.4373, 'grad_norm': 8.543241500854492, 'learning_rate': 1.2500000000000002e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.3844, 'grad_norm': 9.427127838134766, 'learning_rate': 0.0, 'epoch': 0.14}\r\n",
"100%|███████████████████████████████████████| 4000/4000 [40:55<00:00, 1.92it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:03<00:03, 1.96s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:06<00:02, 2.33s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 31.607680000000002, 'eval_rouge-2': 6.832874, 'eval_rouge-l': 25.068815999999998, 'eval_bleu-4': 0.03411200822704291, 'eval_runtime': 12.6342, 'eval_samples_per_second': 3.958, 'eval_steps_per_second': 0.317, 'epoch': 0.14}\r\n",
"100%|███████████████████████████████████████| 4000/4000 [41:08<00:00, 1.92it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:09<00:00, 2.33s/it]\u001B[A\r\n",
" \u001B[ASaving model checkpoint to ./output/checkpoint-4000\r\n",
"/media/zr/Data/Code/ChatGLM3/venv/lib/python3.10/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /media/zr/Data/Models/LLM/chatglm3-6b - will assume that the vocabulary was not modified.\r\n",
" warnings.warn(\r\n",
"\r\n",
"\r\n",
"Training completed. Do not forget to share your model on huggingface.co/models =)\r\n",
"\r\n",
"\r\n",
"{'train_runtime': 2468.7229, 'train_samples_per_second': 6.481, 'train_steps_per_second': 1.62, 'train_loss': 3.419384765625, 'epoch': 0.14}\r\n",
"100%|███████████████████████████████████████| 4000/4000 [41:08<00:00, 1.62it/s]\r\n",
"***** Running Prediction *****\r\n",
" Num examples = 1070\r\n",
" Batch size = 16\r\n",
"100%|███████████████████████████████████████████| 67/67 [12:42<00:00, 11.38s/it]\r\n"
]
}
],
"source": [
"!CUDA_VISIBLE_DEVICES=0 NCCL_P2P_DISABLE=\"1\" NCCL_IB_DISABLE=\"1\" python finetune_hf.py data/AdvertiseGen_fix /media/zr/Data/Models/LLM/chatglm3-6b configs/lora.yaml"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "17c87410a24d844f",
"outputId": "e347fc7d-875e-40c9-c682-3e064100476b",
"ExecuteTime": {
"end_time": "2024-04-14T06:23:41.282431Z",
"start_time": "2024-04-14T05:29:23.810692Z"
}
},
"id": "17c87410a24d844f"
},
{
"cell_type": "markdown",
"source": [
"## 3. 使用微调的数据集进行推理\n",
"在完成微调任务之后,我们可以查看到 `output` 文件夹下多了很多个`checkpoint-*`的文件夹,这些文件夹代表了训练的轮数。\n",
"我们选择最后一轮的微调权重,并使用inference进行导入。"
],
"metadata": {
"collapsed": false,
"id": "d9418f6c5c264601"
},
"id": "d9418f6c5c264601"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████████████| 7/7 [00:02<00:00, 2.45it/s]\r\n",
"Setting eos_token is not supported, use the default one.\r\n",
"Setting pad_token is not supported, use the default one.\r\n",
"Setting unk_token is not supported, use the default one.\r\n",
"这款连衣裙采用压褶的版型设计,不规则的木耳边拼接,修饰了腰线,使得身材更加修长,不规则的压褶设计,增加了层次感,不规则的压褶,修饰了腰线,拉长腿部比例,显瘦又性感,套头的设计,方便穿脱,不规则的压褶,增加层次感,视觉上拉长腿部比例,百褶的网纱拼接,增加了层次感,整体气质优雅。\r\n"
]
}
],
"source": [
"!CUDA_VISIBLE_DEVICES=0 NCCL_P2P_DISABLE=\"1\" NCCL_IB_DISABLE=\"1\" python inference_hf.py output/checkpoint-4000/ --prompt \"类型#裙*版型#显瘦*材质#网纱*风格#性感*裙型#百褶*裙下摆#压褶*裙长#连衣裙*裙衣门襟#拉链*裙衣门襟#套头*裙款式#拼接*裙款式#拉链*裙款式#木耳边*裙款式#抽褶*裙款式#不规则\""
],
"metadata": {
"id": "5060015c24e97ae",
"outputId": "d3f03d0d-46bf-4c74-9b00-dc0160da0e15",
"colab": {
"base_uri": "https://localhost:8080/"
},
"ExecuteTime": {
"end_time": "2024-04-14T06:23:52.725227Z",
"start_time": "2024-04-14T06:23:41.284552Z"
}
},
"id": "5060015c24e97ae"
},
{
"cell_type": "markdown",
"source": [
"## 4. 总结\n",
"到此位置,我们就完成了使用单张 GPU Lora 来微调 ChatGLM3-6B 模型,使其能生产出更好的广告。\n",
"在本章节中,你将会学会:\n",
"+ 如何使用模型进行 Lora 微调\n",
"+ 微调数据集的准备和对齐\n",
"+ 使用微调的模型进行推理"
],
"metadata": {
"collapsed": false,
"id": "18cd83087f096094"
},
"id": "18cd83087f096094"
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3 (ipykernel)"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
},
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "V100"
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 5
}
export HIP_VISIBLE_DEVICES=7
python inference_hf.py /path/to/checkpoint/ --prompt "类型#裙*版型#显瘦*材质#网纱*风格#性感*裙型#百褶*裙下摆#压褶*裙长#连衣裙*裙衣门襟#拉链*裙衣门襟#套头*裙款式#拼接*裙款式#拉链*裙款式#木耳边*裙款式#抽褶*裙款式#不规则"
import json
from typing import Union
from pathlib import Path
def _resolve_path(path: Union[str, Path]) -> Path:
return Path(path).expanduser().resolve()
def _mkdir(dir_name: Union[str, Path]):
dir_name = _resolve_path(dir_name)
if not dir_name.is_dir():
dir_name.mkdir(parents=True, exist_ok=False)
def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]):
def _convert(in_file: Path, out_file: Path):
_mkdir(out_file.parent)
with open(in_file, encoding='utf-8') as fin:
with open(out_file, 'wt', encoding='utf-8') as fout:
for line in fin:
dct = json.loads(line)
sample = {'conversations': [{'role': 'user', 'content': dct['content']},
{'role': 'assistant', 'content': dct['summary']}]}
fout.write(json.dumps(sample, ensure_ascii=False) + '\n')
data_dir = _resolve_path(data_dir)
save_dir = _resolve_path(save_dir)
train_file = data_dir / 'train.json'
if train_file.is_file():
out_file = save_dir / train_file.relative_to(data_dir)
_convert(train_file, out_file)
dev_file = data_dir / 'dev.json'
if dev_file.is_file():
out_file = save_dir / dev_file.relative_to(data_dir)
_convert(dev_file, out_file)
convert_adgen('data/AdvertiseGen', 'data/AdvertiseGen_fix')
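# Sanity check (a sketch; assumes the call above produced data/AdvertiseGen_fix/train.json):
# peek at the first converted sample to confirm the conversations format.
with open('data/AdvertiseGen_fix/train.json', encoding='utf-8') as f:
    first_sample = json.loads(next(f))
for turn in first_sample['conversations']:
    print(turn['role'], ':', turn['content'])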
jieba>=0.42.1
ruamel_yaml>=0.18.6
rouge_chinese>=1.0.3
jupyter>=1.0.0
datasets>=2.18.0
peft>=0.10.0
#deepspeed==0.13.1
mpi4py>=3.1.5
export HIP_VISIBLE_DEVICES=1,2,3,4
export HSA_FORCE_FINE_GRAIN_PCIE=1
torchrun --standalone --nnodes=1 --nproc_per_node=4 finetune_hf_sft.py data/AdvertiseGen_fix /path/to/chatglm3-6b configs/sft.yaml
export HIP_VISIBLE_DEVICES=7
python inference_hf.py /path/to/checkpoint/ --prompt "类型#裙*版型#显瘦*材质#网纱*风格#性感*裙型#百褶*裙下摆#压褶*裙长#连衣裙*裙衣门襟#拉链*裙衣门襟#套头*裙款式#拼接*裙款式#拉链*裙款式#木耳边*裙款式#抽褶*裙款式#不规则"
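# Note: HIP_VISIBLE_DEVICES selects GPUs on AMD/ROCm systems; on NVIDIA GPUs the
# equivalent variable is CUDA_VISIBLE_DEVICES, as used in the notebook cells above.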
import ast
import json
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import List, Optional
class ChatGLM3(LLM):
max_token: int = 8192
    do_sample: bool = True
temperature: float = 0.8
top_p = 0.8
tokenizer: object = None
model: object = None
history: List = []
tool_names: List = []
has_search: bool = False
def __init__(self):
@@ -33,32 +32,50 @@ class ChatGLM3(LLM):
trust_remote_code=True
)
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=model_config, trust_remote_code=True, device_map="auto"
        ).eval()
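        # device_map="auto" dispatches the model weights across the available devices
        # and requires the `accelerate` package to be installed alongside transformers.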
def _tool_history(self, prompt: str):
ans = []
tool_prompts = prompt.split(
"You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")
        tools_json = []
for tool_desc in tool_prompts:
name = tool_desc.split(":")[0]
description = tool_desc.split(", args:")[0].split(":")[1].strip()
parameters_str = tool_desc.split("args:")[1].strip()
parameters_dict = ast.literal_eval(parameters_str)
params_cleaned = {}
for param, details in parameters_dict.items():
params_cleaned[param] = {'description': details['description'], 'type': details['type']}
tools_json.append({
"name": name,
"description": description,
"parameters": params_cleaned
})
ans.append({
"role": "system",
"content": "Answer the following questions as best as you can. You have access to the following tools:",
"tools": tools_json
})
query = f"""{prompt.split("Human: ")[-1].strip()}"""
dialog_parts = prompt.split("Human: ")
for part in dialog_parts[1:]:
if "\nAI: " in part:
user_input, ai_response = part.split("\nAI: ")
ai_response = ai_response.split("\n")[0]
else:
user_input = part
ai_response = None
ans.append({"role": "user", "content": user_input.strip()})
if ai_response:
ans.append({"role": "assistant", "content": ai_response.strip()})
query = dialog_parts[-1].split("\n")[0]
return ans, query
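    # Sketch of the mapping performed above (prompt text assumed): a structured-chat
    # prompt that lists tools after "You have access to the following tools:" and
    # dialog turns as "Human: ...\nAI: ..." becomes one system message carrying
    # tools_json, one user/assistant pair per round, and query set to the last human turn.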
def _extract_observation(self, prompt: str):
@@ -73,16 +90,25 @@ class ChatGLM3(LLM):
if len(self.history[-1]["metadata"]) > 0:
metadata = self.history[-1]["metadata"]
content = self.history[-1]["content"]
if "tool_call" in content:
for tool in self.tool_names:
if tool in metadata:
input_para = content.split("='")[-1].split("'")[0]
action_json = {
"action": tool,
"action_input": input_para
}
self.has_search = True
return f"""
lines = content.split('\n')
for line in lines:
if 'tool_call(' in line and ')' in line and self.has_search is False:
                    # extract the string inside the parentheses
params_str = line.split('tool_call(')[-1].split(')')[0]
                    # parse the key=value parameter pairs
params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs}
action_json = {
"action": metadata,
"action_input": params
}
self.has_search = True
print("*****Action*****")
print(action_json)
print("*****Answer*****")
return f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
@@ -99,17 +125,11 @@ Action:
```"""
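    # Example of the parsing above (values assumed): a model reply whose content
    # contains the line  tool_call(location='北京')  with metadata "weather" is
    # turned into {"action": "weather", "action_input": {"location": "北京"}}.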
def _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):
print("======")
print(prompt)
print("======")
if not self.has_search:
self.history, query = self._tool_history(prompt)
else:
self._extract_observation(prompt)
query = ""
# print("======")
# print(history)
# print("======")
_, self.history = self.model.chat(
self.tokenizer,
query,
......
# README
## Model configuration
In `main.py`, modify the `model_path = /path/to/chatglm3-6b` path; you can also set it to `THUDM/chatglm3-6b` to download the model automatically.
## Adding tools
### Tools already implemented in LangChain
Refer to the [langchain](https://python.langchain.com/docs/modules/agents/tools/) tool documentation and import the tool module in `main.py`. For example, to import the `arxiv` tool:
```python
run_tool(["arxiv"], llm, [
"帮我查询AgentTuning相关工作"
])
```
#### Calculator and Weather tool configuration
If the `LangChain` version in your Python environment is lower than **`0.0.278`**, you need to implement the `_arun` method in these two predefined tool classes;
otherwise you will get
`TypeError: Can't instantiate abstract class Weather with abstract method _arun`
For example:
```python
class Weather(BaseTool):
name = "weather"
description = "Use for searching weather at a specific location"
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
        # _arun is not used in this demo, so it is left unimplemented
pass
```
Run the `main.py` file:
```
python main.py
```
The run will stop because the yaml description file for the `arxiv` tool cannot be found, and you will need to create the `./Tool/arxiv.yaml` file manually. The tool description can be written by yourself or adapted from LangChain's description of the tool.
For the `arxiv` example, reference content is provided in `./Tool/arxiv_example.yaml`; you can build `Tool/arxiv.yaml` from it (in the simplest case, just change the name), and rerunning the model will then call the tool correctly.
> Some tools require an API_KEY; add it to your environment variables as prompted by the langchain error message.
### Custom tools
To define a custom tool, refer to `Tool/Weather.py` and `Tool/Weather.yaml`: subclass `Tool`, implement the corresponding `_run()` method, and import the tool module in `main.py`. For example, after importing the `Weather` tool, it can be called like this:
```python
# call the same tool several times
# requires: export SENIVERSE_KEY=<YOUR_API_KEY_HERE>
run_tool([Weather()], llm, [
"今天北京天气怎么样?",
"What's the weather like in Shanghai today",
])
```
## Using multiple tools
Multiple tools can be combined for the model to choose from automatically, for example:
```python
run_tool([Calculator(), "arxiv", Weather()], llm, [
"帮我检索GLM-130B相关论文",
"今天北京天气怎么样?",
"根号3减去根号二再加上4等于多少?",
])
```
import abc
import math
from typing import Any
from langchain.tools import BaseTool
class Calculator(BaseTool, abc.ABC):
name = "Calculator"
description = "Useful for when you need to answer questions about math"
def __init__(self):
super().__init__()
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
        # _arun is not used in this demo, so it is left unimplemented
pass
def _run(self, para: str) -> str:
para = para.replace("^", "**")
if "sqrt" in para:
para = para.replace("sqrt", "math.sqrt")
elif "log" in para:
para = para.replace("log", "math.log")
return eval(para)
if __name__ == "__main__":
calculator_tool = Calculator()
result = calculator_tool.run("sqrt(2) + 3")
print(result)
name: Calculator
description: Useful for when you need to answer questions about math
parameters:
type: object
properties:
formula:
type: string
description: The formula to be calculated
required:
- formula
name: arxiv
description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
parameters:
type: object
properties:
query:
type: string
description: The search query title
required:
- query
name: weather
description: Search the current weather of a city
parameters:
type: object
properties:
city:
type: string
description: City name
required:
- city
"""
This script demonstrates the use of LangChain's StructuredChatAgent and AgentExecutor alongside various tools.
The script utilizes the ChatGLM3 model, a large language model for understanding and generating human-like text.
The model is loaded from a specified path and integrated into the chat agent.
Tools:
- Calculator: Performs arithmetic calculations.
- Weather: Provides weather-related information based on input queries.
- DistanceConverter: Converts distances between meters, kilometers, and feet.
The agent operates in four modes:
1. Single Parameter without History: Uses Calculator to perform simple arithmetic.
2. Single Parameter with History: Uses the Weather tool to answer queries about temperature, taking the
conversation history into account.
3. Multiple Parameters without History: Uses DistanceConverter to convert distances between specified units.
4. Single LangChain Tool: Uses the Arxiv tool to search for scientific articles.
Note:
Tool calls by the model can fail, which may cause errors or prevent execution. Try lowering the model's
temperature, or reduce the number of tools, especially for the third mode.
The success rate of multi-parameter calls is low; errors such as the following may occur:
Required fields [type=missing, input_value={'distance': '30', 'unit': 'm', 'to': 'km'}, input_type=dict]
In such cases the model hallucinates parameters that do not meet the schema.
Adjust the model's top_p and temperature parameters to better mitigate such problems.
Success example:
*****Action*****
{
'action': 'weather',
'action_input': {
'location': '厦门'
}
}
*****Answer*****
{
'input': '厦门比北京热吗?',
'chat_history': [HumanMessage(content='北京温度多少度'), AIMessage(content='北京现在12度')],
'output': '根据最新的天气数据,厦门今天的气温为18度,天气晴朗。而北京今天的气温为12度。所以,厦门比北京热。'
}
****************
"""
import os
from typing import List

from langchain import hub
from langchain.agents import AgentExecutor, AgentType, create_structured_chat_agent, initialize_agent, load_tools
from langchain_core.messages import AIMessage, HumanMessage

from ChatGLM3 import ChatGLM3
from tools.Calculator import Calculator
from tools.Weather import Weather
from tools.DistanceConversion import DistanceConverter

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')


def run_tool(tools, llm, prompt_chain: List[str]):
    loaded_tools = []
    for tool in tools:
        if isinstance(tool, str):
            loaded_tools.append(load_tools([tool], llm=llm)[0])
        else:
            loaded_tools.append(tool)
    agent = initialize_agent(
        loaded_tools, llm,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors=True
    )
    for prompt in prompt_chain:
        agent.run(prompt)


if __name__ == "__main__":
    llm = ChatGLM3()
    llm.load_model(MODEL_PATH)
    prompt = hub.pull("hwchase17/structured-chat-agent")

    # for single parameter without history
    tools = [Calculator()]
    agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools)
    ans = agent_executor.invoke({"input": "34 * 34"})
    print(ans)

    # for single parameter with history
    tools = [Weather()]
    agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools)
    ans = agent_executor.invoke(
        {
            "input": "厦门比北京热吗?",
            "chat_history": [
                HumanMessage(content="北京温度多少度"),
                AIMessage(content="北京现在12度"),
            ],
        }
    )
    print(ans)

    # for multiple parameters without history
    tools = [DistanceConverter()]
    agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools)
    ans = agent_executor.invoke({"input": "how many meters in 30 km?"})
    print(ans)

    # for using langchain tools
    tools = load_tools(["arxiv"], llm=llm)
    agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools)
    ans = agent_executor.invoke({"input": "Describe the paper about GLM 130B"})
    print(ans)

    # run_tool-based examples (legacy API, kept for reference)
    # arxiv: single tool call, example 1
    run_tool(["arxiv"], llm, [
        "帮我查询GLM-130B相关工作"
    ])
    # weather: single tool call, example 2
    run_tool([Weather()], llm, [
        "今天北京天气怎么样?",
        "What's the weather like in Shanghai today",
    ])
    # calculator: single tool call, example 3
    run_tool([Calculator()], llm, [
        "12345679乘以54等于多少?",
        "3.14的3.14次方等于多少?",
        "根号2加上根号三等于多少?",
    ])
    # arxiv + weather + calculator: combined multi-tool call
    # run_tool([Calculator(), "arxiv", Weather()], llm, [
    #     "帮我检索GLM-130B相关论文",
    #     "今天北京天气怎么样?",
    #     "根号3减去根号二再加上4等于多少?",
    # ])
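# Assumed invocation (not from the original files): point MODEL_PATH at a local
# checkpoint if you have one,
#   MODEL_PATH=/path/to/chatglm3-6b python main.py
# otherwise the THUDM/chatglm3-6b weights are fetched from the Hugging Face Hub.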
langchain
arxiv
import abc
import math  # needed so eval() can resolve the math.* names substituted below
import re
from typing import Type
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class CalculatorInput(BaseModel):
calculation: str = Field(description="calculation to perform")
class Calculator(BaseTool, abc.ABC):
name = "Calculator"
description = "Useful for when you need to calculate math problems"
args_schema: Type[BaseModel] = CalculatorInput
def __init__(self):
super().__init__()
def parameter_validation(self, para: str):
"""
You can write your own parameter validation rules here,
you can refer to the code given here.
:param para:
:return:
"""
symbols = ["math", "sqrt", "log", "sin", "cos", "tan", "pi"]
for sym in symbols:
para = para.replace(sym, "")
        pattern = re.compile(r"[+*/\-%\d()=\s.]{3,}")
        if re.findall(pattern, para):
            return True
    def _run(self, calculation: str) -> str:
        calculation = calculation.replace("^", "**")
        # normalize the various spellings of pi before qualifying names
        for alias in ("pI", "PI", "Pi"):
            calculation = calculation.replace(alias, "pi")
        # qualify supported functions/constants with the math module,
        # skipping names that are already prefixed
        for name in ("sqrt", "log", "sin", "cos", "tan", "pi"):
            if name in calculation and f"math.{name}" not in calculation:
                calculation = calculation.replace(name, f"math.{name}")
        return eval(calculation)
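# Minimal usage sketch (assumed, mirroring the Tool/Calculator.py demo above):
if __name__ == "__main__":
    print(Calculator().run("sqrt(2) + 3"))  # BaseTool.run dispatches to _run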
import abc
from typing import Type
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class DistanceConversionInput(BaseModel):
distance: float = Field(description="The numerical value of the distance to convert")
unit: str = Field(description="The current unit of the distance (m, km, or feet)")
to_unit: str = Field(description="The target unit to convert the distance into (m, km, or feet)")
class DistanceConverter(BaseTool, abc.ABC):
name = "DistanceConverter"
description = "Converts distance between meters, kilometers, and feet"
args_schema: Type[BaseModel] = DistanceConversionInput
def __init__(self):
super().__init__()
    def _run(self, distance: float, unit: str, to_unit: str) -> str:
        unit_conversions = {
            "m_to_km": 0.001,
            "km_to_m": 1000,
            "feet_to_m": 0.3048,
            "m_to_feet": 3.28084,
            "km_to_feet": 3280.84,
            "feet_to_km": 0.0003048
        }
        if unit == to_unit:
            return f"{distance} {unit} is equal to {distance} {to_unit}"
        original_distance = distance
        # normalize the input distance to meters first
        if unit == "km":
            distance *= unit_conversions["km_to_m"]
        elif unit == "feet":
            distance *= unit_conversions["feet_to_m"]
        # then convert from meters to the target unit
        if to_unit == "km":
            converted_distance = distance * unit_conversions["m_to_km"]
        elif to_unit == "feet":
            converted_distance = distance * unit_conversions["m_to_feet"]
        else:
            converted_distance = distance  # target unit is meters
        return f"{original_distance} {unit} is equal to {converted_distance} {to_unit}"
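# Usage sketch (assumed): multi-argument tools take a dict via BaseTool.run, e.g.
#   DistanceConverter().run({"distance": 30, "unit": "km", "to_unit": "m"})
#   -> "30 km is equal to 30000 m"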
import os
import requests
from typing import Type, Any
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class WeatherInput(BaseModel):
location: str = Field(description="the location need to check the weather")
class Weather(BaseTool):
name = "weather"
description = "Use for searching weather at a specific location"
args_schema: Type[BaseModel] = WeatherInput
def __init__(self):
super().__init__()
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
        # _arun is not used in this demo, so it is left unimplemented
pass
    def _run(self, location: str) -> dict[str, Any]:
api_key = os.environ["SENIVERSE_KEY"]
url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
response = requests.get(url)
@@ -30,12 +31,3 @@ class Weather(BaseTool):
else:
raise Exception(
f"Failed to retrieve weather: {response.status_code}")
if __name__ == "__main__":
weather_tool = Weather()
weather_info = weather_tool.run("成都")
print(weather_info)
import os
import yaml
def tool_config_from_file(tool_name, directory="Tool/"):
"""search tool yaml and return json format"""
for filename in os.listdir(directory):
if filename.endswith('.yaml') and tool_name in filename:
file_path = os.path.join(directory, filename)
with open(file_path, encoding='utf-8') as f:
return yaml.safe_load(f)
return None
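# Usage sketch (assumed): with Tool/Weather.yaml in place, tool_config_from_file("Weather")
# returns the parsed mapping, e.g. {"name": "weather", "description": ..., "parameters": {...}}.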