umt5_summary.py 2.36 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from torch import nn
import os

os.environ["HIP_VISIBLE_DEVICES"] = "4,5"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

wanglch's avatar
wanglch committed
12
# 根据用户路径修改对应路径即可
wanglch's avatar
wanglch committed
13
model_checkpoint = "/umt5/umt5_base"
wanglch's avatar
wanglch committed
14
# 训练权重需要预训练后载入
wanglch's avatar
wanglch committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
trained_model_weights = '/umt5/saves/train_dtk_weights/epoch_1_valid_rouge_23.4347_model_dtk_weights.bin'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# 检查是否有多个 GPU 可用
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # 如果有多个 GPUs,使用 nn.DataParallel 包装模型
    model = nn.DataParallel(model).to(device)

model.load_state_dict(torch.load(trained_model_weights))
model = model.to(device)

article_texts = [
"""
5年前的一场风电豪赌,让山东长星集团董事长朱玉国付出了沉重的代价,知情人士爆料称,朱玉国掌控的风电帝国已走到破产边缘。据了解,长星集团涉及多家银行贷款高达60余亿元,现滨州市、邹平县两级政府正在处理善后事宜。
""",
"""
央行今日将召集大型商业银行和股份制银行开会,以应对当前的债市风暴。消息人士表示,央行一方面旨在维稳银行间债券市场,另一方面很可能探讨以丙类户治理为重点的改革内容。此次债市风暴中,国家审计署扮演了至关重要的角色。
""",
"""
今年以来,多家券商都在“找婆家”。7月8日,齐鲁证券4亿股权在北京金融资产交易所挂牌转让,加上目前正在四大产权交易所挂牌转让的世纪证券、申银万国、云南证券等,至少4家券商股权亮相于各地产权交易所。
"""
]

input_ids = tokenizer(
    article_texts,
    padding=True, 
    return_tensors="pt",
    truncation=True,
    max_length=512
).to(device)

generated_tokens = model.module.generate(
    input_ids["input_ids"],
    attention_mask=input_ids["attention_mask"],
    max_length=32,
    no_repeat_ngram_size=2,
    num_beams=4
)
summarys = tokenizer.batch_decode(
    generated_tokens,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)

print('原文', article_texts)
print('umt5摘要结果:', summarys)