chenpangpang / transformers · Commits

Commit 908fa43b (unverified), authored Feb 27, 2020 by srush, committed by GitHub on Feb 27, 2020

Changes to NER examples for PLT and TPU (#3053)

* changes to allow for tpu training
* black
* tpu
* tpu
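As context for the change, here is a minimal, illustrative sketch (not code from this commit) of how the new --n_tpu_cores flag is meant to reach the pytorch-lightning Trainer, mirroring the generic_train changes further down; it assumes the early-2020 Lightning API in which Trainer accepted a num_tpu_cores argument.

import pytorch_lightning as pl

def build_trainer(args, checkpoint_callback):
    # Hypothetical helper; `args` carries the flags added in this diff.
    train_params = dict(
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=args.n_gpu,
        max_epochs=args.num_train_epochs,
        gradient_clip_val=args.max_grad_norm,
        checkpoint_callback=checkpoint_callback,
    )
    if args.n_tpu_cores > 0:
        # On TPU, hand the requested cores to Lightning and disable GPU training.
        train_params["num_tpu_cores"] = args.n_tpu_cores
        train_params["gpus"] = 0
    if args.n_gpu > 1:
        train_params["distributed_backend"] = "ddp"
    return pl.Trainer(**train_params)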
Parent: 8bcb37bf

Showing 3 changed files with 104 additions and 95 deletions (+104 −95):

  examples/ner/run_pl.sh            +15 −1
  examples/ner/run_pl_ner.py        +54 −70
  examples/ner/transformer_base.py  +35 −24
examples/ner/run_pl.sh (view file @ 908fa43b)

-# Require pytorch-lightning=0.6
+# Install newest ptl.
+pip install -U git+http://github.com/PyTorchLightning/pytorch-lightning/
+curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-train.tsv?attredirects=0&d=1' \
+| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp
+curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attredirects=0&d=1' \
+| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp
+curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \
+| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp
+wget "https://raw.githubusercontent.com/stefan-it/fine-tuned-berts-seq/master/scripts/preprocess.py"
 export MAX_LENGTH=128
 export BERT_MODEL=bert-base-multilingual-cased
+python3 preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
+python3 preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
+python3 preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
+cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$" | sort | uniq > labels.txt
 export OUTPUT_DIR=germeval-model
 export BATCH_SIZE=32
 export NUM_EPOCHS=3
...
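The curl pipelines above keep only the GermEval token and tag columns; as a reading aid, a rough Python equivalent of the grep/cut/tr stage is sketched below (illustrative only, assuming the TSV has already been saved to disk rather than streamed from curl).

def tsv_to_token_tag(in_path, out_path):
    # Mirrors: grep -v "^#" | cut -f 2,3 | tr '\t' ' '
    with open(in_path, encoding="utf-8") as src, open(out_path, "w", encoding="utf-8") as dst:
        for line in src:
            if line.startswith("#"):              # grep -v "^#": drop comment lines
                continue
            cols = line.rstrip("\n").split("\t")
            if len(cols) >= 3:
                dst.write(cols[1] + " " + cols[2] + "\n")   # columns 2 and 3, joined by a space
            else:
                dst.write("\n")                   # blank lines separate sentences

tsv_to_token_tag("NER-de-train.tsv", "train.txt.tmp")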
examples/ner/run_pl_ner.py (view file @ 908fa43b)

...
@@ -7,8 +7,7 @@ import numpy as np
 import torch
 from seqeval.metrics import f1_score, precision_score, recall_score
 from torch.nn import CrossEntropyLoss
-from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
-from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import DataLoader, TensorDataset

 from transformer_base import BaseTransformer, add_generic_args, generic_train
 from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
...
@@ -25,13 +24,14 @@ class NERTransformer(BaseTransformer):
     def __init__(self, hparams):
         self.labels = get_labels(hparams.labels)
         num_labels = len(self.labels)
+        self.pad_token_label_id = CrossEntropyLoss().ignore_index
         super(NERTransformer, self).__init__(hparams, num_labels)

     def forward(self, **inputs):
         return self.model(**inputs)

     def training_step(self, batch, batch_num):
-        "Compute loss"
+        "Compute loss and log."
         inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
         if self.hparams.model_type != "distilbert":
             inputs["token_type_ids"] = (
...
@@ -40,25 +40,61 @@ class NERTransformer(BaseTransformer):
         outputs = self.forward(**inputs)
         loss = outputs[0]
         tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
         return {"loss": loss, "log": tensorboard_logs}

+    def _feature_file(self, mode):
+        return os.path.join(
+            self.hparams.data_dir,
+            "cached_{}_{}_{}".format(
+                mode,
+                list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
+                str(self.hparams.max_seq_length),
+            ),
+        )
+
+    def prepare_data(self):
+        "Called to initialize data. Use the call to construct features"
+        args = self.hparams
+        for mode in ["train", "dev", "test"]:
+            cached_features_file = self._feature_file(mode)
+            if not os.path.exists(cached_features_file):
+                logger.info("Creating features from dataset file at %s", args.data_dir)
+                examples = read_examples_from_file(args.data_dir, mode)
+                features = convert_examples_to_features(
+                    examples,
+                    self.labels,
+                    args.max_seq_length,
+                    self.tokenizer,
+                    cls_token_at_end=bool(args.model_type in ["xlnet"]),
+                    cls_token=self.tokenizer.cls_token,
+                    cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0,
+                    sep_token=self.tokenizer.sep_token,
+                    sep_token_extra=bool(args.model_type in ["roberta"]),
+                    pad_on_left=bool(args.model_type in ["xlnet"]),
+                    pad_token=self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_token])[0],
+                    pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
+                    pad_token_label_id=self.pad_token_label_id,
+                )
+                logger.info("Saving features into cached file %s", cached_features_file)
+                torch.save(features, cached_features_file)
+
     def load_dataset(self, mode, batch_size):
-        labels = get_labels(self.hparams.labels)
-        self.pad_token_label_id = CrossEntropyLoss().ignore_index
-        dataset = self.load_and_cache_examples(labels, self.pad_token_label_id, mode)
-        if mode == "train":
-            if self.hparams.n_gpu > 1:
-                sampler = DistributedSampler(dataset)
-            else:
-                sampler = RandomSampler(dataset)
-        else:
-            sampler = SequentialSampler(dataset)
-        dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
-        return dataloader
+        "Load datasets. Called after prepare data."
+        cached_features_file = self._feature_file(mode)
+        logger.info("Loading features from cached file %s", cached_features_file)
+        features = torch.load(cached_features_file)
+        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
+        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
+        all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)
+        return DataLoader(
+            TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids), batch_size=batch_size
+        )

     def validation_step(self, batch, batch_nb):
+        "Compute validation"
         inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
         if self.hparams.model_type != "distilbert":
             inputs["token_type_ids"] = (
...
@@ -68,11 +104,10 @@ class NERTransformer(BaseTransformer):
         tmp_eval_loss, logits = outputs[:2]
         preds = logits.detach().cpu().numpy()
         out_label_ids = inputs["labels"].detach().cpu().numpy()
-        return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}
+        return {"val_loss": tmp_eval_loss, "pred": preds, "target": out_label_ids}

     def _eval_end(self, outputs):
-        "Task specific validation"
+        "Evaluation called for both Val and Test"
         val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean()
         preds = np.concatenate([x["pred"] for x in outputs], axis=0)
         preds = np.argmax(preds, axis=2)
...
@@ -96,7 +131,6 @@ class NERTransformer(BaseTransformer):
         }

         if self.is_logger():
-            logger.info(self.proc_rank)
             logger.info("***** Eval results *****")
             for key in sorted(results.keys()):
                 logger.info(" %s = %s", key, str(results[key]))
...
@@ -140,56 +174,6 @@ class NERTransformer(BaseTransformer):
         )
         return ret

-    def load_and_cache_examples(self, labels, pad_token_label_id, mode):
-        args = self.hparams
-        tokenizer = self.tokenizer
-        if self.proc_rank not in [-1, 0] and mode == "train":
-            torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
-
-        # Load data features from cache or dataset file
-        cached_features_file = os.path.join(
-            args.data_dir,
-            "cached_{}_{}_{}".format(
-                mode, list(filter(None, args.model_name_or_path.split("/"))).pop(), str(args.max_seq_length)
-            ),
-        )
-        if os.path.exists(cached_features_file) and not args.overwrite_cache:
-            logger.info("Loading features from cached file %s", cached_features_file)
-            features = torch.load(cached_features_file)
-        else:
-            logger.info("Creating features from dataset file at %s", args.data_dir)
-            examples = read_examples_from_file(args.data_dir, mode)
-            features = convert_examples_to_features(
-                examples,
-                labels,
-                args.max_seq_length,
-                tokenizer,
-                cls_token_at_end=bool(args.model_type in ["xlnet"]),
-                cls_token=tokenizer.cls_token,
-                cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0,
-                sep_token=tokenizer.sep_token,
-                sep_token_extra=bool(args.model_type in ["roberta"]),
-                pad_on_left=bool(args.model_type in ["xlnet"]),
-                pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
-                pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
-                pad_token_label_id=pad_token_label_id,
-            )
-            if self.proc_rank in [-1, 0]:
-                logger.info("Saving features into cached file %s", cached_features_file)
-                torch.save(features, cached_features_file)
-
-        if self.proc_rank == 0 and mode == "train":
-            torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
-
-        # Convert to Tensors and build dataset
-        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
-        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
-        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
-        all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)
-        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
-        return dataset
-
     @staticmethod
     def add_model_specific_args(parser, root_dir):
         # Add NER specific options
...
examples/ner/transformer_base.py (view file @ 908fa43b)

+import logging
 import os
 import random
...
@@ -26,6 +27,9 @@ from transformers import (
 )

+logger = logging.getLogger(__name__)
+
 ALL_MODELS = sum(
     (
         tuple(conf.pretrained_config_archive_map.keys())
...
@@ -77,20 +81,14 @@ class BaseTransformer(pl.LightningModule):
             cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
         )
         self.config, self.tokenizer, self.model = config, tokenizer, model
-        self.proc_rank = -1

     def is_logger(self):
-        return self.proc_rank <= 0
+        return self.trainer.proc_rank <= 0

     def configure_optimizers(self):
         "Prepare optimizer and schedule (linear warmup and decay)"
         model = self.model
-        t_total = (
-            len(self.train_dataloader())
-            // self.hparams.gradient_accumulation_steps
-            * float(self.hparams.num_train_epochs)
-        )
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
             {
...
@@ -103,18 +101,16 @@ class BaseTransformer(pl.LightningModule):
             },
         ]
         optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
-        )
-        self.lr_scheduler = scheduler
+        self.opt = optimizer
         return [optimizer]

     def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
-        optimizer.step()
-        self.lr_scheduler.step()
-        optimizer.zero_grad()
+        if self.trainer.use_tpu:
+            # Step each time.
+            xm.optimizer_step(optimizer)
+        else:
+            optimizer.step()
+        optimizer.zero_grad()
+        self.lr_scheduler.step()

     def get_tqdm_dict(self):
         tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
...
@@ -127,22 +123,27 @@ class BaseTransformer(pl.LightningModule):
     def test_end(self, outputs):
         return self.validation_end(outputs)

-    @pl.data_loader
     def train_dataloader(self):
-        return self.load_dataset("train", self.hparams.train_batch_size)
+        train_batch_size = self.hparams.train_batch_size
+        dataloader = self.load_dataset("train", train_batch_size)
+        t_total = (
+            (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.n_gpu)))
+            // self.hparams.gradient_accumulation_steps
+            * float(self.hparams.num_train_epochs)
+        )
+        scheduler = get_linear_schedule_with_warmup(
+            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
+        )
+        self.lr_scheduler = scheduler
+        return dataloader

-    @pl.data_loader
     def val_dataloader(self):
         return self.load_dataset("dev", self.hparams.eval_batch_size)

-    @pl.data_loader
     def test_dataloader(self):
         return self.load_dataset("test", self.hparams.eval_batch_size)

-    def init_ddp_connection(self, proc_rank, world_size):
-        self.proc_rank = proc_rank
-        super(BaseTransformer, self).init_ddp_connection(proc_rank, world_size)
-
     @staticmethod
     def add_model_specific_args(parser, root_dir):
         parser.add_argument(
...
@@ -213,6 +214,7 @@ def add_generic_args(parser, root_dir):
     )
     parser.add_argument("--n_gpu", type=int, default=1)
+    parser.add_argument("--n_tpu_cores", type=int, default=0)
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
     parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
...
@@ -252,13 +254,22 @@ def generic_train(model, args):
         accumulate_grad_batches=args.gradient_accumulation_steps,
         gpus=args.n_gpu,
         max_epochs=args.num_train_epochs,
+        early_stop_callback=False,
         gradient_clip_val=args.max_grad_norm,
         checkpoint_callback=checkpoint_callback,
     )

     if args.fp16:
         train_params["use_amp"] = args.fp16
         train_params["amp_level"] = args.fp16_opt_level

+    if args.n_tpu_cores > 0:
+        global xm
+        import torch_xla.core.xla_model as xm
+
+        train_params["num_tpu_cores"] = args.n_tpu_cores
+        train_params["gpus"] = 0
+
     if args.n_gpu > 1:
         train_params["distributed_backend"] = "ddp"
...
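The core TPU change in BaseTransformer.optimizer_step is routing the parameter update through torch_xla. A minimal standalone sketch of that pattern follows (assuming torch_xla is installed; names are illustrative, not the committed method itself).

def step_once(optimizer, lr_scheduler, use_tpu):
    if use_tpu:
        import torch_xla.core.xla_model as xm  # lazy import, only needed on TPU
        xm.optimizer_step(optimizer)           # reduce gradients across TPU cores, then step
    else:
        optimizer.step()
    optimizer.zero_grad()
    lr_scheduler.step()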