chenpangpang / transformers · Commits

Commit f5891c38
run_squad --> run_squad_w_distillation
Authored Oct 04, 2019 by VictorSanh; committed by Victor SANH on Oct 04, 2019
Parent: 764a7923
Showing 2 changed files with 589 additions and 51 deletions (+589 -51):

  examples/distillation/run_squad_w_distillation.py   +585  -0
  examples/run_squad.py                                  +4  -51
examples/distillation/run_squad_w_distillation.py (new file, 0 → 100644) @ f5891c38
This diff is collapsed (585 additions, 0 deletions).
examples/run_squad.py @ f5891c38
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)
-    with an optional step of distillation."""
+""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
 from __future__ import absolute_import, division, print_function
@@ -28,8 +28,6 @@ import torch
 from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                               TensorDataset)
 from torch.utils.data.distributed import DistributedSampler
-import torch.nn.functional as F
-import torch.nn as nn
 from tqdm import tqdm, trange

 from tensorboardX import SummaryWriter
@@ -75,7 +73,7 @@ def set_seed(args):
 def to_list(tensor):
     return tensor.detach().cpu().tolist()

-def train(args, train_dataset, model, tokenizer, teacher=None):
+def train(args, train_dataset, model, tokenizer):
     """ Train the model """
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()
@@ -134,8 +132,6 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
             model.train()
-            if teacher is not None:
-                teacher.eval()
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids':       batch[0],
                       'attention_mask':  batch[1],
@@ -147,27 +143,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
                 inputs.update({'cls_index': batch[5],
                                'p_mask':    batch[6]})
             outputs = model(**inputs)
-            loss, start_logits_stu, end_logits_stu = outputs
-
-            # Distillation loss
-            if teacher is not None:
-                if 'token_type_ids' not in inputs:
-                    inputs['token_type_ids'] = None if args.teacher_type == 'xlm' else batch[2]
-                with torch.no_grad():
-                    start_logits_tea, end_logits_tea = teacher(input_ids=inputs['input_ids'],
-                                                               token_type_ids=inputs['token_type_ids'],
-                                                               attention_mask=inputs['attention_mask'])
-                assert start_logits_tea.size() == start_logits_stu.size()
-                assert end_logits_tea.size() == end_logits_stu.size()
-
-                loss_fct = nn.KLDivLoss(reduction='batchmean')
-                loss_start = loss_fct(F.log_softmax(start_logits_stu/args.temperature, dim=-1),
-                                      F.softmax(start_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
-                loss_end = loss_fct(F.log_softmax(end_logits_stu/args.temperature, dim=-1),
-                                    F.softmax(end_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
-                loss_ce = (loss_start + loss_end)/2.
-
-                loss = args.alpha_ce*loss_ce + args.alpha_squad*loss
+            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

             if args.n_gpu > 1:
                 loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
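For reference, the block removed above (and kept in run_squad_w_distillation.py) softens both the student's and the teacher's start/end logits with a temperature, takes a batch-mean KL divergence between them, and blends that soft-target loss with the regular SQuAD loss. A minimal standalone sketch of that computation, with the argparse values replaced by plain keyword arguments and random tensors standing in for real logits (names and shapes here are illustrative, not part of the commit):

import torch
import torch.nn as nn
import torch.nn.functional as F

def distillation_loss(start_logits_stu, end_logits_stu,
                      start_logits_tea, end_logits_tea,
                      squad_loss, alpha_ce=0.5, alpha_squad=0.5, temperature=2.0):
    """Blend the temperature-scaled KL soft-target loss with the hard SQuAD loss."""
    assert start_logits_tea.size() == start_logits_stu.size()
    assert end_logits_tea.size() == end_logits_stu.size()

    loss_fct = nn.KLDivLoss(reduction='batchmean')
    # Student log-probabilities vs. teacher probabilities, both divided by the temperature;
    # the (temperature ** 2) factor keeps gradient magnitudes comparable across temperatures.
    loss_start = loss_fct(F.log_softmax(start_logits_stu / temperature, dim=-1),
                          F.softmax(start_logits_tea / temperature, dim=-1)) * (temperature ** 2)
    loss_end = loss_fct(F.log_softmax(end_logits_stu / temperature, dim=-1),
                        F.softmax(end_logits_tea / temperature, dim=-1)) * (temperature ** 2)
    loss_ce = (loss_start + loss_end) / 2.

    return alpha_ce * loss_ce + alpha_squad * squad_loss

if __name__ == '__main__':
    # Hypothetical shapes: batch of 8, sequence length 384 (the usual max_seq_length for SQuAD).
    stu_start, stu_end = torch.randn(8, 384), torch.randn(8, 384)
    tea_start, tea_end = torch.randn(8, 384), torch.randn(8, 384)
    hard_loss = torch.tensor(1.3)  # placeholder for the regular SQuAD span loss
    print(distillation_loss(stu_start, stu_end, tea_start, tea_end, hard_loss))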
@@ -367,18 +343,6 @@ def main():
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model checkpoints and predictions will be written.")

-    # Distillation parameters (optional)
-    parser.add_argument('--teacher_type', default=None, type=str,
-                        help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.")
-    parser.add_argument('--teacher_name_or_path', default=None, type=str,
-                        help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.")
-    parser.add_argument('--alpha_ce', default=0.5, type=float,
-                        help="Distillation loss linear weight. Only for distillation.")
-    parser.add_argument('--alpha_squad', default=0.5, type=float,
-                        help="True SQuAD loss linear weight. Only for distillation.")
-    parser.add_argument('--temperature', default=2.0, type=float,
-                        help="Distillation temperature. Only for distillation.")
-
     ## Other parameters
     parser.add_argument("--config_name", default="", type=str,
                         help="Pretrained config name or path if not the same as model_name")
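The distillation-specific options removed here (and retained in run_squad_w_distillation.py) are all optional: leaving --teacher_type unset disables distillation, and the loss weights and temperature default to 0.5 / 0.5 / 2.0. A small sketch of just these options on a fresh parser, with argument names, defaults, and help strings taken from the removed lines (the parse_args([]) call is only there to show the defaults):

import argparse

parser = argparse.ArgumentParser()
# Distillation parameters (optional)
parser.add_argument('--teacher_type', default=None, type=str,
                    help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.")
parser.add_argument('--teacher_name_or_path', default=None, type=str,
                    help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.")
parser.add_argument('--alpha_ce', default=0.5, type=float,
                    help="Distillation loss linear weight. Only for distillation.")
parser.add_argument('--alpha_squad', default=0.5, type=float,
                    help="True SQuAD loss linear weight. Only for distillation.")
parser.add_argument('--temperature', default=2.0, type=float,
                    help="Distillation temperature. Only for distillation.")

args = parser.parse_args([])            # no flags passed: teacher_type stays None, so no teacher is loaded
print(args.alpha_ce, args.alpha_squad, args.temperature)   # 0.5 0.5 2.0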
@@ -506,17 +470,6 @@ def main():
     tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
     model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)

-    if args.teacher_type is not None:
-        assert args.teacher_name_or_path is not None
-        assert args.alpha_ce > 0.
-        assert args.alpha_ce + args.alpha_squad > 0.
-        assert args.teacher_type != 'distilbert', "We constraint teachers not to be of type DistilBERT."
-        teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
-        teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path)
-        teacher = teacher_model_class.from_pretrained(args.teacher_name_or_path, config=teacher_config)
-        teacher.to(args.device)
-    else:
-        teacher = None
-
     if args.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
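The removed setup above resolves the teacher's config and model classes from the script's MODEL_CLASSES table and loads an already SQuAD-fine-tuned checkpoint. A self-contained sketch of the same pattern using transformers classes directly instead of that table; the checkpoint name and device handling are placeholders, not part of the commit:

import torch
from transformers import BertConfig, BertForQuestionAnswering

# Stand-in name: the script expects a BERT checkpoint that has already been fine-tuned on SQuAD.
teacher_name_or_path = 'bert-base-uncased'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

teacher_config = BertConfig.from_pretrained(teacher_name_or_path)
teacher = BertForQuestionAnswering.from_pretrained(teacher_name_or_path, config=teacher_config)
teacher.to(device)
teacher.eval()  # the training loop only ever queries the teacher under torch.no_grad()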
@@ -528,7 +481,7 @@ def main():
     # Training
     if args.do_train:
         train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
-        global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
+        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)