import argparse
from loguru import logger
import os
from os.path import join
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import bitsandbytes as bnb
from component.collator import PretrainCollator, SFTDataCollator
from component.argument import CustomizedArguments
from component.template import template_dict
from component.dataset import (
    UnifiedSFTDataset,
    ChatGLM2SFTDataset,
    ChatGLM3SFTDataset,
    UnifiedDPODataset
)
from transformers import (
    set_seed,
    HfArgumentParser,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    BitsAndBytesConfig,
    Trainer,
    AddedToken
)
import importlib.util

# Unsloth is optional; import it only when the package is installed
if importlib.util.find_spec('unsloth') is not None:
    from unsloth import FastLanguageModel
from datasets import load_dataset, concatenate_datasets
import datasets
from itertools import chain
from tqdm import tqdm
import json
from trl import DPOTrainer, get_kbit_device_map
import torch.nn as nn

os.environ['TOKENIZERS_PARALLELISM'] = 'false'


def setup_everything():
    parser = argparse.ArgumentParser()
    # parser.add_argument("--train_args_file", type=str, default='train_args/pretrain/full/bloom-1b1-pretrain-full.json', help="")
    parser.add_argument("--train_args_file", type=str, default='train_args/sft/qlora/qwen-7b-sft-qlora.json', help="")
    parser.add_argument("--local_rank", type=int, help="")
    args = parser.parse_args()
    train_args_file = args.train_args_file
    # Read the training argument configuration
    parser = HfArgumentParser((CustomizedArguments, TrainingArguments))
    # Parse it into the custom arguments and the built-in TrainingArguments
    args, training_args = parser.parse_json_file(json_file=train_args_file)
    # Create the output directory
    if not os.path.exists(training_args.output_dir):
        os.makedirs(training_args.output_dir)
    logger.add(join(training_args.output_dir, 'train.log'))
    logger.info("train_args:{}".format(training_args))
    # Load the training configuration file
    with open(train_args_file, "r") as f:
        train_args = json.load(f)
    # Save the training arguments to the output directory
    with open(join(training_args.output_dir, 'train_args.json'), "w") as f:
        json.dump(train_args, f, indent=4)
    # Set the random seed
    set_seed(training_args.seed)

    # Check some settings
    assert args.task_type in ['pretrain', 'sft', 'dpo'], "task_type should be in ['pretrain', 'sft', 'dpo']"
    assert args.train_mode in ['full', 'lora', 'qlora'], "train_mode should be in ['full', 'lora', 'qlora']"
    assert sum([training_args.fp16, training_args.bf16]) == 1, "only one of fp16 and bf16 can be True"
    # assert not (args.task_type == 'dpo' and args.use_unsloth), 'We have not tested Unsloth during DPO yet. Please set use_unsloth=False when task_type=dpo'
    return args, training_args


def find_all_linear_names(model, train_mode):
    """
    Find all fully connected (linear) layers so that an adapter can be attached to each of them.
    """
    assert train_mode in ['lora', 'qlora']
    cls = bnb.nn.Linear4bit if train_mode == 'qlora' else nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    lora_module_names = list(lora_module_names)
    logger.info(f'LoRA target module names: {lora_module_names}')
    return lora_module_names


def load_pretrain_dataset(training_args, args, tokenizer):
    """
    Preprocess the pretraining data with multiple worker processes.
    """
    def tokenize_function(examples):
        output = tokenizer(examples["text"])
        output = {'input_ids': output.input_ids}
        return output

    def group_texts(examples):
        # Concatenate all texts.
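        # Packing step: all tokenized documents are chained into one long token stream
        # and then cut into fixed-size chunks of max_seq_length below, so pretraining
        # batches need no padding.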
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder. We could pad instead of dropping it if the model
        # supported padding; customize this part to your needs.
        if total_length >= max_seq_length:
            total_length = (total_length // max_seq_length) * max_seq_length
        # Split by chunks of max_seq_length.
        result = {
            k: [t[i: i + max_seq_length] for i in range(0, total_length, max_seq_length)]
            for k, t in concatenated_examples.items()
        }
        return result

    data_path = args.train_file
    max_seq_length = args.max_seq_length
    # Create the cache directory
    cache_dir = join(data_path, 'cache')
    os.makedirs(cache_dir, exist_ok=True)
    logger.info('Pretraining data path: {}'.format(data_path))

    # Scan all jsonl files
    logger.info('Scanning all the training file...')
    files = []
    for root, dir_names, file_names in os.walk(data_path):
        for file_name in file_names:
            file = join(root, file_name)
            if file_name.endswith('.jsonl'):
                files.append(file)
    logger.info(f'Total num of training file: {len(files)}')

    # Tokenize all texts and pack them into chunks of max_seq_length
    with training_args.main_process_first(desc="dataset map tokenization and grouping"):
        pretrain_dataset = []  # accumulates the datasets of all files
        for idx, file in enumerate(tqdm(files)):
            logger.info(f'Loading file: {file}')
            file_name = os.path.basename(file)
            file_name = file_name.replace('.jsonl', '')
            cache_path = os.path.join(cache_dir, file_name)
            os.makedirs(cache_path, exist_ok=True)

            try:
                processed_dataset = datasets.load_from_disk(cache_path, keep_in_memory=False)
                logger.info(f'Finished loading datasets-{file_name} from cache')
            except Exception:
                tmp_cache_path = join(cache_path, 'tmp')  # temporary cache directory, meant to be removed afterwards
                logger.info(f'There is no cache of file {file_name}, start preprocessing...')
                raw_dataset = load_dataset("json", data_files=file, cache_dir=tmp_cache_path, keep_in_memory=False)
                tokenized_dataset = raw_dataset.map(
                    tokenize_function,
                    batched=True,
                    num_proc=args.tokenize_num_workers,
                    remove_columns="text",
                    load_from_cache_file=True,
                    keep_in_memory=False,
                    cache_file_names={k: os.path.join(tmp_cache_path, 'tokenized.arrow') for k in raw_dataset},
                    desc="Running tokenizer on dataset",
                )
                grouped_datasets = tokenized_dataset.map(
                    group_texts,
                    batched=True,
                    num_proc=args.tokenize_num_workers,
                    load_from_cache_file=True,
                    keep_in_memory=False,
                    cache_file_names={k: os.path.join(tmp_cache_path, 'grouped.arrow') for k in tokenized_dataset},
                    desc=f"Grouping texts in chunks of {max_seq_length}",
                )
                processed_dataset = grouped_datasets
                processed_dataset.save_to_disk(cache_path)
                # Remove the temporary directory
                # shutil.rmtree(tmp_cache_path)

            logger.info(f"Training number of {file_name}: {len(processed_dataset['train'])}")
            if idx == 0:
                pretrain_dataset = processed_dataset['train']
            else:
                assert pretrain_dataset.features.type == processed_dataset["train"].features.type
                pretrain_dataset = concatenate_datasets([pretrain_dataset, processed_dataset["train"]])
    logger.info(f"Total training number: {len(pretrain_dataset)}")
    return pretrain_dataset


def load_tokenizer(args):
    config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        # llama and internlm2 do not support the fast tokenizer
        use_fast=False if config.model_type == 'llama' or config.model_type == 'internlm2' else True
    )

    # For some models, the tokenizer of the base version differs from the chat version
    if 'internlm2' in args.model_name_or_path.lower():
        tokenizer._added_tokens_encoder.update({'<|im_start|>': 92543})
        tokenizer._added_tokens_encoder.update({'<|im_end|>': 92542})
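        # The base internlm2 tokenizer ships without the ChatML control tokens used by the
        # chat variant, so both the encoder and decoder tables are patched with the known
        # ids of <|im_start|> (92543) and <|im_end|> (92542).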
        tokenizer._added_tokens_decoder.update({92543: AddedToken('<|im_start|>')})
        tokenizer._added_tokens_decoder.update({92542: AddedToken('<|im_end|>')})
        tokenizer.add_special_tokens({'additional_special_tokens': ['<|im_start|>', '<|im_end|>']})
    elif 'orion' in args.model_name_or_path.lower():
        tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
    elif 'gemma' in args.model_name_or_path.lower():
        tokenizer.add_special_tokens({'additional_special_tokens': ['<start_of_turn>', '<end_of_turn>']})

    if tokenizer.__class__.__name__ == 'QWenTokenizer':
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    assert tokenizer.pad_token_id is not None, "pad_token_id should not be None"
    assert tokenizer.eos_token_id is not None, "eos_token_id should not be None"
    logger.info(f'vocab_size of tokenizer: {tokenizer.vocab_size}')
    return tokenizer


def load_unsloth_model(args, training_args):
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name_or_path,
        max_seq_length=args.max_seq_length,
        dtype=None,
        trust_remote_code=True,
        load_in_4bit=True if args.train_mode == 'qlora' else False,
    )
    if args.train_mode in ['lora', 'qlora']:
        logger.info('Initializing PEFT Model...')
        target_modules = find_all_linear_names(model, args.train_mode)
        model = FastLanguageModel.get_peft_model(
            model,
            r=args.lora_rank,
            target_modules=target_modules,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            use_gradient_checkpointing=True,
            random_state=training_args.seed,
            max_seq_length=args.max_seq_length,
        )
        logger.info(f'target_modules: {target_modules}')
    return {
        'model': model,
        'ref_model': None,
        'peft_config': None
    }


def load_model(args, training_args):
    """
    Load the model.
    """
    assert training_args.bf16 or training_args.fp16, 'bf16 or fp16 should be True'
    logger.info(f'Loading model from base model: {args.model_name_or_path}')
    logger.info(f'Train model with {args.train_mode}')
    # init model kwargs
    # todo add flash attention
    # attn_implementation = None
    torch_dtype = torch.float16 if training_args.fp16 else torch.bfloat16
    if args.train_mode == 'qlora':
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
        )
    else:
        quantization_config = None

    model_kwargs = dict(
        trust_remote_code=True,
        # attn_implementation=attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)

    # MoE models need the load-balancing (router) loss
    if 'output_router_logits' in model.config.to_dict():
        logger.info('set output_router_logits as True')
        model.config.output_router_logits = True
    # QLoRA: casts all the non int8 modules to full precision (fp32) for stability
    if args.train_mode == 'qlora' and args.task_type in ['pretrain', 'sft']:
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
    # LoRA: enables the gradients for the input embeddings
    if args.train_mode == 'lora' and args.task_type in ['pretrain', 'sft']:
        # For backward compatibility
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
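            # Fallback for older transformers releases without enable_input_require_grads:
            # a forward hook forces the embedding output to require grad so gradients can
            # still flow into the LoRA adapters when gradient checkpointing is enabled.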
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    # init peft_config
    if args.train_mode == 'full':
        peft_config = None
    else:
        # Find all linear layers that should receive an adapter
        target_modules = find_all_linear_names(model, args.train_mode)
        peft_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            target_modules=target_modules,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )

    # init peft model
    if args.train_mode in ['lora', 'qlora'] and args.task_type in ['pretrain', 'sft']:
        model = get_peft_model(model, peft_config)
        logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB')
        model.print_trainable_parameters()

    # init ref_model
    if args.task_type == 'dpo':
        # Only full-parameter DPO needs a separate frozen reference model; with LoRA/QLoRA
        # the frozen base weights serve as the implicit reference.
        ref_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs) if args.train_mode == 'full' else None
    # pretrain and sft do not need a ref_model
    else:
        ref_model = None

    # Count the model parameters
    total = sum(p.numel() for p in model.parameters())
    logger.info("Total model params: %.2fM" % (total / 1e6))

    return {
        'model': model,
        'ref_model': ref_model,
        'peft_config': peft_config
    }


def load_sft_dataset(args, tokenizer):
    if args.template_name not in template_dict.keys():
        raise Exception(f"template_name doesn't exist, all template_name: {template_dict.keys()}")
    template = template_dict[args.template_name]
    if 'chatglm2' in args.model_name_or_path.lower():
        logger.info('Loading data with ChatGLM2SFTDataset')
        train_dataset = ChatGLM2SFTDataset(args.train_file, tokenizer, args.max_seq_length, template)
    elif 'chatglm3' in args.model_name_or_path.lower():
        logger.info('Loading data with ChatGLM3SFTDataset')
        train_dataset = ChatGLM3SFTDataset(args.train_file, tokenizer, args.max_seq_length, template)
    else:
        logger.info('Loading data with UnifiedSFTDataset')
        train_dataset = UnifiedSFTDataset(args.train_file, tokenizer, args.max_seq_length, template)
    return train_dataset


def load_dpo_dataset(args, tokenizer):
    if args.template_name not in template_dict.keys():
        raise Exception(f"template_name doesn't exist, all template_name: {template_dict.keys()}")
    template = template_dict[args.template_name]
    train_dataset = UnifiedDPODataset(args.train_file, tokenizer, args.max_seq_length, args.max_prompt_length, template)
    return train_dataset


def init_components(args, training_args):
    """
    Initialize all the components.
    """
    training_args.ddp_find_unused_parameters = False
    logger.info('Initializing components...')

    # Load the tokenizer
    tokenizer = load_tokenizer(args)
    # Load the model
    if args.use_unsloth:
        components = load_unsloth_model(args, training_args)
    else:
        components = load_model(args, training_args)
    model = components['model']
    ref_model = components['ref_model']
    peft_config = components['peft_config']

    # Initialize the dataset and collator
    if args.task_type == 'pretrain':
        logger.info('Train model with pretrain task')
        train_dataset = load_pretrain_dataset(training_args, args, tokenizer)
        data_collator = PretrainCollator(tokenizer, args.max_seq_length)
    elif args.task_type == 'sft':
        logger.info('Train model with sft task')
        train_dataset = load_sft_dataset(args, tokenizer)
        data_collator = SFTDataCollator(tokenizer, args.max_seq_length)
    else:
        logger.info('Train model with dpo task')
        train_dataset = load_dpo_dataset(args, tokenizer)
        data_collator = None

    # dpo
    if args.task_type == 'dpo':
        trainer = DPOTrainer(
            model,
            ref_model,
            args=training_args,
            beta=args.beta,
            train_dataset=train_dataset,
            data_collator=data_collator,
            tokenizer=tokenizer,
            peft_config=peft_config
        )
    # pretrain or sft
    else:
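        # Pretraining and SFT share the standard Hugging Face Trainer; only the dataset
        # and collator chosen above differ between the two tasks.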
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
        )
    return trainer


def main():
    # Parse the configuration and run sanity checks
    args, training_args = setup_everything()
    # Initialize all components
    trainer = init_components(args, training_args)
    # Start training
    logger.info("*** starting training ***")
    train_result = trainer.train()
    # Save the final checkpoint
    final_save_path = join(training_args.output_dir)
    trainer.save_model(final_save_path)  # Saves the tokenizer too
    # Save the training metrics
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()


if __name__ == "__main__":
    main()
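
# Illustrative usage only, not part of the training pipeline. The config file merges the
# fields of CustomizedArguments and TrainingArguments; the keys below are the ones this
# script reads, but the values and paths are placeholders to adapt to your setup:
#
#   train_args/sft/qlora/qwen-7b-sft-qlora.json
#   {
#       "output_dir": "output/qwen-7b-sft-qlora",
#       "model_name_or_path": "Qwen/Qwen-7B",
#       "train_file": "data/sft.jsonl",
#       "template_name": "qwen",
#       "task_type": "sft",
#       "train_mode": "qlora",
#       "use_unsloth": false,
#       "max_seq_length": 1024,
#       "lora_rank": 64,
#       "lora_alpha": 16,
#       "lora_dropout": 0.05,
#       "bf16": true,
#       "seed": 42
#   }
#
# Single-GPU launch (assuming the script is saved as train.py; a distributed launcher such
# as torchrun or deepspeed also works, depending on your environment):
#   python train.py --train_args_file train_args/sft/qlora/qwen-7b-sft-qlora.json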