Unverified Commit df5e9c53 authored by YeAnbang's avatar YeAnbang Committed by GitHub

[ColossalChat] Update RLHF V2 (#5286)



* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2 nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporary

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
parent 36c4bb28
{
"chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}
# Examples
## Table of Contents
- [Examples](#examples)
- [Table of Contents](#table-of-contents)
- [Install Requirements](#install-requirements)
- [Get Started with ColossalRun](#get-started-with-colossalrun)
- [Training Configuration](#training-configuration)
- [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instruction-tuning)
- [Step 1: Data Collection](#step-1-data-collection)
- [Step 2: Preprocessing](#step-2-preprocessing)
- [Step 3: Training](#step-3-training)
- [RLHF Stage 2: Training Reward Model](#rlhf-training-stage2---training-reward-model)
- [Step 1: Data Collection](#step-1-data-collection-1)
- [Step 2: Preprocessing](#step-2-preprocessing-1)
- [Step 3: Training](#step-3-training-1)
- [Features and Tricks in RM Training](#features-and-tricks-in-rm-training)
- [RLHF Stage 3: Proximal Policy Optimization](#rlhf-training-stage3---proximal-policy-optimization)
- [Step 1: Data Collection](#step-1-data-collection-2)
- [Step 2: Preprocessing](#step-2-preprocessing-2)
- [Step 3: Training](#step-3-training-2)
- [PPO Training Results](#sample-training-results-using-default-script)
- [Reward](#reward)
- [Note on PPO Training](#note-on-ppo-training)
- [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instruction-tuning)
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
- [Hardware Requirements](#hardware-requirements)
- [Inference example](#inference-example)
- [Attention](#attention)
---
## Install Requirements
```shell
pip install -r requirements.txt
```
## Get Started with ColossalRun
You can use `colossalai run` to launch multi-node training:
```
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
train.py --OTHER_CONFIGURATIONS
```
Here is a sample hostfile:
```
hostname1
hostname2
hostname3
hostname4
```
Make sure the master node can access all nodes (including itself) via passwordless SSH. Here are some other useful arguments.
- nnodes: number of nodes used in the training
- nproc-per-node: number of processes to launch on each node
- rdzv-endpoint: address of the host node
### Training Configuration
This section gives a brief introduction to the different training strategies that you can use, and how to use them with our boosters and plugins to reduce training time and VRAM consumption. For more details on training strategies, please refer to [here](https://colossalai.org/docs/concepts/paradigms_of_parallelism). For details on boosters and plugins, please refer to [here](https://colossalai.org/docs/basics/booster_plugins).
<details><summary><b>Gemini</b></summary>
This plugin implements Zero-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. Note that it does not support local gradient accumulation. More details can be found in [Gemini Doc](https://colossalai.org/docs/features/zero_with_chunk).
Below shows how to use the gemini plugin in SFT training.
```
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--save_interval 5000 \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--plugin gemini \
--batch_size 4 \
--max_epochs 1 \
--accumulation_steps 1 \ # the gradient accumulation has to be disabled
--lr 2e-5 \
--max_len 2048 \
--use_wandb
```
</details>
<details><summary><b>Gemini-Auto</b></summary>
This option uses Gemini and automatically offloads low-priority tensors to CPU memory. It does not support local gradient accumulation. More details can be found in [Gemini Doc](https://colossalai.org/docs/features/zero_with_chunk).
Below shows how to use the gemini-auto plugin in SFT training.
```
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--save_interval 5000 \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--plugin gemini_auto \
--batch_size 4 \
--max_epochs 1 \
--accumulation_steps 1 \ # the gradient accumulation has to be disabled
--lr 2e-5 \
--max_len 2048 \
--use_wandb
```
</details>
<details><summary><b>Zero2</b></summary>
This option distributes the optimizer states and gradients across multiple GPUs and does not offload weights to CPU. It uses reduce and gather operations to synchronize gradients and weights. It does not support local gradient accumulation: you can still accumulate gradients if you insist, but doing so will not reduce communication cost. For the same reason, it is not a good idea to use Zero-2 with pipeline parallelism.
Below shows how to use the zero2 plugin in SFT training.
```
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--save_interval 5000 \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--plugin zero2 \
--batch_size 4 \
--max_epochs 1 \
--accumulation_steps 4 \
--lr 2e-5 \
--max_len 2048 \
--use_wandb
```
</details>
<details><summary><b>Zero2CPU</b></summary>
This option distributes the optimizer states and gradients across multiple GPUs and also offloads parameters to CPU. It does not support local gradient accumulation: you can still accumulate gradients if you insist, but doing so will not reduce communication cost.
Below shows how to use the zero2-cpu plugin in SFT training.
```
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--save_interval 5000 \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--plugin zero2_cpu \
--batch_size 4 \
--max_epochs 1 \
--accumulation_steps 4 \
--lr 2e-5 \
--max_len 2048 \
--use_wandb
```
</details>
<details><summary><b>Tensor Parallelism</b></summary>
This option supports Tensor Parallelism (TP). Note that if you enable TP, ZeRO and pipeline parallelism will be disabled. TP splits large model weights, optimizer parameters, and gradients into multiple smaller shards and distributes them across multiple GPUs, so it is recommended when your model is large (e.g., 20B and above) or when your training algorithm consumes a lot of memory (e.g., PPO).
Below shows how to use TP in PPO training.
```
colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 30039 train_ppo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--rm_pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--rm_checkpoint_path $REWARD_MODEL_PATH \
--prompt_dataset ${prompt_dataset[@]} \
--pretrain_dataset ${ptx_dataset[@]} \
--ptx_batch_size 1 \
--ptx_coef 0.0 \
--plugin "zero2" \
--save_interval 200 \
--save_path $SAVE_DIR \
--num_episodes 2000 \
--num_collect_steps 4 \
--num_update_steps 1 \
--experience_batch_size 8 \
--train_batch_size 4 \
--accumulation_steps 8 \
--tp 4 \ # TP size, nproc_per_node must be divisible by it
--lr 9e-6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--weight_decay 0.01 \
--warmup_steps 100 \
--grad_checkpoint \
--use_wandb
```
</details>
<details><summary><b>Gradient Checkpointing</b></summary>
This option saves VRAM by selectively recomputing some of the intermediate values on the fly during the backward pass, rather than storing them in memory.
To enable gradient checkpointing, add `--grad_checkpoint` to your training script.
```
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--save_interval 5000 \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--plugin zero2_cpu \
--batch_size 4 \
--max_epochs 1 \
--accumulation_steps 4 \
--lr 2e-5 \
--max_len 2048 \
--grad_checkpoint \ # This enables gradient checkpointing
--use_wandb
```
</details>
<details><summary><b>Flash Attention</b></summary>
Details about flash attention can be found in the paper: [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135).
To enable flash attention, add `--use_flash_attn` to your training script.
```
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--save_interval 5000 \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--plugin zero2_cpu \
--batch_size 4 \
--max_epochs 1 \
--accumulation_steps 4 \
--lr 2e-5 \
--max_len 2048 \
--use_flash_attn \ # This enables flash attention
--use_wandb
```
</details>
<details><summary><b>Low Rank Adaption</b></summary>
Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduces VRAM consumption at the cost of some model capability, which makes it suitable for training LLMs with constrained resources.
To enable LoRA, set `--lora_rank` to a positive value (usually between 20 and 64).
```
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--save_interval 5000 \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--plugin zero2_cpu \
--batch_size 4 \
--max_epochs 1 \
--accumulation_steps 4 \
--lr 2e-5 \
--max_len 2048 \
--lora_rank 32 \ # This enables LoRA
--use_wandb
```
</details>
<details><summary><b>Other Training Arguments</b></summary>
- grad_clip: gradients larger than this value will be clipped.
- weight_decay: weight decay hyper-parameter.
- warmup_steps: number of warmup steps used by the learning rate scheduler.
- pretrain: pretrained model path; weights will be loaded from this model unless checkpoint_path is provided.
- tokenizer_dir: where to load the tokenizer from; if not provided, the tokenizer will be loaded from the pretrained model path.
- dataset: a list of strings, each a path to a folder containing buffered dataset files in arrow format.
- checkpoint_path: if provided, weights will be loaded from this checkpoint.
- config_file: path to store the training config file.
- save_dir: path to store the model checkpoints.
- max_length: inputs will be padded/truncated to max_length before being fed to the model.
- max_epochs: number of epochs to train.
- batch_size: training batch size.
- mixed_precision: precision to use in training; supports 'fp16' and 'bf16'. Note that some devices may not support 'bf16'; please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
- save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
- merge_lora_weights: whether to merge LoRA weights before saving the model.
- lr: the learning rate used in training.
- accumulation_steps: accumulate gradients every accumulation_steps.
- log_dir: path to store the logs.
- use_wandb: if this flag is set, logs will be reported to wandb.
</details>
### RLHF Training Stage1 - Supervised Instruction Tuning
Stage 1 is supervised instruction fine-tuning (SFT). This step is a crucial part of the RLHF training process, as it trains the model on human-provided instructions to learn the initial behavior for the task at hand. Here's a detailed guide on how to SFT your LLM with ColossalChat:
#### Step 1: Data Collection
The first step in Stage 1 is to collect a dataset of human demonstrations in the following format.
```json
[
    {
        "messages": [
            {
                "from": "human",
                "content": "what are some pranks with a pen i can do?"
            },
            {
                "from": "assistant",
                "content": "Are you looking for practical joke ideas?"
            },
            ...
        ]
    },
    ...
]
```
#### Step 2: Preprocessing
Once you have collected your SFT dataset, you will need to preprocess it. This involves four steps: data cleaning, data deduplication, formatting, and tokenization. In this section, we focus on formatting and tokenization.
We provide a flexible way for users to set the conversation template for formatting chat data, using Hugging Face's chat template feature. Please follow the steps below to define your chat template and preprocess your data.
- Step 1: (Optional). Define your conversation template. You need to provide a conversation template config file similar to the config files under the ./config/conversation_template directory. This config should include the following fields.
```json
{
"chat_template": (Optional), A string of chat_template used for formatting chat data. If not set (None), will use the default chat template of the provided tokenizer. If a path to a huggingface model or local model is provided, will use the chat_template of that model. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating,
"system_message": A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added,
"stop_ids": (Optional), A list of string indicating the end of assistant's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, will set to tokenizer.eos_token_ids automatically,
}
```
On your first run of the data preparation script, you only need to define the "chat_template" (if you want a custom chat template) and the "system_message" (if you want a custom system message).
- Step 2: Run the data preparation script, [prepare_sft_dataset.sh](./examples/data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path.
- Step 3: (Optional) Check the correctness of the processed data. We provide an easy way to manually check the processed data: inspect the "$SAVE_DIR/jsonl/part-XXXX.jsonl" files.
After finishing the above steps, you will have converted the raw conversations to the designated chat format, tokenized the formatted conversations, computed the input_ids, labels, and attention_masks, and buffered everything into binary dataset files under the "$SAVE_DIR/arrow/part-XXXX" folders.
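To preview how a conversation will be rendered before running the preparation script, below is a minimal sketch (not part of the repo) that uses Hugging Face's `apply_chat_template`; the tokenizer path is a placeholder.
```python
from transformers import AutoTokenizer

# Minimal sketch: preview the formatted conversation produced by a chat template.
tokenizer = AutoTokenizer.from_pretrained("/PATH/TO/TOKENIZER", trust_remote_code=True)
# To test a custom template, paste the Jinja string from your config file:
# tokenizer.chat_template = "{% for message in messages %}...{% endfor %}"
messages = [
    {"role": "user", "content": "what are some pranks with a pen i can do?"},
    {"role": "assistant", "content": "Are you looking for practical joke ideas?"},
]
print(tokenizer.apply_chat_template(messages, tokenize=False))
```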
For example, the Colossal-LLaMA-2 chat format looks like this:
```
<s> A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
Human: <s> what are some pranks with a pen i can do?</s> Assistant: <s> Are you looking for practical joke ideas?</s>
...
```
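To double-check the buffered binary output, you can load one shard back with the `datasets` library; a minimal sketch, with a placeholder path:
```python
from datasets import load_from_disk

# Minimal sketch: inspect one buffered arrow shard; replace the path with your SAVE_DIR.
dataset = load_from_disk("/PATH/TO/SAVE_DIR/arrow/part-00000")
print(dataset)  # columns depend on the dataset type (e.g., input_ids, labels)
print(dataset[0]["input_ids"][:16])  # first few token ids of the first sample
```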
#### Step 3: Training
Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start supervised instruction fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
### RLHF Training Stage2 - Training Reward Model
Stage 2 trains a reward model. The model learns to assign scores to responses by training on human preference data, in which different outputs for the same prompt are manually ranked. The trained reward model then supervises the PPO training in the next stage.
#### Step 1: Data Collection
Below shows the preference dataset format used in training the reward model.
```json
[
    {
        "context": [
            {
                "from": "human",
                "content": "Introduce butterfly species in Oregon."
            }
        ],
        "chosen": [
            {
                "from": "assistant",
                "content": "About 150 species of butterflies live in Oregon, with about 100 species being moths..."
            },
            ...
        ],
        "rejected": [
            {
                "from": "assistant",
                "content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
            },
            ...
        ]
    },
    ...
]
```
#### Step 2: Preprocessing
Similar to the second step of the previous stage, we format the preference data into the structured format used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
#### Step 3: Training
You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
#### Features and Tricks in RM Training
- We recommend using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) and [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets for training the reward model.
- We support two kinds of loss functions, `log_sig` (used by OpenAI) and `log_exp` (used by Anthropic); see the sketch after this list.
- We log the training accuracy `train/acc`, `reward_chosen` and `reward_rejected` to monitor progress during training.
- We use a cosine-annealing learning rate scheduler for RM training.
- We use a single linear layer as the value head and initialize its weight from an N(0, 1/(d_model + 1)) distribution.
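Below is a minimal sketch of the two ranking losses and the value-head initialization described above. It is an illustration under the definitions in this list, not the repo's exact code; `d_model` is a hypothetical hidden size.
```python
import torch
import torch.nn.functional as F

# Minimal sketch of the two ranking losses. chosen_reward / rejected_reward are
# the scalar value-head outputs for the chosen and rejected responses.
def log_sig_loss(chosen_reward, rejected_reward):  # used by OpenAI
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()

def log_exp_loss(chosen_reward, rejected_reward):  # used by Anthropic
    # Algebraically equivalent to log_sig_loss, since log(1 + e^-x) = -logsigmoid(x)
    return torch.log(1 + torch.exp(rejected_reward - chosen_reward)).mean()

# Value head: a single linear layer with weights drawn from N(0, 1/(d_model + 1)).
d_model = 4096  # hypothetical hidden size
value_head = torch.nn.Linear(d_model, 1)
torch.nn.init.normal_(value_head.weight, mean=0.0, std=(1.0 / (d_model + 1)) ** 0.5)
```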
#### Note on Reward Model Training
Before you move on to the next stage, please check the following list to ensure that your reward model is stable and robust. You can check the reward chart and the accuracy chart on wandb.
- The mean reward for chosen data is much higher than that for rejected data.
- The accuracy is larger than 0.5 by a significant margin (usually it should be greater than 0.6).
- Optional: check that the reward is positive for chosen data and negative for rejected data.
Your training reward curves should look similar to the following charts.
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/mean_reward_chart.png">
</p>
### RLHF Training Stage3 - Proximal Policy Optimization
In stage 3 we use a reinforcement learning algorithm, Proximal Policy Optimization (PPO), which is the most complex part of the training process:
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/stage-3.jpeg" width=800/>
</p>
#### Step 1: Data Collection
PPO uses two kinds of training data: the prompt data and the pretrain data (optional). The prompt dataset is mandatory; each data sample within it should end with a line from "human", so that the "assistant" generates a response to answer the "human". Note that you can still use conversations that end with a line from the "assistant"; in that case, the last line will be dropped. Here is an example of the prompt dataset format.
```json
[
    {
        "messages": [
            {
                "from": "human",
                "content": "what are some pranks with a pen i can do?"
            },
            ...
        ]
    },
    ...
]
```
The second dataset, the pretrain dataset, is optional; provide it if you want to use the ptx loss introduced in the [InstructGPT paper](https://arxiv.org/abs/2203.02155). It follows the format below.
```json
[
    {
        "source": "", # system instruction
        "Target": "Provide a list of the top 10 most popular mobile games in Asia\nThe top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved"
    },
    ...
]
```
#### Step 2: Preprocessing
To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh).
For the ptx dataset, you can reuse the SFT dataset you prepared in the SFT stage or prepare a new one from a different source. The ptx data is used to calculate the ptx loss, which stabilizes the training according to the [InstructGPT paper](https://arxiv.org/pdf/2203.02155.pdf).
#### Step 3: Training
You can run [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Below are some arguments unique to PPO; please refer to the [training configuration](#training-configuration) section for the other supported training options.
```bash
--pretrain $PRETRAINED_MODEL_PATH \
--rm_pretrain $PRETRAINED_MODEL_PATH \ # the base model architecture of the reward model
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--rm_checkpoint_path $REWARD_MODEL_PATH \ # reward model checkpoint path
--prompt_dataset ${prompt_dataset[@]} \ # list of strings, the prompt dataset
--conversation_template_config $CONVERSATION_TEMPLATE_CONFIG_PATH \ # path to the conversation template config file
--pretrain_dataset ${ptx_dataset[@]} \ # list of strings, the sft dataset
--ptx_batch_size 1 \ # batch size for computing the ptx loss
--ptx_coef 0.0 \ # set to a non-zero value to enable the ptx loss
--num_episodes 2000 \ # number of episodes to train
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size 8 \
--train_batch_size 4 \
--accumulation_steps 2
```
Each episode has two phases, the collect phase and the update phase. During the collect phase, we collect experiences (answers generated by the actor) and store them in the ExperienceBuffer. During the update phase, the data in the ExperienceBuffer is used to update the parameters of the actor and the critic.
- Without tensor parallelism,
```
experience buffer size
= num_process * num_collect_steps * experience_batch_size
= train_batch_size * accumulation_steps * num_process
```
- With tensor parallelism,
```
num_tp_group = num_process / tp
experience buffer size
= num_tp_group * num_collect_steps * experience_batch_size
= train_batch_size * accumulation_steps * num_tp_group
```
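As a sanity check, you can verify that your argument choices satisfy these identities before launching training. Below is a minimal sketch with a hypothetical helper (not part of the repo) whose parameter names mirror the CLI arguments:
```python
# Minimal sketch: the experiences collected per episode must match the
# experiences consumed per episode, otherwise the buffer over- or under-fills.
def check_ppo_batch_config(num_process, num_collect_steps, experience_batch_size,
                           train_batch_size, accumulation_steps, tp=1):
    num_tp_group = num_process // tp  # tp=1 reduces to the no-TP formula
    collected = num_tp_group * num_collect_steps * experience_batch_size
    consumed = num_tp_group * train_batch_size * accumulation_steps
    assert collected == consumed, f"collected {collected} != consumed {consumed}"
    return collected

# Example: the TP training script above (4 GPUs with tp=4, hence one TP group).
check_ppo_batch_config(num_process=4, num_collect_steps=4, experience_batch_size=8,
                       train_batch_size=4, accumulation_steps=8, tp=4)
```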
### Sample Training Results Using Default Script
#### Reward
<p align="center">
<img width="700" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/reward.png">
</p>
### Note on PPO Training
#### Q1: My reward is negative
Answer: Check the reward model trained in stage 2. If the reward model only produces negative rewards, then a negative reward during PPO is expected. However, even if the reward is negative, it should still go up over training.
#### Q2: My actor loss is negative
Answer: This is normal for the actor loss, as PPO does not restrict the actor loss to be positive.
#### Q3: My reward doesn't go up (decreases)
Answer: There are two common causes. First, check your reward model: make sure it gives positive, strong rewards for good responses and negative, strong rewards for bad ones. Second, try different hyperparameter settings.
#### Q4: Generation is garbage
Answer: Yes, this happens and is well documented by other implementations. After training for too many episodes, the actor gradually deviates from its original state, which may lead to a decrease in language modeling capability. A way to fix this is to add a supervised loss during PPO: set ptx_coef to a non-zero value (between 0 and 1), which balances the PPO loss and the SFT loss.
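For illustration, below is a minimal sketch of this balancing; the convex-combination form is an assumption for illustration, and the repo's exact weighting may differ.
```python
# Minimal sketch: mix a supervised (ptx) loss into the PPO actor objective,
# following the idea from InstructGPT. ptx_coef corresponds to --ptx_coef;
# the convex combination below is an assumed form, not the repo's exact code.
def combined_actor_loss(actor_loss, ptx_loss, ptx_coef):
    return (1 - ptx_coef) * actor_loss + ptx_coef * ptx_loss
```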
## Alternative Option For RLHF: Direct Preference Optimization
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in [this paper](https://arxiv.org/abs/2305.18290), offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO.
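For intuition, below is a minimal sketch of the DPO objective from the paper; it is an illustration, not the repo's exact implementation.
```python
import torch.nn.functional as F

# Minimal sketch of the DPO loss. The four inputs are the summed log-probabilities
# of the chosen/rejected responses under the trained policy and the frozen
# reference model; beta controls the strength of the implicit KL penalty.
def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()
```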
### DPO Training Stage1 - Supervised Instruction Tuning
Please refer to the [SFT section](#rlhf-training-stage1---supervised-instruction-tuning) in the PPO part.
### DPO Training Stage2 - DPO Training
#### Step 1: Data Collection & Preparation
For DPO training, you only need the preference dataset. Please follow the instruction in the [preference dataset preparation section](#rlhf-training-stage2---training-reward-model) to prepare the preference data for DPO training.
#### Step 2: Training
You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
#### DPO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/DPO.png">
</p>
## Hardware Requirements
For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model on a dummy dataset with a 2048 sequence length and 512 generation length for different tp_size (equal to the number of GPUs). In this experiment, we used H800 GPUs with 80GB VRAM.
| PPO | tp=8 | tp=4 |
|-------|---------------|---------------|
| bs=1 | 18485.19 MB | 42934.45 MB |
| bs=4 | 25585.65 MB | 42941.93 MB |
| bs=16 | 41408.28 MB | 56778.97 MB |
| bs=30 | 64047.42 MB | failed |
For DPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with a 2048 sequence length.
- 1 H800 GPU
- zero2-cpu, batch size=2, VRAM Usage=49873.90 MB
- zero2-cpu, batch size=4, VRAM Usage=60998.22 MB
- 4 H800 GPUs
- zero2, batch size=4, VRAM Usage=67544.47 MB
## Inference example
We support different inference options, including int8 and int4 quantization.
For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
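As a rough illustration, below is a minimal sketch of int8 loading with Hugging Face Transformers and bitsandbytes; note that the inference folder may use its own quantization path, and the model path here is a placeholder.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Minimal sketch: load a trained model in int8 via bitsandbytes; paths are placeholders.
model = AutoModelForCausalLM.from_pretrained(
    "/PATH/TO/SFT/MODEL",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("/PATH/TO/SFT/MODEL")
```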
## Attention
These examples are demos of the whole training process. You will need to tune the hyper-parameters to achieve good performance.
@@ -24,7 +24,9 @@ def main(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+        strategy = GeminiStrategy(
+            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
+        )
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Prepare dataset scripts
Usage:
- For SFT dataset preparation (SFT)
python prepare_dataset.py --type sft \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
- For prompt dataset preparation (PPO)
python prepare_dataset.py --type prompt \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
- For Preference dataset preparation (DPO and Reward model training)
python prepare_dataset.py --type preference \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
"""
import argparse
import json
import math
import os
import random
import time
from multiprocessing import cpu_count
from coati.dataset import setup_conversation_template, supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
from datasets import dataset_dict, load_dataset
from transformers import AutoTokenizer
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--type",
type=str,
required=True,
default=None,
choices=["sft", "prompt", "preference"],
help="Type of dataset, chose from 'sft', 'prompt', 'preference'.",
)
parser.add_argument(
"--data_input_dirs",
type=str,
required=True,
default=None,
help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.",
)
parser.add_argument(
"--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
)
parser.add_argument(
"--conversation_template_config",
type=str,
default="conversation_template_config",
help="Path \
to save conversation template config files.",
)
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
parser.add_argument(
"--data_jsonl_output_dir",
type=str,
default="jsonl_output",
help="Output directory of spliced dataset with jsonl format",
)
parser.add_argument(
"--data_arrow_output_dir",
type=str,
default="arrow_output",
help="Output directory of spliced dataset with arrow format",
)
parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
parser.add_argument(
"--num_samples_per_datafile",
type=int,
default=-1,
help="Number of samples to be generated from each data file. -1 denote all samples.",
)
args = parser.parse_args()
if args.num_spliced_dataset_bins >= 100000:
raise ValueError("Too many spliced divisions, must be smaller than 100000")
assert not os.path.exists(args.data_cache_dir), f"Found existing data cache dir {args.data_cache_dir}"
assert not os.path.exists(
args.data_jsonl_output_dir
), f"Found existing jsonl data output dir {args.data_jsonl_output_dir}"
assert not os.path.exists(
args.data_arrow_output_dir
), f"Found existing arrow data output dir {args.data_arrow_output_dir}"
os.makedirs(args.data_jsonl_output_dir)
os.makedirs(args.data_arrow_output_dir)
# Collect all input data files
input_data_paths = []
input_data_dirs = args.data_input_dirs.split(",")
for ds_dir in input_data_dirs:
ds_dir = os.path.abspath(ds_dir)
assert os.path.exists(ds_dir), f"Cannot find data dir {ds_dir}"
ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")]
ds_paths = [os.path.join(ds_dir, name) for name in ds_files]
input_data_paths.extend(ds_paths)
# Prepare the data splits.
train_splits = []
split_interval = math.ceil(100 / args.num_spliced_dataset_bins)
for i in range(0, 100, split_interval):
start = i
end = i + split_interval
if end > 100:
end = 100
train_splits.append(f"train[{start}%:{end}%]")
# Prepare the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir, use_fast=False, trust_remote_code=True)
if os.path.exists(args.conversation_template_config):
chat_template_config = json.load(open(args.conversation_template_config, "r", encoding="utf8"))
else:
chat_template_config = {
"system_message": "A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
} # Use default system message
if args.type == "preference":
if "stop_ids" not in chat_template_config:
# Ask the user to define stop_ids for PPO training
dummy_messages = [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
{"role": "user", "content": "Who made you?"},
{"role": "assistant", "content": "I am a chatbot trained by Colossal-AI."},
]
dummy_prompt = tokenizer.apply_chat_template(dummy_messages, tokenize=False)
tokenized = tokenizer(dummy_prompt, add_special_tokens=False)["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(tokenized, skip_special_tokens=False)
corresponding_str = [tokenizer.convert_tokens_to_string([token]) for token in tokens]
token_id_mapping = [{"token": s, "id": tokenized[i]} for i, s in enumerate(corresponding_str)]
stop_ids = input(
"For PPO, we recommend providing stop_ids to properly stop generation during the rollout stage. "
"stop_ids are the token ids of the repetitive pattern that marks the end of the assistant's response. "
"Below is an example of a formatted prompt and its token-id mapping. You can set stop_ids by entering a list "
"of integers separated by spaces, then press `Enter` to finish. Alternatively, press `Enter` without input if you are "
"not using PPO or prefer not to set stop_ids; in that case, stop_ids will be set to tokenizer.eos_token_id. "
f"\nPrompt:\n{dummy_prompt}\nToken-id Mapping:\n{token_id_mapping}\nstop_ids:"
)
if stop_ids == "":
chat_template_config["stop_ids"] = [tokenizer.eos_token_id]
else:
try:
chat_template_config["stop_ids"] = [int(s) for s in stop_ids.split()]
except ValueError:
raise ValueError("Invalid input, please provide a list of integers.")
else:
# Set stop_ids to eos_token_id for other dataset types if not exist
if "stop_ids" not in chat_template_config:
chat_template_config["stop_ids"] = [tokenizer.eos_token_id]
conversation_template = setup_conversation_template(
tokenizer, chat_template_config=chat_template_config, save_path=args.conversation_template_config
)
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
try:
# Some tokenizers don't allow setting pad_token manually, e.g., Qwen
tokenizer.pad_token = tokenizer.eos_token
except AttributeError as e:
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
logger.warning(
"The tokenizer does not have a pad token, which is required. This may lead to unintended behavior during training; please consider setting it manually."
)
list_dataset = load_dataset(
path="json",
data_files=input_data_paths,
cache_dir=os.path.join(args.data_cache_dir, "raw"),
keep_in_memory=False,
split=train_splits,
num_proc=cpu_count(),
)
if args.type == "sft":
preparation_function = supervised_tokenize_sft
elif args.type == "prompt":
preparation_function = tokenize_prompt_dataset
elif args.type == "preference":
preparation_function = tokenize_rlhf
else:
raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference']")
for index, dataset in enumerate(list_dataset):
assert isinstance(dataset, dataset_dict.Dataset)
if len(dataset) == 0:
# Hack: Skip empty dataset. If dataset contains less than num_of_rank samples, some rank may have empty dataset and leads to error
continue
if args.num_samples_per_datafile > 0:
# limit the number of samples in each dataset
dataset = dataset.select(
random.sample(range(len(dataset)), min(args.num_samples_per_datafile, len(dataset)))
)
logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
dataset = dataset.map(
function=preparation_function,
fn_kwargs={
"tokenizer": tokenizer,
"conversation_template": conversation_template,
"max_length": args.max_length,
},
keep_in_memory=False,
num_proc=min(len(dataset), cpu_count()),
)
dataset = dataset.filter(
lambda data: data["chosen_input_ids" if args.type == "preference" else "input_ids"] is not None
)
# Save each jsonl spliced dataset.
output_index = "0" * (5 - len(str(index))) + str(index)
output_name = f"part-{output_index}"
output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
st = time.time()
with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
count = 0
for data_point in dataset:
if count % 500 == 0:
logger.info(f"processing {count} spliced data points for {fp_writer.name}")
count += 1
fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n")
logger.info(
f"Current file {fp_writer.name}; "
f"Data size: {len(dataset)}; "
f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
)
# Save each arrow spliced dataset
output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
logger.info(f"Start to save {output_arrow_path}")
dataset = load_dataset(
path="json",
data_files=[output_jsonl_path],
cache_dir=os.path.join(args.data_cache_dir, "tokenized"),
keep_in_memory=False,
num_proc=cpu_count(),
split="train",
)
dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(dataset), cpu_count()))
if __name__ == "__main__":
main()
SAVE_DIR=""
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type preference \
--data_input_dirs "PATH/TO/PREFERENCE/DATA" \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
SAVE_DIR=""
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type prompt \
--data_input_dirs /PATH/TO/PROMPT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
SAVE_DIR=""
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type sft \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
"""
Command line IO utilities for the chatbot.
"""
import abc
import re
from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
class ChatIO(abc.ABC):
@abc.abstractmethod
def prompt_for_input(self, role: str) -> str:
"""Prompt for input from a role."""
@abc.abstractmethod
def prompt_for_output(self, role: str):
"""Prompt for output from a role."""
@abc.abstractmethod
def stream_output(self, output_stream):
"""Stream output."""
class SimpleChatIO(ChatIO):
def prompt_for_input(self, role) -> str:
return input(f"{role}: ")
def prompt_for_output(self, role: str):
print(f"{role}: ", end="", flush=True)
def stream_output(self, output_stream):
pre = 0
for outputs in output_stream:
outputs = outputs.strip()
outputs = outputs.split(" ")
now = len(outputs) - 1
if now > pre:
print(" ".join(outputs[pre:now]), end=" ", flush=True)
pre = now
print(" ".join(outputs[pre:]), flush=True)
return " ".join(outputs)
class RichChatIO(ChatIO):
def __init__(self):
self._prompt_session = PromptSession(history=InMemoryHistory())
self._completer = WordCompleter(words=["!exit", "!reset"], pattern=re.compile("$"))
self._console = Console()
def prompt_for_input(self, role) -> str:
self._console.print(f"[bold]{role}:")
prompt_input = self._prompt_session.prompt(
completer=self._completer,
multiline=False,
auto_suggest=AutoSuggestFromHistory(),
key_bindings=None,
)
self._console.print()
return prompt_input
def prompt_for_output(self, role: str):
self._console.print(f"[bold]{role}:")
def stream_output(self, output_stream):
"""Stream output from a role."""
# Create a Live context for updating the console output
with Live(console=self._console, refresh_per_second=60) as live:
# Read lines from the stream
for outputs in output_stream:
accumulated_text = outputs
if not accumulated_text:
continue
# Render the accumulated text as Markdown
# NOTE: this is a workaround for rendering non-standard markdown in rich.
# The chatbot's output treats "\n" as a new line for better compatibility
# with real-world text, but rendering it as markdown would break the format,
# because standard markdown treats a single "\n" in normal text as a space.
# Our workaround adds two trailing spaces at the end of each line. This is
# not a perfect solution, as it introduces trailing spaces (only) in code
# blocks, but it works well for console output, since the console generally
# does not care about trailing spaces.
lines = []
for line in accumulated_text.splitlines():
lines.append(line)
if line.startswith("```"):
# Code block marker - do not add trailing spaces, as it would
# break the syntax highlighting
lines.append("\n")
else:
lines.append(" \n")
markdown = Markdown("".join(lines))
# Update the Live console output
live.update(markdown)
self._console.print()
return outputs
class DummyChatIO(ChatIO):
"""
Dummy ChatIO class for testing
"""
def __init__(self):
self.roles = []
self._console = Console()
def prompt_for_input(self, role) -> str:
self.roles.append(role)
if len(self.roles) == 1:
ret = "Hello"
elif len(self.roles) == 2:
ret = "What's the value of 1+1?"
else:
ret = "exit"
self._console.print(f"[bold]{role}:{ret}")
return ret
def prompt_for_output(self, role: str):
self._console.print(f"[bold]{role}:")
def stream_output(self, output_stream):
"""Stream output from a role."""
# Create a Live context for updating the console output
with Live(console=self._console, refresh_per_second=60) as live:
# Read lines from the stream
for outputs in output_stream:
accumulated_text = outputs
if not accumulated_text:
continue
# Render the accumulated text as Markdown
# NOTE: this is a workaround for rendering non-standard markdown in rich.
# The chatbot's output treats "\n" as a new line for better compatibility
# with real-world text, but rendering it as markdown would break the format,
# because standard markdown treats a single "\n" in normal text as a space.
# Our workaround adds two trailing spaces at the end of each line. This is
# not a perfect solution, as it introduces trailing spaces (only) in code
# blocks, but it works well for console output, since the console generally
# does not care about trailing spaces.
lines = []
for line in accumulated_text.splitlines():
lines.append(line)
if line.startswith("```"):
# Code block marker - do not add trailing spaces, as it would
# break the syntax highlighting
lines.append("\n")
else:
lines.append(" \n")
markdown = Markdown("".join(lines))
# Update the Live console output
live.update(markdown)
self._console.print()
return outputs
simple_io = SimpleChatIO()
rich_io = RichChatIO()
dummy_io = DummyChatIO()
import argparse
import json
import os
from typing import Dict
import torch
from chatio import dummy_io, rich_io, simple_io
from coati.dataset.conversation import setup_conversation_template
from coati.models import generate_streaming
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
def get_gpu_memory(max_gpus=None):
"""
Get the available memory for each GPU.
Args:
max_gpus (int, optional): The maximum number of GPUs to consider. Defaults to None.
Returns:
list: A list of available memory for each GPU.
"""
gpu_memory = []
num_gpus = torch.cuda.device_count() if max_gpus is None else min(max_gpus, torch.cuda.device_count())
for gpu_id in range(num_gpus):
# Query the total and currently allocated memory for this GPU
with torch.cuda.device(gpu_id):
device = torch.cuda.current_device()
gpu_properties = torch.cuda.get_device_properties(device)
total_memory = gpu_properties.total_memory / (1024**3)
allocated_memory = torch.cuda.memory_allocated() / (1024**3)
available_memory = total_memory - allocated_memory
gpu_memory.append(available_memory)
return gpu_memory
def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs):
"""
Load the model and tokenizer from the specified paths and move the model to the specified device.
Args:
model_path (str): The path to the pre-trained model.
tokenizer_path (str): The path to the pre-trained tokenizer.
device (str, optional): The device to move the model to. Defaults to "cuda".
**kwargs: Additional keyword arguments to be passed to the `AutoModelForCausalLM.from_pretrained` function.
Returns:
tuple: A tuple containing the loaded model and tokenizer.
"""
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer.pad_token = tokenizer.eos_token
model.to(device)
return model, tokenizer
def _set_default_generate_kwargs(model: PreTrainedModel) -> Dict:
"""
Set default keyword arguments for generation based on the given model.
Args:
model (PreTrainedModel): The model used for generation.
Returns:
Dict: A dictionary containing the default keyword arguments for generation.
"""
unwrapped_model = model
new_kwargs = {}
# Use huggingface models method directly
if hasattr(unwrapped_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation
if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation
return new_kwargs
def generation_wrapper(*args, **kwargs):
input_ids = args[1]
tokenizer = args[2]
for output in generate_streaming(*args, **kwargs):
yield tokenizer.batch_decode(output[:, input_ids.size(1) :], skip_special_tokens=True)[0]
def main(args):
conversation_template_config = json.load(open(args.conversation_template_config, "r", encoding="utf8"))
max_new_tokens = args.max_new_tokens
model_max_length = args.model_max_length
model, tokenizer = load_model_and_tokenizer(
args.model_path, args.tokenizer_path or args.model_path, local_files_only=True
)
assert max_new_tokens <= model_max_length
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
try:
# Some tokenizers don't allow setting pad_token manually, e.g., Qwen
tokenizer.pad_token = tokenizer.eos_token
except AttributeError as e:
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
tokenizer.padding_side = "left"
model_kwargs = {
"max_new_tokens": max_new_tokens,
# 'early_stopping': True,
# 'top_k': -1,
# 'top_p': 1.0,
# 'temperature': 1.0,
# 'temperature':0.1,
}
round = 1
conv = setup_conversation_template(tokenizer, conversation_template_config, args.conversation_template_config)
while True:
if args.io == "simple":
chat_io = simple_io
elif args.io == "rich":
chat_io = rich_io
elif args.io == "dummy":
chat_io = dummy_io
else:
raise ValueError(f"Unknown io type: {args.io}")
# raw_text = print(">>> Human:", end=" ")
inp = chat_io.prompt_for_input("user")
if not inp:
print("prompt should not be empty!")
continue
if inp.strip() == "clear":
conv.clear()
os.system("clear")
continue
if inp.strip() == "exit":
print("End of chat.")
break
query_text = inp.strip()
conv.append_message("user", query_text)
chat_io.prompt_for_output("assistant")
prompt = conv.get_prompt(add_generation_prompt=True)
print(prompt + "<end_of_prompt>")
input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(
torch.cuda.current_device()
)
default_generate_kwargs = _set_default_generate_kwargs(model)
model_kwargs.update(default_generate_kwargs)
output_stream = generation_wrapper(
model,
input_ids,
tokenizer,
max_length=model_max_length,
temperature=0.7,
early_stopping=True,
stop_token_ids=conversation_template_config["stop_ids"],
**model_kwargs,
)
# print(f">>> Assistant:", end=" ")
outputs = chat_io.stream_output(output_stream)
conv.append_message("assistant", outputs.strip())
with open("round.txt", mode="a", encoding="utf-8") as f:
f.write("\n\n" + "=" * 10 + "\n")
f.write(f"round {round}:\n{conv.save_prompt()}\n\n")
f.write("=" * 10 + "\n")
# print(f">>> Assistant:", end=" ")
round += 1
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--tokenizer_path", type=str, default=None)
parser.add_argument("--conversation_template_config", type=str, default=None)
parser.add_argument("--model_max_length", type=int, default=2048)
parser.add_argument("--max_new_tokens", type=int, default=512)
parser.add_argument("--io", type=str, default="rich", choices=["simple", "rich", "dummy"])
args = parser.parse_args()
main(args)