"docs/basic_usage/openai_api_vision.ipynb" did not exist on "908dd7f9aae52a9c961c836d99e46ba6681fee42"
umt5_summary.py 2.28 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from torch import nn
import os

os.environ["HIP_VISIBLE_DEVICES"] = "4,5"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

model_checkpoint = "/umt5/umt5_base"
trained_model_weights = '/umt5/saves/train_dtk_weights/epoch_1_valid_rouge_23.4347_model_dtk_weights.bin'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# 检查是否有多个 GPU 可用
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # 如果有多个 GPUs,使用 nn.DataParallel 包装模型
    model = nn.DataParallel(model).to(device)

model.load_state_dict(torch.load(trained_model_weights))
model = model.to(device)

article_texts = [
"""
5年前的一场风电豪赌,让山东长星集团董事长朱玉国付出了沉重的代价,知情人士爆料称,朱玉国掌控的风电帝国已走到破产边缘。据了解,长星集团涉及多家银行贷款高达60余亿元,现滨州市、邹平县两级政府正在处理善后事宜。
""",
"""
央行今日将召集大型商业银行和股份制银行开会,以应对当前的债市风暴。消息人士表示,央行一方面旨在维稳银行间债券市场,另一方面很可能探讨以丙类户治理为重点的改革内容。此次债市风暴中,国家审计署扮演了至关重要的角色。
""",
"""
今年以来,多家券商都在“找婆家”。7月8日,齐鲁证券4亿股权在北京金融资产交易所挂牌转让,加上目前正在四大产权交易所挂牌转让的世纪证券、申银万国、云南证券等,至少4家券商股权亮相于各地产权交易所。
"""
]

input_ids = tokenizer(
    article_texts,
    padding=True, 
    return_tensors="pt",
    truncation=True,
    max_length=512
).to(device)

generated_tokens = model.module.generate(
    input_ids["input_ids"],
    attention_mask=input_ids["attention_mask"],
    max_length=32,
    no_repeat_ngram_size=2,
    num_beams=4
)
summarys = tokenizer.batch_decode(
    generated_tokens,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)

print('原文', article_texts)
print('umt5摘要结果:', summarys)