Commit 4bd96acc authored by lvzhen's avatar lvzhen

Update tools_using_demo/cli_demo_tool.py, tools_using_demo/openai_api_demo.py,...

Update tools_using_demo/cli_demo_tool.py, tools_using_demo/openai_api_demo.py, tools_using_demo/README.md, tools_using_demo/README_en.md, tools_using_demo/tool_register.py, tensorrt_llm_demo/README.md, tensorrt_llm_demo/tensorrt_llm_cli_demo.py, resources/cli-demo.png, resources/web-demo2.png, resources/tool_en.png, resources/tool.png, resources/heart.png, resources/wechat.jpg, resources/web-demo.gif, resources/web-demo2.gif, resources/WECHAT.md, resources/code_en.gif, openai_api_demo/api_server.py, openai_api_demo/.env, openai_api_demo/openai_api_request.py, openai_api_demo/docker-compose.yml, openai_api_demo/utils.py, openai_api_demo/zhipu_api_request.py, openai_api_demo/langchain_openai_api.py, langchain_demo/ChatGLM3.py, langchain_demo/main.py, langchain_demo/tools/Calculator.py, langchain_demo/tools/DistanceConversion.py, langchain_demo/tools/Weather.py, Intel_device_demo/README.md, Intel_device_demo/ipex_llm_cpu_demo/api_server.py, Intel_device_demo/ipex_llm_cpu_demo/chatglm3_infer.py, Intel_device_demo/ipex_llm_cpu_demo/chatglm3_web_demo.py, Intel_device_demo/ipex_llm_cpu_demo/openai_api_request.py, Intel_device_demo/ipex_llm_cpu_demo/generate.py, Intel_device_demo/ipex_llm_cpu_demo/utils.py, Intel_device_demo/openvino_demo/openvino_cli_demo.py, Intel_device_demo/openvino_demo/README.md, finetune_demo/lora_finetune.ipynb, finetune_demo/finetune_hf.py, finetune_demo/inference_hf.py, finetune_demo/README.md, finetune_demo/README_en.md, finetune_demo/requirements.txt, finetune_demo/configs/ds_zero_3.json, finetune_demo/configs/ds_zero_2.json, finetune_demo/configs/ptuning_v2.yaml, finetune_demo/configs/lora.yaml, finetune_demo/configs/sft.yaml, composite_demo/assets/emojis.png, composite_demo/assets/demo.png, composite_demo/assets/heart.png, composite_demo/assets/tool.png, composite_demo/.streamlit/config.toml, composite_demo/client.py, composite_demo/conversation.py, composite_demo/README_en.md, composite_demo/main.py, composite_demo/demo_chat.py, composite_demo/README.md, composite_demo/requirements.txt, composite_demo/demo_tool.py, composite_demo/tool_registry.py, composite_demo/demo_ci.py, basic_demo/cli_demo_bad_word_ids.py, basic_demo/cli_demo.py, basic_demo/cli_batch_request_demo.py, basic_demo/web_demo_gradio.py, basic_demo/web_demo_streamlit.py, .github/ISSUE_TEMPLATE/bug_report.yaml, .github/ISSUE_TEMPLATE/feature-request.yaml, .github/PULL_REQUEST_TEMPLATE/pr_template.md, MODEL_LICENSE, .gitignore, DEPLOYMENT.md, DEPLOYMENT_en.md, LICENSE, PROMPT.md, README_en.md, requirements.txt, README.md, PROMPT_en.md, update_requirements.sh files
parent d0572507
"""
This code is the tool registration part. By registering the tool, the model can call the tool.
This code provides extended functionality to the model, enabling it to call and interact with a variety of utilities
through defined interfaces.
"""
import copy
import inspect
from pprint import pformat
import traceback
from types import GenericAlias
from typing import get_origin, Annotated
import subprocess
_TOOL_HOOKS = {}
_TOOL_DESCRIPTIONS = {}
def register_tool(func: callable):
tool_name = func.__name__
tool_description = inspect.getdoc(func).strip()
python_params = inspect.signature(func).parameters
tool_params = []
for name, param in python_params.items():
annotation = param.annotation
if annotation is inspect.Parameter.empty:
raise TypeError(f"Parameter `{name}` missing type annotation")
if get_origin(annotation) != Annotated:
raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
typ, (description, required) = annotation.__origin__, annotation.__metadata__
typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
if not isinstance(description, str):
raise TypeError(f"Description for `{name}` must be a string")
if not isinstance(required, bool):
raise TypeError(f"Required for `{name}` must be a bool")
tool_params.append({
"name": name,
"description": description,
"type": typ,
"required": required
})
tool_def = {
"name": tool_name,
"description": tool_description,
"params": tool_params
}
print("[registered tool] " + pformat(tool_def))
_TOOL_HOOKS[tool_name] = func
_TOOL_DESCRIPTIONS[tool_name] = tool_def
return func
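# Look up a registered tool by name and invoke it with the model-supplied
# arguments; on failure, the traceback is returned as the observation string.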
def dispatch_tool(tool_name: str, tool_params: dict) -> str:
if tool_name not in _TOOL_HOOKS:
return f"Tool `{tool_name}` not found. Please use a provided tool."
tool_call = _TOOL_HOOKS[tool_name]
try:
ret = tool_call(**tool_params)
    except Exception:
        ret = traceback.format_exc()
return str(ret)
def get_tools() -> dict:
return copy.deepcopy(_TOOL_DESCRIPTIONS)
# Tool Definitions
@register_tool
def random_number_generator(
seed: Annotated[int, 'The random seed used by the generator', True],
range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
) -> int:
"""
Generates a random number x, s.t. range[0] <= x < range[1]
"""
if not isinstance(seed, int):
raise TypeError("Seed must be an integer")
if not isinstance(range, tuple):
raise TypeError("Range must be a tuple")
if not isinstance(range[0], int) or not isinstance(range[1], int):
raise TypeError("Range must be a tuple of integers")
    import random
    # `randrange` excludes the upper bound, matching range[0] <= x < range[1].
    return random.Random(seed).randrange(*range)
@register_tool
def get_weather(
city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
"""
Get the current weather for `city_name`
"""
if not isinstance(city_name, str):
raise TypeError("City name must be a string")
key_selection = {
"current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
}
import requests
try:
        resp = requests.get(f"https://wttr.in/{city_name}?format=j1", timeout=30)
resp.raise_for_status()
resp = resp.json()
ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
    except Exception:
        ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
return str(ret)
@register_tool
def get_shell(
        query: Annotated[str, 'The command to run in the Linux shell', True],
) -> str:
"""
Use shell to run command
"""
if not isinstance(query, str):
raise TypeError("Command must be a string")
try:
result = subprocess.run(query, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
text=True)
return result.stdout
except subprocess.CalledProcessError as e:
return e.stderr
if __name__ == "__main__":
# print(dispatch_tool("get_shell", {"query": "pwd"}))
print(get_tools())
# ChatGLM3-6B Fine-tuning
This directory provides fine-tuning examples for the ChatGLM3-6B model, including full-parameter fine-tuning and P-Tuning v2. In terms of data format, it provides multi-turn dialogue fine-tuning samples and input/output-format fine-tuning samples.
If you have downloaded the model locally, replace every `THUDM/chatglm3-6b` in this document and in the code with the corresponding local path so that the model is loaded from disk.
Running the examples requires `python>=3.10`; besides the basic `torch` dependency, the example code needs the additional dependencies below.
**We provide a [sample notebook](lora_finetune.ipynb) demonstrating how to use our fine-tuning code.**
```bash
pip install -r requirements.txt
```
## Tested hardware
We only provide single-node multi-GPU / multi-node multi-GPU examples, so you need at least one machine with multiple GPUs. With the **default configuration files** in this repository, the measured GPU memory usage is:
+ SFT full-parameter fine-tuning: 4 GPUs, evenly distributed, each using `48346MiB` of GPU memory.
+ P-Tuning v2 fine-tuning: 1 GPU, using `18426MiB` of GPU memory.
+ LoRA fine-tuning: 1 GPU, using `14082MiB` of GPU memory.
> Please note that these numbers are for reference only; memory usage may differ with different parameters. Adjust according to your hardware.
> Please note that we have only tested with NVIDIA Hopper (representative GPU: H100) and Ampere (representative GPU: A100) architecture GPUs. On GPUs of other architectures you may encounter
> 1. unknown training issues / GPU memory usage deviating from the numbers above;
> 2. features unsupported because the architecture is too old;
> 3. degraded inference quality.
> These three situations have been encountered by the community before. Although the probability is low, if you run into them, you can try to resolve them in the community.
## Multi-turn dialogue format
The multi-turn dialogue fine-tuning example follows the ChatGLM3 conversation format convention and applies a different `loss_mask` to each role, so that the `loss` of every assistant turn is computed in a single forward pass.
For the data files, the samples use the following format.
If you only want to fine-tune the model's dialogue ability, not its tool ability, organize your data in the following format.
```json
[
{
"conversations": [
{
"role": "system",
"content": "<system prompt text>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
      // ... Multi Turn
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
// ...
]
```
**Please note that with many fine-tuning steps, this method will degrade the tool-calling ability of the model.**
If you want to fine-tune both the dialogue and tool abilities of the model, organize your data in the following format.
```json
[
{
"tools": [
// available tools, format is not restricted
],
"conversations": [
{
"role": "system",
"content": "<system prompt text>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant thought to text>"
},
{
"role": "tool",
"name": "<name of the tool to be called",
"parameters": {
"<parameter_name>": "<parameter_value>"
},
"observation": "<observation>"
// don't have to be string
},
{
"role": "assistant",
"content": "<assistant response to observation>"
},
      // ... Multi Turn
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
// ...
]
```
- There is no need to manually insert the system prompt describing the tools; during preprocessing the `tools` field is formatted with `json.dumps(..., ensure_ascii=False)` and inserted as the first system prompt.
- Each role can carry a `bool` field `loss` indicating whether the content predicted at that position participates in the `loss` computation (see the sketch after this list). If the field is absent, the sample implementation by default computes no `loss` for `system` and `user` and computes `loss` for all other roles.
- `tool` is not a native role in ChatGLM3. During preprocessing, a `tool` turn is automatically converted into an `assistant` role carrying the tool-call `metadata` (with `loss` computed by default) and an `observation` role carrying the tool's return value (no `loss` computed).
- Fine-tuning of the `Code interpreter` task is not implemented yet.
- The `system` role is optional, but if present it must appear before the `user` role, and a complete conversation sample (single-turn or multi-turn) may contain only one `system` role.
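To make the per-role `loss` flag concrete, here is a minimal standalone sketch (illustrative only; it mirrors the `-100` label convention used by `process_batch` in `finetune_hf.py`) of how token-level masks become training labels:
```python
def masks_to_labels(input_ids: list[int], loss_masks: list[bool]) -> list[int]:
    # Positions whose mask is False are replaced with -100, the index that
    # cross-entropy loss ignores, so only assistant replies contribute to loss.
    return [tok if mask else -100 for tok, mask in zip(input_ids, loss_masks)]

# Example: only the last two (assistant) tokens contribute to the loss.
print(masks_to_labels([64790, 64792, 13, 307, 720],
                      [False, False, False, True, True]))
# -> [-100, -100, -100, 307, 720]
```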
## Dataset format example
Here we take the AdvertiseGen dataset as an example.
You can download the AdvertiseGen dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing)
or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1).
Place the extracted AdvertiseGen directory under the `data` directory and convert it yourself into a dataset of the following format.
> Please note that the current fine-tuning code includes a validation set, so a complete fine-tuning dataset must contain a training set and a validation set; the test set is optional, and the validation set may simply be reused as the test set.
```
{"conversations": [{"role": "user", "content": "类型#裙*裙长#半身裙"}, {"role": "assistant", "content": "这款百搭时尚的仙女半身裙,整体设计非常的飘逸随性,穿上之后每个女孩子都能瞬间变成小仙女啦。料子非常的轻盈,透气性也很好,穿到夏天也很舒适。"}]}
```
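For reference, a minimal conversion sketch (the full version ships in the [sample notebook](lora_finetune.ipynb)) that turns raw AdvertiseGen records, which carry `content` and `summary` fields, into the conversation format above:
```python
import json
import os

def convert_line(line: str) -> str:
    dct = json.loads(line)
    sample = {'conversations': [{'role': 'user', 'content': dct['content']},
                                {'role': 'assistant', 'content': dct['summary']}]}
    return json.dumps(sample, ensure_ascii=False)

os.makedirs('data/AdvertiseGen_fix', exist_ok=True)
with open('data/AdvertiseGen/train.json', encoding='utf-8') as fin, \
        open('data/AdvertiseGen_fix/train.json', 'w', encoding='utf-8') as fout:
    for line in fin:
        fout.write(convert_line(line) + '\n')
```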
## Configuration files
The fine-tuning configuration files are located in the `configs` directory and include the following files:
1. `ds_zero_2.json / ds_zero_3.json`: DeepSpeed configuration files.
2. `lora.yaml / ptuning_v2.yaml / sft.yaml`: configuration files for the different fine-tuning modes, covering model parameters, optimizer parameters, training parameters, and so on. Some important parameters are explained below; a minimal loading sketch follows the list.
+ data_config section
  + train_file: file path of the training dataset.
  + val_file: file path of the validation dataset.
  + test_file: file path of the test dataset.
  + num_proc: number of processes to use when loading data.
+ max_input_length: maximum length of the input sequence.
+ max_output_length: maximum length of the output sequence.
+ training_args section
  + output_dir: directory for saving the model and other outputs.
  + max_steps: maximum number of training steps.
  + per_device_train_batch_size: training batch size per device (e.g. per GPU).
  + dataloader_num_workers: number of worker processes for data loading.
  + remove_unused_columns: whether to remove unused columns from the data.
  + save_strategy: model saving strategy (e.g. every how many steps).
  + save_steps: save the model every this many steps.
  + log_level: log level (e.g. info).
  + logging_strategy: logging strategy.
  + logging_steps: log every this many steps.
  + per_device_eval_batch_size: evaluation batch size per device.
  + evaluation_strategy: evaluation strategy (e.g. every how many steps).
  + eval_steps: evaluate every this many steps.
  + predict_with_generate: whether to use generation mode for prediction.
+ generation_config section
  + max_new_tokens: maximum number of new tokens to generate.
+ peft_config section
  + peft_type: the PEFT method to use (e.g. LORA).
  + task_type: the task type, here causal language modeling (CAUSAL_LM).
  + LoRA parameters:
    + r: rank of LoRA.
    + lora_alpha: scaling factor of LoRA.
    + lora_dropout: dropout probability used in the LoRA layers.
  + P-Tuning v2 parameters:
    + num_virtual_tokens: number of virtual tokens.
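These YAML files are parsed by `FinetuningConfig.from_file` in `finetune_hf.py`. A minimal loading sketch (assuming it is run from `finetune_demo` with the dependencies installed):
```python
from finetune_hf import FinetuningConfig

ft_config = FinetuningConfig.from_file('configs/lora.yaml')
print(ft_config.peft_config.peft_type)    # e.g. PeftType.LORA
print(ft_config.training_args.max_steps)  # e.g. 3000
```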
## Start fine-tuning
Run **single-node multi-GPU / multi-node multi-GPU** fine-tuning with the following commands. This uses `deepspeed` as the acceleration backend, so you need to install `deepspeed`.
```bash
cd finetune_demo
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml configs/ds_zero_2.json
```
Run **single-node single-GPU** fine-tuning with the following command.
```bash
cd finetune_demo
python finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml
```
## Fine-tuning from a checkpoint
If you train as above, every fine-tuning run starts from scratch. To resume fine-tuning from a partially trained model, add a fourth parameter, which can be passed in two ways:
1. `yes`: automatically resume from the last saved checkpoint.
2. `XX`: a checkpoint number, e.g. `600`, to resume from checkpoint 600.
For example, this resumes fine-tuning from the last saved checkpoint:
```bash
cd finetune_demo
python finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml yes
```
## Use the fine-tuned model
### Verify the fine-tuned model in inference_hf.py
You can test the fine-tuned model in `finetune_demo/inference_hf.py` with a single line:
```bash
python inference_hf.py your_finetune_path --prompt "your prompt"
```
The answer you get is then the fine-tuned model's answer.
### Use the fine-tuned model in other demos in this repository or in external repositories
You can use the `lora` and full-parameter fine-tuned models in any demo. This requires you to modify the code yourself, following the steps below.
1. Replace the way the demo loads the model with the way the model is loaded in `finetune_demo/inference_hf.py`.
> Please note that for LoRA and P-Tuning v2 we do not merge the trained model; instead, the path of the base model is recorded in `adapter_config.json`. If the location of your original (base) model changes, you should update the `base_model_name_or_path` path in `adapter_config.json`.
```python
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code
)
return model, tokenizer
```
2. Load the fine-tuned model. Note that you should use the location of the fine-tuned model: for example, if your adapter is at `/path/to/finetune_adapter_model` and the base model is at `path/to/base_model`, use `/path/to/finetune_adapter_model` as `model_dir`.
3. After these changes, you can use the fine-tuned model normally; everything else about calling it stays unchanged. A combined usage sketch follows.
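Putting these steps together, a minimal inference sketch using the `load_model_and_tokenizer` shown above (the adapter path is a placeholder):
```python
# Illustrative path; point model_dir at your fine-tuned adapter directory.
model, tokenizer = load_model_and_tokenizer('/path/to/finetune_adapter_model')
response, history = model.chat(tokenizer, "Hello")
print(response)
```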
### Tips
1. Before training starts, the fine-tuning code prints the preprocessing result of the first training sample (commented out by default; it can be uncommented), shown as
```log
Sanity Check >> >> >> >> >> >> >
'[gMASK]': 64790 -> -100
'sop': 64792 -> -100
'<|system|>': 64794 -> -100
'': 30910 -> -100
'\n': 13 -> -100
'Answer': 20115 -> -100
'the': 267 -> -100
'following': 1762 -> -100
...
'know': 683 -> -100
'the': 267 -> -100
'response': 3010 -> -100
'details': 3296 -> -100
'.': 30930 -> -100
'<|assistant|>': 64796 -> -100
'': 30910 -> 30910
'\n': 13 -> 13
'I': 307 -> 307
'need': 720 -> 720
'to': 289 -> 289
'use': 792 -> 792
...
<< << << << << << < Sanity Check
```
Each line shows, in order, a detokenized string, its token_id, and its target_id. `target_id` is the index of `token_id` in the model vocabulary, and `-100` means that token does not participate in the `loss` computation.
2. `_prepare_model_for_training` iterates over all trainable parameters of the model and makes sure their dtype is `torch.float32`. This is necessary in some cases, because mixed-precision training or other operations may change the dtype of model parameters. This code is enabled by default and can be commented out, but if training in `half` precision runs into problems, you can switch it back on, at the cost of possibly higher GPU memory usage.
3. In our [Hugging Face model code](https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py), there is the following:
```python
if self.gradient_checkpointing and self.training:
layer_ret = torch.utils.checkpoint.checkpoint(
layer,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_caches[index],
use_cache,
use_reentrant=False
)
```
This may increase GPU memory usage during training, so if you run out of GPU memory, you can try changing `use_reentrant` to `True`.
4. The fine-tuned model can be loaded with any model acceleration framework that supports `peft` loading; we do not provide a demo for this here.
5. The fine-tuning dataset format of this repository differs somewhat from the API fine-tuning dataset format:
   + the `messages` field in a ZhipuAI API fine-tuning dataset is called `conversations` in this repository;
   + the ZhipuAI API fine-tuning file is `jsonl`; for this repository, simply rename the file to `json` (see the sketch below).
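A minimal sketch of that conversion (file names here are placeholders), renaming the API's `messages` field to `conversations` while rewriting the `jsonl` file:
```python
import json

with open('api_data.jsonl', encoding='utf-8') as fin, \
        open('repo_data.json', 'w', encoding='utf-8') as fout:
    for line in fin:
        record = json.loads(line)
        # The API's `messages` field maps to this repository's `conversations`.
        record['conversations'] = record.pop('messages')
        fout.write(json.dumps(record, ensure_ascii=False) + '\n')
```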
## Citation
```
@inproceedings{liu2022p,
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short
Papers)},
pages={61--68},
year={2022}
}
@misc{tang2023toolalpaca,
title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},
author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},
year={2023},
eprint={2306.05301},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
# ChatGLM3-6B Fine-tuning
This directory provides fine-tuning examples for the ChatGLM3-6B model, including full-parameter fine-tuning and P-Tuning v2. In terms of data format, it provides multi-turn dialogue fine-tuning samples and input/output-format fine-tuning samples.
If you have downloaded the model locally, replace every `THUDM/chatglm3-6b` in this document and in the code with the corresponding local path so that the model is loaded from disk.
Running the examples requires `python>=3.10`; besides the basic `torch` dependency, the example code needs the additional dependencies below.
**We provide a [sample notebook](lora_finetune.ipynb) demonstrating how to use our fine-tuning code.**
```bash
pip install -r requirements.txt
```
## Tested hardware
We only provide single-node multi-GPU / multi-node multi-GPU examples, so you need at least one machine with multiple GPUs. With the **default configuration files** in this repository, the measured GPU memory usage is:
+ SFT full-parameter fine-tuning: 4 GPUs, evenly distributed, each using `48346MiB` of GPU memory.
+ P-Tuning v2 fine-tuning: 1 GPU, using `18426MiB` of GPU memory.
+ LoRA fine-tuning: 1 GPU, using `14082MiB` of GPU memory.
> Please note that these numbers are for reference only; memory usage may differ with different parameters. Adjust according to your hardware.
## Multi-turn dialogue format
The multi-turn dialogue fine-tuning example follows the ChatGLM3 conversation format convention and applies a different `loss_mask` to each role, so that the `loss` of every assistant turn is computed in a single forward pass.
For the data files, the samples use the following format.
If you only want to fine-tune the model's dialogue ability, not its tool ability, organize your data in the following format.
```json
[
{
"conversations": [
{
"role": "system",
"content": "<system prompt text>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
      // ... Multi Turn
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
// ...
]
```
**Please note that with many fine-tuning steps, this method will degrade the tool-calling ability of the model.**
If you want to fine-tune both the dialogue and tool abilities of the model, organize your data in the following format.
```json
[
{
"tools": [
// available tools, format is not restricted
],
"conversations": [
{
"role": "system",
"content": "<system prompt text>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant thought to text>"
},
{
"role": "tool",
"name": "<name of the tool to be called",
"parameters": {
"<parameter_name>": "<parameter_value>"
},
"observation": "<observation>"
// don't have to be string
},
{
"role": "assistant",
"content": "<assistant response to observation>"
},
      // ... Multi Turn
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
// ...
]
```
- There is no need to manually insert the system prompt describing the tools; during preprocessing the `tools` field is formatted with `json.dumps(..., ensure_ascii=False)` and inserted as the first system prompt (see the sketch after this list).
- Each role can carry a `bool` field `loss` indicating whether the content predicted at that position participates in the `loss` computation. If the field is absent, the sample implementation by default computes no `loss` for `system` and `user` and computes `loss` for all other roles.
- `tool` is not a native role in ChatGLM3. During preprocessing, a `tool` turn is automatically converted into an `assistant` role carrying the tool-call `metadata` (with `loss` computed by default) and an `observation` role carrying the tool's return value (no `loss` computed).
- Fine-tuning of the `Code interpreter` task is not implemented yet.
- The `system` role is optional, but if present it must appear before the `user` role, and a complete conversation sample (single-turn or multi-turn) may contain only one `system` role.
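As a concrete illustration of the first bullet, a minimal sketch (not the actual preprocessing code; the `tools` value is an example) of how the `tools` field becomes the first system prompt:
```python
import json

tools = [{"name": "get_weather", "description": "Get the current weather", "params": []}]

# The tools list is serialized with ensure_ascii=False and prepended
# as the first system message of the conversation.
system_prompt = {"role": "system", "content": json.dumps(tools, ensure_ascii=False)}
print(system_prompt)
```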
## Dataset format example
Here we take the AdvertiseGen dataset as an example.
You can download the AdvertiseGen dataset
from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing)
or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1).
Place the extracted AdvertiseGen directory under the `data` directory and convert it yourself into a dataset of the following format.
> Please note that the current fine-tuning code includes a validation set, so a complete fine-tuning dataset must contain a training set and a validation set; the test set is optional, and the validation set may simply be reused as the test set.
```
{"conversations": [{"role": "user", "content": "Type#skirt*skirt length#skirt"}, {"role": "assistant", "content": "This is versatile Fashionable fairy skirt, the overall design is very elegant and casual. Every girl can instantly turn into a fairy after wearing it. The material is very light and breathable, making it very comfortable to wear in summer."} ]}
```
## Configuration files
The fine-tuning configuration files are located in the `configs` directory and include the following files:
1. `ds_zero_2.json / ds_zero_3.json`: DeepSpeed configuration files.
2. `lora.yaml / ptuning_v2.yaml / sft.yaml`: configuration files for the different fine-tuning modes, covering model parameters, optimizer parameters, training parameters, and so on. Some important parameters are explained below:
+ data_config section
  + train_file: file path of the training dataset.
  + val_file: file path of the validation dataset.
  + test_file: file path of the test dataset.
  + num_proc: number of processes to use when loading data.
+ max_input_length: maximum length of the input sequence.
+ max_output_length: maximum length of the output sequence.
+ training_args section
  + output_dir: directory for saving the model and other outputs.
  + max_steps: maximum number of training steps.
  + per_device_train_batch_size: training batch size per device (e.g. per GPU).
  + dataloader_num_workers: number of worker processes for data loading.
  + remove_unused_columns: whether to remove unused columns from the data.
  + save_strategy: model saving strategy (e.g. every how many steps).
  + save_steps: save the model every this many steps.
  + log_level: log level (e.g. info).
  + logging_strategy: logging strategy.
  + logging_steps: log every this many steps.
  + per_device_eval_batch_size: evaluation batch size per device.
  + evaluation_strategy: evaluation strategy (e.g. every how many steps).
  + eval_steps: evaluate every this many steps.
  + predict_with_generate: whether to use generation mode for prediction.
+ generation_config section
  + max_new_tokens: maximum number of new tokens to generate.
+ peft_config section
  + peft_type: the PEFT method to use (e.g. LORA).
  + task_type: the task type, here causal language modeling (CAUSAL_LM).
  + LoRA parameters:
    + r: rank of LoRA.
    + lora_alpha: scaling factor of LoRA.
    + lora_dropout: dropout probability used in the LoRA layers.
  + P-Tuning v2 parameters:
    + num_virtual_tokens: number of virtual tokens.
## Start fine-tuning
Run **single-node multi-GPU / multi-node multi-GPU** fine-tuning with the following commands. This uses `deepspeed` as the acceleration backend, so you need to install `deepspeed`.
```bash
cd finetune_demo
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml configs/ds_zero_2.json
```
Run **single-node single-GPU** fine-tuning with the following command.
```bash
cd finetune_demo
python finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml
```
## Fine-tuning from a checkpoint
If you train as above, every fine-tuning run starts from scratch. To resume fine-tuning from a partially trained model, add a fourth parameter, which can be passed in two ways:
1. `yes`: automatically resume from the last saved checkpoint.
2. `XX`: a checkpoint number, e.g. `600`, to resume from checkpoint 600.
For example, this resumes fine-tuning from the last saved checkpoint:
```bash
cd finetune_demo
python finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml yes
```
## Use the fine-tuned model
### Verify the fine-tuned model in inference_hf.py
You can test the fine-tuned model in `finetune_demo/inference_hf.py` with a single line:
```bash
python inference_hf.py your_finetune_path --prompt "your prompt"
```
The answer you get is then the fine-tuned model's answer.
### Use the fine-tuned model in other demos in this repository or in external repositories
You can use the `lora` and full-parameter fine-tuned models in any demo. This requires you to modify the code yourself, following the steps below.
1. Replace the way the demo loads the model with the way the model is loaded in `finetune_demo/inference_hf.py`.
> Please note that for LoRA and P-Tuning v2 we do not merge the trained model; instead, the path of the base model is recorded in `adapter_config.json`. If the location of your original (base) model changes, you should update the `base_model_name_or_path` path in `adapter_config.json`.
> Please note that we have only tested with NVIDIA Hopper (representative GPU: H100) and Ampere (representative GPU: A100) architecture GPUs. On GPUs of other architectures you may encounter
> 1. unknown training issues / GPU memory usage deviating from the numbers above;
> 2. features unsupported because the architecture is too old;
> 3. degraded inference quality.
> These three situations have been encountered by the community before. Although the probability is extremely low, if you run into them, you can try to resolve them in the community.
```python
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code
)
return model, tokenizer
```
2. Load the fine-tuned model. Note that you should use the location of the fine-tuned model: for example, if your adapter is at `/path/to/finetune_adapter_model` and the base model is at `path/to/base_model`, use `/path/to/finetune_adapter_model` as `model_dir`.
3. After these changes, you can use the fine-tuned model normally; everything else about calling it stays unchanged.
### Tips
1. Before training starts, the fine-tuning code prints the preprocessing result of the first training sample (commented out by default; it can be uncommented), shown as
```log
Sanity Check >> >> >> >> >> >> >
'[gMASK]': 64790 -> -100
'sop': 64792 -> -100
'<|system|>': 64794 -> -100
'': 30910 -> -100
'\n': 13 -> -100
'Answer': 20115 -> -100
'the': 267 -> -100
'following': 1762 -> -100
...
'know': 683 -> -100
'the': 267 -> -100
'response': 3010 -> -100
'details': 3296 -> -100
'.': 30930 -> -100
'<|assistant|>': 64796 -> -100
'': 30910 -> 30910
'\n': 13 -> 13
'I': 307 -> 307
'need': 720 -> 720
'to': 289 -> 289
'use': 792 -> 792
...
<< << << << << << < Sanity Check
```
Each line shows, in order, a detokenized string, its token_id, and its target_id. `target_id` is the index of `token_id` in the model vocabulary, and `-100` means that token does not participate in the `loss` computation.
2. `_prepare_model_for_training` iterates over all trainable parameters of the model and makes sure their dtype is `torch.float32`. This is necessary in some cases, because mixed-precision training or other operations may change the dtype of model parameters. This code is enabled by default and can be commented out, but if training in `half` precision runs into problems, you can switch it back on, at the cost of possibly higher GPU memory usage.
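For reference, this is the function as it appears in `finetune_hf.py` (imports added here for context):
```python
import torch
from torch import nn

def _prepare_model_for_training(model: nn.Module, use_cpu: bool):
    # Cast trainable parameters (or all parameters when training on CPU) to fp32.
    for param in model.parameters():
        if param.requires_grad or use_cpu:
            param.data = param.data.to(torch.float32)
```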
3. In our [Hugging Face model code](https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py), there is the following:
```python
if self.gradient_checkpointing and self.training:
layer_ret = torch.utils.checkpoint.checkpoint(
layer,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_caches[index],
use_cache,
use_reentrant=False
)
```
This may increase GPU memory usage during training, so if you run out of GPU memory, you can try changing `use_reentrant` to `True`.
4. The fine-tuned model can be loaded with any model acceleration framework that supports `peft` loading; we do not provide a demo for this here.
5. The fine-tuning dataset format of this repository differs somewhat from the API fine-tuning dataset format:
   + the `messages` field in a ZhipuAI API fine-tuning dataset is called `conversations` in this repository;
   + the ZhipuAI API fine-tuning file is `jsonl`; for this repository, simply rename the file to `json`.
## Citation
```
@inproceedings{liu2022p,
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short
Papers)},
pages={61--68},
year={2022}
}
@misc{tang2023toolalpaca,
title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},
author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},
year={2023},
eprint={2306.05301},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
{
"train_micro_batch_size_per_gpu": "auto",
"zero_allow_untested_optimizer": true,
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"zero_optimization": {
"stage": 3,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"contiguous_gradients": true,
"overlap_comm": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
}
}
data_config:
train_file: train.json
val_file: dev.json
test_file: dev.json
num_proc: 16
max_input_length: 256
max_output_length: 512
training_args:
# see `transformers.Seq2SeqTrainingArguments`
output_dir: ./output
max_steps: 3000
# needed to be fit for the dataset
learning_rate: 5e-5
# settings for data loading
per_device_train_batch_size: 4
dataloader_num_workers: 16
remove_unused_columns: false
# settings for saving checkpoints
save_strategy: steps
save_steps: 500
# settings for logging
log_level: info
logging_strategy: steps
logging_steps: 10
# settings for evaluation
per_device_eval_batch_size: 16
evaluation_strategy: steps
eval_steps: 500
# settings for optimizer
# adam_epsilon: 1e-6
# uncomment the following line to detect nan or inf values
# debug: underflow_overflow
predict_with_generate: true
# see `transformers.GenerationConfig`
generation_config:
max_new_tokens: 512
# set your absolute deepspeed path here
#deepspeed: ds_zero_2.json
# set to true if train with cpu.
use_cpu: false
peft_config:
peft_type: LORA
task_type: CAUSAL_LM
r: 8
lora_alpha: 32
lora_dropout: 0.1
data_config:
train_file: train.json
val_file: dev.json
test_file: dev.json
num_proc: 16
max_input_length: 256
max_output_length: 512
training_args:
# see `transformers.Seq2SeqTrainingArguments`
output_dir: ./output
max_steps: 3000
# needed to be fit for the dataset
learning_rate: 5e-5
# settings for data loading
per_device_train_batch_size: 4
dataloader_num_workers: 16
remove_unused_columns: false
# settings for saving checkpoints
save_strategy: steps
save_steps: 500
# settings for logging
log_level: info
logging_strategy: steps
logging_steps: 10
# settings for evaluation
per_device_eval_batch_size: 16
evaluation_strategy: steps
eval_steps: 500
# settings for optimizer
# adam_epsilon: 1e-6
# uncomment the following line to detect nan or inf values
# debug: underflow_overflow
predict_with_generate: true
# see `transformers.GenerationConfig`
generation_config:
max_new_tokens: 512
# set your absolute deepspeed path here
#deepspeed: ds_zero_3.json
use_cpu: false
peft_config:
peft_type: PREFIX_TUNING
task_type: CAUSAL_LM
num_virtual_tokens: 128
data_config:
train_file: train.json
val_file: dev.json
test_file: dev.json
num_proc: 16
max_input_length: 256
max_output_length: 512
training_args:
# see `transformers.Seq2SeqTrainingArguments`
output_dir: ./output
max_steps: 3000
# needed to be fit for the dataset
learning_rate: 5e-5
# settings for data loading
per_device_train_batch_size: 4
dataloader_num_workers: 16
remove_unused_columns: false
# settings for saving checkpoints
save_strategy: steps
save_steps: 500
# settings for logging
log_level: info
logging_strategy: steps
logging_steps: 10
# settings for evaluation
per_device_eval_batch_size: 16
evaluation_strategy: steps
eval_steps: 500
# settings for optimizer
# adam_epsilon: 1e-6
# uncomment the following line to detect nan or inf values
# debug: underflow_overflow
predict_with_generate: true
generation_config:
max_new_tokens: 512
# set your absolute deepspeed path here
deepspeed: ds_zero_3.json
# -*- coding: utf-8 -*-
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Optional, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from peft import (
PeftConfig,
PeftModelForCausalLM,
get_peft_config,
get_peft_model
)
from rouge_chinese import Rouge
from torch import nn
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
EvalPrediction,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
Seq2SeqTrainingArguments, AutoConfig,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
app = typer.Typer(pretty_exceptions_show_locals=False)
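# Extends the stock collator to also pad `output_ids` (present only during
# generation-based evaluation), since the base class pads just input_ids/labels.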
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
output_ids = (
[feature['output_ids'] for feature in features]
if 'output_ids' in features[0].keys()
else None
)
if output_ids is not None:
max_output_length = max(len(out) for out in output_ids)
if self.pad_to_multiple_of is not None:
max_output_length = (
(
max_output_length + self.pad_to_multiple_of - 1) //
self.pad_to_multiple_of * self.pad_to_multiple_of
)
for feature in features:
remainder = [self.tokenizer.pad_token_id] * (
max_output_length - len(feature['output_ids'])
)
if isinstance(feature['output_ids'], list):
feature['output_ids'] = feature['output_ids'] + remainder
else:
feature['output_ids'] = np.concatenate(
[feature['output_ids'], remainder]
).astype(np.int64)
return super().__call__(features, return_tensors)
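# Trims the prompt tokens off the generated sequence and substitutes
# `output_ids` for the labels when `predict_with_generate` is enabled.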
class Seq2SeqTrainer(_Seq2SeqTrainer):
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, Any],
prediction_loss_only: bool,
ignore_keys=None,
**gen_kwargs,
) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
if self.args.predict_with_generate:
output_ids = inputs.pop('output_ids')
input_ids = inputs['input_ids']
loss, generated_tokens, labels = super().prediction_step(
model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
)
generated_tokens = generated_tokens[:, input_ids.size()[1]:]
if self.args.predict_with_generate:
labels = output_ids
return loss, generated_tokens, labels
# For P-Tuning a new save_model function is fine for the prefix_encoder model
# but may cost problems for the whole model loading
# def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
# if output_dir is None:
# output_dir = self.args.output_dir
# os.makedirs(output_dir, exist_ok=True)
# ptuning_params = {k: v for k, v in self.model.transformer.prefix_encoder.state_dict().items()}
#
# torch.save(ptuning_params, os.path.join(output_dir, 'pytorch_model.bin'))
#
# print(f"P-Tuning model weights saved in {output_dir}")
#
# if self.tokenizer is not None:
# self.tokenizer.save_pretrained(output_dir)
def _resolve_path(path: Union[str, Path]) -> Path:
return Path(path).expanduser().resolve()
def _sanity_check(
input_ids: Sequence[int],
output_ids: Sequence[int],
tokenizer: PreTrainedTokenizer,
):
print('--> Sanity check')
for in_id, out_id in zip(input_ids, output_ids):
if in_id == 0:
continue
if in_id in tokenizer.tokenizer.index_special_tokens:
in_text = tokenizer.tokenizer.index_special_tokens[in_id]
else:
in_text = tokenizer.decode([in_id])
print(f'{repr(in_text):>20}: {in_id} -> {out_id}')
@functools.cache
def _get_yaml_parser() -> yaml.YAML:
parser = yaml.YAML(typ='safe', pure=True)
parser.indent(mapping=2, offset=2, sequence=4)
parser.default_flow_style = False
return parser
@dc.dataclass
class DataConfig(object):
train_file: str
val_file: Optional[str] = None
test_file: Optional[str] = None
num_proc: Optional[int] = None
@property
def data_format(self) -> str:
return Path(self.train_file).suffix
@property
def data_files(self) -> dict[NamedSplit, str]:
return {
split: data_file
for split, data_file in zip(
[Split.TRAIN, Split.VALIDATION, Split.TEST],
[self.train_file, self.val_file, self.test_file],
)
if data_file is not None
}
@dc.dataclass
class FinetuningConfig(object):
data_config: DataConfig
max_input_length: int
max_output_length: int
training_args: Seq2SeqTrainingArguments = dc.field(
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
)
peft_config: Optional[PeftConfig] = None
def __post_init__(self):
if not self.training_args.do_eval or self.data_config.val_file is None:
            # skips the evaluation stage when `do_eval` is False or `val_file` is not provided
self.training_args.do_eval = False
self.training_args.evaluation_strategy = 'no'
self.data_config.val_file = None
else:
self.training_args.per_device_eval_batch_size = (
self.training_args.per_device_eval_batch_size
or self.training_args.per_device_train_batch_size
)
@classmethod
def from_dict(cls, **kwargs) -> 'FinetuningConfig':
training_args = kwargs.get('training_args', None)
if training_args is not None and not isinstance(
training_args, Seq2SeqTrainingArguments
):
gen_config = training_args.get('generation_config')
# TODO: a bit hacky
if not isinstance(gen_config, GenerationConfig):
training_args['generation_config'] = GenerationConfig(
**gen_config
)
kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)
data_config = kwargs.get('data_config')
if not isinstance(data_config, DataConfig):
kwargs['data_config'] = DataConfig(**data_config)
peft_config = kwargs.get('peft_config', None)
if peft_config is not None and not isinstance(peft_config, PeftConfig):
kwargs['peft_config'] = get_peft_config(peft_config)
return cls(**kwargs)
@classmethod
def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
path = _resolve_path(path)
kwargs = _get_yaml_parser().load(path)
return cls.from_dict(**kwargs)
def _load_datasets(
data_dir: Path,
data_format: str,
data_files: dict[NamedSplit, str],
num_proc: Optional[int],
) -> DatasetDict:
if data_format in ('.csv', '.json', '.jsonl'):
dataset_dct = load_dataset(
data_format[1:],
data_dir=data_dir,
data_files=data_files,
num_proc=num_proc,
)
else:
err_msg = f"Cannot load dataset in the '{data_format}' format."
raise NotImplementedError(err_msg)
return dataset_dct
class DataManager(object):
def __init__(self, data_dir: str, data_config: DataConfig):
self._num_proc = data_config.num_proc
self._dataset_dct = _load_datasets(
_resolve_path(data_dir),
data_config.data_format,
data_config.data_files,
self._num_proc,
)
def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
return self._dataset_dct.get(split, None)
def get_dataset(
self,
split: NamedSplit,
process_fn: Callable[[dict[str, Any]], dict[str, Any]],
batched: bool = True,
remove_orig_columns: bool = True,
) -> Optional[Dataset]:
orig_dataset = self._get_dataset(split)
if orig_dataset is None:
return
if remove_orig_columns:
remove_columns = orig_dataset.column_names
else:
remove_columns = None
return orig_dataset.map(
process_fn,
batched=batched,
remove_columns=remove_columns,
num_proc=self._num_proc,
)
def print_model_size(model: PreTrainedModel):
print("--> Model")
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n--> model has {total_params / 1e6}M params\n")
def process_batch(
batch: Mapping[str, Sequence],
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
) -> dict[str, list]:
batched_tools = batch.get('tools', None)
batched_conv = batch['conversations']
batched_input_ids = []
batched_labels = []
if batched_tools is None:
batched_tools = [None] * len(batched_conv)
for tools, conv in zip(batched_tools, batched_conv):
input_ids, loss_masks = [
tokenizer.get_command('[gMASK]'),
tokenizer.get_command('sop'),
], [False, False]
if tools is not None:
raise NotImplementedError()
for message in conv:
if message['role'] in ('system', 'user'):
loss_mask_val = False
else:
loss_mask_val = True
if message['role'] == 'tool':
raise NotImplementedError()
else:
new_input_ids = tokenizer.build_single_message(
message['role'], '', message['content']
)
new_loss_masks = [loss_mask_val] * len(new_input_ids)
input_ids += new_input_ids
loss_masks += new_loss_masks
input_ids.append(tokenizer.eos_token_id)
loss_masks = [False, *loss_masks]
labels = []
for input_id, mask in zip(input_ids, loss_masks):
if mask:
labels.append(input_id)
else:
labels.append(-100)
max_length = max_input_length + max_output_length + 1
batched_input_ids.append(input_ids[:max_length])
batched_labels.append(labels[:max_length])
return {'input_ids': batched_input_ids, 'labels': batched_labels}
def process_batch_eval(
batch: Mapping[str, Sequence],
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
) -> dict[str, list]:
batched_tools = batch.get('tools', None)
batched_conv = batch['conversations']
batched_input_ids = []
# To avoid computing loss, we do not provide the `labels` field in the input dictionary.
batched_output_ids = []
if batched_tools is None:
batched_tools = [None] * len(batched_conv)
for tools, conv in zip(batched_tools, batched_conv):
input_ids = [
tokenizer.get_command('[gMASK]'),
tokenizer.get_command('sop'),
]
if tools is not None:
raise NotImplementedError()
for message in conv:
if len(input_ids) >= max_input_length:
break
if message['role'] == 'tool':
raise NotImplementedError()
else:
new_input_ids = tokenizer.build_single_message(
message['role'], '', message['content']
)
if message['role'] == 'assistant':
output_prompt, output_ids = (
new_input_ids[:1],
new_input_ids[1:],
)
output_ids.append(tokenizer.eos_token_id)
batched_input_ids.append(
input_ids[:max_input_length] + output_prompt[:1]
)
batched_output_ids.append(output_ids[:max_output_length])
input_ids += new_input_ids
return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}
# Not sure if this is necessary; could set it to half.
# If training on CPU, cast all params to fp32 instead of only the trainable ones.
def _prepare_model_for_training(model: nn.Module, use_cpu: bool):
for param in model.parameters():
if param.requires_grad or use_cpu:
param.data = param.data.to(torch.float32)
def load_tokenizer_and_model(
model_dir: str,
peft_config: Optional[PeftConfig] = None,
) -> tuple[PreTrainedTokenizer, nn.Module]:
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
if peft_config is not None:
if peft_config.peft_type.name == "PREFIX_TUNING":
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
config.pre_seq_len = peft_config.num_virtual_tokens
config.use_cache = False
model = AutoModelForCausalLM.from_pretrained(
model_dir,
trust_remote_code=True,
config=config,
)
if peft_config.peft_type.name == "LORA":
model = AutoModelForCausalLM.from_pretrained(
model_dir,
trust_remote_code=True,
empty_init=False,
use_cache=False
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir,
trust_remote_code=True,
empty_init=False,
use_cache=False
)
print_model_size(model)
return tokenizer, model
def compute_metrics(eval_preds: EvalPrediction, tokenizer: PreTrainedTokenizer):
batched_pred_ids, batched_label_ids = eval_preds
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
pred_txt = tokenizer.decode(pred_ids).strip()
label_txt = tokenizer.decode(label_ids).strip()
pred_tokens = list(jieba.cut(pred_txt))
label_tokens = list(jieba.cut(label_txt))
rouge = Rouge()
scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
for k, v in scores[0].items():
metrics_dct[k].append(round(v['f'] * 100, 4))
metrics_dct['bleu-4'].append(
sentence_bleu(
[label_tokens],
pred_tokens,
smoothing_function=SmoothingFunction().method3,
)
)
return {k: np.mean(v) for k, v in metrics_dct.items()}
@app.command()
def main(
data_dir: Annotated[str, typer.Argument(help='')],
model_dir: Annotated[
str,
typer.Argument(
help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
),
],
config_file: Annotated[str, typer.Argument(help='')],
auto_resume_from_checkpoint: str = typer.Argument(
default='',
        help='Pass "yes" to resume from the latest saved checkpoint, a checkpoint number such as 600 to resume from that checkpoint, or leave empty to train from scratch.'
),
):
ft_config = FinetuningConfig.from_file(config_file)
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)
train_dataset = data_manager.get_dataset(
Split.TRAIN,
functools.partial(
process_batch,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
print('train_dataset:', train_dataset)
val_dataset = data_manager.get_dataset(
Split.VALIDATION,
functools.partial(
process_batch_eval,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
if val_dataset is not None:
print('val_dataset:', val_dataset)
test_dataset = data_manager.get_dataset(
Split.TEST,
functools.partial(
process_batch_eval,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
if test_dataset is not None:
print('test_dataset:', test_dataset)
# checks encoded dataset
_sanity_check(
train_dataset[0]["input_ids"], train_dataset[0]["labels"], tokenizer
)
# turn model to fp32
_prepare_model_for_training(model, ft_config.training_args.use_cpu)
ft_config.training_args.generation_config.pad_token_id = (
tokenizer.pad_token_id
)
ft_config.training_args.generation_config.eos_token_id = [
tokenizer.eos_token_id,
tokenizer.get_command('<|user|>'),
tokenizer.get_command('<|observation|>'),
]
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
trainer = Seq2SeqTrainer(
model=model,
args=ft_config.training_args,
data_collator=DataCollatorForSeq2Seq(
tokenizer=tokenizer,
padding='longest',
return_tensors='pt',
),
train_dataset=train_dataset,
        eval_dataset=val_dataset.select(list(range(50))) if val_dataset is not None else None,
tokenizer=tokenizer if ft_config.peft_config.peft_type != "LORA" else None, # LORA does not need tokenizer
compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
)
    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.upper() == "":
trainer.train()
else:
output_dir = ft_config.training_args.output_dir
dirlist = os.listdir(output_dir)
checkpoint_sn = 0
for checkpoint_str in dirlist:
if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
if checkpoint > checkpoint_sn:
checkpoint_sn = checkpoint
if auto_resume_from_checkpoint.upper() == "YES":
if checkpoint_sn > 0:
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
trainer.train(resume_from_checkpoint=checkpoint_directory)
else:
trainer.train()
else:
if auto_resume_from_checkpoint.isdigit():
if int(auto_resume_from_checkpoint) > 0:
checkpoint_sn = int(auto_resume_from_checkpoint)
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
trainer.train(resume_from_checkpoint=checkpoint_directory)
else:
                print(auto_resume_from_checkpoint,
                      "The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct checkpoint in the model output directory")
# test stage
if test_dataset is not None:
trainer.predict(test_dataset)
if __name__ == '__main__':
app()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from pathlib import Path
from typing import Annotated, Union
import typer
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
app = typer.Typer(pretty_exceptions_show_locals=False)
def _resolve_path(path: Union[str, Path]) -> Path:
return Path(path).expanduser().resolve()
def load_model_and_tokenizer(model_dir: Union[str, Path]) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=True, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=True, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=True
)
return model, tokenizer
@app.command()
def main(
model_dir: Annotated[str, typer.Argument(help='')],
prompt: Annotated[str, typer.Option(help='')],
):
model, tokenizer = load_model_and_tokenizer(model_dir)
response, _ = model.chat(tokenizer, prompt)
print(response)
if __name__ == '__main__':
app()
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# 单卡GPU 进行 ChatGLM3-6B模型 LORA 高效微调\n",
"本 Cookbook 将带领开发者使用 `AdvertiseGen` 对 ChatGLM3-6B 数据集进行 lora微调,使其具备专业的广告生成能力。\n",
"\n",
"## 硬件需求\n",
"显存:24GB及以上(推荐使用30系或A10等sm80架构以上的NVIDIA显卡进行尝试)\n",
"内存:16GB\n",
"RAM: 2.9 /16 GB\n",
"GPU RAM: 15.5/16.0 GB"
],
"metadata": {
"collapsed": false,
"id": "89b89f64d8f8053d"
},
"id": "89b89f64d8f8053d"
},
{
"cell_type": "markdown",
"source": [
"## 0. 环境检查\n",
"首先,先检查代码的运行地址,确保运行地址处于 `finetune_demo` 中。\n",
"并且,确保已经安装了 `requirements.txt`中的依赖。\n",
"\n",
"> 本 demo 中,不需要使用 deepspeed, mpi4py 两个依赖,如果您安装这两个依赖遇到问题,可以不安装这两个依赖。"
],
"metadata": {
"collapsed": false,
"id": "a7bd9a514ed09ea6"
},
"id": "a7bd9a514ed09ea6"
},
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/media/zr/Data/Code/ChatGLM3/finetune_demo\r\n"
]
}
],
"source": [
"!pwd"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-14T05:29:22.200365Z",
"start_time": "2024-04-14T05:29:22.080929Z"
}
},
"id": "f7703109d1443346",
"execution_count": 1
},
{
"cell_type": "markdown",
"source": [
"## 1. 准备数据集\n",
"我们使用 AdvertiseGen 数据集来进行微调。从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 AdvertiseGen 数据集,将解压后的 AdvertiseGen 目录放到本目录的 `/data/` 下, 例如。\n",
"> /media/zr/Data/Code/ChatGLM3/finetune_demo/data/AdvertiseGen"
],
"metadata": {
"collapsed": false
},
"id": "2f50e92810011977"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"import json\n",
"from typing import Union\n",
"from pathlib import Path\n",
"\n",
"\n",
"def _resolve_path(path: Union[str, Path]) -> Path:\n",
" return Path(path).expanduser().resolve()\n",
"\n",
"\n",
"def _mkdir(dir_name: Union[str, Path]):\n",
" dir_name = _resolve_path(dir_name)\n",
" if not dir_name.is_dir():\n",
" dir_name.mkdir(parents=True, exist_ok=False)\n",
"\n",
"\n",
"def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]):\n",
" def _convert(in_file: Path, out_file: Path):\n",
" _mkdir(out_file.parent)\n",
" with open(in_file, encoding='utf-8') as fin:\n",
" with open(out_file, 'wt', encoding='utf-8') as fout:\n",
" for line in fin:\n",
" dct = json.loads(line)\n",
" sample = {'conversations': [{'role': 'user', 'content': dct['content']},\n",
" {'role': 'assistant', 'content': dct['summary']}]}\n",
" fout.write(json.dumps(sample, ensure_ascii=False) + '\\n')\n",
"\n",
" data_dir = _resolve_path(data_dir)\n",
" save_dir = _resolve_path(save_dir)\n",
"\n",
" train_file = data_dir / 'train.json'\n",
" if train_file.is_file():\n",
" out_file = save_dir / train_file.relative_to(data_dir)\n",
" _convert(train_file, out_file)\n",
"\n",
" dev_file = data_dir / 'dev.json'\n",
" if dev_file.is_file():\n",
" out_file = save_dir / dev_file.relative_to(data_dir)\n",
" _convert(dev_file, out_file)\n",
"\n",
"\n",
"convert_adgen('data/AdvertiseGen', 'data/AdvertiseGen_fix')"
],
"metadata": {
"collapsed": true,
"cellView": "form",
"id": "initial_id",
"ExecuteTime": {
"end_time": "2024-04-14T05:29:23.809255Z",
"start_time": "2024-04-14T05:29:22.202731Z"
}
},
"id": "initial_id",
"execution_count": 2
},
{
"cell_type": "markdown",
"source": [
"## 2. 使用命令行开始微调,我们使用 lora 进行微调\n",
"接着,我们仅需要将配置好的参数以命令行的形式传参给程序,就可以使用命令行进行高效微调。"
],
"metadata": {
"collapsed": false,
"id": "a1b7a99923349056"
},
"id": "a1b7a99923349056"
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Setting eos_token is not supported, use the default one.\r\n",
"Setting pad_token is not supported, use the default one.\r\n",
"Setting unk_token is not supported, use the default one.\r\n",
"Loading checkpoint shards: 100%|██████████████████| 7/7 [00:02<00:00, 2.77it/s]\r\n",
"trainable params: 1,949,696 || all params: 6,245,533,696 || trainable%: 0.031217444255383614\r\n",
"--> Model\r\n",
"\r\n",
"--> model has 1.949696M params\r\n",
"\r\n",
"Setting num_proc from 16 back to 1 for the train split to disable multiprocessing as it only contains one shard.\r\n",
"Generating train split: 114599 examples [00:00, 836881.77 examples/s]\r\n",
"Setting num_proc from 16 back to 1 for the validation split to disable multiprocessing as it only contains one shard.\r\n",
"Generating validation split: 1070 examples [00:00, 252512.53 examples/s]\r\n",
"Setting num_proc from 16 back to 1 for the test split to disable multiprocessing as it only contains one shard.\r\n",
"Generating test split: 1070 examples [00:00, 313510.67 examples/s]\r\n",
"Map (num_proc=16): 100%|██████| 114599/114599 [00:02<00:00, 39254.76 examples/s]\r\n",
"train_dataset: Dataset({\r\n",
" features: ['input_ids', 'labels'],\r\n",
" num_rows: 114599\r\n",
"})\r\n",
"Map (num_proc=16): 100%|███████████| 1070/1070 [00:00<00:00, 1399.56 examples/s]\r\n",
"val_dataset: Dataset({\r\n",
" features: ['input_ids', 'output_ids'],\r\n",
" num_rows: 1070\r\n",
"})\r\n",
"Map (num_proc=16): 100%|███████████| 1070/1070 [00:00<00:00, 1339.19 examples/s]\r\n",
"test_dataset: Dataset({\r\n",
" features: ['input_ids', 'output_ids'],\r\n",
" num_rows: 1070\r\n",
"})\r\n",
"--> Sanity check\r\n",
" '[gMASK]': 64790 -> -100\r\n",
" 'sop': 64792 -> -100\r\n",
" '<|user|>': 64795 -> -100\r\n",
" '': 30910 -> -100\r\n",
" '\\n': 13 -> -100\r\n",
" '': 30910 -> -100\r\n",
" '类型': 33467 -> -100\r\n",
" '#': 31010 -> -100\r\n",
" '裤': 56532 -> -100\r\n",
" '*': 30998 -> -100\r\n",
" '版': 55090 -> -100\r\n",
" '型': 54888 -> -100\r\n",
" '#': 31010 -> -100\r\n",
" '宽松': 40833 -> -100\r\n",
" '*': 30998 -> -100\r\n",
" '风格': 32799 -> -100\r\n",
" '#': 31010 -> -100\r\n",
" '性感': 40589 -> -100\r\n",
" '*': 30998 -> -100\r\n",
" '图案': 37505 -> -100\r\n",
" '#': 31010 -> -100\r\n",
" '线条': 37216 -> -100\r\n",
" '*': 30998 -> -100\r\n",
" '裤': 56532 -> -100\r\n",
" '型': 54888 -> -100\r\n",
" '#': 31010 -> -100\r\n",
" '阔': 56529 -> -100\r\n",
" '腿': 56158 -> -100\r\n",
" '裤': 56532 -> -100\r\n",
" '<|assistant|>': 64796 -> -100\r\n",
" '': 30910 -> 30910\r\n",
" '\\n': 13 -> 13\r\n",
" '': 30910 -> 30910\r\n",
" '宽松': 40833 -> 40833\r\n",
" '的': 54530 -> 54530\r\n",
" '阔': 56529 -> 56529\r\n",
" '腿': 56158 -> 56158\r\n",
" '裤': 56532 -> 56532\r\n",
" '这': 54551 -> 54551\r\n",
" '两年': 33808 -> 33808\r\n",
" '真的': 32041 -> 32041\r\n",
" '吸': 55360 -> 55360\r\n",
" '粉': 55486 -> 55486\r\n",
" '不少': 32138 -> 32138\r\n",
" ',': 31123 -> 31123\r\n",
" '明星': 32943 -> 32943\r\n",
" '时尚': 33481 -> 33481\r\n",
" '达': 54880 -> 54880\r\n",
" '人的': 31664 -> 31664\r\n",
" '心头': 46565 -> 46565\r\n",
" '爱': 54799 -> 54799\r\n",
" '。': 31155 -> 31155\r\n",
" '毕竟': 33051 -> 33051\r\n",
" '好': 54591 -> 54591\r\n",
" '穿': 55432 -> 55432\r\n",
" '时尚': 33481 -> 33481\r\n",
" ',': 31123 -> 31123\r\n",
" '谁': 55622 -> 55622\r\n",
" '都能': 32904 -> 32904\r\n",
" '穿': 55432 -> 55432\r\n",
" '出': 54557 -> 54557\r\n",
" '腿': 56158 -> 56158\r\n",
" '长': 54625 -> 54625\r\n",
" '2': 30943 -> 30943\r\n",
" '米': 55055 -> 55055\r\n",
" '的效果': 35590 -> 35590\r\n",
" '宽松': 40833 -> 40833\r\n",
" '的': 54530 -> 54530\r\n",
" '裤': 56532 -> 56532\r\n",
" '腿': 56158 -> 56158\r\n",
" ',': 31123 -> 31123\r\n",
" '当然是': 48466 -> 48466\r\n",
" '遮': 57148 -> 57148\r\n",
" '肉': 55343 -> 55343\r\n",
" '小': 54603 -> 54603\r\n",
" '能手': 49355 -> 49355\r\n",
" '啊': 55674 -> 55674\r\n",
" '。': 31155 -> 31155\r\n",
" '上身': 51605 -> 51605\r\n",
" '随': 55119 -> 55119\r\n",
" '性': 54642 -> 54642\r\n",
" '自然': 31799 -> 31799\r\n",
" '不': 54535 -> 54535\r\n",
" '拘': 57036 -> 57036\r\n",
" '束': 55625 -> 55625\r\n",
" ',': 31123 -> 31123\r\n",
" '面料': 46839 -> 46839\r\n",
" '亲': 55113 -> 55113\r\n",
" '肤': 56089 -> 56089\r\n",
" '舒适': 33894 -> 33894\r\n",
" '贴': 55778 -> 55778\r\n",
" '身体': 31902 -> 31902\r\n",
" '验': 55017 -> 55017\r\n",
" '感': 54706 -> 54706\r\n",
" '棒': 56382 -> 56382\r\n",
" '棒': 56382 -> 56382\r\n",
" '哒': 59230 -> 59230\r\n",
" '。': 31155 -> 31155\r\n",
" '系': 54712 -> 54712\r\n",
" '带': 54882 -> 54882\r\n",
" '部分': 31726 -> 31726\r\n",
" '增加': 31917 -> 31917\r\n",
" '设计': 31735 -> 31735\r\n",
" '看点': 45032 -> 45032\r\n",
" ',': 31123 -> 31123\r\n",
" '还': 54656 -> 54656\r\n",
" '让': 54772 -> 54772\r\n",
" '单品': 46539 -> 46539\r\n",
" '的设计': 34481 -> 34481\r\n",
" '感': 54706 -> 54706\r\n",
" '更强': 43084 -> 43084\r\n",
" '。': 31155 -> 31155\r\n",
" '腿部': 46799 -> 46799\r\n",
" '线条': 37216 -> 37216\r\n",
" '若': 55351 -> 55351\r\n",
" '隐': 55733 -> 55733\r\n",
" '若': 55351 -> 55351\r\n",
" '现': 54600 -> 54600\r\n",
" '的': 54530 -> 54530\r\n",
" ',': 31123 -> 31123\r\n",
" '性感': 40589 -> 40589\r\n",
" '撩': 58521 -> 58521\r\n",
" '人': 54533 -> 54533\r\n",
" '。': 31155 -> 31155\r\n",
" '颜色': 33692 -> 33692\r\n",
" '敲': 57004 -> 57004\r\n",
" '温柔': 34678 -> 34678\r\n",
" '的': 54530 -> 54530\r\n",
" ',': 31123 -> 31123\r\n",
" '与': 54619 -> 54619\r\n",
" '裤子': 44722 -> 44722\r\n",
" '本身': 32754 -> 32754\r\n",
" '所': 54626 -> 54626\r\n",
" '呈现': 33169 -> 33169\r\n",
" '的风格': 48084 -> 48084\r\n",
" '有点': 33149 -> 33149\r\n",
" '反': 54955 -> 54955\r\n",
" '差': 55342 -> 55342\r\n",
" '萌': 56842 -> 56842\r\n",
" '。': 31155 -> 31155\r\n",
" '': 2 -> 2\r\n",
"/media/zr/Data/Code/ChatGLM3/venv/lib/python3.10/site-packages/accelerate/accelerator.py:436: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \r\n",
"dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\r\n",
" warnings.warn(\r\n",
"max_steps is given, it will override any value given in num_train_epochs\r\n",
"***** Running training *****\r\n",
" Num examples = 114,599\r\n",
" Num Epochs = 1\r\n",
" Instantaneous batch size per device = 4\r\n",
" Total train batch size (w. parallel, distributed & accumulation) = 4\r\n",
" Gradient Accumulation steps = 1\r\n",
" Total optimization steps = 4,000\r\n",
" Number of trainable parameters = 1,949,696\r\n",
"{'loss': 4.832, 'grad_norm': 2.1177706718444824, 'learning_rate': 4.9875000000000006e-05, 'epoch': 0.0}\r\n",
"{'loss': 4.6094, 'grad_norm': 3.104412078857422, 'learning_rate': 4.975e-05, 'epoch': 0.0}\r\n",
"{'loss': 4.5043, 'grad_norm': 2.9755077362060547, 'learning_rate': 4.962500000000001e-05, 'epoch': 0.0}\r\n",
"{'loss': 4.14, 'grad_norm': 3.3869752883911133, 'learning_rate': 4.9500000000000004e-05, 'epoch': 0.0}\r\n",
"{'loss': 4.1275, 'grad_norm': 2.698483467102051, 'learning_rate': 4.937500000000001e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.8748, 'grad_norm': 2.9052674770355225, 'learning_rate': 4.9250000000000004e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.8506, 'grad_norm': 2.8566994667053223, 'learning_rate': 4.9125e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.7518, 'grad_norm': 2.9119534492492676, 'learning_rate': 4.9e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.6375, 'grad_norm': 3.1845204830169678, 'learning_rate': 4.8875e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.7219, 'grad_norm': 3.359720230102539, 'learning_rate': 4.875e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.676, 'grad_norm': 3.559992790222168, 'learning_rate': 4.8625e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.849, 'grad_norm': 3.822449207305908, 'learning_rate': 4.85e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.6154, 'grad_norm': 3.4438886642456055, 'learning_rate': 4.8375000000000004e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.7326, 'grad_norm': 4.374788284301758, 'learning_rate': 4.825e-05, 'epoch': 0.0}\r\n",
"{'loss': 3.6854, 'grad_norm': 3.5999808311462402, 'learning_rate': 4.8125000000000004e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.7447, 'grad_norm': 3.8460822105407715, 'learning_rate': 4.8e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5766, 'grad_norm': 4.053386211395264, 'learning_rate': 4.7875000000000005e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5758, 'grad_norm': 4.296564102172852, 'learning_rate': 4.775e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5486, 'grad_norm': 4.701301574707031, 'learning_rate': 4.7625000000000006e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5775, 'grad_norm': 4.4896979331970215, 'learning_rate': 4.75e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.55, 'grad_norm': 4.9407429695129395, 'learning_rate': 4.7375e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6437, 'grad_norm': 4.0624542236328125, 'learning_rate': 4.7249999999999997e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6098, 'grad_norm': 4.786097049713135, 'learning_rate': 4.7125e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5107, 'grad_norm': 4.457597255706787, 'learning_rate': 4.7e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4723, 'grad_norm': 5.279415130615234, 'learning_rate': 4.6875e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6016, 'grad_norm': 5.297557353973389, 'learning_rate': 4.6750000000000005e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5475, 'grad_norm': 5.397997856140137, 'learning_rate': 4.6625e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6115, 'grad_norm': 4.472784519195557, 'learning_rate': 4.6500000000000005e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6273, 'grad_norm': 4.7433905601501465, 'learning_rate': 4.6375e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5379, 'grad_norm': 5.81007194519043, 'learning_rate': 4.6250000000000006e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4654, 'grad_norm': 5.297420501708984, 'learning_rate': 4.6125e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6057, 'grad_norm': 5.738197326660156, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4168, 'grad_norm': 5.207597732543945, 'learning_rate': 4.5875000000000004e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4932, 'grad_norm': 5.2784833908081055, 'learning_rate': 4.575e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.518, 'grad_norm': 5.428376197814941, 'learning_rate': 4.5625e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5727, 'grad_norm': 5.190096855163574, 'learning_rate': 4.55e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.3615, 'grad_norm': 4.818575859069824, 'learning_rate': 4.5375e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5275, 'grad_norm': 5.174643039703369, 'learning_rate': 4.525e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.5232, 'grad_norm': 5.241923809051514, 'learning_rate': 4.5125e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4699, 'grad_norm': 5.603521823883057, 'learning_rate': 4.5e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6916, 'grad_norm': 5.468681335449219, 'learning_rate': 4.4875e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.4975, 'grad_norm': 4.969369888305664, 'learning_rate': 4.4750000000000004e-05, 'epoch': 0.01}\r\n",
"{'loss': 3.6207, 'grad_norm': 5.575362682342529, 'learning_rate': 4.4625e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4152, 'grad_norm': 6.52517032623291, 'learning_rate': 4.4500000000000004e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4098, 'grad_norm': 5.987551212310791, 'learning_rate': 4.4375e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4244, 'grad_norm': 5.613704681396484, 'learning_rate': 4.4250000000000005e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5303, 'grad_norm': 5.790269374847412, 'learning_rate': 4.4125e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4475, 'grad_norm': 7.037369728088379, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4562, 'grad_norm': 5.771510601043701, 'learning_rate': 4.3875e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5623, 'grad_norm': 5.876147747039795, 'learning_rate': 4.375e-05, 'epoch': 0.02}\r\n",
" 12%|█████ | 500/4000 [04:39<37:01, 1.58it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:16<00:16, 8.09s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:32<00:11, 11.45s/it]\u001B[A\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:49<00:00, 13.52s/it]\u001B[ABuilding prefix dict from the default dictionary ...\r\n",
"Dumping model to file cache /tmp/jieba.cache\r\n",
"Loading model cost 0.580 seconds.\r\n",
"Prefix dict has been built successfully.\r\n",
" \r\n",
"\u001B[A{'eval_rouge-1': 31.645344, 'eval_rouge-2': 6.79404, 'eval_rouge-l': 23.83732, 'eval_bleu-4': 0.03250689604242964, 'eval_runtime': 54.3911, 'eval_samples_per_second': 0.919, 'eval_steps_per_second': 0.074, 'epoch': 0.02}\r\n",
" 12%|█████ | 500/4000 [05:34<37:01, 1.58it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:50<00:00, 13.52s/it]\u001B[A\r\n",
"{'loss': 3.3207, 'grad_norm': 5.6840596199035645, 'learning_rate': 4.3625e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5459, 'grad_norm': 6.672524929046631, 'learning_rate': 4.35e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5822, 'grad_norm': 5.989180564880371, 'learning_rate': 4.3375000000000004e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4859, 'grad_norm': 5.341927528381348, 'learning_rate': 4.325e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5219, 'grad_norm': 5.3769707679748535, 'learning_rate': 4.3125000000000005e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.6453, 'grad_norm': 5.812618732452393, 'learning_rate': 4.3e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4934, 'grad_norm': 5.726740837097168, 'learning_rate': 4.2875000000000005e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.3719, 'grad_norm': 5.551002025604248, 'learning_rate': 4.275e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4236, 'grad_norm': 6.213701248168945, 'learning_rate': 4.2625000000000006e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4887, 'grad_norm': 6.39825963973999, 'learning_rate': 4.25e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4365, 'grad_norm': 6.213500499725342, 'learning_rate': 4.237500000000001e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4559, 'grad_norm': 6.593310356140137, 'learning_rate': 4.2250000000000004e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4463, 'grad_norm': 5.9485673904418945, 'learning_rate': 4.2125e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4531, 'grad_norm': 6.2323737144470215, 'learning_rate': 4.2e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5338, 'grad_norm': 5.925570964813232, 'learning_rate': 4.1875e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4822, 'grad_norm': 6.287123203277588, 'learning_rate': 4.175e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5402, 'grad_norm': 6.1548848152160645, 'learning_rate': 4.1625e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.3025, 'grad_norm': 6.961801052093506, 'learning_rate': 4.15e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4016, 'grad_norm': 6.60474967956543, 'learning_rate': 4.1375e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.3547, 'grad_norm': 6.296048641204834, 'learning_rate': 4.125e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.4992, 'grad_norm': 7.013551712036133, 'learning_rate': 4.1125000000000004e-05, 'epoch': 0.02}\r\n",
"{'loss': 3.5275, 'grad_norm': 6.747519493103027, 'learning_rate': 4.1e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.2475, 'grad_norm': 6.900665283203125, 'learning_rate': 4.0875000000000004e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5727, 'grad_norm': 5.7873334884643555, 'learning_rate': 4.075e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3961, 'grad_norm': 6.46198844909668, 'learning_rate': 4.0625000000000005e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4777, 'grad_norm': 6.117852687835693, 'learning_rate': 4.05e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.6215, 'grad_norm': 6.421164035797119, 'learning_rate': 4.0375e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4736, 'grad_norm': 6.280588626861572, 'learning_rate': 4.025e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3248, 'grad_norm': 6.418524265289307, 'learning_rate': 4.0125e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5496, 'grad_norm': 6.983282089233398, 'learning_rate': 4e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.2926, 'grad_norm': 6.696746349334717, 'learning_rate': 3.9875e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3609, 'grad_norm': 6.474392414093018, 'learning_rate': 3.9750000000000004e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.458, 'grad_norm': 7.111743450164795, 'learning_rate': 3.9625e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4062, 'grad_norm': 6.317008018493652, 'learning_rate': 3.9500000000000005e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5057, 'grad_norm': 6.232912540435791, 'learning_rate': 3.9375e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5305, 'grad_norm': 6.192782402038574, 'learning_rate': 3.9250000000000005e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.2908, 'grad_norm': 7.155930042266846, 'learning_rate': 3.9125e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4904, 'grad_norm': 6.664801597595215, 'learning_rate': 3.9000000000000006e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4529, 'grad_norm': 7.4175615310668945, 'learning_rate': 3.8875e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.2643, 'grad_norm': 7.862004280090332, 'learning_rate': 3.875e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4562, 'grad_norm': 7.8772687911987305, 'learning_rate': 3.8625e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4186, 'grad_norm': 6.901059150695801, 'learning_rate': 3.85e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4582, 'grad_norm': 7.472389221191406, 'learning_rate': 3.8375e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5643, 'grad_norm': 7.333090305328369, 'learning_rate': 3.825e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3639, 'grad_norm': 6.445948600769043, 'learning_rate': 3.8125e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4389, 'grad_norm': 7.957160949707031, 'learning_rate': 3.8e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.5336, 'grad_norm': 5.9428324699401855, 'learning_rate': 3.7875e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3242, 'grad_norm': 6.897878646850586, 'learning_rate': 3.775e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.4594, 'grad_norm': 7.274386882781982, 'learning_rate': 3.7625e-05, 'epoch': 0.03}\r\n",
"{'loss': 3.3949, 'grad_norm': 7.8012471199035645, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.03}\r\n",
" 25%|█████████▊ | 1000/4000 [10:11<28:52, 1.73it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:03<00:03, 1.53s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:05<00:01, 1.97s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 32.134831999999996, 'eval_rouge-2': 6.325576000000001, 'eval_rouge-l': 25.315346000000005, 'eval_bleu-4': 0.03137707571044217, 'eval_runtime': 9.9272, 'eval_samples_per_second': 5.037, 'eval_steps_per_second': 0.403, 'epoch': 0.03}\r\n",
" 25%|█████████▊ | 1000/4000 [10:21<28:52, 1.73it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:07<00:00, 1.77s/it]\u001B[A\r\n",
"{'loss': 3.4504, 'grad_norm': 6.908702373504639, 'learning_rate': 3.737500000000001e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4596, 'grad_norm': 7.377086639404297, 'learning_rate': 3.7250000000000004e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.6484, 'grad_norm': 8.061379432678223, 'learning_rate': 3.7125e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4, 'grad_norm': 6.452291011810303, 'learning_rate': 3.7e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3891, 'grad_norm': 8.560649871826172, 'learning_rate': 3.6875e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3551, 'grad_norm': 7.644310474395752, 'learning_rate': 3.675e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3895, 'grad_norm': 7.036133766174316, 'learning_rate': 3.6625e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4611, 'grad_norm': 7.2408528327941895, 'learning_rate': 3.65e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.5271, 'grad_norm': 7.058151721954346, 'learning_rate': 3.6375e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4662, 'grad_norm': 6.564244747161865, 'learning_rate': 3.625e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3428, 'grad_norm': 6.844818115234375, 'learning_rate': 3.6125000000000004e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.5244, 'grad_norm': 7.949232578277588, 'learning_rate': 3.6e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4357, 'grad_norm': 7.32559871673584, 'learning_rate': 3.5875000000000005e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3572, 'grad_norm': 8.051689147949219, 'learning_rate': 3.575e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3174, 'grad_norm': 7.550294399261475, 'learning_rate': 3.5625000000000005e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3588, 'grad_norm': 7.240135669708252, 'learning_rate': 3.55e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4516, 'grad_norm': 6.720525741577148, 'learning_rate': 3.5375e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4717, 'grad_norm': 6.3586320877075195, 'learning_rate': 3.525e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3574, 'grad_norm': 6.693387985229492, 'learning_rate': 3.5125e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.407, 'grad_norm': 6.322566509246826, 'learning_rate': 3.5e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.2439, 'grad_norm': 6.481217384338379, 'learning_rate': 3.4875e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3391, 'grad_norm': 7.359728813171387, 'learning_rate': 3.475e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3771, 'grad_norm': 7.4071478843688965, 'learning_rate': 3.4625e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3758, 'grad_norm': 7.325416564941406, 'learning_rate': 3.45e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4434, 'grad_norm': 6.780652046203613, 'learning_rate': 3.4375e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.2818, 'grad_norm': 7.619284152984619, 'learning_rate': 3.4250000000000006e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.4562, 'grad_norm': 7.123080253601074, 'learning_rate': 3.4125e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3322, 'grad_norm': 7.0780863761901855, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.04}\r\n",
"{'loss': 3.3887, 'grad_norm': 6.898688316345215, 'learning_rate': 3.3875000000000003e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4793, 'grad_norm': 7.293100357055664, 'learning_rate': 3.375000000000001e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4607, 'grad_norm': 6.927903175354004, 'learning_rate': 3.3625000000000004e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4535, 'grad_norm': 6.639427661895752, 'learning_rate': 3.35e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4008, 'grad_norm': 10.613078117370605, 'learning_rate': 3.3375e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3059, 'grad_norm': 7.491557598114014, 'learning_rate': 3.325e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3484, 'grad_norm': 7.497087001800537, 'learning_rate': 3.3125e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.2969, 'grad_norm': 8.017332077026367, 'learning_rate': 3.3e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.5152, 'grad_norm': 7.311262130737305, 'learning_rate': 3.2875e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3871, 'grad_norm': 7.2260003089904785, 'learning_rate': 3.275e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3563, 'grad_norm': 7.222864151000977, 'learning_rate': 3.2625e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4166, 'grad_norm': 6.612077713012695, 'learning_rate': 3.2500000000000004e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3465, 'grad_norm': 7.431714057922363, 'learning_rate': 3.2375e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.2621, 'grad_norm': 7.619777202606201, 'learning_rate': 3.2250000000000005e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3795, 'grad_norm': 7.628826141357422, 'learning_rate': 3.2125e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3551, 'grad_norm': 7.093392848968506, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.2658, 'grad_norm': 6.70922327041626, 'learning_rate': 3.1875e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3914, 'grad_norm': 7.325173377990723, 'learning_rate': 3.175e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4367, 'grad_norm': 9.542543411254883, 'learning_rate': 3.1624999999999996e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.2979, 'grad_norm': 6.646926403045654, 'learning_rate': 3.15e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4375, 'grad_norm': 7.366168975830078, 'learning_rate': 3.1375e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4574, 'grad_norm': 6.800962924957275, 'learning_rate': 3.125e-05, 'epoch': 0.05}\r\n",
" 38%|██████████████▋ | 1500/4000 [14:57<20:28, 2.03it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:02<00:02, 1.43s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:18<00:07, 7.54s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 31.905676000000007, 'eval_rouge-2': 6.630377999999999, 'eval_rouge-l': 25.126853999999998, 'eval_bleu-4': 0.03152151596531457, 'eval_runtime': 23.6793, 'eval_samples_per_second': 2.112, 'eval_steps_per_second': 0.169, 'epoch': 0.05}\r\n",
" 38%|██████████████▋ | 1500/4000 [15:21<20:28, 2.03it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:20<00:00, 5.41s/it]\u001B[A\r\n",
"{'loss': 3.3451, 'grad_norm': 6.90294075012207, 'learning_rate': 3.1125000000000004e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3844, 'grad_norm': 8.37482738494873, 'learning_rate': 3.1e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4359, 'grad_norm': 8.105109214782715, 'learning_rate': 3.0875000000000005e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.3988, 'grad_norm': 7.031566143035889, 'learning_rate': 3.075e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4945, 'grad_norm': 7.260471343994141, 'learning_rate': 3.0625000000000006e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4061, 'grad_norm': 8.252367973327637, 'learning_rate': 3.05e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4643, 'grad_norm': 7.982962131500244, 'learning_rate': 3.0375000000000003e-05, 'epoch': 0.05}\r\n",
"{'loss': 3.4326, 'grad_norm': 7.5859808921813965, 'learning_rate': 3.025e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.5098, 'grad_norm': 9.218013763427734, 'learning_rate': 3.0125000000000004e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3924, 'grad_norm': 7.129590034484863, 'learning_rate': 3e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3645, 'grad_norm': 7.882465362548828, 'learning_rate': 2.9875000000000004e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3656, 'grad_norm': 8.374431610107422, 'learning_rate': 2.975e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.4676, 'grad_norm': 7.145497798919678, 'learning_rate': 2.9625000000000002e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3199, 'grad_norm': 7.946256160736084, 'learning_rate': 2.95e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3682, 'grad_norm': 7.46930456161499, 'learning_rate': 2.9375000000000003e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.2996, 'grad_norm': 6.9753265380859375, 'learning_rate': 2.925e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.475, 'grad_norm': 8.484821319580078, 'learning_rate': 2.9125000000000003e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3715, 'grad_norm': 7.118030548095703, 'learning_rate': 2.9e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3742, 'grad_norm': 7.3347368240356445, 'learning_rate': 2.8875e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.5146, 'grad_norm': 6.8588714599609375, 'learning_rate': 2.8749999999999997e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.4602, 'grad_norm': 7.292227745056152, 'learning_rate': 2.8625e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.499, 'grad_norm': 7.423632621765137, 'learning_rate': 2.8499999999999998e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.4059, 'grad_norm': 7.430981636047363, 'learning_rate': 2.8375000000000002e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.398, 'grad_norm': 7.364171981811523, 'learning_rate': 2.825e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.4631, 'grad_norm': 7.548583984375, 'learning_rate': 2.8125000000000003e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.442, 'grad_norm': 7.765754699707031, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3605, 'grad_norm': 8.27833366394043, 'learning_rate': 2.7875e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3459, 'grad_norm': 8.09084415435791, 'learning_rate': 2.7750000000000004e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3928, 'grad_norm': 8.150015830993652, 'learning_rate': 2.7625e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3408, 'grad_norm': 7.760500907897949, 'learning_rate': 2.7500000000000004e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3803, 'grad_norm': 8.982950210571289, 'learning_rate': 2.7375e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3381, 'grad_norm': 7.609743118286133, 'learning_rate': 2.725e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.5785, 'grad_norm': 7.900216102600098, 'learning_rate': 2.7125000000000002e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3395, 'grad_norm': 8.472111701965332, 'learning_rate': 2.7000000000000002e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.4895, 'grad_norm': 8.781264305114746, 'learning_rate': 2.6875e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3846, 'grad_norm': 7.472824573516846, 'learning_rate': 2.6750000000000003e-05, 'epoch': 0.06}\r\n",
"{'loss': 3.3115, 'grad_norm': 8.073516845703125, 'learning_rate': 2.6625e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3037, 'grad_norm': 7.2763519287109375, 'learning_rate': 2.6500000000000004e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3965, 'grad_norm': 7.201462268829346, 'learning_rate': 2.6375e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3717, 'grad_norm': 7.831448554992676, 'learning_rate': 2.625e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.391, 'grad_norm': 7.940402507781982, 'learning_rate': 2.6124999999999998e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.477, 'grad_norm': 7.303577899932861, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.2766, 'grad_norm': 7.596188545227051, 'learning_rate': 2.5875e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4998, 'grad_norm': 7.545307159423828, 'learning_rate': 2.5750000000000002e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3592, 'grad_norm': 6.786509990692139, 'learning_rate': 2.5625e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.2854, 'grad_norm': 8.573935508728027, 'learning_rate': 2.5500000000000003e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3727, 'grad_norm': 7.578614234924316, 'learning_rate': 2.5375e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.2307, 'grad_norm': 7.565990447998047, 'learning_rate': 2.525e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.41, 'grad_norm': 7.094372749328613, 'learning_rate': 2.5124999999999997e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4619, 'grad_norm': 7.98245096206665, 'learning_rate': 2.5e-05, 'epoch': 0.07}\r\n",
" 50%|███████████████████▌ | 2000/4000 [19:57<17:54, 1.86it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:16<00:16, 8.01s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:32<00:11, 11.33s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 31.442076, 'eval_rouge-2': 7.156823999999999, 'eval_rouge-l': 23.246924000000003, 'eval_bleu-4': 0.03405216374744, 'eval_runtime': 64.2793, 'eval_samples_per_second': 0.778, 'eval_steps_per_second': 0.062, 'epoch': 0.07}\r\n",
" 50%|███████████████████▌ | 2000/4000 [21:01<17:54, 1.86it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:48<00:00, 12.97s/it]\u001B[A\r\n",
" \u001B[ASaving model checkpoint to ./output/checkpoint-2000\r\n",
"/media/zr/Data/Code/ChatGLM3/venv/lib/python3.10/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /media/zr/Data/Models/LLM/chatglm3-6b - will assume that the vocabulary was not modified.\r\n",
" warnings.warn(\r\n",
"{'loss': 3.3818, 'grad_norm': 8.677833557128906, 'learning_rate': 2.4875e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4928, 'grad_norm': 7.391153812408447, 'learning_rate': 2.4750000000000002e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.5547, 'grad_norm': 8.77245044708252, 'learning_rate': 2.4625000000000002e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4939, 'grad_norm': 8.10531997680664, 'learning_rate': 2.45e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3687, 'grad_norm': 8.14376449584961, 'learning_rate': 2.4375e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3307, 'grad_norm': 7.644017219543457, 'learning_rate': 2.425e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4414, 'grad_norm': 7.982100486755371, 'learning_rate': 2.4125e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4115, 'grad_norm': 8.171486854553223, 'learning_rate': 2.4e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.4326, 'grad_norm': 7.437331199645996, 'learning_rate': 2.3875e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3533, 'grad_norm': 7.70622444152832, 'learning_rate': 2.375e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.2926, 'grad_norm': 7.60914945602417, 'learning_rate': 2.3624999999999998e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.5812, 'grad_norm': 8.040843963623047, 'learning_rate': 2.35e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.2502, 'grad_norm': 7.3959574699401855, 'learning_rate': 2.3375000000000002e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3521, 'grad_norm': 8.238727569580078, 'learning_rate': 2.3250000000000003e-05, 'epoch': 0.07}\r\n",
"{'loss': 3.3969, 'grad_norm': 7.359251022338867, 'learning_rate': 2.3125000000000003e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.5178, 'grad_norm': 8.128018379211426, 'learning_rate': 2.3000000000000003e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.393, 'grad_norm': 7.082696914672852, 'learning_rate': 2.2875e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4172, 'grad_norm': 7.790773868560791, 'learning_rate': 2.275e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3604, 'grad_norm': 7.583011150360107, 'learning_rate': 2.2625e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4316, 'grad_norm': 7.347414970397949, 'learning_rate': 2.25e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4496, 'grad_norm': 6.759352207183838, 'learning_rate': 2.2375000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4145, 'grad_norm': 7.640699863433838, 'learning_rate': 2.2250000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4189, 'grad_norm': 8.391305923461914, 'learning_rate': 2.2125000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3705, 'grad_norm': 8.04839038848877, 'learning_rate': 2.2000000000000003e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.2355, 'grad_norm': 8.35435962677002, 'learning_rate': 2.1875e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3584, 'grad_norm': 7.815989017486572, 'learning_rate': 2.175e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4268, 'grad_norm': 8.53368854522705, 'learning_rate': 2.1625e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.467, 'grad_norm': 7.677575588226318, 'learning_rate': 2.15e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.2885, 'grad_norm': 8.361733436584473, 'learning_rate': 2.1375e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3535, 'grad_norm': 8.110257148742676, 'learning_rate': 2.125e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3191, 'grad_norm': 8.498170852661133, 'learning_rate': 2.1125000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3271, 'grad_norm': 8.709260940551758, 'learning_rate': 2.1e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3629, 'grad_norm': 9.01534366607666, 'learning_rate': 2.0875e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3635, 'grad_norm': 7.54719352722168, 'learning_rate': 2.075e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.2623, 'grad_norm': 8.59843635559082, 'learning_rate': 2.0625e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3803, 'grad_norm': 8.170056343078613, 'learning_rate': 2.05e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3506, 'grad_norm': 7.873594284057617, 'learning_rate': 2.0375e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4871, 'grad_norm': 8.418689727783203, 'learning_rate': 2.025e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.2262, 'grad_norm': 8.624137878417969, 'learning_rate': 2.0125e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4514, 'grad_norm': 7.584123611450195, 'learning_rate': 2e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.4514, 'grad_norm': 7.975276470184326, 'learning_rate': 1.9875000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.2789, 'grad_norm': 7.9726481437683105, 'learning_rate': 1.9750000000000002e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3652, 'grad_norm': 7.4362945556640625, 'learning_rate': 1.9625000000000003e-05, 'epoch': 0.08}\r\n",
"{'loss': 3.3795, 'grad_norm': 8.107170104980469, 'learning_rate': 1.9500000000000003e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.2727, 'grad_norm': 7.757025241851807, 'learning_rate': 1.9375e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3055, 'grad_norm': 7.5721869468688965, 'learning_rate': 1.925e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.2545, 'grad_norm': 8.496746063232422, 'learning_rate': 1.9125e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4332, 'grad_norm': 7.52405309677124, 'learning_rate': 1.9e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4711, 'grad_norm': 7.90508508682251, 'learning_rate': 1.8875e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.39, 'grad_norm': 9.309752464294434, 'learning_rate': 1.8750000000000002e-05, 'epoch': 0.09}\r\n",
" 62%|████████████████████████▍ | 2500/4000 [25:37<13:33, 1.84it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:03<00:03, 1.72s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:06<00:02, 2.25s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 31.633207999999996, 'eval_rouge-2': 6.800014, 'eval_rouge-l': 25.123896000000006, 'eval_bleu-4': 0.03327400496195634, 'eval_runtime': 25.5968, 'eval_samples_per_second': 1.953, 'eval_steps_per_second': 0.156, 'epoch': 0.09}\r\n",
" 62%|████████████████████████▍ | 2500/4000 [26:03<13:33, 1.84it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:22<00:00, 7.31s/it]\u001B[A\r\n",
"{'loss': 3.2988, 'grad_norm': 8.42829704284668, 'learning_rate': 1.8625000000000002e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3408, 'grad_norm': 9.460935592651367, 'learning_rate': 1.85e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.2467, 'grad_norm': 7.881652355194092, 'learning_rate': 1.8375e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3906, 'grad_norm': 8.49362564086914, 'learning_rate': 1.825e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3859, 'grad_norm': 7.6069016456604, 'learning_rate': 1.8125e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3982, 'grad_norm': 8.237305641174316, 'learning_rate': 1.8e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.465, 'grad_norm': 7.80671501159668, 'learning_rate': 1.7875e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4805, 'grad_norm': 8.655023574829102, 'learning_rate': 1.775e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3734, 'grad_norm': 8.358222961425781, 'learning_rate': 1.7625e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4732, 'grad_norm': 8.640260696411133, 'learning_rate': 1.75e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3471, 'grad_norm': 8.130788803100586, 'learning_rate': 1.7375e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4129, 'grad_norm': 7.604771614074707, 'learning_rate': 1.725e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.5184, 'grad_norm': 7.612947463989258, 'learning_rate': 1.7125000000000003e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4441, 'grad_norm': 8.518109321594238, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3992, 'grad_norm': 7.822119235992432, 'learning_rate': 1.6875000000000004e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3439, 'grad_norm': 7.961773872375488, 'learning_rate': 1.675e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4062, 'grad_norm': 8.931722640991211, 'learning_rate': 1.6625e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.2609, 'grad_norm': 7.5368194580078125, 'learning_rate': 1.65e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4715, 'grad_norm': 8.477120399475098, 'learning_rate': 1.6375e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4461, 'grad_norm': 9.24991512298584, 'learning_rate': 1.6250000000000002e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.4182, 'grad_norm': 8.294699668884277, 'learning_rate': 1.6125000000000002e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.2432, 'grad_norm': 7.574826717376709, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.09}\r\n",
"{'loss': 3.3834, 'grad_norm': 8.255449295043945, 'learning_rate': 1.5875e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.385, 'grad_norm': 8.229700088500977, 'learning_rate': 1.575e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.449, 'grad_norm': 8.934239387512207, 'learning_rate': 1.5625e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3947, 'grad_norm': 8.390064239501953, 'learning_rate': 1.55e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3486, 'grad_norm': 8.181641578674316, 'learning_rate': 1.5375e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2568, 'grad_norm': 8.498324394226074, 'learning_rate': 1.525e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2709, 'grad_norm': 7.9656147956848145, 'learning_rate': 1.5125e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2258, 'grad_norm': 7.652721405029297, 'learning_rate': 1.5e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.4379, 'grad_norm': 8.255173683166504, 'learning_rate': 1.4875e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3639, 'grad_norm': 7.929840564727783, 'learning_rate': 1.475e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3836, 'grad_norm': 8.210647583007812, 'learning_rate': 1.4625e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.4367, 'grad_norm': 8.759031295776367, 'learning_rate': 1.45e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.4047, 'grad_norm': 8.681133270263672, 'learning_rate': 1.4374999999999999e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.327, 'grad_norm': 8.468674659729004, 'learning_rate': 1.4249999999999999e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3654, 'grad_norm': 8.48736572265625, 'learning_rate': 1.4125e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.5008, 'grad_norm': 9.581798553466797, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2943, 'grad_norm': 8.112646102905273, 'learning_rate': 1.3875000000000002e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3182, 'grad_norm': 8.913463592529297, 'learning_rate': 1.3750000000000002e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2932, 'grad_norm': 7.881869792938232, 'learning_rate': 1.3625e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2365, 'grad_norm': 7.5258941650390625, 'learning_rate': 1.3500000000000001e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.3527, 'grad_norm': 9.253165245056152, 'learning_rate': 1.3375000000000002e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.248, 'grad_norm': 8.01251220703125, 'learning_rate': 1.3250000000000002e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.36, 'grad_norm': 8.332780838012695, 'learning_rate': 1.3125e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.2068, 'grad_norm': 9.181897163391113, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.4514, 'grad_norm': 8.965094566345215, 'learning_rate': 1.2875000000000001e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.424, 'grad_norm': 8.944855690002441, 'learning_rate': 1.2750000000000002e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.4562, 'grad_norm': 8.20882511138916, 'learning_rate': 1.2625e-05, 'epoch': 0.1}\r\n",
"{'loss': 3.358, 'grad_norm': 7.769922733306885, 'learning_rate': 1.25e-05, 'epoch': 0.1}\r\n",
" 75%|█████████████████████████████▎ | 3000/4000 [30:40<08:42, 1.91it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:02<00:02, 1.43s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:05<00:01, 1.94s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 33.007998, 'eval_rouge-2': 7.157356, 'eval_rouge-l': 25.306306000000003, 'eval_bleu-4': 0.0348571644891679, 'eval_runtime': 38.0831, 'eval_samples_per_second': 1.313, 'eval_steps_per_second': 0.105, 'epoch': 0.1}\r\n",
" 75%|█████████████████████████████▎ | 3000/4000 [31:18<08:42, 1.91it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:21<00:00, 7.25s/it]\u001B[A\r\n",
"{'loss': 3.4711, 'grad_norm': 8.417685508728027, 'learning_rate': 1.2375000000000001e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3418, 'grad_norm': 8.048948287963867, 'learning_rate': 1.225e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3564, 'grad_norm': 8.270435333251953, 'learning_rate': 1.2125e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.2293, 'grad_norm': 7.761234760284424, 'learning_rate': 1.2e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3873, 'grad_norm': 8.1546049118042, 'learning_rate': 1.1875e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.5338, 'grad_norm': 7.905092239379883, 'learning_rate': 1.175e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.2963, 'grad_norm': 8.120687484741211, 'learning_rate': 1.1625000000000001e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.292, 'grad_norm': 9.561246871948242, 'learning_rate': 1.1500000000000002e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.2029, 'grad_norm': 9.09880542755127, 'learning_rate': 1.1375e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3873, 'grad_norm': 7.879208087921143, 'learning_rate': 1.125e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3383, 'grad_norm': 8.732316970825195, 'learning_rate': 1.1125000000000001e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3205, 'grad_norm': 8.577627182006836, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3717, 'grad_norm': 9.737064361572266, 'learning_rate': 1.0875e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.2996, 'grad_norm': 8.619685173034668, 'learning_rate': 1.075e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.4496, 'grad_norm': 8.600975036621094, 'learning_rate': 1.0625e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.4277, 'grad_norm': 8.75851821899414, 'learning_rate': 1.05e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.4809, 'grad_norm': 7.5685930252075195, 'learning_rate': 1.0375e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.226, 'grad_norm': 8.321500778198242, 'learning_rate': 1.025e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.3586, 'grad_norm': 7.587204933166504, 'learning_rate': 1.0125e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.4166, 'grad_norm': 8.86058235168457, 'learning_rate': 1e-05, 'epoch': 0.11}\r\n",
"{'loss': 3.382, 'grad_norm': 9.254091262817383, 'learning_rate': 9.875000000000001e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.3961, 'grad_norm': 7.718448162078857, 'learning_rate': 9.750000000000002e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.4699, 'grad_norm': 8.792988777160645, 'learning_rate': 9.625e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.2145, 'grad_norm': 8.899701118469238, 'learning_rate': 9.5e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.4141, 'grad_norm': 8.802495956420898, 'learning_rate': 9.375000000000001e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.3627, 'grad_norm': 9.895890235900879, 'learning_rate': 9.25e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.4182, 'grad_norm': 8.153362274169922, 'learning_rate': 9.125e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.2916, 'grad_norm': 8.173482894897461, 'learning_rate': 9e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.2963, 'grad_norm': 9.929978370666504, 'learning_rate': 8.875e-06, 'epoch': 0.11}\r\n",
"{'loss': 3.4039, 'grad_norm': 7.541258335113525, 'learning_rate': 8.75e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3602, 'grad_norm': 7.881056785583496, 'learning_rate': 8.625e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2324, 'grad_norm': 8.763860702514648, 'learning_rate': 8.500000000000002e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.4018, 'grad_norm': 9.141348838806152, 'learning_rate': 8.375e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3771, 'grad_norm': 8.166316032409668, 'learning_rate': 8.25e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2783, 'grad_norm': 9.261619567871094, 'learning_rate': 8.125000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.4312, 'grad_norm': 8.153901100158691, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.327, 'grad_norm': 7.708031177520752, 'learning_rate': 7.875e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3779, 'grad_norm': 7.920627117156982, 'learning_rate': 7.75e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2857, 'grad_norm': 9.732666015625, 'learning_rate': 7.625e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3588, 'grad_norm': 8.037003517150879, 'learning_rate': 7.5e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2002, 'grad_norm': 8.716700553894043, 'learning_rate': 7.375e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2863, 'grad_norm': 9.12403678894043, 'learning_rate': 7.25e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3447, 'grad_norm': 8.44495677947998, 'learning_rate': 7.1249999999999995e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3088, 'grad_norm': 8.425846099853516, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3281, 'grad_norm': 8.53967571258545, 'learning_rate': 6.875000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3451, 'grad_norm': 9.039155960083008, 'learning_rate': 6.750000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2674, 'grad_norm': 9.248905181884766, 'learning_rate': 6.625000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2703, 'grad_norm': 10.257024765014648, 'learning_rate': 6.5000000000000004e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.4084, 'grad_norm': 8.447395324707031, 'learning_rate': 6.375000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.4488, 'grad_norm': 8.430671691894531, 'learning_rate': 6.25e-06, 'epoch': 0.12}\r\n",
" 88%|██████████████████████████████████▏ | 3500/4000 [35:52<04:30, 1.85it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:04<00:04, 2.18s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:06<00:02, 2.23s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 32.222722, 'eval_rouge-2': 6.6331180000000005, 'eval_rouge-l': 25.087382, 'eval_bleu-4': 0.03253227960558209, 'eval_runtime': 25.0679, 'eval_samples_per_second': 1.995, 'eval_steps_per_second': 0.16, 'epoch': 0.12}\r\n",
" 88%|██████████████████████████████████▏ | 3500/4000 [36:17<04:30, 1.85it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:08<00:00, 2.14s/it]\u001B[A\r\n",
"{'loss': 3.3912, 'grad_norm': 9.152791976928711, 'learning_rate': 6.125e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3229, 'grad_norm': 9.17188549041748, 'learning_rate': 6e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2846, 'grad_norm': 8.172340393066406, 'learning_rate': 5.875e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.308, 'grad_norm': 8.928167343139648, 'learning_rate': 5.750000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3578, 'grad_norm': 8.738048553466797, 'learning_rate': 5.625e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2117, 'grad_norm': 8.161530494689941, 'learning_rate': 5.500000000000001e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.3182, 'grad_norm': 7.672643184661865, 'learning_rate': 5.375e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.4324, 'grad_norm': 9.408201217651367, 'learning_rate': 5.25e-06, 'epoch': 0.12}\r\n",
"{'loss': 3.2418, 'grad_norm': 9.635400772094727, 'learning_rate': 5.125e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.1869, 'grad_norm': 8.71308708190918, 'learning_rate': 5e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2719, 'grad_norm': 10.24747085571289, 'learning_rate': 4.875000000000001e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.5238, 'grad_norm': 8.207618713378906, 'learning_rate': 4.75e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3982, 'grad_norm': 9.101743698120117, 'learning_rate': 4.625e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2068, 'grad_norm': 9.008282661437988, 'learning_rate': 4.5e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3084, 'grad_norm': 9.63040828704834, 'learning_rate': 4.375e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.1973, 'grad_norm': 8.8562593460083, 'learning_rate': 4.250000000000001e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.298, 'grad_norm': 8.217488288879395, 'learning_rate': 4.125e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3773, 'grad_norm': 8.624151229858398, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3955, 'grad_norm': 8.07646369934082, 'learning_rate': 3.875e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.4082, 'grad_norm': 9.692364692687988, 'learning_rate': 3.75e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3699, 'grad_norm': 9.671299934387207, 'learning_rate': 3.625e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.39, 'grad_norm': 9.423399925231934, 'learning_rate': 3.5000000000000004e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.168, 'grad_norm': 10.555978775024414, 'learning_rate': 3.3750000000000003e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.4062, 'grad_norm': 9.081645011901855, 'learning_rate': 3.2500000000000002e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2328, 'grad_norm': 8.238192558288574, 'learning_rate': 3.125e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2117, 'grad_norm': 8.344420433044434, 'learning_rate': 3e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2488, 'grad_norm': 9.779040336608887, 'learning_rate': 2.8750000000000004e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2828, 'grad_norm': 8.346026420593262, 'learning_rate': 2.7500000000000004e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.4674, 'grad_norm': 8.168132781982422, 'learning_rate': 2.625e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2598, 'grad_norm': 7.97592830657959, 'learning_rate': 2.5e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3447, 'grad_norm': 10.082160949707031, 'learning_rate': 2.375e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2311, 'grad_norm': 8.935636520385742, 'learning_rate': 2.25e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3922, 'grad_norm': 8.796125411987305, 'learning_rate': 2.1250000000000004e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.315, 'grad_norm': 8.807939529418945, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.2951, 'grad_norm': 8.721334457397461, 'learning_rate': 1.875e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.3289, 'grad_norm': 9.166098594665527, 'learning_rate': 1.7500000000000002e-06, 'epoch': 0.13}\r\n",
"{'loss': 3.46, 'grad_norm': 8.010759353637695, 'learning_rate': 1.6250000000000001e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.4809, 'grad_norm': 8.220529556274414, 'learning_rate': 1.5e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.4166, 'grad_norm': 8.10384750366211, 'learning_rate': 1.3750000000000002e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.458, 'grad_norm': 8.7192964553833, 'learning_rate': 1.25e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.2795, 'grad_norm': 8.834420204162598, 'learning_rate': 1.125e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.3441, 'grad_norm': 9.3894681930542, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.14}\r\n",
"{'loss': 3.3844, 'grad_norm': 7.872992038726807, 'learning_rate': 8.750000000000001e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.5111, 'grad_norm': 8.390124320983887, 'learning_rate': 7.5e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.3422, 'grad_norm': 9.196588516235352, 'learning_rate': 6.25e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.2922, 'grad_norm': 8.946027755737305, 'learning_rate': 5.000000000000001e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.4168, 'grad_norm': 7.884989261627197, 'learning_rate': 3.75e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.4125, 'grad_norm': 9.072811126708984, 'learning_rate': 2.5000000000000004e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.4373, 'grad_norm': 8.543241500854492, 'learning_rate': 1.2500000000000002e-07, 'epoch': 0.14}\r\n",
"{'loss': 3.3844, 'grad_norm': 9.427127838134766, 'learning_rate': 0.0, 'epoch': 0.14}\r\n",
"100%|███████████████████████████████████████| 4000/4000 [40:55<00:00, 1.92it/s]***** Running Evaluation *****\r\n",
" Num examples = 50\r\n",
" Batch size = 16\r\n",
"\r\n",
" 0%| | 0/4 [00:00<?, ?it/s]\u001B[A\r\n",
" 50%|██████████████████████▌ | 2/4 [00:03<00:03, 1.96s/it]\u001B[A\r\n",
" 75%|█████████████████████████████████▊ | 3/4 [00:06<00:02, 2.33s/it]\u001B[A\r\n",
" \u001B[A\r\n",
"\u001B[A{'eval_rouge-1': 31.607680000000002, 'eval_rouge-2': 6.832874, 'eval_rouge-l': 25.068815999999998, 'eval_bleu-4': 0.03411200822704291, 'eval_runtime': 12.6342, 'eval_samples_per_second': 3.958, 'eval_steps_per_second': 0.317, 'epoch': 0.14}\r\n",
"100%|███████████████████████████████████████| 4000/4000 [41:08<00:00, 1.92it/s]\r\n",
"100%|█████████████████████████████████████████████| 4/4 [00:09<00:00, 2.33s/it]\u001B[A\r\n",
" \u001B[ASaving model checkpoint to ./output/checkpoint-4000\r\n",
"/media/zr/Data/Code/ChatGLM3/venv/lib/python3.10/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /media/zr/Data/Models/LLM/chatglm3-6b - will assume that the vocabulary was not modified.\r\n",
" warnings.warn(\r\n",
"\r\n",
"\r\n",
"Training completed. Do not forget to share your model on huggingface.co/models =)\r\n",
"\r\n",
"\r\n",
"{'train_runtime': 2468.7229, 'train_samples_per_second': 6.481, 'train_steps_per_second': 1.62, 'train_loss': 3.419384765625, 'epoch': 0.14}\r\n",
"100%|███████████████████████████████████████| 4000/4000 [41:08<00:00, 1.62it/s]\r\n",
"***** Running Prediction *****\r\n",
" Num examples = 1070\r\n",
" Batch size = 16\r\n",
"100%|███████████████████████████████████████████| 67/67 [12:42<00:00, 11.38s/it]\r\n"
]
}
],
"source": [
"!CUDA_VISIBLE_DEVICES=0 NCCL_P2P_DISABLE=\"1\" NCCL_IB_DISABLE=\"1\" python finetune_hf.py data/AdvertiseGen_fix /media/zr/Data/Models/LLM/chatglm3-6b configs/lora.yaml"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "17c87410a24d844f",
"outputId": "e347fc7d-875e-40c9-c682-3e064100476b",
"ExecuteTime": {
"end_time": "2024-04-14T06:23:41.282431Z",
"start_time": "2024-04-14T05:29:23.810692Z"
}
},
"id": "17c87410a24d844f"
},
{
"cell_type": "markdown",
"source": [
"## 3. 使用微调的数据集进行推理\n",
"在完成微调任务之后,我们可以查看到 `output` 文件夹下多了很多个`checkpoint-*`的文件夹,这些文件夹代表了训练的轮数。\n",
"我们选择最后一轮的微调权重,并使用inference进行导入。"
],
"metadata": {
"collapsed": false,
"id": "d9418f6c5c264601"
},
"id": "d9418f6c5c264601"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████████████| 7/7 [00:02<00:00, 2.45it/s]\r\n",
"Setting eos_token is not supported, use the default one.\r\n",
"Setting pad_token is not supported, use the default one.\r\n",
"Setting unk_token is not supported, use the default one.\r\n",
"这款连衣裙采用压褶的版型设计,不规则的木耳边拼接,修饰了腰线,使得身材更加修长,不规则的压褶设计,增加了层次感,不规则的压褶,修饰了腰线,拉长腿部比例,显瘦又性感,套头的设计,方便穿脱,不规则的压褶,增加层次感,视觉上拉长腿部比例,百褶的网纱拼接,增加了层次感,整体气质优雅。\r\n"
]
}
],
"source": [
"!CUDA_VISIBLE_DEVICES=0 NCCL_P2P_DISABLE=\"1\" NCCL_IB_DISABLE=\"1\" python inference_hf.py output/checkpoint-4000/ --prompt \"类型#裙*版型#显瘦*材质#网纱*风格#性感*裙型#百褶*裙下摆#压褶*裙长#连衣裙*裙衣门襟#拉链*裙衣门襟#套头*裙款式#拼接*裙款式#拉链*裙款式#木耳边*裙款式#抽褶*裙款式#不规则\""
],
"metadata": {
"id": "5060015c24e97ae",
"outputId": "d3f03d0d-46bf-4c74-9b00-dc0160da0e15",
"colab": {
"base_uri": "https://localhost:8080/"
},
"ExecuteTime": {
"end_time": "2024-04-14T06:23:52.725227Z",
"start_time": "2024-04-14T06:23:41.284552Z"
}
},
"id": "5060015c24e97ae"
},
{
"cell_type": "markdown",
"source": [
"## 4. 总结\n",
"到此位置,我们就完成了使用单张 GPU Lora 来微调 ChatGLM3-6B 模型,使其能生产出更好的广告。\n",
"在本章节中,你将会学会:\n",
"+ 如何使用模型进行 Lora 微调\n",
"+ 微调数据集的准备和对齐\n",
"+ 使用微调的模型进行推理"
],
"metadata": {
"collapsed": false,
"id": "18cd83087f096094"
},
"id": "18cd83087f096094"
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3 (ipykernel)"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
},
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "V100"
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 5
}
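# Dependencies for the fine-tuning demo above; install them with, e.g.:
#   pip install -r requirements.txt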
jieba>=0.42.1
ruamel_yaml>=0.18.6
rouge_chinese>=1.0.3
jupyter>=1.0.0
datasets>=2.18.0
peft>=0.10.0
deepspeed==0.13.1
mpi4py>=3.1.5
import ast
import json
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import List, Optional
class ChatGLM3(LLM):
max_token: int = 8192
do_sample: bool = True
temperature: float = 0.8
top_p: float = 0.8
tokenizer: object = None
model: object = None
history: List = []
has_search: bool = False
def __init__(self):
super().__init__()
@property
def _llm_type(self) -> str:
return "ChatGLM3"
def load_model(self, model_name_or_path=None):
model_config = AutoConfig.from_pretrained(
model_name_or_path,
trust_remote_code=True
)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=True
)
self.model = AutoModel.from_pretrained(
model_name_or_path, config=model_config, trust_remote_code=True, device_map="auto").eval()
def _tool_history(self, prompt: str):
ans = []
tool_prompts = prompt.split(
"You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")
tools_json = []
for tool_desc in tool_prompts:
name = tool_desc.split(":")[0]
description = tool_desc.split(", args:")[0].split(":")[1].strip()
parameters_str = tool_desc.split("args:")[1].strip()
parameters_dict = ast.literal_eval(parameters_str)
params_cleaned = {}
for param, details in parameters_dict.items():
params_cleaned[param] = {'description': details['description'], 'type': details['type']}
tools_json.append({
"name": name,
"description": description,
"parameters": params_cleaned
})
ans.append({
"role": "system",
"content": "Answer the following questions as best as you can. You have access to the following tools:",
"tools": tools_json
})
dialog_parts = prompt.split("Human: ")
for part in dialog_parts[1:]:
if "\nAI: " in part:
user_input, ai_response = part.split("\nAI: ")
ai_response = ai_response.split("\n")[0]
else:
user_input = part
ai_response = None
ans.append({"role": "user", "content": user_input.strip()})
if ai_response:
ans.append({"role": "assistant", "content": ai_response.strip()})
query = dialog_parts[-1].split("\n")[0]
return ans, query
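# Note: _tool_history above assumes the structured-chat-agent prompt layout,
# roughly the following (an illustrative sketch, not the exact template):
#   ... You have access to the following tools:
#
#   Calculator: Useful for when you need to calculate math problems, args: {...}
#
#   Use a json blob ...
#   Human: <user question>
#   AI: <assistant reply>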
def _extract_observation(self, prompt: str):
return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
self.history.append({
"role": "observation",
"content": return_json
})
return
def _extract_tool(self):
if len(self.history[-1]["metadata"]) > 0:
metadata = self.history[-1]["metadata"]
content = self.history[-1]["content"]
lines = content.split('\n')
for line in lines:
if 'tool_call(' in line and ')' in line and self.has_search is False:
# Extract the string inside the parentheses
params_str = line.split('tool_call(')[-1].split(')')[0]
# Parse the key=value parameter pairs
params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs}
action_json = {
"action": metadata,
"action_input": params
}
self.has_search = True
print("*****Action*****")
print(action_json)
print("*****Answer*****")
return f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
```"""
final_answer_json = {
"action": "Final Answer",
"action_input": self.history[-1]["content"]
}
self.has_search = False
return f"""
Action:
```
{json.dumps(final_answer_json, ensure_ascii=False)}
```"""
def _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):
if not self.has_search:
self.history, query = self._tool_history(prompt)
else:
self._extract_observation(prompt)
query = ""
_, self.history = self.model.chat(
self.tokenizer,
query,
history=self.history,
do_sample=self.do_sample,
max_length=self.max_token,
temperature=self.temperature,
)
response = self._extract_tool()
history.append((prompt, response))
return response
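# Usage sketch (illustrative; mirrors main.py below):
#   llm = ChatGLM3()
#   llm.load_model("THUDM/chatglm3-6b")
#   # the instance can then be passed to a LangChain structured-chat agent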
"""
This script demonstrates the use of LangChain's StructuredChatAgent and AgentExecutor alongside various tools.
The script utilizes the ChatGLM3 model, a large language model for understanding and generating human-like text.
The model is loaded from a specified path and integrated into the chat agent.
Tools:
- Calculator: Performs arithmetic calculations.
- Weather: Provides weather-related information based on input queries.
- DistanceConverter: Converts distances between meters, kilometers, and feet.
The agent operates in four modes:
1. Single parameter without history: uses Calculator to perform simple arithmetic.
2. Single parameter with history: uses the Weather tool to answer queries about temperature, taking the
conversation history into account.
3. Multiple parameters without history: uses DistanceConverter to convert distances between specified units.
4. Built-in LangChain tool: uses the Arxiv tool to search for scientific articles.
Note:
Tool calls made by the model can fail, which may cause errors or prevent execution. Try lowering the model's
temperature, or reducing the number of tools, especially for the third mode.
The success rate of multi-parameter calls is low; errors such as the following may occur:
Required fields [type=missing, input_value={'distance': '30', 'unit': 'm', 'to': 'km'}, input_type=dict]
Here the model hallucinates parameters that do not match the required schema.
Adjusting the model's top_p and temperature parameters helps mitigate such problems.
Success example:
*****Action*****
{
'action': 'weather',
'action_input': {
'location': '厦门'
}
}
*****Answer*****
{
'input': '厦门比北京热吗?',
'chat_history': [HumanMessage(content='北京温度多少度'), AIMessage(content='北京现在12度')],
'output': '根据最新的天气数据,厦门今天的气温为18度,天气晴朗。而北京今天的气温为12度。所以,厦门比北京热。'
}
****************
"""
import os
from langchain import hub
from langchain.agents import AgentExecutor, create_structured_chat_agent, load_tools
from langchain_core.messages import AIMessage, HumanMessage
from ChatGLM3 import ChatGLM3
from tools.Calculator import Calculator
from tools.Weather import Weather
from tools.DistanceConversion import DistanceConverter
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
if __name__ == "__main__":
llm = ChatGLM3()
llm.load_model(MODEL_PATH)
prompt = hub.pull("hwchase17/structured-chat-agent")
# for single parameter without history
tools = [Calculator()]
agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools)
ans = agent_executor.invoke({"input": "34 * 34"})
print(ans)
# for single parameter with history
tools = [Weather()]
agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools)
ans = agent_executor.invoke(
{
"input": "厦门比北京热吗?",
"chat_history": [
HumanMessage(content="北京温度多少度"),
AIMessage(content="北京现在12度"),
],
}
)
print(ans)
# for multiple parameters without history
tools = [DistanceConverter()]
agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools)
ans = agent_executor.invoke({"input": "how many meters in 30 km?"})
print(ans)
# for using langchain tools
tools = load_tools(["arxiv"], llm=llm)
agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools)
ans = agent_executor.invoke({"input": "Describe the paper about GLM 130B"})
print(ans)
import abc
import math  # eval() below relies on this to resolve names like math.sqrt
import re
from typing import Type
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class CalculatorInput(BaseModel):
calculation: str = Field(description="calculation to perform")
class Calculator(BaseTool, abc.ABC):
name = "Calculator"
description = "Useful for when you need to calculate math problems"
args_schema: Type[BaseModel] = CalculatorInput
def __init__(self):
super().__init__()
def parameter_validation(self, para: str):
"""
You can write your own parameter validation rules here;
the code below is one reference implementation.
:param para:
:return:
"""
# Strip known math tokens, then require the remainder to look like arithmetic.
symbols = ["math", "sqrt", "log", "sin", "cos", "tan", "pi"]
for sym in symbols:
para = para.replace(sym, "")
pattern = re.compile(r"[+*/\-%\d()=\s.]{3,}")
if re.findall(pattern, para):
return True
return False
def _run(self, calculation: str) -> str:
calculation = calculation.replace("^", "**")
# Qualify bare math names so eval() can resolve them via the math module.
# The check is done once up front so that multiple functions in one
# expression (e.g. "sin(pi)") are all qualified.
needs_prefix = "math" not in calculation
if needs_prefix and "sqrt" in calculation:
calculation = calculation.replace("sqrt", "math.sqrt")
if needs_prefix and "log" in calculation:
calculation = calculation.replace("log", "math.log")
if needs_prefix and "sin" in calculation:
calculation = calculation.replace("sin", "math.sin")
if needs_prefix and "cos" in calculation:
calculation = calculation.replace("cos", "math.cos")
if needs_prefix and "tan" in calculation:
calculation = calculation.replace("tan", "math.tan")
if needs_prefix and "pi" in calculation:
calculation = calculation.replace("pi", "math.pi")
if needs_prefix and "pI" in calculation:
calculation = calculation.replace("pI", "math.pi")
if needs_prefix and "PI" in calculation:
calculation = calculation.replace("PI", "math.pi")
if needs_prefix and "Pi" in calculation:
calculation = calculation.replace("Pi", "math.pi")
return str(eval(calculation))
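# Example (illustrative): Calculator()._run("sqrt(16) + 2^2") evaluates
# "math.sqrt(16) + 2**2" and returns "8.0". Note that eval() is unsafe on
# untrusted input; parameter_validation above is only a loose filter.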
import abc
from typing import Type
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class DistanceConversionInput(BaseModel):
distance: float = Field(description="The numerical value of the distance to convert")
unit: str = Field(description="The current unit of the distance (m, km, or feet)")
to_unit: str = Field(description="The target unit to convert the distance into (m, km, or feet)")
class DistanceConverter(BaseTool, abc.ABC):
name = "DistanceConverter"
description = "Converts distance between meters, kilometers, and feet"
args_schema: Type[BaseModel] = DistanceConversionInput
def __init__(self):
super().__init__()
def _run(self, distance: float, unit: str, to_unit: str) -> str:
unit_conversions = {
"m_to_km": 0.001,
"km_to_m": 1000,
"feet_to_m": 0.3048,
"m_to_feet": 3.28084,
"km_to_feet": 3280.84,
"feet_to_km": 0.0003048
}
if unit == to_unit:
return f"{distance} {unit} is equal to {distance} {to_unit}"
# Normalize to meters first, keeping the original value for the message.
meters = distance
if unit == "km":
meters = distance * unit_conversions["km_to_m"]
elif unit == "feet":
meters = distance * unit_conversions["feet_to_m"]
if to_unit == "km":
converted_distance = meters * unit_conversions["m_to_km"]
elif to_unit == "feet":
converted_distance = meters * unit_conversions["m_to_feet"]
else:
converted_distance = meters  # target unit is meters
return f"{distance} {unit} is equal to {converted_distance} {to_unit}"
import os
import requests
from typing import Type, Any
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class WeatherInput(BaseModel):
location: str = Field(description="the location to check the weather for")
class Weather(BaseTool):
name = "weather"
description = "Use for searching weather at a specific location"
args_schema: Type[BaseModel] = WeatherInput
def __init__(self):
super().__init__()
def _run(self, location: str) -> dict[str, Any]:
api_key = os.environ["SENIVERSE_KEY"]
url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
weather = {
"temperature": data["results"][0]["now"]["temperature"],
"description": data["results"][0]["now"]["text"],
}
return weather
else:
raise Exception(
f"Failed to retrieve weather: {response.status_code}")
LOCAL_MODEL_PATH=<your_path>
LOCAL_EMBEDDING_MODEL_PATH=<your_path>
"""
This script implements an API for the ChatGLM3-6B model,
formatted similarly to OpenAI's API (https://platform.openai.com/docs/api-reference/chat).
It's designed to be run as a web server using FastAPI and uvicorn,
making the ChatGLM3-6B model accessible through OpenAI Client.
Key Components and Features:
- Model and Tokenizer Setup: Configures the model and tokenizer paths and loads them.
- FastAPI Configuration: Sets up a FastAPI application with CORS middleware for handling cross-origin requests.
- API Endpoints:
- "/v1/models": Lists the available models, specifically ChatGLM3-6B.
- "/v1/chat/completions": Processes chat completion requests with options for streaming and regular responses.
- "/v1/embeddings": Processes Embedding request of a list of text inputs.
- Token Limit Caution: In the OpenAI API, 'max_tokens' is equivalent to HuggingFace's 'max_new_tokens', not 'max_length'.
For instance, setting 'max_tokens' to 8192 for a 6b model would result in an error due to the model's inability to output
that many tokens after accounting for the history and prompt tokens.
- Stream Handling and Custom Functions: Manages streaming responses and custom function calls within chat responses.
- Pydantic Models: Defines structured models for requests and responses, enhancing API documentation and type safety.
- Main Execution: Initializes the model and tokenizer, and starts the FastAPI app on the designated host and port.
Note:
This script doesn't include the setup for special tokens or multi-GPU support by default.
Users need to configure their special tokens and can enable multi-GPU support as per the provided instructions.
The embedding model is only supported on a single GPU.
Running this script requires 14-15 GB of GPU memory: about 2 GB for the embedding model and 12-13 GB for the FP16 ChatGLM3 LLM.
"""
import os
import time
import tiktoken
import torch
import uvicorn
import json
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from loguru import logger
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModel
from utils import process_response, generate_chatglm3, generate_stream_chatglm3
from sentence_transformers import SentenceTransformer
from tools.schema import tool_class, tool_def, tool_param_start_with, tool_define_param_name
from sse_starlette.sse import EventSourceResponse
# Set up limit request time
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
# set LLM path
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
# set Embedding Model path
EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', 'BAAI/bge-m3')
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class FunctionCallResponse(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system", "function"]
content: str = None
name: Optional[str] = None
function_call: Optional[FunctionCallResponse] = None
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
function_call: Optional[FunctionCallResponse] = None
## for Embedding
class EmbeddingRequest(BaseModel):
input: Union[List[str], str]
model: str
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class EmbeddingResponse(BaseModel):
data: list
model: str
object: str
usage: CompletionUsage
# for ChatCompletionRequest
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = 0.8
top_p: Optional[float] = 0.8
max_tokens: Optional[int] = None
stream: Optional[bool] = False
tools: Optional[Union[dict, List[dict]]] = None
repetition_penalty: Optional[float] = 1.1
agent: Optional[bool] = False
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length", "function_call"]
class ChatCompletionResponseStreamChoice(BaseModel):
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length", "function_call"]]
index: int
class ChatCompletionResponse(BaseModel):
model: str
id: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
usage: Optional[UsageInfo] = None
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
# Normalize a single string to a list so the usage accounting below also works.
if isinstance(request.input, str):
request.input = [request.input]
embeddings = [embedding_model.encode(text) for text in request.input]
embeddings = [embedding.tolist() for embedding in embeddings]
def num_tokens_from_string(string: str) -> int:
"""
Returns the number of tokens in a text string.
use cl100k_base tokenizer
"""
encoding = tiktoken.get_encoding('cl100k_base')
num_tokens = len(encoding.encode(string))
return num_tokens
response = {
"data": [
{
"object": "embedding",
"embedding": embedding,
"index": index
}
for index, embedding in enumerate(embeddings)
],
"model": request.model,
"object": "list",
"usage": CompletionUsage(
prompt_tokens=sum(len(text.split()) for text in request.input),
completion_tokens=0,
total_tokens=sum(num_tokens_from_string(text) for text in request.input),
)
}
return response
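# Example embeddings request (a sketch; field names follow EmbeddingRequest above):
#   curl http://127.0.0.1:8000/v1/embeddings -H "Content-Type: application/json" \
#        -d '{"model": "bge-m3", "input": ["hello world"]}'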
@app.get("/v1/models", response_model=ModelList)
async def list_models():
model_card = ModelCard(
id="chatglm3-6b"
)
return ModelList(
data=[model_card]
)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
raise HTTPException(status_code=400, detail="Invalid request")
gen_params = dict(
messages=request.messages,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens or 1024,
echo=False,
stream=request.stream,
repetition_penalty=request.repetition_penalty,
agent=request.agent
)
logger.debug(f"==== request ====\n{gen_params}")
gen_params["tools"] = tool_def if gen_params["agent"] else []
if request.stream:
# Use stream mode to read the first few characters; if it is not a function call, stream the output directly
predict_stream_generator = predict_stream(request.model, gen_params)
output = next(predict_stream_generator)
if not contains_custom_function(output, gen_params["tools"]):
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
# Otherwise, obtain the full result at once and determine whether tools need to be called.
logger.debug(f"First result output:\n{output}")
function_call = None
if output and request.tools:
try:
function_call = process_response(output, use_tool=True)
except Exception:
logger.warning("Failed to parse tool call")
# CallFunction
if isinstance(function_call, dict):
function_call = FunctionCallResponse(**function_call)
"""
In this demo, we did not register any tools.
You can use the tools that have been implemented in our `tools_using_demo` and implement your own streaming tool implementation here.
Similar to the following method:
"""
if tool_param_start_with in output:
tool = tool_class.get(function_call.name)
if tool:
this_tool_define_param_name = tool_define_param_name.get(function_call.name)
if this_tool_define_param_name:
tool_param = json.loads(function_call.arguments).get(this_tool_define_param_name)
if tool().parameter_validation(tool_param):
observation = str(tool().run(tool_param))
tool_response = observation
else:
tool_response = "Tool parameter values error, please tell the user about this situation."
else:
tool_response = "Tool parameter is not defined in tools schema, please tell the user about this situation."
else:
tool_response = "No available tools found, please tell the user about this situation."
else:
tool_response = "Tool parameter content error, please tell the user about this situation."
if not gen_params.get("messages"):
gen_params["messages"] = []
gen_params["messages"].append(ChatMessage(
role="assistant",
content=output,
))
gen_params["messages"].append(ChatMessage(
role="function",
name=function_call.name,
content=tool_response,
))
# Streaming output of results after function calls
generate = predict(request.model, gen_params)
return EventSourceResponse(generate, media_type="text/event-stream")
else:
# Handled to avoid exceptions in the above parsing function process.
generate = parse_output_text(request.model, output)
return EventSourceResponse(generate, media_type="text/event-stream")
# Here is the handling of stream = False
response = generate_chatglm3(model, tokenizer, gen_params)
# Remove the first newline character
if response["text"].startswith("\n"):
response["text"] = response["text"][1:]
response["text"] = response["text"].strip()
usage = UsageInfo()
function_call, finish_reason = None, "stop"
if request.tools:
try:
function_call = process_response(response["text"], use_tool=True)
except Exception:
logger.warning("Failed to parse tool call; maybe the response is not a tool call or has already been answered.")
if isinstance(function_call, dict):
finish_reason = "function_call"
function_call = FunctionCallResponse(**function_call)
message = ChatMessage(
role="assistant",
content=response["text"],
function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
)
logger.debug(f"==== message ====\n{message}")
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason=finish_reason,
)
task_usage = UsageInfo.model_validate(response["usage"])
for usage_key, usage_value in task_usage.model_dump().items():
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
return ChatCompletionResponse(
model=request.model,
id="", # for open_source model, id is empty
choices=[choice_data],
object="chat.completion",
usage=usage
)
async def predict(model_id: str, params: dict):
global model, tokenizer
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
previous_text = ""
for new_response in generate_stream_chatglm3(model, tokenizer, params):
decoded_unicode = new_response["text"]
delta_text = decoded_unicode[len(previous_text):]
previous_text = decoded_unicode
finish_reason = new_response["finish_reason"]
if len(delta_text) == 0 and finish_reason != "function_call":
continue
function_call = None
if finish_reason == "function_call":
try:
function_call = process_response(decoded_unicode, use_tool=True)
except Exception:
logger.warning(
"Failed to parse tool call; maybe the response is not a tool call or has already been answered.")
if isinstance(function_call, dict):
function_call = FunctionCallResponse(**function_call)
delta = DeltaMessage(
content=delta_text,
role="assistant",
function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=delta,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
choices=[choice_data],
object="chat.completion.chunk"
)
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
choices=[choice_data],
object="chat.completion.chunk"
)
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield '[DONE]'
def predict_stream(model_id, gen_params):
"""
Makes function calls compatible with stream-mode output.
The first few characters are inspected: if the output is not a function call,
it is streamed directly; otherwise the complete function-call text is returned.
:param model_id:
:param gen_params:
:return:
"""
output = ""
is_function_call = False
has_send_first_chunk = False
for new_response in generate_stream_chatglm3(model, tokenizer, gen_params):
decoded_unicode = new_response["text"]
delta_text = decoded_unicode[len(output):]
output = decoded_unicode
# When it has not yet been judged a function call and the output length is > 7,
# try to detect a function call by the special function-name prefix
if not is_function_call and len(output) > 7:
# Determine whether a function is called
is_function_call = contains_custom_function(output, gen_params["tools"])
if is_function_call:
continue
# Non-function call, direct stream output
finish_reason = new_response["finish_reason"]
# Send an empty string first to avoid truncation by subsequent next() operations.
if not has_send_first_chunk:
message = DeltaMessage(
content="",
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
choices=[choice_data],
created=int(time.time()),
object="chat.completion.chunk"
)
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
send_msg = delta_text if has_send_first_chunk else output
has_send_first_chunk = True
message = DeltaMessage(
content=send_msg,
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
choices=[choice_data],
created=int(time.time()),
object="chat.completion.chunk"
)
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
if is_function_call:
yield output
else:
yield '[DONE]'
async def parse_output_text(model_id: str, value: str):
"""
Directly output the text content of value
:param model_id:
:param value:
:return:
"""
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=value),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield '[DONE]'
def contains_custom_function(value: str, tools: list) -> bool:
"""
Determine whether the output is a 'function_call' by checking for a registered tool name.
[Note] This is not a rigorous check; it is provided for reference only.
:param value:
:param tools:
:return:
"""
for tool in tools:
if value and tool["name"] in value:
return True
return False
if __name__ == "__main__":
# Load LLM
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
# load Embedding
embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
version: "3.6"
services:
glm3_api:
image: python:3.10.13-slim
restart: unless-stopped
working_dir: /glm3
container_name: glm3_api
env_file: ./.env
networks:
- v_glm3
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
environment:
- MODEL_PATH=/models/chatglm3-6b
- EMBEDDING_PATH=/models/bge-large-zh-v1.5
- TZ=Asia/Shanghai
- PYTHONDONTWRITEBYTECODE=1
- PYTHONUNBUFFERED=1
- DOCKER=True
ports:
- 8100:8000
volumes:
- ./:/glm3
- ${LOCAL_MODEL_PATH}:/models/chatglm3-6b
- ${LOCAL_EMBEDDING_MODEL_PATH}:/models/bge-large-zh-v1.5
command:
- sh
- -c
- |
sed -i s/deb.debian.org/mirrors.tencentyun.com/g /etc/apt/sources.list
sed -i s/security.debian.org/mirrors.tencentyun.com/g /etc/apt/sources.list
apt-get update
python -m pip install -i https://mirror.sjtu.edu.cn/pypi/web/simple --upgrade pip
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
python api_server.py
networks:
v_glm3:
driver: bridge
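# Usage sketch (assumes Docker with the NVIDIA container runtime and a .env file
# defining LOCAL_MODEL_PATH and LOCAL_EMBEDDING_MODEL_PATH):
#   docker compose up -d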