chenpangpang / transformers · Commits

Commit 889d3bfd (unverified)
Authored Feb 20, 2020 by srush; committed by GitHub on Feb 20, 2020

    default arg fix (#2937)

Parent: ea8eba35

Changes: 1 changed file, with 10 additions and 4 deletions (+10 −4)

examples/ner/transformer_base.py
@@ -248,16 +248,22 @@ def generic_train(model, args):
         filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
     )
 
-    trainer = pl.Trainer(
+    train_params = dict(
         accumulate_grad_batches=args.gradient_accumulation_steps,
         gpus=args.n_gpu,
         max_epochs=args.num_train_epochs,
-        use_amp=args.fp16,
-        amp_level=args.fp16_opt_level,
-        distributed_backend="ddp",
         gradient_clip_val=args.max_grad_norm,
         checkpoint_callback=checkpoint_callback,
     )
+
+    if args.fp16:
+        train_params["use_amp"] = args.fp16
+        train_params["amp_level"] = args.fp16_opt_level
+
+    if args.n_gpu > 1:
+        train_params["distributed_backend"] = "ddp"
+
+    trainer = pl.Trainer(**train_params)
 
     if args.do_train:
         trainer.fit(model)
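The fix replaces unconditional keyword arguments with a `train_params` dict that is only populated when the corresponding feature is enabled, so `pl.Trainer` falls back to its own defaults otherwise (previously, defaults such as `distributed_backend="ddp"` were always forced). A minimal sketch of that conditional-kwargs pattern; `make_trainer`, `trainer_cls`, and the plain-dict config are hypothetical stand-ins, not code from the commit:

```python
def make_trainer(config, trainer_cls):
    """Assemble trainer kwargs conditionally, then splat them with **.

    Keys are added only when a feature is requested, so trainer_cls keeps
    its own defaults for everything else -- the pattern used in the diff.
    """
    params = dict(max_epochs=config.get("num_train_epochs", 3))

    if config.get("fp16"):
        # Mixed precision is opt-in; only then do AMP settings apply.
        params["use_amp"] = True
        params["amp_level"] = config.get("fp16_opt_level", "O1")

    if config.get("n_gpu", 0) > 1:
        # Only force a distributed backend when more than one GPU is used.
        params["distributed_backend"] = "ddp"

    return trainer_cls(**params)
```

With a single GPU and no fp16, `make_trainer` passes neither `use_amp` nor `distributed_backend`, which is exactly what the commit's "default arg fix" achieves for `pl.Trainer`.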