Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
donut_pytorch
Commits
3e809211
Commit
3e809211
authored
Sep 19, 2022
by
SamSamhuns
Browse files
Change model_kwarg to encoder_outputs
parent
392ed80c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
donut/model.py
donut/model.py
+2
-2
No files found.
donut/model.py
View file @
3e809211
...
@@ -206,7 +206,7 @@ class BARTDecoder(nn.Module):
...
@@ -206,7 +206,7 @@ class BARTDecoder(nn.Module):
if
newly_added_num
>
0
:
if
newly_added_num
>
0
:
self
.
model
.
resize_token_embeddings
(
len
(
self
.
tokenizer
))
self
.
model
.
resize_token_embeddings
(
len
(
self
.
tokenizer
))
def
prepare_inputs_for_inference
(
self
,
input_ids
:
torch
.
Tensor
,
past
=
None
,
use_cache
:
bool
=
None
,
**
model_kwargs
):
def
prepare_inputs_for_inference
(
self
,
input_ids
:
torch
.
Tensor
,
past
=
None
,
use_cache
:
bool
=
None
,
encoder_outputs
:
torch
.
Tensor
=
None
):
"""
"""
Args:
Args:
input_ids: (batch_size, sequence_lenth)
input_ids: (batch_size, sequence_lenth)
...
@@ -223,7 +223,7 @@ class BARTDecoder(nn.Module):
...
@@ -223,7 +223,7 @@ class BARTDecoder(nn.Module):
"attention_mask"
:
attention_mask
,
"attention_mask"
:
attention_mask
,
"past_key_values"
:
past
,
"past_key_values"
:
past
,
"use_cache"
:
use_cache
,
"use_cache"
:
use_cache
,
"encoder_hidden_states"
:
model_kwargs
[
"
encoder_outputs
"
]
.
last_hidden_state
,
"encoder_hidden_states"
:
encoder_outputs
.
last_hidden_state
,
}
}
return
output
return
output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment