Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
764a7923
Commit
764a7923
authored
Sep 27, 2019
by
VictorSanh
Committed by
Victor SANH
Oct 04, 2019
Browse files
add distillation+finetuning option in run_squad
parent
bb464289
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
7 deletions
+57
-7
examples/run_squad.py
examples/run_squad.py
+57
-7
No files found.
examples/run_squad.py
View file @
764a7923
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
""" Finetuning the library models for question-answering on SQuAD (Bert, XLM, XLNet)."""
""" Finetuning the library models for question-answering on SQuAD (
DistilBERT,
Bert, XLM, XLNet)
with an optional step of distillation
."""
from
__future__
import
absolute_import
,
division
,
print_function
from
__future__
import
absolute_import
,
division
,
print_function
...
@@ -28,6 +28,8 @@ import torch
...
@@ -28,6 +28,8 @@ import torch
from
torch.utils.data
import
(
DataLoader
,
RandomSampler
,
SequentialSampler
,
from
torch.utils.data
import
(
DataLoader
,
RandomSampler
,
SequentialSampler
,
TensorDataset
)
TensorDataset
)
from
torch.utils.data.distributed
import
DistributedSampler
from
torch.utils.data.distributed
import
DistributedSampler
import
torch.nn.functional
as
F
import
torch.nn
as
nn
from
tqdm
import
tqdm
,
trange
from
tqdm
import
tqdm
,
trange
from
tensorboardX
import
SummaryWriter
from
tensorboardX
import
SummaryWriter
...
@@ -73,7 +75,7 @@ def set_seed(args):
...
@@ -73,7 +75,7 @@ def set_seed(args):
def
to_list
(
tensor
):
def
to_list
(
tensor
):
return
tensor
.
detach
().
cpu
().
tolist
()
return
tensor
.
detach
().
cpu
().
tolist
()
def
train
(
args
,
train_dataset
,
model
,
tokenizer
):
def
train
(
args
,
train_dataset
,
model
,
tokenizer
,
teacher
=
None
):
""" Train the model """
""" Train the model """
if
args
.
local_rank
in
[
-
1
,
0
]:
if
args
.
local_rank
in
[
-
1
,
0
]:
tb_writer
=
SummaryWriter
()
tb_writer
=
SummaryWriter
()
...
@@ -132,17 +134,40 @@ def train(args, train_dataset, model, tokenizer):
...
@@ -132,17 +134,40 @@ def train(args, train_dataset, model, tokenizer):
epoch_iterator
=
tqdm
(
train_dataloader
,
desc
=
"Iteration"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
])
epoch_iterator
=
tqdm
(
train_dataloader
,
desc
=
"Iteration"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
])
for
step
,
batch
in
enumerate
(
epoch_iterator
):
for
step
,
batch
in
enumerate
(
epoch_iterator
):
model
.
train
()
model
.
train
()
if
teacher
is
not
None
:
teacher
.
eval
()
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
inputs
=
{
'input_ids'
:
batch
[
0
],
inputs
=
{
'input_ids'
:
batch
[
0
],
'attention_mask'
:
batch
[
1
],
'attention_mask'
:
batch
[
1
],
'token_type_ids'
:
None
if
args
.
model_type
==
'xlm'
else
batch
[
2
],
'start_positions'
:
batch
[
3
],
'start_positions'
:
batch
[
3
],
'end_positions'
:
batch
[
4
]}
'end_positions'
:
batch
[
4
]}
if
args
.
model_type
!=
'distilbert'
:
inputs
[
'token_type_ids'
]
=
None
if
args
.
model_type
==
'xlm'
else
batch
[
2
]
if
args
.
model_type
in
[
'xlnet'
,
'xlm'
]:
if
args
.
model_type
in
[
'xlnet'
,
'xlm'
]:
inputs
.
update
({
'cls_index'
:
batch
[
5
],
inputs
.
update
({
'cls_index'
:
batch
[
5
],
'p_mask'
:
batch
[
6
]})
'p_mask'
:
batch
[
6
]})
outputs
=
model
(
**
inputs
)
outputs
=
model
(
**
inputs
)
loss
=
outputs
[
0
]
# model outputs are always tuple in transformers (see doc)
loss
,
start_logits_stu
,
end_logits_stu
=
outputs
# Distillation loss
if
teacher
is
not
None
:
if
'token_type_ids'
not
in
inputs
:
inputs
[
'token_type_ids'
]
=
None
if
args
.
teacher_type
==
'xlm'
else
batch
[
2
]
with
torch
.
no_grad
():
start_logits_tea
,
end_logits_tea
=
teacher
(
input_ids
=
inputs
[
'input_ids'
],
token_type_ids
=
inputs
[
'token_type_ids'
],
attention_mask
=
inputs
[
'attention_mask'
])
assert
start_logits_tea
.
size
()
==
start_logits_stu
.
size
()
assert
end_logits_tea
.
size
()
==
end_logits_stu
.
size
()
loss_fct
=
nn
.
KLDivLoss
(
reduction
=
'batchmean'
)
loss_start
=
loss_fct
(
F
.
log_softmax
(
start_logits_stu
/
args
.
temperature
,
dim
=-
1
),
F
.
softmax
(
start_logits_tea
/
args
.
temperature
,
dim
=-
1
))
*
(
args
.
temperature
**
2
)
loss_end
=
loss_fct
(
F
.
log_softmax
(
end_logits_stu
/
args
.
temperature
,
dim
=-
1
),
F
.
softmax
(
end_logits_tea
/
args
.
temperature
,
dim
=-
1
))
*
(
args
.
temperature
**
2
)
loss_ce
=
(
loss_start
+
loss_end
)
/
2.
loss
=
args
.
alpha_ce
*
loss_ce
+
args
.
alpha_squad
*
loss
if
args
.
n_gpu
>
1
:
if
args
.
n_gpu
>
1
:
loss
=
loss
.
mean
()
# mean() to average on multi-gpu parallel (not distributed) training
loss
=
loss
.
mean
()
# mean() to average on multi-gpu parallel (not distributed) training
...
@@ -218,9 +243,10 @@ def evaluate(args, model, tokenizer, prefix=""):
...
@@ -218,9 +243,10 @@ def evaluate(args, model, tokenizer, prefix=""):
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
inputs
=
{
'input_ids'
:
batch
[
0
],
inputs
=
{
'input_ids'
:
batch
[
0
],
'attention_mask'
:
batch
[
1
],
'attention_mask'
:
batch
[
1
]
'token_type_ids'
:
None
if
args
.
model_type
==
'xlm'
else
batch
[
2
]
# XLM don't use segment_ids
}
}
if
args
.
model_type
!=
'distilbert'
:
inputs
[
'token_type_ids'
]
=
None
if
args
.
model_type
==
'xlm'
else
batch
[
2
]
# XLM don't use segment_ids
example_indices
=
batch
[
3
]
example_indices
=
batch
[
3
]
if
args
.
model_type
in
[
'xlnet'
,
'xlm'
]:
if
args
.
model_type
in
[
'xlnet'
,
'xlm'
]:
inputs
.
update
({
'cls_index'
:
batch
[
4
],
inputs
.
update
({
'cls_index'
:
batch
[
4
],
...
@@ -341,6 +367,18 @@ def main():
...
@@ -341,6 +367,18 @@ def main():
parser
.
add_argument
(
"--output_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
parser
.
add_argument
(
"--output_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The output directory where the model checkpoints and predictions will be written."
)
help
=
"The output directory where the model checkpoints and predictions will be written."
)
# Distillation parameters (optional)
parser
.
add_argument
(
'--teacher_type'
,
default
=
None
,
type
=
str
,
help
=
"Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation."
)
parser
.
add_argument
(
'--teacher_name_or_path'
,
default
=
None
,
type
=
str
,
help
=
"Path to the already SQuAD fine-tuned teacher model. Only for distillation."
)
parser
.
add_argument
(
'--alpha_ce'
,
default
=
0.5
,
type
=
float
,
help
=
"Distillation loss linear weight. Only for distillation."
)
parser
.
add_argument
(
'--alpha_squad'
,
default
=
0.5
,
type
=
float
,
help
=
"True SQuAD loss linear weight. Only for distillation."
)
parser
.
add_argument
(
'--temperature'
,
default
=
2.0
,
type
=
float
,
help
=
"Distillation temperature. Only for distillation."
)
## Other parameters
## Other parameters
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained config name or path if not the same as model_name"
)
help
=
"Pretrained config name or path if not the same as model_name"
)
...
@@ -468,6 +506,18 @@ def main():
...
@@ -468,6 +506,18 @@ def main():
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
tokenizer_name
if
args
.
tokenizer_name
else
args
.
model_name_or_path
,
do_lower_case
=
args
.
do_lower_case
)
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
tokenizer_name
if
args
.
tokenizer_name
else
args
.
model_name_or_path
,
do_lower_case
=
args
.
do_lower_case
)
model
=
model_class
.
from_pretrained
(
args
.
model_name_or_path
,
from_tf
=
bool
(
'.ckpt'
in
args
.
model_name_or_path
),
config
=
config
)
model
=
model_class
.
from_pretrained
(
args
.
model_name_or_path
,
from_tf
=
bool
(
'.ckpt'
in
args
.
model_name_or_path
),
config
=
config
)
if
args
.
teacher_type
is
not
None
:
assert
args
.
teacher_name_or_path
is
not
None
assert
args
.
alpha_ce
>
0.
assert
args
.
alpha_ce
+
args
.
alpha_squad
>
0.
assert
args
.
teacher_type
!=
'distilbert'
,
"We constraint teachers not to be of type DistilBERT."
teacher_config_class
,
teacher_model_class
,
_
=
MODEL_CLASSES
[
args
.
teacher_type
]
teacher_config
=
teacher_config_class
.
from_pretrained
(
args
.
teacher_name_or_path
)
teacher
=
teacher_model_class
.
from_pretrained
(
args
.
teacher_name_or_path
,
config
=
teacher_config
)
teacher
.
to
(
args
.
device
)
else
:
teacher
=
None
if
args
.
local_rank
==
0
:
if
args
.
local_rank
==
0
:
torch
.
distributed
.
barrier
()
# Make sure only the first process in distributed training will download model & vocab
torch
.
distributed
.
barrier
()
# Make sure only the first process in distributed training will download model & vocab
...
@@ -478,7 +528,7 @@ def main():
...
@@ -478,7 +528,7 @@ def main():
# Training
# Training
if
args
.
do_train
:
if
args
.
do_train
:
train_dataset
=
load_and_cache_examples
(
args
,
tokenizer
,
evaluate
=
False
,
output_examples
=
False
)
train_dataset
=
load_and_cache_examples
(
args
,
tokenizer
,
evaluate
=
False
,
output_examples
=
False
)
global_step
,
tr_loss
=
train
(
args
,
train_dataset
,
model
,
tokenizer
)
global_step
,
tr_loss
=
train
(
args
,
train_dataset
,
model
,
tokenizer
,
teacher
=
teacher
)
logger
.
info
(
" global_step = %s, average loss = %s"
,
global_step
,
tr_loss
)
logger
.
info
(
" global_step = %s, average loss = %s"
,
global_step
,
tr_loss
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment