ModelZoo / donut_pytorch · Commit 1e66b652

feat: update pl version supports

Authored Jul 31, 2023 by Minseo Kang (parent 681d9aa3)
Showing 3 changed files, with 51 additions and 22 deletions:

donut/model.py        +1  -0
lightning_module.py   +18 -17
train.py              +32 -5
donut/model.py

@@ -69,6 +69,7 @@ class SwinEncoder(nn.Module):
             num_heads=[4, 8, 16, 32],
             num_classes=0,
         )
+        self.model.norm = None
 
         # weight init with swin
         if not name_or_path:
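The single added line drops the Swin backbone's final LayerNorm. If this fork's forward pass mirrors upstream Donut, where SwinEncoder.forward takes features from self.model.layers(...) directly, the norm submodule is never executed, and assigning None to a registered submodule removes its parameters from the module tree. A minimal sketch with a hypothetical toy module, not the Donut code:

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(8, 8), nn.GELU())
        self.norm = nn.LayerNorm(8)  # registered, but bypassed below

    def forward(self, x):
        return self.layers(x)  # norm is never applied, as in SwinEncoder

enc = TinyEncoder()
print(sum(p.numel() for p in enc.parameters()))  # 88 (72 linear + 16 norm)
enc.norm = None  # PyTorch allows nulling a registered submodule
print(sum(p.numel() for p in enc.parameters()))  # 72: the dead weights are gone
print(enc(torch.randn(2, 8)).shape)  # forward still works: torch.Size([2, 8])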
lightning_module.py
@@ -44,6 +44,8 @@ class DonutModelPLModule(pl.LightningModule):
                 # encoder_layer=[2,2,14,2], decoder_layer=4, ...
             )
         )
+        self.pytorch_lightning_version_is_1 = int(pl.__version__[0]) < 2
+        self.num_of_loaders = len(self.config.dataset_name_or_paths)
 
     def training_step(self, batch, batch_idx):
         image_tensors, decoder_input_ids, decoder_labels = list(), list(), list()
@@ -56,9 +58,16 @@ class DonutModelPLModule(pl.LightningModule):
         decoder_labels = torch.cat(decoder_labels)
         loss = self.model(image_tensors, decoder_input_ids, decoder_labels)[0]
         self.log_dict({"train_loss": loss}, sync_dist=True)
+        if not self.pytorch_lightning_version_is_1:
+            self.log('loss', loss, prog_bar=True)
         return loss
 
-    def validation_step(self, batch, batch_idx, dataset_idx=0):
+    def on_validation_epoch_start(self) -> None:
+        super().on_validation_epoch_start()
+        self.validation_step_outputs = [[] for _ in range(self.num_of_loaders)]
+        return
+
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         image_tensors, decoder_input_ids, prompt_end_idxs, answers = batch
         decoder_prompts = pad_sequence(
             [input_id[: end_idx + 1] for input_id, end_idx in zip(decoder_input_ids, prompt_end_idxs)],
@@ -84,17 +93,16 @@ class DonutModelPLModule(pl.LightningModule):
                 self.print(f" Answer: {answer}")
                 self.print(f" Normed ED: {scores[0]}")
+        self.validation_step_outputs[dataloader_idx].append(scores)
         return scores
 
-    def validation_epoch_end(self, validation_step_outputs):
-        num_of_loaders = len(self.config.dataset_name_or_paths)
-        if num_of_loaders == 1:
-            validation_step_outputs = [validation_step_outputs]
-        assert len(validation_step_outputs) == num_of_loaders
-        cnt = [0] * num_of_loaders
-        total_metric = [0] * num_of_loaders
-        val_metric = [0] * num_of_loaders
-        for i, results in enumerate(validation_step_outputs):
+    def on_validation_epoch_end(self):
+        assert len(self.validation_step_outputs) == self.num_of_loaders
+        cnt = [0] * self.num_of_loaders
+        total_metric = [0] * self.num_of_loaders
+        val_metric = [0] * self.num_of_loaders
+        for i, results in enumerate(self.validation_step_outputs):
             for scores in results:
                 cnt[i] += len(scores)
                 total_metric[i] += np.sum(scores)
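For reference, the pattern these two hunks adopt: Lightning 2.0 removed the validation_epoch_end(outputs) hook, so modules now buffer outputs themselves, one list per dataloader, and aggregate in on_validation_epoch_end. A self-contained sketch with hypothetical stand-ins for the model and metric, not the Donut code:

import pytorch_lightning as pl
import torch.nn as nn

class TinyPLModule(pl.LightningModule):
    def __init__(self, num_of_loaders=2):
        super().__init__()
        self.layer = nn.Linear(4, 1)
        self.num_of_loaders = num_of_loaders

    def on_validation_epoch_start(self):
        # Fresh buffer each epoch: one list per validation dataloader.
        self.validation_step_outputs = [[] for _ in range(self.num_of_loaders)]

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        scores = [float(self.layer(batch).mean())]  # hypothetical metric
        self.validation_step_outputs[dataloader_idx].append(scores)
        return scores

    def on_validation_epoch_end(self):
        # Aggregate per loader, mirroring the cnt/total_metric bookkeeping above.
        for i, results in enumerate(self.validation_step_outputs):
            flat = [s for scores in results for s in scores]
            self.log(f"val_metric_{i}", sum(flat) / max(len(flat), 1))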
@@ -136,13 +144,6 @@ class DonutModelPLModule(pl.LightningModule):
         return LambdaLR(optimizer, lr_lambda)
 
-    def get_progress_bar_dict(self):
-        items = super().get_progress_bar_dict()
-        items.pop("v_num", None)
-        items["exp_name"] = f"{self.config.get('exp_name', '')}"
-        items["exp_version"] = f"{self.config.get('exp_version', '')}"
-        return items
-
     @rank_zero_only
     def on_save_checkpoint(self, checkpoint):
         save_path = Path(self.config.result_path) / self.config.exp_name / self.config.exp_version
train.py
@@ -51,8 +51,34 @@ def save_config_file(config, path):
     print(f"Config is saved at {save_path}")
 
+class ProgressBar(pl.callbacks.TQDMProgressBar):
+    def __init__(self, config):
+        super().__init__()
+        self.enable = True
+        self.config = config
+
+    def disable(self):
+        self.enable = False
+
+    def get_metrics(self, trainer, model):
+        items = super().get_metrics(trainer, model)
+        items.pop("v_num", None)
+        items["exp_name"] = f"{self.config.get('exp_name', '')}"
+        items["exp_version"] = f"{self.config.get('exp_version', '')}"
+        return items
+
+
+def set_seed(seed):
+    pytorch_lightning_version = int(pl.__version__[0])
+    if pytorch_lightning_version < 2:
+        pl.utilities.seed.seed_everything(seed, workers=True)
+    else:
+        import lightning_fabric
+        lightning_fabric.utilities.seed.seed_everything(seed, workers=True)
+
+
 def train(config):
-    pl.utilities.seed.seed_everything(config.get("seed", 42), workers=True)
+    set_seed(config.get("seed", 42))
     model_module = DonutModelPLModule(config)
     data_module = DonutDataPLModule(config)
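The set_seed helper reflects that pl.utilities.seed.seed_everything no longer exists in Lightning 2.0, where the implementation moved into lightning_fabric. A hedged alternative, assuming the top-level re-export is stable on both majors (worth verifying against the pinned versions), would sidestep the conditional import:

import pytorch_lightning as pl

# Assumption: pl.seed_everything is exported at the package top level on both
# 1.x and 2.x, making it a drop-in for the version-gated helper above.
pl.seed_everything(42, workers=True)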
@@ -111,11 +137,12 @@ def train(config):
         mode="min",
     )
+    bar = ProgressBar(config)
     custom_ckpt = CustomCheckpointIO()
     trainer = pl.Trainer(
-        resume_from_checkpoint=config.get("resume_from_checkpoint_path", None),
         num_nodes=config.get("num_nodes", 1),
-        gpus=torch.cuda.device_count(),
+        devices=torch.cuda.device_count(),
         strategy="ddp",
         accelerator="gpu",
         plugins=custom_ckpt,
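This hunk swaps the removed gpus= argument for the accelerator=/devices= pair and drops resume_from_checkpoint= from the Trainer (it reappears as ckpt_path in fit, below). If a codebase also had to support a 1.x release that predates the devices= API, the same major-version gate could pick the kwargs; a hypothetical helper, not part of this commit:

import torch
import pytorch_lightning as pl

def device_kwargs():
    # gpus= was removed in Lightning 2.0; devices=/accelerator= exist on
    # recent 1.x releases too, so this gate only matters for older pins.
    n = max(torch.cuda.device_count(), 1)
    if int(pl.__version__[0]) < 2:
        return {"gpus": n}
    return {"accelerator": "gpu", "devices": n}

trainer = pl.Trainer(num_nodes=1, strategy="ddp", **device_kwargs())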
@@ -127,10 +154,10 @@ def train(config):
         precision=16,
         num_sanity_val_steps=0,
         logger=logger,
-        callbacks=[lr_callback, checkpoint_callback],
+        callbacks=[lr_callback, checkpoint_callback, bar],
     )
 
-    trainer.fit(model_module, data_module)
+    trainer.fit(model_module, data_module, ckpt_path=config.get("resume_from_checkpoint_path", None))
 
 
 if __name__ == "__main__":
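Background for the fit() change, per the Lightning deprecation notes rather than anything stated in this commit: Trainer(resume_from_checkpoint=...) was deprecated during the 1.x series and removed in 2.0, while Trainer.fit(ckpt_path=...) is accepted by both majors, so a single call covers fresh starts (ckpt_path=None) and resumes:

# Usage sketch: resuming is a property of the fit call, not the Trainer.
trainer.fit(model_module, data_module,
            ckpt_path=config.get("resume_from_checkpoint_path", None))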