import os
import json
import torch
import logging
import argparse
from tqdm import tqdm
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
import wandb
from accelerate import Accelerator
from transformers import set_seed, get_cosine_schedule_with_warmup
import shutil
import traceback
from jinja2 import Template
from transformers import AutoModelForCausalLM, AutoTokenizer

os.umask(0)

logger = logging.getLogger(__name__)
logging.basicConfig(level='INFO')


class Train_dataset(torch.utils.data.Dataset):
    def __init__(self, config, tokenizer):
        self.config = config
        self.tokenizer = tokenizer
        with open(config.data_path) as f:
            self.data = json.load(f)

        newdata = []
        for da in self.data:
            newdata.append(da)
        print('After filtering:', len(self.data), len(newdata))
        self.data = newdata

        self.max_seq_len = self.config.max_seq_len
        self.debug = 0

        # When training from a base LLM, fall back to the llama3-instruct chat template.
        chat_template_llama3 = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
        if not tokenizer.chat_template:
            tokenizer.chat_template = chat_template_llama3
        self.template = Template(tokenizer.chat_template)

    def __getitem__(self, index):
        return self.data[index]

    def get_response(self, da):
        temp = '## Thinking\n\n{}\n\n## Final Response\n\n{}'
        return temp.format(da['Complex_CoT'], da['Response'])

    def get_prompt(self, da):
        q = da['Question']
        a = self.get_response(da)
        assert q is not None and a is not None, f'q:{q} a:{a}'
        input = self.template.render(
            messages=[{"role": "user", "content": q}, {"role": "assistant", "content": a}],
            bos_token=self.tokenizer.bos_token,
            add_generation_prompt=False)
        input_ids = self.tokenizer.encode(input, add_special_tokens=False)
        query = self.template.render(
            messages=[{"role": "user", "content": q}],
            bos_token=self.tokenizer.bos_token,
            add_generation_prompt=True)
        query_ids = self.tokenizer.encode(query, add_special_tokens=False)
        # Mask prompt tokens with -100 so the loss only covers the assistant completion.
        labels = [-100] * len(query_ids) + input_ids[len(query_ids):]
        assert len(labels) == len(input_ids)
        return {"input_ids": input_ids[-self.max_seq_len:], "labels": labels[-self.max_seq_len:]}

    def collate_fn(self, batch):
        data = [self.get_prompt(da) for da in batch]
        input_ids = [item["input_ids"] for item in data]
        labels = [item["labels"] for item in data]
        max_len = max(len(x) for x in input_ids)
        max_len = min(max_len, self.max_seq_len)
        input_ids = [item[:max_len] + [self.tokenizer.eos_token_id] * (max_len - len(item)) for item in input_ids]
        labels = [item[:max_len] + [-100] * (max_len - len(item)) for item in labels]

        if self.debug < 3:
            print('input_ids', self.tokenizer.decode(input_ids[-1]))
            print('labels', self.tokenizer.decode([0 if x == -100 else x for x in labels[-1]]))
            self.debug += 1

        return {
            "input_ids": torch.LongTensor(input_ids),
            "labels": torch.LongTensor(labels),
        }

    def __len__(self):
        return len(self.data)
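
# Illustrative note (comment only, not executed): get_prompt masks the prompt portion of
# `labels` with -100 so the loss is computed only on the assistant completion
# ("## Thinking ... ## Final Response ..."), and collate_fn right-pads `input_ids` with
# eos_token_id and `labels` with -100. For a hypothetical 2-token prompt and 2-token answer:
#   input_ids: [q1, q2, a1, a2, eos]
#   labels:    [-100, -100, a1, a2, -100]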

class SFTMetric:
    def __init__(self, device):
        self.n_step = 0
        self.right = torch.Tensor([0]).to(device=device)
        self.total = torch.Tensor([0]).to(device=device)
        self.total_loss = torch.Tensor([0]).to(device=device)
        self.world_size = dist.get_world_size()

    def __call__(self, logits, labels, loss):
        return self.update(logits, labels, loss)

    def update(self, logits, labels, loss):
        self.n_step += 1
        with torch.no_grad():
            shift_preds = logits[..., :-1, :].argmax(dim=-1)
            shift_labels = labels[..., 1:]
            self.right += (shift_preds == shift_labels).masked_fill(shift_labels.eq(-100), 0).sum().item()
            self.total += (shift_labels != -100).sum().item()
            self.total_loss += loss.item()

    def get_metric(self, reset=True):
        dist.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM)
        dist.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM)
        dist.all_reduce(self.total_loss, op=torch.distributed.ReduceOp.SUM)

        acc = (self.right / self.total).item()
        loss = self.total_loss.item() / (self.world_size * self.n_step)

        if reset:
            self.n_step = 0
            self.right.fill_(0)
            self.total.fill_(0)
            self.total_loss.fill_(0)
        return acc, loss
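
# Minimal usage sketch for SFTMetric (illustrative; names such as `outputs` and `batch`
# are assumptions, not part of this file's API):
#   metric = SFTMetric(device=torch.cuda.current_device())
#   metric(outputs.logits, batch["labels"], loss)   # accumulate once per step
#   acc, avg_loss = metric.get_metric(reset=True)   # all-reduce across ranks, then reset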

def train(args):
    accelerator = Accelerator(mixed_precision='bf16', gradient_accumulation_steps=args.gradient_accumulation_steps)

    if accelerator.is_main_process:
        wandb.init(project=args.experiment_name, config=args, dir=args.log_dir, mode="offline")

    accelerator.print(f'args:\n{args}')
    accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_bsz_per_gpu
    accelerator.state.deepspeed_plugin.deepspeed_config['train_batch_size'] = args.train_bsz_per_gpu * dist.get_world_size() * accelerator.gradient_accumulation_steps

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)

    # Enable gradient checkpointing to reduce activation memory.
    model.gradient_checkpointing_enable()

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    train_dataset = Train_dataset(args, tokenizer)
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_bsz_per_gpu, shuffle=True, drop_last=True, collate_fn=train_dataset.collate_fn)

    num_training_steps = int(len(train_dataloader) * (args.n_epochs)) // accelerator.gradient_accumulation_steps // dist.get_world_size()
    lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_rates * num_training_steps), num_training_steps=num_training_steps)

    accelerator.print(f'gradient_accumulation_steps:{accelerator.gradient_accumulation_steps} data_path:{args.data_path} lr:{args.learning_rate} num_training_steps:{num_training_steps}')

    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

    start_epoch = 0
    start_step = 0
    global_step = 0

    metric = SFTMetric(device=torch.cuda.current_device())

    def save_checkpoint(epoch, step, global_step):
        save_dir = os.path.join(args.output_dir, f"checkpoint-{epoch}-{global_step}")

        if accelerator.is_main_process:
            # Keep at most args.max_ckpts checkpoints; delete the oldest one first.
            checkpoint_files = os.listdir(args.output_dir)
            checkpoint_files = [file for file in checkpoint_files if file.startswith("checkpoint-")]
            num_checkpoints = len(checkpoint_files)
            if args.max_ckpts > 0:
                if num_checkpoints >= args.max_ckpts:
                    checkpoint_files.sort(key=lambda x: os.path.getctime(os.path.join(args.output_dir, x)))
                    oldest_checkpoint = checkpoint_files[0]
                    shutil.rmtree(os.path.join(args.output_dir, oldest_checkpoint))

            os.makedirs(save_dir, exist_ok=True)
            output_dir = os.path.join(save_dir, 'tfmr')
            if accelerator.state.deepspeed_plugin.zero_stage != 3:
                model.save_pretrained(output_dir, state_dict=accelerator.get_state_dict(model))
            tokenizer.save_pretrained(output_dir)

            # Copy auxiliary files (configs, tokenizer assets) from the base model directory,
            # skipping weight shards and index files.
            copy_files = []
            for item in os.listdir(args.model_path):
                if os.path.exists(os.path.join(output_dir, item)):
                    continue
                if item.startswith("pytorch_model") and item.endswith(".bin"):
                    continue
                if item.endswith(".index.json") or item.endswith(".safetensors"):
                    continue
                s = os.path.join(args.model_path, item)
                if os.path.isfile(s):
                    shutil.copy(s, os.path.join(output_dir, item))
                    copy_files.append(item)
            print(f'HuggingFace model saved in {output_dir}, copied files: {copy_files}')

        if accelerator.state.deepspeed_plugin.zero_stage == 3:
            # Under ZeRO stage 3, parameters are sharded, so saving must go through the
            # unwrapped model with accelerator's save function and gathered state dict.
            unwrap_model = accelerator.unwrap_model(model)
            unwrap_model.save_pretrained(
                os.path.join(save_dir, 'tfmr'),
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
                state_dict=accelerator.get_state_dict(model))

        accelerator.wait_for_everyone()
        accelerator.save({"epoch": epoch, "step": step, "global_step": global_step}, os.path.join(save_dir, "training_state.pt"))
        accelerator.print(f'checkpoint checkpoint-{epoch}-{global_step} is saved...')

    accelerator.print(accelerator.deepspeed_config)
    model.train()

    for epoch in range(start_epoch, args.n_epochs):
        train_dataloader_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader)) if accelerator.is_main_process else enumerate(train_dataloader)

        for batch_cnt, batch in train_dataloader_iterator:
            if epoch == start_epoch and batch_cnt