ModelZoo / donut_pytorch
Commit 7e451193 (Unverified)
Authored Aug 04, 2022 by Geewook Kim
Committed by GitHub, Aug 04, 2022
feat: add categorical special tokens (optional), related to #10
parent dd12dae5
Showing 1 changed file with 12 additions and 0 deletions
train.py (+12, -0)
@@ -61,6 +61,18 @@ def train(config):
     datasets = {"train": [], "validation": []}
     for i, dataset_name_or_path in enumerate(config.dataset_name_or_paths):
         task_name = os.path.basename(dataset_name_or_path)  # e.g., cord-v2, docvqa, rvlcdip, ...
+
+        # add categorical special tokens (optional)
+        if task_name == "rvlcdip":
+            model_module.model.decoder.add_special_tokens([
+                "<advertisement/>", "<budget/>", "<email/>", "<file_folder/>",
+                "<form/>", "<handwritten/>", "<invoice/>", "<letter/>",
+                "<memo/>", "<news_article/>", "<presentation/>", "<questionnaire/>",
+                "<resume/>", "<scientific_publication/>", "<scientific_report/>", "<specification/>"
+            ])
+        if task_name == "docvqa":
+            model_module.model.decoder.add_special_tokens(["<yes/>", "<no/>"])
+
         for split in ["train", "validation"]:
             datasets[split].append(
                 DonutDataset(
...
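The hunk above selects tokens by task name and registers them on the decoder. A minimal, self-contained sketch of that pattern follows. `ToyDecoder`, `CATEGORICAL_TOKENS`, and `register_task_tokens` are hypothetical names introduced only for illustration. The toy class is a stand-in and is not the Donut decoder: it only assumes that registering tokens extends the vocabulary and that the embedding table must grow to match.

```python
import os

# Hypothetical lookup table holding the same tokens as the diff above.
CATEGORICAL_TOKENS = {
    "rvlcdip": [
        "<advertisement/>", "<budget/>", "<email/>", "<file_folder/>",
        "<form/>", "<handwritten/>", "<invoice/>", "<letter/>",
        "<memo/>", "<news_article/>", "<presentation/>", "<questionnaire/>",
        "<resume/>", "<scientific_publication/>", "<scientific_report/>",
        "<specification/>",
    ],
    "docvqa": ["<yes/>", "<no/>"],
}


class ToyDecoder:
    """Illustrative stand-in for a decoder; NOT the real Donut class."""

    def __init__(self, vocab):
        self.vocab = list(vocab)
        # Assumption: one embedding row per vocabulary entry.
        self.embedding_rows = len(self.vocab)

    def add_special_tokens(self, tokens):
        # De-duplicate while keeping order, then skip tokens already known.
        new = [t for t in dict.fromkeys(tokens) if t not in self.vocab]
        self.vocab.extend(new)
        if new:
            # Grow the (simulated) embedding table only when needed.
            self.embedding_rows = len(self.vocab)
        return len(new)


def register_task_tokens(decoder, dataset_name_or_path):
    """Register the categorical tokens for the task named by the path."""
    task_name = os.path.basename(dataset_name_or_path)
    return decoder.add_special_tokens(CATEGORICAL_TOKENS.get(task_name, []))


if __name__ == "__main__":
    decoder = ToyDecoder(["<s>", "</s>"])
    print(register_task_tokens(decoder, "some/dir/rvlcdip"))
    print(register_task_tokens(decoder, "docvqa"))
    print(decoder.embedding_rows)
```

In this sketch a second registration of the same task adds nothing, so the call is safe to repeat. Note that `os.path.basename` returns an empty string for a path with a trailing slash, in which case no task would match and no tokens would be added.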