Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
53155b52
Unverified
Commit
53155b52
authored
Mar 27, 2023
by
Joao Gante
Committed by
GitHub
Mar 27, 2023
Browse files
Trainer: move Seq2SeqTrainer imports under the typing guard (#22401)
parent
0e708178
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
15 deletions
+18
-15
src/transformers/trainer_seq2seq.py
src/transformers/trainer_seq2seq.py
+18
-15
No files found.
src/transformers/trainer_seq2seq.py
View file @
53155b52
...
@@ -14,39 +14,42 @@
...
@@ -14,39 +14,42 @@
from
copy
import
deepcopy
from
copy
import
deepcopy
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
from
.data.data_collator
import
DataCollator
from
.deepspeed
import
is_deepspeed_zero3_enabled
from
.deepspeed
import
is_deepspeed_zero3_enabled
from
.generation.configuration_utils
import
GenerationConfig
from
.generation.configuration_utils
import
GenerationConfig
from
.modeling_utils
import
PreTrainedModel
from
.tokenization_utils_base
import
PreTrainedTokenizerBase
from
.trainer
import
Trainer
from
.trainer
import
Trainer
from
.trainer_callback
import
TrainerCallback
from
.trainer_utils
import
EvalPrediction
,
PredictionOutput
from
.training_args
import
TrainingArguments
from
.utils
import
logging
from
.utils
import
logging
if
TYPE_CHECKING
:
from
.data.data_collator
import
DataCollator
from
.modeling_utils
import
PreTrainedModel
from
.tokenization_utils_base
import
PreTrainedTokenizerBase
from
.trainer_callback
import
TrainerCallback
from
.trainer_utils
import
EvalPrediction
,
PredictionOutput
from
.training_args
import
TrainingArguments
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
class
Seq2SeqTrainer
(
Trainer
):
class
Seq2SeqTrainer
(
Trainer
):
def
__init__
(
def
__init__
(
self
,
self
,
model
:
Union
[
PreTrainedModel
,
nn
.
Module
]
=
None
,
model
:
Union
[
"
PreTrainedModel
"
,
nn
.
Module
]
=
None
,
args
:
TrainingArguments
=
None
,
args
:
"
TrainingArguments
"
=
None
,
data_collator
:
Optional
[
DataCollator
]
=
None
,
data_collator
:
Optional
[
"
DataCollator
"
]
=
None
,
train_dataset
:
Optional
[
Dataset
]
=
None
,
train_dataset
:
Optional
[
Dataset
]
=
None
,
eval_dataset
:
Optional
[
Union
[
Dataset
,
Dict
[
str
,
Dataset
]]]
=
None
,
eval_dataset
:
Optional
[
Union
[
Dataset
,
Dict
[
str
,
Dataset
]]]
=
None
,
tokenizer
:
Optional
[
PreTrainedTokenizerBase
]
=
None
,
tokenizer
:
Optional
[
"
PreTrainedTokenizerBase
"
]
=
None
,
model_init
:
Optional
[
Callable
[[],
PreTrainedModel
]]
=
None
,
model_init
:
Optional
[
Callable
[[],
"
PreTrainedModel
"
]]
=
None
,
compute_metrics
:
Optional
[
Callable
[[
EvalPrediction
],
Dict
]]
=
None
,
compute_metrics
:
Optional
[
Callable
[[
"
EvalPrediction
"
],
Dict
]]
=
None
,
callbacks
:
Optional
[
List
[
TrainerCallback
]]
=
None
,
callbacks
:
Optional
[
List
[
"
TrainerCallback
"
]]
=
None
,
optimizers
:
Tuple
[
torch
.
optim
.
Optimizer
,
torch
.
optim
.
lr_scheduler
.
LambdaLR
]
=
(
None
,
None
),
optimizers
:
Tuple
[
torch
.
optim
.
Optimizer
,
torch
.
optim
.
lr_scheduler
.
LambdaLR
]
=
(
None
,
None
),
preprocess_logits_for_metrics
:
Optional
[
Callable
[[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
]]
=
None
,
preprocess_logits_for_metrics
:
Optional
[
Callable
[[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
]]
=
None
,
):
):
...
@@ -161,7 +164,7 @@ class Seq2SeqTrainer(Trainer):
...
@@ -161,7 +164,7 @@ class Seq2SeqTrainer(Trainer):
ignore_keys
:
Optional
[
List
[
str
]]
=
None
,
ignore_keys
:
Optional
[
List
[
str
]]
=
None
,
metric_key_prefix
:
str
=
"test"
,
metric_key_prefix
:
str
=
"test"
,
**
gen_kwargs
,
**
gen_kwargs
,
)
->
PredictionOutput
:
)
->
"
PredictionOutput
"
:
"""
"""
Run prediction and returns predictions and potential metrics.
Run prediction and returns predictions and potential metrics.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment