"""Evaluate a fine-tuned seq2seq summarization checkpoint on an LCSTS-style test file.

The script loads a pretrained checkpoint plus fine-tuned weights, generates
summaries with beam search, reports ROUGE-1/2/L F-scores and writes one JSON
object per example to ``test_data_pred.json``.

``collate_fn`` relies on the module-level names ``tokenizer`` and ``model``,
which are created in the ``__main__`` section below.
"""
import json
import os
import random

import numpy as np
import torch
from rouge import Rouge
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import AdamW, get_scheduler
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hyper-parameters. They live at module level (rather than only under the
# ``__main__`` guard) because ``LCSTS.load_data`` and ``collate_fn`` read them.
max_dataset_size = 200000
max_input_length = 512
max_target_length = 32
batch_size = 16
learning_rate = 1e-5  # not used by this evaluation script
epoch_num = 3  # not used by this evaluation script
beam_size = 4
no_repeat_ngram_size = 2


def seed_everything(seed=1029):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also sets ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


class LCSTS(Dataset):
    """Dataset of ``{'title', 'content'}`` records read from a text file.

    Each line of the file must have the form ``title!=!content``.
    """

    def __init__(self, data_file):
        """Load at most ``max_dataset_size`` examples from ``data_file``."""
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        """Parse ``data_file`` into a dict mapping line index to a record.

        Args:
            data_file: Path to a UTF-8 text file with one example per line.

        Returns:
            Dict ``{idx: {'title': str, 'content': str}}`` holding the first
            ``max_dataset_size`` lines.

        Raises:
            ValueError: If a line does not split into exactly two fields
                on the ``'!=!'`` separator.
        """
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= max_dataset_size:
                    break
                items = line.strip().split('!=!')
                # Explicit raise instead of ``assert`` so the check survives
                # ``python -O`` and the message points at the offending line.
                if len(items) != 2:
                    raise ValueError(
                        f"{data_file}: line {idx + 1} has {len(items)} "
                        f"'!=!'-separated fields, expected 2"
                    )
                Data[idx] = {
                    'title': items[0],
                    'content': items[1]
                }
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def _unwrap_model(wrapped):
    """Return the underlying model, whether or not it is wrapped in DataParallel."""
    return wrapped.module if isinstance(wrapped, nn.DataParallel) else wrapped


def _strip_module_prefix(state_dict):
    """Drop a uniform ``'module.'`` key prefix from ``state_dict``.

    If every key starts with ``'module.'`` the prefix is removed so the
    weights can be loaded into an unwrapped model; otherwise the mapping is
    returned unchanged.
    """
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def collate_fn(batch_samples):
    """Tokenize a list of samples into a padded model-input batch.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        batch_samples: Iterable of dicts with ``'content'`` (source text)
            and ``'title'`` (target text) keys.

    Returns:
        The tokenizer's batch encoding for the sources, extended with
        ``'decoder_input_ids'`` and ``'labels'``. In ``labels`` every position
        after the first EOS token of a row is set to ``-100``.
    """
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['content'])
        batch_targets.append(sample['title'])
    batch_data = tokenizer(
        batch_inputs,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    )
    # ``text_target`` replaces the deprecated ``as_target_tokenizer()`` context.
    labels = tokenizer(
        text_target=batch_targets,
        padding=True,
        max_length=max_target_length,
        truncation=True,
        return_tensors="pt"
    )["input_ids"]
    # Decoder inputs are derived from the labels *before* masking.
    # ``_unwrap_model`` makes this work with and without nn.DataParallel.
    batch_data['decoder_input_ids'] = (
        _unwrap_model(model).prepare_decoder_input_ids_from_labels(labels)
    )
    # Mask row by row so a row with zero or several EOS tokens cannot shift
    # the mask of the following rows.
    eos_mask = labels == tokenizer.eos_token_id
    for row, row_mask in zip(labels, eos_mask):
        eos_positions = row_mask.nonzero(as_tuple=True)[0]
        if eos_positions.numel() > 0:
            row[int(eos_positions[0]) + 1:] = -100
    batch_data['labels'] = labels
    return batch_data


if __name__ == '__main__':
    os.environ["HIP_VISIBLE_DEVICES"] = "4,5"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using {device} device')
    seed_everything(5)

    model_checkpoint = "/umt5/utm5_base"
    trained_model_weights = '/umt5/saves/train_dtk_weights/epoch_1_valid_rouge_23.4347_model_dtk_weights.bin'

    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

    # Load the fine-tuned weights into the bare model before any wrapping, so
    # loading does not depend on how many GPUs are visible right now.
    # ``map_location`` lets a checkpoint saved from GPU load on a CPU-only host.
    state_dict = torch.load(trained_model_weights, map_location='cpu')
    model.load_state_dict(_strip_module_prefix(state_dict))
    model = model.to(device)

    # Check whether more than one GPU is available.
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # With several GPUs, wrap the model in nn.DataParallel.
        model = nn.DataParallel(model)
    model.eval()

    test_data = LCSTS('/umt5/data/lcsts_tsv/data3.tsv')
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    rouge = Rouge()
    generator = _unwrap_model(model)

    with torch.no_grad():
        print('evaluating on test set...')
        sources, preds, labels = [], [], []
        for batch_data in tqdm(test_dataloader):
            # Move the batch to the target device.
            batch_data = {k: v.to(device) for k, v in batch_data.items()}
            generated_tokens = generator.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
                num_beams=beam_size,
                no_repeat_ngram_size=no_repeat_ngram_size,
            ).cpu().numpy()
            label_tokens = batch_data["labels"].cpu().numpy()

            decoded_sources = tokenizer.batch_decode(
                batch_data["input_ids"].cpu().numpy(),
                skip_special_tokens=True,
                use_source_tokenizer=True
            )
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            # -100 is not a valid token id; swap it for the pad id before decoding.
            label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
            decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

            sources += [source.strip() for source in decoded_sources]
            preds += [pred.strip() for pred in decoded_preds]
            labels += [label.strip() for label in decoded_labels]

        # Strings are joined character by character with spaces, so ROUGE is
        # computed over single characters.
        # NOTE(review): the `rouge` package appears to raise on an empty
        # hypothesis string - confirm whether empty predictions can occur here.
        scores = rouge.get_scores(
            hyps=[' '.join(pred) for pred in preds],
            refs=[' '.join(label) for label in labels],
            avg=True
        )
        rouges = {key: value['f'] * 100 for key, value in scores.items()}
        rouges['avg'] = np.mean(list(rouges.values()))
        print(f"Test Rouge1: {rouges['rouge-1']:>0.2f} Rouge2: {rouges['rouge-2']:>0.2f} RougeL: {rouges['rouge-l']:>0.2f}\n")

        results = []
        print('saving predicted results...')
        for source, pred, label in zip(sources, preds, labels):
            results.append({
                "document": source,
                "prediction": pred,
                "summarization": label
            })
        with open('test_data_pred.json', 'wt', encoding='utf-8') as f:
            for example_result in results:
                f.write(json.dumps(example_result, ensure_ascii=False) + '\n')