import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from torch import nn
import os

os.environ["HIP_VISIBLE_DEVICES"] = "0"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

# 根据用户路径修改对应路径即可
model_checkpoint = "/umt5/umt5_base"
# 训练权重需要预训练后载入
trained_model_weights = '/umt5/saves/train_dtk_weights/epoch_1_valid_rouge_23.4347_model_dtk_weights.bin'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

model.load_state_dict(torch.load(trained_model_weights))
model = model.to(device)

article_texts = [
"""
5年前的一场风电豪赌，让山东长星集团董事长朱玉国付出了沉重的代价，知情人士爆料称，朱玉国掌控的风电帝国已走到破产边缘。据了解，长星集团涉及多家银行贷款高达60余亿元，现滨州市、邹平县两级政府正在处理善后事宜。
""",
"""
央行今日将召集大型商业银行和股份制银行开会，以应对当前的债市风暴。消息人士表示，央行一方面旨在维稳银行间债券市场，另一方面很可能探讨以丙类户治理为重点的改革内容。此次债市风暴中，国家审计署扮演了至关重要的角色。
""",
"""
今年以来，多家券商都在“找婆家”。7月8日，齐鲁证券4亿股权在北京金融资产交易所挂牌转让，加上目前正在四大产权交易所挂牌转让的世纪证券、申银万国、云南证券等，至少4家券商股权亮相于各地产权交易所。
"""
]

input_ids = tokenizer(
    article_texts,
    padding=True, 
    return_tensors="pt",
    truncation=True,
    max_length=512
).to(device)

generated_tokens = model.module.generate(
    input_ids["input_ids"],
    attention_mask=input_ids["attention_mask"],
    max_length=32,
    no_repeat_ngram_size=2,
    num_beams=4
)
summarys = tokenizer.batch_decode(
    generated_tokens,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)

print('原文', article_texts)
print('umt5摘要结果:', summarys)
