import torch from transformers import AutoTokenizer from transformers import AutoModelForSeq2SeqLM from torch import nn import os os.environ["HIP_VISIBLE_DEVICES"] = "4,5" device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f'Using {device} device') # 根据用户路径修改对应路径即可 model_checkpoint = "/umt5/umt5_base" # 训练权重需要预训练后载入 trained_model_weights = '/umt5/saves/train_dtk_weights/epoch_1_valid_rouge_23.4347_model_dtk_weights.bin' tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint) # 检查是否有多个 GPU 可用 if torch.cuda.device_count() > 1: print("Let's use", torch.cuda.device_count(), "GPUs!") # 如果有多个 GPUs,使用 nn.DataParallel 包装模型 model = nn.DataParallel(model).to(device) model.load_state_dict(torch.load(trained_model_weights)) model = model.to(device) article_texts = [ """ 5年前的一场风电豪赌,让山东长星集团董事长朱玉国付出了沉重的代价,知情人士爆料称,朱玉国掌控的风电帝国已走到破产边缘。据了解,长星集团涉及多家银行贷款高达60余亿元,现滨州市、邹平县两级政府正在处理善后事宜。 """, """ 央行今日将召集大型商业银行和股份制银行开会,以应对当前的债市风暴。消息人士表示,央行一方面旨在维稳银行间债券市场,另一方面很可能探讨以丙类户治理为重点的改革内容。此次债市风暴中,国家审计署扮演了至关重要的角色。 """, """ 今年以来,多家券商都在“找婆家”。7月8日,齐鲁证券4亿股权在北京金融资产交易所挂牌转让,加上目前正在四大产权交易所挂牌转让的世纪证券、申银万国、云南证券等,至少4家券商股权亮相于各地产权交易所。 """ ] input_ids = tokenizer( article_texts, padding=True, return_tensors="pt", truncation=True, max_length=512 ).to(device) generated_tokens = model.module.generate( input_ids["input_ids"], attention_mask=input_ids["attention_mask"], max_length=32, no_repeat_ngram_size=2, num_beams=4 ) summarys = tokenizer.batch_decode( generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False ) print('原文', article_texts) print('umt5摘要结果:', summarys)