ModelZoo / UMT5_pytorch

Commit 512350b9
Authored Aug 22, 2024 by wanglch

Delete multi_dcu_test.py

Parent: 2950b694
Showing 1 changed file with 0 additions and 160 deletions.
multi_dcu_test.py (deleted, 100644 → 0)
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import AdamW, get_scheduler
from tqdm.auto import tqdm
from rouge import Rouge
import random
import numpy as np
import os
import json
from torch import nn

def seed_everything(seed=1029):
    # Fix every RNG (Python, NumPy, PyTorch CPU/GPU) for reproducible runs
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

class LCSTS(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        # Each line of the LCSTS file is "title!=!content"
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= max_dataset_size:
                    break
                items = line.strip().split('!=!')
                assert len(items) == 2
                Data[idx] = {'title': items[0], 'content': items[1]}
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch_samples):
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['content'])
        batch_targets.append(sample['title'])
    batch_data = tokenizer(
        batch_inputs,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets,
            padding=True,
            max_length=max_target_length,
            truncation=True,
            return_tensors="pt"
        )["input_ids"]
        # The model is wrapped in nn.DataParallel below, hence model.module
        batch_data['decoder_input_ids'] = model.module.prepare_decoder_input_ids_from_labels(labels)
        # Mask out everything after the EOS token so it is ignored downstream
        end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx + 1:] = -100
        batch_data['labels'] = labels
    return batch_data

if __name__ == '__main__':
    os.environ["HIP_VISIBLE_DEVICES"] = "4,5"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using {device} device')
    seed_everything(5)

    max_dataset_size = 200000
    max_input_length = 512
    max_target_length = 32
    batch_size = 16
    learning_rate = 1e-5
    epoch_num = 3
    beam_size = 4
    no_repeat_ngram_size = 2

    test_data = LCSTS('/umt5/data/lcsts_tsv/data3.tsv')
    test_dataloader = DataLoader(test_data, batch_size=16, shuffle=False, collate_fn=collate_fn)

    model_checkpoint = "/umt5/utm5_base"
    trained_model_weights = '/umt5/saves/train_dtk_weights/epoch_1_valid_rouge_23.4347_model_dtk_weights.bin'
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
    model = model.to(device)

    # Check whether multiple GPUs are available
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # If there is more than one GPU, wrap the model in nn.DataParallel
        model = nn.DataParallel(model)
    model.load_state_dict(torch.load(trained_model_weights))
    model.eval()

    rouge = Rouge()
    with torch.no_grad():
        print('evaluating on test set...')
        sources, preds, labels = [], [], []
        for batch_data in tqdm(test_dataloader):
            batch_data = {k: v.to(device) for k, v in batch_data.items()}  # move the batch to the device
            generated_tokens = model.module.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
                num_beams=beam_size,
                no_repeat_ngram_size=no_repeat_ngram_size,
            ).cpu().numpy()
            if isinstance(generated_tokens, tuple):
                generated_tokens = generated_tokens[0]
            label_tokens = batch_data["labels"].cpu().numpy()
            decoded_sources = tokenizer.batch_decode(
                batch_data["input_ids"].cpu().numpy(),
                skip_special_tokens=True,
                use_source_tokenizer=True
            )
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            # Replace the -100 masking in the labels before decoding
            label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
            decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)
            sources += [source.strip() for source in decoded_sources]
            preds += [pred.strip() for pred in decoded_preds]
            labels += [label.strip() for label in decoded_labels]
        # Space-join each string so the rouge package scores character-level overlap
        scores = rouge.get_scores(
            hyps=[' '.join(pred) for pred in preds],
            refs=[' '.join(label) for label in labels],
            avg=True
        )
        rouges = {key: value['f'] * 100 for key, value in scores.items()}
        rouges['avg'] = np.mean(list(rouges.values()))
        print(f"Test Rouge1: {rouges['rouge-1']:>0.2f} Rouge2: {rouges['rouge-2']:>0.2f} RougeL: {rouges['rouge-l']:>0.2f}\n")

        results = []
        print('saving predicted results...')
        for source, pred, label in zip(sources, preds, labels):
            results.append({
                "document": source,
                "prediction": pred,
                "summarization": label
            })
        with open('test_data_pred.json', 'wt', encoding='utf-8') as f:
            for example_result in results:
                f.write(json.dumps(example_result, ensure_ascii=False) + '\n')
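Note on the ROUGE call near the end of the file: joining each string with spaces (' '.join(pred)) inserts a space between every character, so the whitespace-tokenized rouge package scores character-level overlap, which suits Chinese LCSTS summaries. A minimal standalone sketch of that trick, using illustrative strings that are not from the dataset:

    from rouge import Rouge

    # Hypothetical prediction/reference pair; the space-join makes the
    # whitespace-based tokenizer in `rouge` compare individual characters.
    pred, ref = "今天天气很好", "今天天气不错"
    scores = Rouge().get_scores(hyps=[' '.join(pred)], refs=[' '.join(ref)], avg=True)
    print(scores['rouge-1']['f'])  # character-level ROUGE-1 F-score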