Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f5891c38
Commit
f5891c38
authored
Oct 04, 2019
by
VictorSanh
Committed by
Victor SANH
Oct 04, 2019
Browse files
run_squad --> run_squad_w_distillation
parent
764a7923
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
589 additions
and
51 deletions
+589
-51
examples/distillation/run_squad_w_distillation.py
examples/distillation/run_squad_w_distillation.py
+585
-0
examples/run_squad.py
examples/run_squad.py
+4
-51
No files found.
examples/distillation/run_squad_w_distillation.py
0 → 100644
View file @
f5891c38
This diff is collapsed.
Click to expand it.
examples/run_squad.py
View file @
f5891c38
...
...
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)
with an optional step of distillation
."""
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
from
__future__
import
absolute_import
,
division
,
print_function
...
...
@@ -28,8 +28,6 @@ import torch
from
torch.utils.data
import
(
DataLoader
,
RandomSampler
,
SequentialSampler
,
TensorDataset
)
from
torch.utils.data.distributed
import
DistributedSampler
import
torch.nn.functional
as
F
import
torch.nn
as
nn
from
tqdm
import
tqdm
,
trange
from
tensorboardX
import
SummaryWriter
...
...
@@ -75,7 +73,7 @@ def set_seed(args):
def
to_list
(
tensor
):
return
tensor
.
detach
().
cpu
().
tolist
()
def
train
(
args
,
train_dataset
,
model
,
tokenizer
,
teacher
=
None
):
def
train
(
args
,
train_dataset
,
model
,
tokenizer
):
""" Train the model """
if
args
.
local_rank
in
[
-
1
,
0
]:
tb_writer
=
SummaryWriter
()
...
...
@@ -134,8 +132,6 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
epoch_iterator
=
tqdm
(
train_dataloader
,
desc
=
"Iteration"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
])
for
step
,
batch
in
enumerate
(
epoch_iterator
):
model
.
train
()
if
teacher
is
not
None
:
teacher
.
eval
()
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
inputs
=
{
'input_ids'
:
batch
[
0
],
'attention_mask'
:
batch
[
1
],
...
...
@@ -147,27 +143,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
inputs
.
update
({
'cls_index'
:
batch
[
5
],
'p_mask'
:
batch
[
6
]})
outputs
=
model
(
**
inputs
)
loss
,
start_logits_stu
,
end_logits_stu
=
outputs
# Distillation loss
if
teacher
is
not
None
:
if
'token_type_ids'
not
in
inputs
:
inputs
[
'token_type_ids'
]
=
None
if
args
.
teacher_type
==
'xlm'
else
batch
[
2
]
with
torch
.
no_grad
():
start_logits_tea
,
end_logits_tea
=
teacher
(
input_ids
=
inputs
[
'input_ids'
],
token_type_ids
=
inputs
[
'token_type_ids'
],
attention_mask
=
inputs
[
'attention_mask'
])
assert
start_logits_tea
.
size
()
==
start_logits_stu
.
size
()
assert
end_logits_tea
.
size
()
==
end_logits_stu
.
size
()
loss_fct
=
nn
.
KLDivLoss
(
reduction
=
'batchmean'
)
loss_start
=
loss_fct
(
F
.
log_softmax
(
start_logits_stu
/
args
.
temperature
,
dim
=-
1
),
F
.
softmax
(
start_logits_tea
/
args
.
temperature
,
dim
=-
1
))
*
(
args
.
temperature
**
2
)
loss_end
=
loss_fct
(
F
.
log_softmax
(
end_logits_stu
/
args
.
temperature
,
dim
=-
1
),
F
.
softmax
(
end_logits_tea
/
args
.
temperature
,
dim
=-
1
))
*
(
args
.
temperature
**
2
)
loss_ce
=
(
loss_start
+
loss_end
)
/
2.
loss
=
args
.
alpha_ce
*
loss_ce
+
args
.
alpha_squad
*
loss
loss
=
outputs
[
0
]
# model outputs are always tuple in transformers (see doc)
if
args
.
n_gpu
>
1
:
loss
=
loss
.
mean
()
# mean() to average on multi-gpu parallel (not distributed) training
...
...
@@ -367,18 +343,6 @@ def main():
parser
.
add_argument
(
"--output_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The output directory where the model checkpoints and predictions will be written."
)
# Distillation parameters (optional)
parser
.
add_argument
(
'--teacher_type'
,
default
=
None
,
type
=
str
,
help
=
"Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation."
)
parser
.
add_argument
(
'--teacher_name_or_path'
,
default
=
None
,
type
=
str
,
help
=
"Path to the already SQuAD fine-tuned teacher model. Only for distillation."
)
parser
.
add_argument
(
'--alpha_ce'
,
default
=
0.5
,
type
=
float
,
help
=
"Distillation loss linear weight. Only for distillation."
)
parser
.
add_argument
(
'--alpha_squad'
,
default
=
0.5
,
type
=
float
,
help
=
"True SQuAD loss linear weight. Only for distillation."
)
parser
.
add_argument
(
'--temperature'
,
default
=
2.0
,
type
=
float
,
help
=
"Distillation temperature. Only for distillation."
)
## Other parameters
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained config name or path if not the same as model_name"
)
...
...
@@ -506,17 +470,6 @@ def main():
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
tokenizer_name
if
args
.
tokenizer_name
else
args
.
model_name_or_path
,
do_lower_case
=
args
.
do_lower_case
)
model
=
model_class
.
from_pretrained
(
args
.
model_name_or_path
,
from_tf
=
bool
(
'.ckpt'
in
args
.
model_name_or_path
),
config
=
config
)
if
args
.
teacher_type
is
not
None
:
assert
args
.
teacher_name_or_path
is
not
None
assert
args
.
alpha_ce
>
0.
assert
args
.
alpha_ce
+
args
.
alpha_squad
>
0.
assert
args
.
teacher_type
!=
'distilbert'
,
"We constraint teachers not to be of type DistilBERT."
teacher_config_class
,
teacher_model_class
,
_
=
MODEL_CLASSES
[
args
.
teacher_type
]
teacher_config
=
teacher_config_class
.
from_pretrained
(
args
.
teacher_name_or_path
)
teacher
=
teacher_model_class
.
from_pretrained
(
args
.
teacher_name_or_path
,
config
=
teacher_config
)
teacher
.
to
(
args
.
device
)
else
:
teacher
=
None
if
args
.
local_rank
==
0
:
torch
.
distributed
.
barrier
()
# Make sure only the first process in distributed training will download model & vocab
...
...
@@ -528,7 +481,7 @@ def main():
# Training
if
args
.
do_train
:
train_dataset
=
load_and_cache_examples
(
args
,
tokenizer
,
evaluate
=
False
,
output_examples
=
False
)
global_step
,
tr_loss
=
train
(
args
,
train_dataset
,
model
,
tokenizer
,
teacher
=
teacher
)
global_step
,
tr_loss
=
train
(
args
,
train_dataset
,
model
,
tokenizer
)
logger
.
info
(
" global_step = %s, average loss = %s"
,
global_step
,
tr_loss
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment