"server/vscode:/vscode.git/clone" did not exist on "12590fdccebb34f39fb85b7dae29b80fade2b6b0"
Commit 9d518ec9 authored by Rayyyyy's avatar Rayyyyy
Browse files

Updata

parent 190f5704
# GLM-4-9b # GLM-4-9b
## 论文 ## 论文
暂无
## 模型结构 ## 模型结构
基于transformer结构 基于transformer结构
...@@ -8,8 +9,10 @@ ...@@ -8,8 +9,10 @@
</div> </div>
## 算法原理 ## 算法原理
GLM-4-9B是智谱AI推出的最新一代预训练模型GLM-4系列中的开源版本,在语义、数学、推理、代码和知识等多方面的数据集测评中,GLM-4-9B及其人类偏好对齐的版本GLM-4-9B-Chat均表现出超越Llama-3-8B的卓越性能 GLM-4-9B是智谱AI推出的最新一代预训练模型GLM-4系列中的开源版本,在语义、数学、推理、代码和知识等多方面的数据集测评中,GLM-4-9B及其人类偏好对齐的版本GLM-4-9B-Chat均表现出超越Llama-3-8B的卓越性能。以多模态模型GLM-4V-9B为例,这一模型采用了与CogVLM2相似的架构设计,能够处理高达1120 x 1120分辨率的输入,并通过降采样技术有效减少了token的开销。为了减小部署与计算开销,GLM-4V-9B没有引入额外的视觉专家模块,采用了直接混合文本和图片数据的方式进行训练,在保持文本性能的同时提升多模态能力。
<div align=center>
<img src="./doc/multi-mode.png" witdh=500 height=700/>
</div>
## 环境配置 ## 环境配置
-v 路径、docker_name和imageID根据实际情况修改 -v 路径、docker_name和imageID根据实际情况修改
...@@ -43,10 +46,9 @@ pip install -r requirements.txt ...@@ -43,10 +46,9 @@ pip install -r requirements.txt
DTK软件栈:dtk24.04 DTK软件栈:dtk24.04
python:python3.10 python:python3.10
torch:2.1 torch:2.1
deepspeed: 0.12.3+gita724046.abi0.dtk2404.torch2.1.0 deepspeed: 0.12.3
``` ```
**Tips**:以上dtk软件栈、python、torch等DCU相关工具版本需要严格一一对应
Tips:以上dtk软件栈、python、torch等DCU相关工具版本需要严格一一对应
2、其他非特殊库直接按照下面步骤进行安装 2、其他非特殊库直接按照下面步骤进行安装
...@@ -95,12 +97,12 @@ python gen_messages_data.py --data_path /path/to/AdvertiseGen ...@@ -95,12 +97,12 @@ python gen_messages_data.py --data_path /path/to/AdvertiseGen
角色必须存在并且 `content` 字段为空。 角色必须存在并且 `content` 字段为空。
## 训练 ## 训练
1. 进入`finetune_demo`目录下,执行 1. 进入`finetune_demo`目录下,首先安装所需环境信息
```bash ```bash
pip install -r requirements.txt pip install -r requirements.txt
``` ```
2. 配置文件位于configs目录下,包括以下文件: 2. 配置文件位于[configs](./finetune_demo/configs/)目录下,包括以下文件:
- `deepspeed配置文件`[ds_zereo_2](./finetune_demo/configs/ds_zereo_2.json)[ds_zereo_3](./finetune_demo/configs/ds_zereo_3.json) - `deepspeed配置文件`[ds_zereo_2](./finetune_demo/configs/ds_zereo_2.json)[ds_zereo_3](./finetune_demo/configs/ds_zereo_3.json)
- `lora.yaml/ ptuning_v2.yaml / sft.yaml`: 模型不同方式的配置文件,包括模型参数、优化器参数、训练参数等。部分重要参数解释如下: - `lora.yaml/ ptuning_v2.yaml / sft.yaml`: 模型不同方式的配置文件,包括模型参数、优化器参数、训练参数等。部分重要参数解释如下:
+ data_config 部分 + data_config 部分
...@@ -139,7 +141,10 @@ pip install -r requirements.txt ...@@ -139,7 +141,10 @@ pip install -r requirements.txt
+ num_attention_heads: 2: P-TuningV2 的注意力头数(不要改动)。 + num_attention_heads: 2: P-TuningV2 的注意力头数(不要改动)。
+ token_dim: 256: P-TuningV2 的 token 维度(不要改动)。 + token_dim: 256: P-TuningV2 的 token 维度(不要改动)。
3. `data/AdvertiseGen/saves/``.jsonl`数据地址,`THUDM/glm-4-9b-chat`为模型地址,`configs/lora.yaml`为配置文件地址,以上参数均可根据自身数据地址进行替换。 3. 脚本中主要参数解释, 以下参数均可根据自身数据地址进行替换:
+ `data/AdvertiseGen/saves/`: `.jsonl`数据地址
+ `../checkpoints/glm-4-9b-chat/`: 模型地址
+ `configs/lora.yaml`: 配置文件地址
### 单机单卡 ### 单机单卡
```shell ```shell
...@@ -174,7 +179,10 @@ python finetune.py ../data/AdvertiseGen/saves/ ../checkpoints/glm-4-9b-chat/ con ...@@ -174,7 +179,10 @@ python finetune.py ../data/AdvertiseGen/saves/ ../checkpoints/glm-4-9b-chat/ con
- --device: 当前默认"cuda" - --device: 当前默认"cuda"
- --query: 待测输入语句,当前默认"你好" - --query: 待测输入语句,当前默认"你好"
``` ```bash
pip install -U huggingface_hub hf_transfer
export HF_ENDPOINT=https://hf-mirror.com/
python inference.py python inference.py
``` ```
...@@ -189,6 +197,12 @@ python trans_cli_vision_demo.py --model_name_or_path ../checkpoints/GLM-4V-9B ...@@ -189,6 +197,12 @@ python trans_cli_vision_demo.py --model_name_or_path ../checkpoints/GLM-4V-9B
python trans_web_demo.py --model_name_or_path ../checkpoints/GLM-4-9B-Chat python trans_web_demo.py --model_name_or_path ../checkpoints/GLM-4-9B-Chat
``` ```
### 验证微调后的模型
您可以在 `finetune_demo/inference.py` 中使用微调后的模型,仅需要一行代码就能简单的进行测试。
```shell
python inference.py your_finetune_path
```
## result ## result
<div align=center> <div align=center>
<img src="./doc/result.png" width=1500 heigh=400/> <img src="./doc/result.png" width=1500 heigh=400/>
...@@ -196,10 +210,11 @@ python trans_web_demo.py --model_name_or_path ../checkpoints/GLM-4-9B-Chat ...@@ -196,10 +210,11 @@ python trans_web_demo.py --model_name_or_path ../checkpoints/GLM-4-9B-Chat
### 精度 ### 精度
数据集:AdvertiseGen 数据集:AdvertiseGen
| device | iter | loss |
| device | iters | train_loss |
| :------: | :------: | :------: | | :------: | :------: | :------: |
| A800 | 1000 | 3.0781 | | A800 | 1000 | 3.0219 |
| K100 | 1000 | 3.0734 | | K100 | 1000 | 3.0205 |
## 应用场景 ## 应用场景
### 算法类别 ### 算法类别
......
'''based on transformers'''
import torch import torch
import argparse import argparse
......
...@@ -38,7 +38,6 @@ def load_json_infos(file_path): ...@@ -38,7 +38,6 @@ def load_json_infos(file_path):
if __name__ == "__main__": if __name__ == "__main__":
files = ['train.json', 'dev.json'] files = ['train.json', 'dev.json']
for file in files: for file in files:
file_path = os.path.join(args.data_path, file) file_path = os.path.join(args.data_path, file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment