Commit 0938ae70 authored by zhaoying1

fix save method of adapter_model.bin

parent 1b73554f
@@ -3,4 +3,4 @@ COPY requirements.txt requirements.txt
RUN source /opt/dtk-23.04/env.sh
RUN cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && echo 'Asia/Shanghai' >/etc/timezone
ENV LANG C.UTF-8
-RUN pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
+RUN pip install -r requirements.txt --no-dependencies -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
@@ -23,9 +23,9 @@ docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk
```
Install the dependencies that are not included in the docker image:
```
-pip install transformers==4.28.0 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
+pip install transformers==4.31.0 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
+pip install accelerate==0.22.0 --no-dependencies -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
-pip install datasets accelerate peft trl tiktoken jieba rouge-chinese nltk gradio matplotlib uvicorn fastapi sse-starlette
+pip install datasets peft trl tiktoken jieba rouge-chinese nltk gradio matplotlib uvicorn fastapi sse-starlette
```
@@ -51,9 +51,11 @@ conda create -n chatglm python=3.8
3. Install the remaining dependencies according to requirements.txt:
```
-pip install -r requirements.txt
+pip install -r requirements.txt --no-dependencies -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
```
+Note: if accelerate, transformers, or other libraries check for a dependency on deepspeed 0.9.3, comment out the corresponding version-check code; deepspeed 0.9.3 has not been adapted yet, and deepspeed 0.9.2 works as-is.

## Datasets
The input data are json files placed in the project's [data](.data) directory and specified with the --dataset option (see the examples below); separate multiple input files with `,`. The example json format and field descriptions are as follows:
@@ -79,6 +81,8 @@ The json file stores a list, and each element of the list is one sample, where
```
For dataset usage, please refer to [data/README.md](data/README_zh.md).
+Note: please set the dataset_dir path at L38 of [./src/llmtuner/hparams/data_args.py](src/llmtuner/hparams/data_args.py).
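Below is a minimal, illustrative sketch of what such a sample file can look like, assuming the common alpaca-style fields (instruction / input / output / history) used by LLaMA-Efficient-Tuning; the file name and field values are placeholders, not taken from this commit.
```
# Hypothetical example: write a tiny alpaca-style dataset to data/my_dataset.json.
# Field names (instruction/input/output/history) are assumed, not taken from this commit.
import json

samples = [
    {
        "instruction": "Translate the sentence into English.",  # what the user asks the model to do
        "input": "今天天气很好。",                                # optional extra context for the instruction
        "output": "The weather is nice today.",                  # the reference response to learn from
        "history": []                                            # optional earlier [query, response] turns
    }
]

with open("data/my_dataset.json", "w", encoding="utf-8") as f:
    json.dump(samples, f, ensure_ascii=False, indent=2)
```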
## Model Download
Hugging Face model download links:
......
{
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "zero_allow_untested_optimizer": true,
    "fp16": {
        "enabled": "auto",
......
{
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "zero_allow_untested_optimizer": true,
    "fp16": {
        "enabled": "auto",
......
#!/bin/bash
-export HSA_FORCE_FINE_GRAIN_PCIE=1
export MIOPEN_FIND_MODE=3
-export MIOPEN_COMPILE_PARALLEL_LEVEL=1
+export GPU_MAX_HW_QUEUES=16
-export NCCL_PLUGIN_P2P=ucx
-export RCCL_NCHANNELS=2
-export NCCL_SOCKET_IFNAME=ib0
-export NCCL_P2P_LEVEL=5
lrank=$OMPI_COMM_WORLD_LOCAL_RANK
-echo "LRANK===============================$lrank"
-RANK=$OMPI_COMM_WORLD_RANK
-WORLD_SIZE=$OMPI_COMM_WORLD_SIZE
-export NCCL_IB_HCA=mlx5_0 # NIC 0
+comm_rank=$OMPI_COMM_WORLD_RANK
+comm_size=$OMPI_COMM_WORLD_SIZE
+export LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
+export RANK=$comm_rank
+export WORLD_SIZE=$comm_size
+export MASTER_ADDR=$1
+export MASTER_PORT=29500
+export NCCL_IB_HCA=mlx5
+export NCCL_SOCKET_IFNAME=ib0
+export HIP_DIRECT_DISPATCH=0
APP="python3 ../src/train_bash.py --stage sft \
-    --model_name_or_path ../../baichun-7b \
+    --model_name_or_path ../../baichuan-13b-base \
    --do_train \
    --template default \
-    --dataset alpaca_gpt4_en,alpaca_gpt4_zh,codealpaca \
+    --dataset alpaca_gpt4_en \
    --finetuning_type lora \
    --lora_rank 16 \
    --lora_target W_pack,o_proj,gate_proj,down_proj,up_proj \
-    --output_dir output/baichuan-7b-lora-2-3 \
+    --output_dir out/baichuan-7b-lora-test7 \
-    --per_device_train_batch_size 8 \
+    --per_device_train_batch_size 1 \
-    --per_device_eval_batch_size 8 \
+    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
-    --preprocessing_num_workers 16 \
+    --preprocessing_num_workers 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
-    --save_steps 2000 \
+    --save_steps 2 \
+    --eval_steps 2 \
    --learning_rate 1e-4 \
    --max_grad_norm 0.5 \
    --num_train_epochs 1.0 \
+    --val_size 0.001 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
    --plot_loss \
    --fp16 \
    --deepspeed deepspeed.json
......
#!/bin/bash
-#SBATCH -p kshdnormal
+#SBATCH -p kshdexclu11
-#SBATCH -N 32
+#SBATCH -N 4
#SBATCH --cpus-per-task=1
#SBATCH --ntasks-per-node=32
#SBATCH --gres=dcu:4
#SBATCH -J baichuan
-#SBATCH -o logs-7B/baichuan-lora-%j.out
+#SBATCH -o logs-13B/baichuan-lora-%j.out
-#SBATCH -e logs-7B/baichuan-lora-%j.err
+#SBATCH -e logs-13B/baichuan-lora-%j.err
-ulimit -u 200000
+#SBATCH --exclusive
ulimit -s unlimited
export HIP_VISIBLE_DEVICES=0,1,2,3
export MIOPEN_FIND_MODE=3
export MIOPEN_DEBUG_CONV_IMPLICIT_GEMM=0
export MIOPEN_USER_DB_PATH=/tmp/miopen-udb
export MIOPEN_CUSTOM_CACHE_DIR=/tmp/miopen-cache
export NCCL_SOCKET_IFNAME=ib0
export HSA_FORCE_FINE_GRAIN_PCIE=1
export OMP_NUM_THREADS=1
export NCCL_IB_HCA=mlx5
export NCCL_DEBUG=INFO
export MIOPEN_FIND_MODE=3
export HSA_FORCE_FINE_GRAIN_PCIE=1
export MIOPEN_COMPILE_PARALLEL_LEVEL=1
export NCCL_PLUGIN_P2P=ucx
export NCCL_SOCKET_IFNAME=ib0
export NCCL_P2P_LEVEL=5
echo "START TIME: $(date)"
hostfile=./hostfile/$SLURM_JOB_ID
nodes=($(scontrol show hostnames $SLURM_JOB_NODELIST ))
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
echo Node IP: $head_node_ip
echo headnode: $head_node
NODE_RANK=$SLURM_NODEID
hostfile=./hostfile/$SLURM_JOB_ID  # get the node list
scontrol show hostnames $SLURM_JOB_NODELIST > ${hostfile}
rm `pwd`/hostfile-dl -f
for i in `cat $hostfile`
do
    echo ${i} slots=4 >> `pwd`/hostfile/hostfile-dl-$SLURM_JOB_ID  # one line per node
done
np=$(cat $hostfile|sort|uniq |wc -l)  # deduplicate nodes
np=$(($np*4))
nodename=$(cat $hostfile |sed -n "1p")  # read the node list; the first entry is the master node
dist_url=`echo $nodename | awk '{print $1}'`
-mpirun -np $np --allow-run-as-root --hostfile hostfile/hostfile-dl-$SLURM_JOB_ID --bind-to none `pwd`/run-7b-sft-lora-single.sh $dist_url $np
+mpirun -np $np --allow-run-as-root --hostfile hostfile/hostfile-dl-$SLURM_JOB_ID --bind-to none `pwd`/run-7b-single-lora.sh $dist_url
import torch
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
-from transformers import TextIteratorStreamer
+from transformers import GenerationConfig, TextIteratorStreamer

from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.template import get_template_and_fix_tokenizer
@@ -14,7 +14,6 @@ class ChatModel:
        model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
        self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
        self.model = dispatch_model(self.model)
-        self.model = self.model.eval() # enable evaluation mode
        self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
        self.system_prompt = data_args.system_prompt
@@ -41,26 +40,30 @@ class ChatModel:
        max_length = input_kwargs.pop("max_length", None)
        max_new_tokens = input_kwargs.pop("max_new_tokens", None)

-        gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs.update(dict(
-            input_ids=input_ids,
-            do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
-            temperature=temperature or gen_kwargs["temperature"],
-            top_p=top_p or gen_kwargs["top_p"],
-            top_k=top_k or gen_kwargs["top_k"],
-            repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
-            eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)),
-            pad_token_id=self.tokenizer.pad_token_id,
-            logits_processor=get_logits_processor()
-        ))
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(dict(
+            do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+            temperature=temperature or generating_args["temperature"],
+            top_p=top_p or generating_args["top_p"],
+            top_k=top_k or generating_args["top_k"],
+            repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+            pad_token_id=self.tokenizer.pad_token_id
+        ))

        if max_length:
-            gen_kwargs.pop("max_new_tokens", None)
-            gen_kwargs["max_length"] = max_length
+            generating_args.pop("max_new_tokens", None)
+            generating_args["max_length"] = max_length

        if max_new_tokens:
-            gen_kwargs.pop("max_length", None)
-            gen_kwargs["max_new_tokens"] = max_new_tokens
+            generating_args.pop("max_length", None)
+            generating_args["max_new_tokens"] = max_new_tokens

+        gen_kwargs = dict(
+            inputs=input_ids,
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor()
+        )

        return gen_kwargs, prompt_length
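For context, a hedged sketch of how gen_kwargs built above is typically consumed for streaming generation, consistent with the Thread and TextIteratorStreamer imports in this file; the helper function and its argument names are illustrative, not part of this commit.
```
# Illustrative only: drive model.generate() on a background thread and stream text.
from threading import Thread

from transformers import TextIteratorStreamer


def stream_generate(model, tokenizer, gen_kwargs):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs["streamer"] = streamer          # model.generate() pushes decoded pieces here
    Thread(target=model.generate, kwargs=gen_kwargs).start()
    for new_text in streamer:                  # iterate as tokens arrive
        yield new_text
```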
......
@@ -31,11 +31,15 @@ def preprocess_dataset(
            yield query, response, history, system

    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build grouped texts with format `X1 X2 X3 ...` (without <eos>)
-        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
-            kwargs = dict(allowed_special="all")
+        # build grouped texts with format `X1 X2 X3 ...`
+        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
+            kwargs = dict(allowed_special="all") # for tiktoken tokenizer (Qwen)
        else:
-            kwargs = dict(add_special_tokens=False)
+            kwargs = dict(add_special_tokens=True)
+
+        if hasattr(tokenizer, "add_bos_token") and hasattr(tokenizer, "add_eos_token"):
+            setattr(tokenizer, "add_bos_token", True) # for LLaMA tokenizer
+            setattr(tokenizer, "add_eos_token", True)

        tokenized_examples = tokenizer(examples["prompt"], **kwargs)
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
@@ -59,7 +63,9 @@ def preprocess_dataset(
        for query, response, history, system in construct_example(examples):
            input_ids, labels = [], []

-            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
+            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
+                tokenizer, query, response, history, system
+            )):
                if len(source_ids) > data_args.max_source_length:
                    source_ids = source_ids[:data_args.max_source_length]
                if len(target_ids) > data_args.max_target_length:
@@ -68,8 +74,17 @@ def preprocess_dataset(
                if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
                    break

+                if turn_idx != 0 and template.efficient_eos:
+                    source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+                else:
+                    source_mask = [IGNORE_INDEX] * len(source_ids)
+
                input_ids += source_ids + target_ids
-                labels += [IGNORE_INDEX] * len(source_ids) + target_ids
+                labels += source_mask + target_ids

+            if template.efficient_eos:
+                input_ids += [tokenizer.eos_token_id]
+                labels += [tokenizer.eos_token_id]
+
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
@@ -89,6 +104,9 @@ def preprocess_dataset(
            if len(target_ids) > data_args.max_target_length:
                target_ids = target_ids[:data_args.max_target_length]

+            if template.efficient_eos:
+                target_ids += [tokenizer.eos_token_id]
+
            model_inputs["input_ids"].append(source_ids)
            model_inputs["attention_mask"].append([1] * len(source_ids))
            model_inputs["labels"].append(target_ids)
@@ -109,6 +127,10 @@ def preprocess_dataset(
            if len(rejected_ids) > data_args.max_target_length:
                rejected_ids = rejected_ids[:data_args.max_target_length]

+            if template.efficient_eos:
+                chosen_ids += [tokenizer.eos_token_id]
+                rejected_ids += [tokenizer.eos_token_id]
+
            model_inputs["prompt_ids"].append(prompt_ids)
            model_inputs["chosen_ids"].append(chosen_ids)
            model_inputs["rejected_ids"].append(rejected_ids)
......
@@ -5,7 +5,9 @@ from typing import TYPE_CHECKING
from datetime import timedelta

from transformers import TrainerCallback
-from transformers.trainer_utils import has_length
+from transformers.trainer_callback import TrainerControl, TrainerState
+from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
+from transformers.training_args import TrainingArguments

from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger
@@ -17,6 +19,24 @@ if TYPE_CHECKING:
logger = get_logger(__name__)

+class SavePeftModelCallback(TrainerCallback):
+
+    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called after a checkpoint save.
+        """
+        output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+        getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
+        return control
+
+    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        r"""
+        Event called at the end of training.
+        """
+        getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
+        return control
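A minimal sketch of how a callback like SavePeftModelCallback is typically attached to a Hugging Face Trainer so that each checkpoint-&lt;step&gt; directory gets an adapter_model.bin; the surrounding model/dataset objects are placeholders, and this is not the repository's own training wiring.
```
# Hypothetical wiring: `model` is assumed to wrap a PEFT model in `pretrained_model`,
# which is what on_save()/on_train_end() above expect.
from transformers import Trainer, TrainingArguments


def build_trainer(model, train_dataset, tokenizer):
    args = TrainingArguments(output_dir="out/lora-demo", save_steps=2, per_device_train_batch_size=1)
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        callbacks=[SavePeftModelCallback()],   # saves the adapter at every checkpoint and at train end
    )
```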
class LogCallback(TrainerCallback):

    def __init__(self, runner=None):
......
@@ -2,28 +2,16 @@ IGNORE_INDEX = -100
LOG_FILE_NAME = "trainer_log.jsonl"
-VALUE_HEAD_FILE_NAME = "value_head.bin"
-FINETUNING_ARGS_NAME = "finetuning_args.json"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
METHODS = ["full", "freeze", "lora"]
-STAGES = [
-    "SFT",
-    "Reward Modeling",
-    "PPO",
-    "DPO",
-    "Pre-Training"
-]
-DATASET_STAGE_MAP = {
-    "SFT": "sft",
-    "Pre-Training": "pt",
-    "Reward Modeling": "rm",
-    "PPO": "sft",
-    "DPO": "rm"
-}
+TRAINING_STAGES = {
+    "Supervised Fine-Tuning": "sft",
+    "Reward Modeling": "rm",
+    "PPO": "ppo",
+    "DPO": "dpo",
+    "Pre-Training": "pt"
+}

SUPPORTED_MODELS = {
@@ -54,11 +42,16 @@ SUPPORTED_MODELS = {
    "Baichuan-7B": "baichuan-inc/Baichuan-7B",
    "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
    "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
+    "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
+    "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
+    "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
+    "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
    "InternLM-7B": "internlm/internlm-7b",
    "InternLM-7B-Chat": "internlm/internlm-chat-7b",
    "Qwen-7B": "Qwen/Qwen-7B",
    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
    "XVERSE-13B": "xverse/XVERSE-13B",
+    "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
    "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
}
@@ -70,6 +63,7 @@ DEFAULT_MODULE = {
    "BLOOMZ": "query_key_value",
    "Falcon": "query_key_value",
    "Baichuan": "W_pack",
+    "Baichuan2": "W_pack",
    "InternLM": "q_proj,v_proj",
    "Qwen": "c_attn",
    "XVERSE": "q_proj,v_proj",
@@ -80,7 +74,9 @@ DEFAULT_TEMPLATE = {
    "LLaMA2": "llama2",
    "ChineseLLaMA2": "llama2_zh",
    "Baichuan": "baichuan",
+    "Baichuan2": "baichuan2",
    "InternLM": "intern",
    "Qwen": "chatml",
+    "XVERSE": "xverse",
    "ChatGLM2": "chatglm2"
}
import gc
import torch
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
-from llmtuner.extras.constants import LAYERNORM_NAMES

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel
@@ -28,12 +27,6 @@ class AverageMeter:
        self.avg = self.sum / self.count

-def get_logits_processor() -> LogitsProcessorList:
-    logits_processor = LogitsProcessorList()
-    logits_processor.append(InfNanRemoveLogitsProcessor())
-    return logits_processor

def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and number of all parameters in the model.
@@ -56,48 +49,17 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    return trainable_params, all_param

-# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
-# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
-def prepare_model_for_training(
-    model: "PreTrainedModel",
-    finetuning_type: str,
-    output_layer_name: Optional[str] = "lm_head",
-    use_gradient_checkpointing: Optional[bool] = True,
-    layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
-) -> "PreTrainedModel":
-    for name, param in model.named_parameters():
-        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            param.data = param.data.to(torch.float32)
-    if use_gradient_checkpointing:
-        if hasattr(model, "enable_input_require_grads"):
-            model.enable_input_require_grads()
-        else:
-            def make_inputs_require_grad(module, input, output):
-                output.requires_grad_(True)
-            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-        model.gradient_checkpointing_enable()
-        model.config.use_cache = False # turn off when gradient checkpointing is enabled
-    if finetuning_type != "full" and hasattr(model, output_layer_name):
-        output_layer: torch.nn.Linear = getattr(model, output_layer_name)
-        input_dtype = output_layer.weight.dtype
-        class CastOutputToFloat(torch.nn.Sequential):
-            def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return super().forward(x.to(input_dtype)).to(torch.float32)
-        setattr(model, output_layer_name, CastOutputToFloat(output_layer))
-    return model
+def get_logits_processor() -> LogitsProcessorList:
+    logits_processor = LogitsProcessorList()
+    logits_processor.append(InfNanRemoveLogitsProcessor())
+    return logits_processor

def torch_gc() -> None:
    r"""
    Collects GPU memory.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
......
# coding=utf-8
# Modified from:
# [1] https://huggingface.co/Birchlabs/flash_llama/blob/main/modeling_flash_llama.py
# [2] https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py
# [3] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# With fix from Alex Birch: https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from transformers.models.llama.configuration_llama import LlamaConfig
try:
from flash_attn.flash_attn_interface import (
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
)
from flash_attn.bert_padding import unpad_input, pad_input
flash_attn_v2_installed = True
print('>>>> Flash Attention installed')
except ImportError:
flash_attn_v2_installed = False
raise ImportError('Please install Flash Attention: `pip install flash-attn --no-build-isolation`')
try:
from flash_attn.layers.rotary import apply_rotary_emb_func
flash_rope_installed = True
print('>>>> Flash RoPE installed')
except ImportError:
flash_rope_installed = False
raise ImportError('Please install RoPE kernels: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`')
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
def rmsnorm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
return (weight * hidden_states).to(input_dtype)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.register_buffer(
"variance_epsilon",
torch.tensor(eps),
persistent=False,
)
def forward(self, hidden_states):
return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
class FlashRotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et. al).
A crucial insight from the method is that the query and keys are
transformed by rotation matrices which depend on the relative positions.
Other implementations are available in the Rotary Transformer repo_ and in
GPT-NeoX_, GPT-NeoX was an inspiration
.. _RoFormer: https://arxiv.org/abs/2104.09864
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
"""
def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
scaling_factor=1.0, pos_idx_in_fp32=True, device=None):
"""
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
otherwise they might be in lower precision.
This option was added because previously (before 2023-07-02), when we construct
the position indices, we use the dtype of self.inv_freq. In most cases this would
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
self.inv_freq would be bf16, and the position indices are also in bf16.
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
embeddings for some positions will coincide.
To maintain compatibility with models previously trained in pure bf16,
we add this option.
scaling_factor: RotaryEmbedding extended with linear scaling.
"""
super().__init__()
self.dim = dim
self.base = float(base)
self.pos_idx_in_fp32 = pos_idx_in_fp32
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = self._compute_inv_freq(device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.interleaved = interleaved
self.scale_base = scale_base
self.scaling_factor = scaling_factor
scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
/ (1.4 * dim) if scale_base is not None else None)
self.register_buffer("scale", scale)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
def _compute_inv_freq(self, device=None):
return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
dtype=torch.float32) / self.dim))
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
# if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (seqlen > self._seq_len_cached or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())):
self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
if self.pos_idx_in_fp32:
t = torch.arange(seqlen, device=device, dtype=torch.float32)
t /= self.scaling_factor
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
# will be large. Having it in bf16 will lose a lot of precision and cause the
# cos & sin output to change significantly.
# We want to recompute self.inv_freq if it was not loaded in fp32
if self.inv_freq.dtype != torch.float32:
inv_freq = self.inv_freq.to(torch.float32)
else:
inv_freq = self.inv_freq
else:
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
t /= self.scaling_factor
inv_freq = self.inv_freq
# Don't do einsum, it converts fp32 to fp16 under AMP
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, inv_freq)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
- seqlen // 2) / self.scale_base)
scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
"""
q: (batch, seqlen, nheads, headdim)
k: (batch, seqlen, nheads, headdim)
seqlen_offset: can be used in generation where the qkv being passed in is only the last
token in the batch.
"""
self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
if self.scale is None:
return apply_rotary_emb_func(
q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved, True # inplace=True
), apply_rotary_emb_func(
k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self.interleaved, True # inplace=True
)
else:
assert False
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, slen, _, num_key_value_heads, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, :, :, None, :].expand(batch, slen, 2, num_key_value_heads, n_rep, head_dim)
return hidden_states.reshape(batch, slen, 2, num_key_value_heads * n_rep, head_dim)
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.register_buffer(
"norm_factor",
torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
persistent=False,
)
if self.config.rope_scaling is None:
scaling_factor = 1
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
assert scaling_type == 'linear'
self.rotary_emb = FlashRotaryEmbedding(
self.head_dim, base=10000, interleaved=False, scaling_factor=scaling_factor,
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
is_padded_inputs: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, h_size = hidden_states.size()
has_layer_past = past_key_value is not None
if has_layer_past:
past_kv = past_key_value[0]
past_len = past_key_value[1]
else:
past_len = 0
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = q.view(bsz, q_len, self.num_heads, self.head_dim)
k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
q, k = self.rotary_emb(q, k, past_len)
kv = torch.stack([k, v], 2)
kv = repeat_kv(kv, self.num_key_value_groups)
# Cache QKV values
if has_layer_past:
new_len = past_len+q.size(1)
if new_len > past_kv.size(1):
past_kv = torch.cat([past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1)
past_kv[:, past_len:new_len] = kv
kv = past_kv[:, :new_len]
else:
past_kv = kv
past_key_value = (past_kv, past_len+q.size(1)) if use_cache else None
if is_padded_inputs:
# varlen, ignore padding tokens, efficient for large batch with many paddings
logger.warning_once("padded")
assert attention_mask is not None
unpadded_kv, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(kv, attention_mask)
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])
attn_outputs = flash_attn_varlen_kvpacked_func(
unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k,
dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
causal=(not has_layer_past), return_attn_probs=output_attentions
)
attn_output = attn_outputs[0] if output_attentions else attn_outputs
attn_output = pad_input(
attn_output, indices_q, bsz, q_len
).reshape(bsz, q_len, h_size)
attn_weights = attn_outputs[2] if output_attentions else None
else:
# no padding tokens, more efficient
attn_outputs = flash_attn_kvpacked_func(
q, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor, causal=(not has_layer_past), return_attn_probs=output_attentions)
attn_output = attn_outputs[0] if output_attentions else attn_outputs
attn_output = attn_output.reshape(bsz, q_len, h_size)
attn_weights = attn_outputs[2] if output_attentions else None
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
is_padded_inputs: Optional[bool] = False,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
is_padded_inputs=is_padded_inputs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
LLAMA_START_DOCSTRING, LLAMA_INPUTS_DOCSTRING = "", ""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
config_class = LlamaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LlamaModel):
module.gradient_checkpointing = value
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
is_padded_inputs: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
position_ids = None
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
is_padded_inputs
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
is_padded_inputs=is_padded_inputs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class LlamaForCausalLM(LlamaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
is_padded_inputs: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_padded_inputs = ((attention_mask is not None) and (not attention_mask.all().item()))
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: "CausalLMOutputWithPast" = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_padded_inputs=is_padded_inputs,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"is_padded_inputs": ((attention_mask is not None) and (not attention_mask.all().item()))
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
import os
import torch
-from typing import Dict
-from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
-from transformers.modeling_utils import load_sharded_checkpoint
+from transformers.trainer import WEIGHTS_NAME

-from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger

logger = get_logger(__name__)

-def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
-    state_dict: Dict[str, torch.Tensor] = model.state_dict()
-    filtered_state_dict = {}
-    for k, v in model.named_parameters():
-        if v.requires_grad:
-            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
-    return filtered_state_dict
-
-def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
-    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
-    if os.path.exists(weights_file):
-        model_state_dict = torch.load(weights_file, map_location="cpu")
-        model.load_state_dict(model_state_dict, strict=False) # skip missing keys
-    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
-        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
-    else:
-        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
-        return False
-    return True

def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
-    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
-    if not os.path.exists(valuehead_file):
+    vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
+    if not os.path.exists(vhead_file):
        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
        return False
-    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
-    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
-    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
-    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
-    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
+    vhead_params = torch.load(vhead_file, map_location="cpu")
+    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
+    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
+    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
+    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
    return True
@@ -20,6 +20,7 @@ class Template:
    sep: List[Union[str, Dict[str, str]]]
    stop_words: List[str]
    use_history: bool
+    efficient_eos: bool

    def encode_oneturn(
        self,
@@ -74,19 +75,19 @@ class Template:
        self,
        tokenizer: "PreTrainedTokenizer"
    ) -> Tuple[List[int], List[int]]:
-        if (
-            tokenizer.bos_token_id is not None
-            and getattr(tokenizer, "add_bos_token", True)
-        ): # baichuan-13b has no bos token
+        if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
            bos_ids = [tokenizer.bos_token_id]
-        else:
-            bos_ids = [] # bos token is optional
+        else: # baichuan, qwen and gpt2 models have no bos token
+            bos_ids = []

-        if tokenizer.eos_token_id is not None:
-            eos_ids = [tokenizer.eos_token_id]
-        else:
+        if tokenizer.eos_token_id is None:
            raise ValueError("EOS token is required.")

+        if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
+            eos_ids = []
+        else:
+            eos_ids = [tokenizer.eos_token_id]

        return bos_ids, eos_ids

    def _encode(
@@ -137,6 +138,8 @@ class Template:
            token_ids = []
            for elem in context:
                if isinstance(elem, str):
+                    if len(elem) == 0:
+                        continue
                    elem = elem.replace("{{system}}", system, 1) if system is not None else elem
                    elem = elem.replace("{{query}}", query, 1) if query is not None else elem
                    elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
@@ -184,7 +187,8 @@ def register_template(
    system: str,
    sep: List[Union[str, Dict[str, str]]],
    stop_words: Optional[List[str]] = [],
-    use_history: Optional[bool] = True
+    use_history: Optional[bool] = True,
+    efficient_eos: Optional[bool] = False
) -> None:
    template_class = Llama2Template if "llama2" in name else Template
    templates[name] = template_class(
@@ -193,7 +197,8 @@ def register_template(
        system=system,
        sep=sep,
        stop_words=stop_words,
-        use_history=use_history
+        use_history=use_history,
+        efficient_eos=efficient_eos
    )
@@ -201,31 +206,21 @@ def get_template_and_fix_tokenizer(
    name: str,
    tokenizer: "PreTrainedTokenizer"
) -> Template:
-    template = templates.get(name, None)
-    assert template is not None, "Template {} does not exist.".format(name)
-
-    additional_special_tokens = template.stop_words
-    if len(template.stop_words): # inplace method
-        if tokenizer.eos_token_id is not None:
-            additional_special_tokens.append(tokenizer.eos_token)
-        tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
-        additional_special_tokens.pop(0)
-        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
-
    if tokenizer.eos_token_id is None:
        tokenizer.eos_token = "<|endoftext|>"
        logger.info("Add eos token: {}".format(tokenizer.eos_token))

    if tokenizer.pad_token_id is None:
-        if tokenizer.unk_token_id is not None:
-            tokenizer.pad_token = tokenizer.unk_token
-        else:
-            tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))

+    if name is None:
+        return None
+
+    template = templates.get(name, None)
+    assert template is not None, "Template {} does not exist.".format(name)
+
    tokenizer.add_special_tokens(
-        dict(additional_special_tokens=additional_special_tokens),
+        dict(additional_special_tokens=template.stop_words),
        replace_additional_special_tokens=False
    )
    return template
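A hedged usage sketch of the function above: it fills in missing eos/pad tokens and returns the registered template. The model path is a placeholder, and the extra AutoTokenizer arguments are assumptions for Baichuan-style tokenizers, not taken from this commit.
```
# Illustrative only.
from transformers import AutoTokenizer

from llmtuner.extras.template import get_template_and_fix_tokenizer

tokenizer = AutoTokenizer.from_pretrained("../../baichuan-13b-base", use_fast=False, trust_remote_code=True)
template = get_template_and_fix_tokenizer("baichuan", tokenizer)

print(tokenizer.eos_token, tokenizer.pad_token)  # pad token falls back to the eos token if it was missing
print(template.efficient_eos)                    # True for the "baichuan" template registered below
```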
...@@ -464,18 +459,18 @@ register_template( ...@@ -464,18 +459,18 @@ register_template(
], ],
system="", system="",
sep=[ sep=[
{"token": "<eoa>"},
"\n" "\n"
], ],
stop_words=[ stop_words=[
"</s>", # internlm cannot replace eos token
"<eoa>" "<eoa>"
] ],
efficient_eos=True
) )
 r"""
 Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
-Used for training and inference of the fine-tuned models.
 """
 register_template(
     name="baichuan",
@@ -485,33 +480,31 @@ register_template(
     prompt=[
         {"token": "<reserved_102>"}, # user token
         "{{query}}",
         {"token": "<reserved_103>"} # assistant token
     ],
     system="",
     sep=[],
-    stop_words=[]
+    efficient_eos=True
 )
 r"""
-Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
-Used for inference of the original model.
+Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
+          https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
 """
 register_template(
-    name="baichuan_eval",
+    name="baichuan2",
     prefix=[
-        "{{system}}",
-        {"token": "<reserved_102>"} # user token
+        "{{system}}"
     ],
     prompt=[
+        {"token": "<reserved_106>"}, # user token
         "{{query}}",
-        {"token": "<reserved_103>"} # assistant token
+        {"token": "<reserved_107>"} # assistant token
     ],
     system="",
     sep=[],
-    stop_words=[
-        "<reserved_102>" # user token
-    ]
+    efficient_eos=True
 )
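For orientation, the new `baichuan2` template above wraps each user turn in the `<reserved_106>` / `<reserved_107>` pair. The snippet below is only a string-level illustration of that layout; in the actual pipeline the `{"token": ...}` entries are mapped to single token ids by the tokenizer rather than spliced in as text, and the helper name is hypothetical.

```
def render_baichuan2_prompt(system, query):
    # Illustrative string-level view of the baichuan2 template above.
    return f"{system}<reserved_106>{query}<reserved_107>"

print(render_baichuan2_prompt("", "你好"))
```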
@@ -524,7 +517,6 @@ register_template(
     prefix=[
         {"token": "<|system|>"},
         "\n{{system}}",
-        {"token": "<|end|>"}
     ],
     prompt=[
         {"token": "<|user|>"},
@@ -535,11 +527,13 @@ register_template(
     ],
     system="",
     sep=[
+        {"token": "<|end|>"},
         "\n"
     ],
     stop_words=[
         "<|end|>"
-    ]
+    ],
+    efficient_eos=True
 )
@@ -550,8 +544,7 @@ register_template(
     name="chatml",
     prefix=[
         {"token": "<|im_start|>"},
-        "system\n{{system}}",
-        {"token": "<|im_end|>"}
+        "system\n{{system}}"
     ],
     prompt=[
         {"token": "<|im_start|>"},
@@ -563,11 +556,13 @@ register_template(
     ],
     system="You are a helpful assistant.",
     sep=[
+        {"token": "<|im_end|>"},
         "\n"
     ],
     stop_words=[
         "<|im_end|>"
-    ]
+    ],
+    efficient_eos=True
 )
@@ -587,7 +582,8 @@ register_template(
     system="",
     sep=[
         "\n\n"
-    ]
+    ],
+    efficient_eos=True
 )
...
@@ -11,24 +11,23 @@ class DatasetAttr:
     dataset_name: Optional[str] = None
     dataset_sha1: Optional[str] = None
     system_prompt: Optional[str] = None
-    stage: Optional[str] = None
+    ranking: Optional[bool] = False
+    prompt: Optional[str] = "instruction"
+    query: Optional[str] = "input"
+    response: Optional[str] = "output"
+    history: Optional[str] = None

     def __repr__(self) -> str:
         return self.dataset_name

-    def __post_init__(self):
-        self.prompt = "instruction"
-        self.query = "input"
-        self.response = "output"
-        self.history = None

 @dataclass
 class DataArguments:
     r"""
     Arguments pertaining to what data we are going to input our model for training and evaluation.
     """
-    template: str = field(
+    template: Optional[str] = field(
+        default=None,
         metadata={"help": "Which template to use for constructing prompts in training and inference."}
     )
     dataset: Optional[str] = field(
@@ -36,7 +35,7 @@ class DataArguments:
         metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
     )
     dataset_dir: Optional[str] = field(
-        default="data",
+        default="/public/home/zhaoying1/work/Baichuan-13B-main/LLaMA-Efficient-Tuning-remove-pe/data",
         metadata={"help": "The name of the folder containing datasets."}
     )
     split: Optional[str] = field(
@@ -48,7 +47,7 @@ class DataArguments:
         metadata={"help": "Enable streaming mode."}
     )
     buffer_size: Optional[int] = field(
-        default=16384,
+        default=1024,
         metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
     )
     mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
@@ -114,21 +113,14 @@ class DataArguments:
                 raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))

             if "hf_hub_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr(
-                    "hf_hub",
-                    dataset_name=dataset_info[name]["hf_hub_url"],
-                    stage=dataset_info[name].get("stage", None))
+                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
             elif "script_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr(
-                    "script",
-                    dataset_name=dataset_info[name]["script_url"],
-                    stage=dataset_info[name].get("stage", None))
+                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
             else:
                 dataset_attr = DatasetAttr(
                     "file",
                     dataset_name=dataset_info[name]["file_name"],
-                    dataset_sha1=dataset_info[name].get("file_sha1", None),
-                    stage=dataset_info[name].get("stage", None)
+                    dataset_sha1=dataset_info[name].get("file_sha1", None)
                 )

             if "columns" in dataset_info[name]:
@@ -137,5 +129,6 @@ class DataArguments:
                 dataset_attr.response = dataset_info[name]["columns"].get("response", None)
                 dataset_attr.history = dataset_info[name]["columns"].get("history", None)

+            dataset_attr.ranking = dataset_info[name].get("ranking", False)
             dataset_attr.system_prompt = prompt_list[i]
             self.dataset_list.append(dataset_attr)
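The parsing above expects entries in `dataset_info.json` shaped roughly like the sketch below; the key names are taken from the code in this hunk, while the dataset name, file name, and column values are placeholders. Shown as a Python dict for brevity.

```
# Hypothetical dataset_info.json entry, expressed as a Python dict.
dataset_info = {
    "my_dataset": {
        "file_name": "my_dataset.json",   # or "hf_hub_url" / "script_url"
        "file_sha1": None,                # optional checksum
        "ranking": False,                 # set True for pairwise/ranking data (assumption)
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output",
            "history": None
        }
    }
}
```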
@@ -16,7 +16,7 @@ class ModelArguments:
         metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
     )
     use_fast_tokenizer: Optional[bool] = field(
-        default=False,
+        default=True,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
     )
     use_auth_token: Optional[bool] = field(
@@ -27,10 +27,6 @@ class ModelArguments:
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
     )
-    padding_side: Optional[Literal["left", "right"]] = field(
-        default="left",
-        metadata={"help": "The side on which the model should have padding applied."}
-    )
     quantization_bit: Optional[int] = field(
         default=None,
         metadata={"help": "The number of bits to quantize the model."}
@@ -47,6 +43,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "Adopt scaled rotary positional embeddings."}
     )
+    flash_attn: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Enable flash attention for faster training."}
+    )
     checkpoint_dir: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
...
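The dataclass/`field(...)` pattern above is the `HfArgumentParser` convention, so the new `flash_attn` flag and the changed `use_fast_tokenizer` default can be toggled from the command line. A minimal, self-contained sketch follows; the trimmed-down `ModelArguments` here is illustrative, not the full class from this repository.

```
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    use_fast_tokenizer: Optional[bool] = field(default=True)
    flash_attn: Optional[bool] = field(default=False)

parser = HfArgumentParser(ModelArguments)
# Equivalent to passing "--flash_attn True --use_fast_tokenizer False" on the CLI.
(model_args,) = parser.parse_args_into_dataclasses(
    ["--flash_attn", "True", "--use_fast_tokenizer", "False"]
)
print(model_args.flash_attn, model_args.use_fast_tokenizer)  # True False
```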