chenpangpang / transformers · Commits

Commit bb9c5ead — "update distiller"
Authored Oct 02, 2019 by VictorSanh; committed by Victor SANH on Oct 03, 2019
Parent: a12ab0a8
Changes: 1 changed file, examples/distillation/distiller.py, with 118 additions and 71 deletions (+118, -71).
@@ -12,8 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" The distiller to distil DistilBERT
-    adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
+""" The distiller to distil the student.
+    Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
 """
 import os
 import math
@@ -28,16 +28,19 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.optim import AdamW
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import RandomSampler, BatchSampler, DataLoader

 from transformers import WarmupLinearSchedule

 from utils import logger
-from dataset import Dataset
+from lm_seqs_dataset import LmSeqsDataset
+from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups


 class Distiller:
     def __init__(self,
                  params: dict,
-                 dataloader: Dataset,
+                 dataset: LmSeqsDataset,
                  token_probs: torch.tensor,
                  student: nn.Module,
                  teacher: nn.Module):
@@ -50,24 +53,36 @@ class Distiller:
         self.student = student
         self.teacher = teacher

-        self.dataloader = dataloader
-        if self.params.n_gpu > 1:
-            self.dataloader.split()
-        self.get_iterator(seed=params.seed)
+        self.student_config = student.config
+        self.vocab_size = student.config.vocab_size
+
+        if params.n_gpu <= 1:
+            sampler = RandomSampler(dataset)
+        else:
+            sampler = DistributedSampler(dataset)
+
+        if params.group_by_size:
+            groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
+            sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
+        else:
+            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
+
+        self.dataloader = DataLoader(dataset=dataset,
+                                     batch_sampler=sampler,
+                                     collate_fn=dataset.batch_sequences)

         self.temperature = params.temperature
         assert self.temperature > 0.

         self.alpha_ce = params.alpha_ce
         self.alpha_mlm = params.alpha_mlm
+        self.alpha_clm = params.alpha_clm
         self.alpha_mse = params.alpha_mse
         self.alpha_cos = params.alpha_cos
-        assert self.alpha_ce >= 0.
-        assert self.alpha_mlm >= 0.
-        assert self.alpha_mse >= 0.
-        assert self.alpha_cos >= 0.
-        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0.

+        self.mlm = params.mlm
+        if self.mlm:
+            logger.info(f'Using MLM loss for LM step.')
+            self.mlm_mask_prop = params.mlm_mask_prop
+            assert 0.0 <= self.mlm_mask_prop <= 1.0
+            assert params.word_mask + params.word_keep + params.word_rand == 1.0
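The new constructor replaces the old custom dataloader splitting with standard PyTorch sampling. As a point of reference, here is a minimal, self-contained sketch of the same sampler-to-DataLoader wiring using only torch.utils.data; ToySeqDataset and batch_sequences below are illustrative stand-ins for the repo's LmSeqsDataset and its collate function, not the actual classes.

import torch
from torch.utils.data import Dataset, RandomSampler, BatchSampler, DataLoader

class ToySeqDataset(Dataset):
    # toy stand-in for LmSeqsDataset: variable-length token sequences plus their lengths
    def __init__(self, seqs):
        self.seqs = [torch.tensor(s, dtype=torch.long) for s in seqs]
        self.lengths = torch.tensor([len(s) for s in seqs])
    def __len__(self):
        return len(self.seqs)
    def __getitem__(self, i):
        return self.seqs[i], self.lengths[i]

def batch_sequences(batch):
    # pad to the longest sequence in the batch (the repo's dataset.batch_sequences plays this role)
    seqs, lengths = zip(*batch)
    max_len = max(int(l) for l in lengths)
    padded = torch.zeros(len(seqs), max_len, dtype=torch.long)
    for i, s in enumerate(seqs):
        padded[i, : len(s)] = s
    return padded, torch.stack(lengths)

ds = ToySeqDataset([[5, 6, 7], [8, 9], [1, 2, 3, 4]])
sampler = BatchSampler(sampler=RandomSampler(ds), batch_size=2, drop_last=False)
loader = DataLoader(dataset=ds, batch_sampler=sampler, collate_fn=batch_sequences)
for token_ids, lengths in loader:
    print(token_ids.shape, lengths)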
@@ -77,6 +92,8 @@ class Distiller:
             if self.fp16:
                 self.pred_probs = self.pred_probs.half()
                 self.token_probs = self.token_probs.half()
+        else:
+            logger.info(f'Using CLM loss for LM step.')

         self.epoch = 0
         self.n_iter = 0
@@ -86,12 +103,13 @@ class Distiller:
         self.last_loss = 0
         self.last_loss_ce = 0
         self.last_loss_mlm = 0
+        self.last_loss_clm = 0
         if self.alpha_mse > 0.: self.last_loss_mse = 0
         if self.alpha_cos > 0.: self.last_loss_cos = 0
         self.last_log = 0

         self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
-        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
+        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
         if self.alpha_mse > 0.: self.mse_loss_fct = nn.MSELoss(reduction='sum')
         if self.alpha_cos > 0.:
@@ -99,7 +117,7 @@ class Distiller:
         logger.info('--- Initializing model optimizer')

         assert params.gradient_accumulation_steps >= 1
-        self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
+        self.num_steps_epoch = len(self.dataloader)
         num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1

         no_decay = ['bias', 'LayerNorm.weight']
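The step count per epoch is now simply len(self.dataloader): a DataLoader built on a batch_sampler yields one element per batch, so its length already is the number of batches. A quick, illustrative check on a toy dataset (not the repo's):

import torch
from torch.utils.data import TensorDataset, RandomSampler, BatchSampler, DataLoader

ds = TensorDataset(torch.arange(10))
loader = DataLoader(ds, batch_sampler=BatchSampler(RandomSampler(ds), batch_size=4, drop_last=False))
assert len(loader) == 3    # ceil(10 / 4) batches, i.e. one optimization step per batch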
@@ -140,42 +158,17 @@ class Distiller:
             logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
             self.student = DistributedDataParallel(self.student,
                                                    device_ids=[params.local_rank],
-                                                   output_device=params.local_rank)
+                                                   output_device=params.local_rank,
+                                                   find_unused_parameters=True)

         self.is_master = params.is_master
         if self.is_master:
             logger.info('--- Initializing Tensorboard')
             self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
-            self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)
-
-    def get_iterator(self,
-                     seed: int = None):
-        """
-        Initialize the data iterator.
-        Each process has its own data iterator (iterating on his own random portion of the dataset).
-
-        Input:
-        ------
-        seed: `int` - The random seed.
-        """
-        logger.info('--- Initializing Data Iterator')
-        self.data_iterator = self.dataloader.get_iterator(seed=seed)
+            self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0)
+            self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0)

-    def get_batch(self):
-        """
-        Call the data iterator to output a new batch.
-        If the data iterator went through the whole dataset, create a new iterator.
-        """
-        assert hasattr(self, 'data_iterator')
-        try:
-            x = next(self.data_iterator)
-        except StopIteration:
-            logger.warning('--- Went through the whole dataset. Creating new data iterator.')
-            self.data_iterator = self.dataloader.get_iterator()
-            x = next(self.data_iterator)
-        return x
-
-    def prepare_batch(self,
+    def prepare_batch_mlm(self,
                           batch):
         """
         Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.
@@ -222,7 +215,7 @@ class Distiller:
             assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

         _token_ids_real = token_ids[pred_mask]
-        _token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size)
+        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
         _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
         probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
         _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
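For context, the mask/keep/random mixing above draws, for every selected position, an index 0, 1 or 2 from pred_probs and combines the three candidate token tensors accordingly. A standalone sketch with illustrative probabilities and ids (the real values come from params and the tokenizer):

import torch

pred_probs = torch.tensor([0.8, 0.1, 0.1])            # P(mask), P(keep), P(random), illustrative
_token_ids_real = torch.tensor([12, 45, 7, 88])        # tokens selected for prediction
mask_id, vocab_size = 103, 30522                       # illustrative special-token id and vocab size

_token_ids_mask = _token_ids_real.clone().fill_(mask_id)          # every position replaced by [MASK]
_token_ids_rand = _token_ids_real.clone().random_(vocab_size)     # every position replaced by a random id

probs = torch.multinomial(pred_probs, len(_token_ids_real), replacement=True)   # 0, 1 or 2 per token
_token_ids = (_token_ids_mask * (probs == 0).long()
              + _token_ids_real * (probs == 1).long()
              + _token_ids_rand * (probs == 2).long())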
@@ -230,8 +223,41 @@ class Distiller:
         mlm_labels[~pred_mask] = -1    # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

+        # sanity checks
+        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
+
         return token_ids, attn_mask, mlm_labels

+    def prepare_batch_clm(self,
+                          batch):
+        """
+        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.
+
+        Input:
+        ------
+        batch: `Tuple`
+            token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
+            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
+
+        Output:
+        -------
+            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
+            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
+            clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -1 where there is nothing to predict.
+        """
+        token_ids, lengths = batch
+        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
+        assert token_ids.size(0) == lengths.size(0)
+
+        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
+        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
+        clm_labels[~attn_mask] = -1    # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
+
+        # sanity checks
+        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
+
+        return token_ids, attn_mask, clm_labels

     def round_batch(self,
                     x: torch.tensor,
                     lengths: torch.tensor):
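prepare_batch_clm builds the attention mask by broadcasting a position index against the lengths, then copies the token ids and writes -1 wherever there is padding, so that CrossEntropyLoss(ignore_index=-1) skips those positions. A small worked example with made-up ids:

import torch

token_ids = torch.tensor([[5, 6, 7, 0],
                          [8, 9, 0, 0]])          # (bs, seq_length), 0 used as padding here
lengths = torch.tensor([3, 2])                     # (bs,)

# True where a position holds a real token, False on padding
attn_mask = torch.arange(token_ids.size(1), device=lengths.device) < lengths[:, None]

clm_labels = token_ids.clone()
clm_labels[~attn_mask] = -1                        # -1 = ignored by CrossEntropyLoss(ignore_index=-1)

print(attn_mask)
# tensor([[ True,  True,  True, False],
#         [ True,  True, False, False]])
print(clm_labels)
# tensor([[ 5,  6,  7, -1],
#         [ 8,  9, -1, -1]])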
@@ -269,7 +295,10 @@ class Distiller:
         if ml1 % 8 != 0:
             pad = 8 - (ml1 % 8)
             ml2 = ml1 + pad
-            pad_id = self.params.special_tok_ids['pad_token']
+            if self.mlm:
+                pad_id = self.params.special_tok_ids['pad_token']
+            else:
+                pad_id = self.params.special_tok_ids['unk_token']
             padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
             x = torch.cat([x, padding_tensor], 1)
             assert x.size() == (bs2, ml2)
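The fp16 path pads the sequence dimension up to the next multiple of 8 so the tensor shapes stay friendly to tensor cores; the only change here is choosing the padding id per mode. A minimal sketch of the padding arithmetic (pad_id = 0 is illustrative; the distiller uses pad_token for MLM and unk_token for CLM):

import torch

x = torch.randint(1, 100, (4, 13))       # (bs2, ml1) with ml1 = 13
bs2, ml1 = x.size()
if ml1 % 8 != 0:
    pad = 8 - (ml1 % 8)                  # 3
    ml2 = ml1 + pad                      # 16, next multiple of 8
    pad_id = 0                           # illustrative padding id
    padding_tensor = torch.zeros(bs2, pad, dtype=x.dtype).fill_(pad_id)
    x = torch.cat([x, padding_tensor], 1)
    assert x.size() == (bs2, ml2)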
@@ -292,14 +321,16 @@ class Distiller:
         if self.multi_gpu:
             torch.distributed.barrier()

-        iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
-        for __ in range(self.num_steps_epoch):
-            batch = self.get_batch()
+        iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
+        for batch in iter_bar:
             if self.params.n_gpu > 0:
                 batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
-            token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)
-            self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)
+
+            if self.mlm:
+                token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
+            else:
+                token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
+            self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)

             iter_bar.update()
             iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
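The epoch loop now iterates the DataLoader directly under tqdm and dispatches to prepare_batch_mlm or prepare_batch_clm depending on the objective. A condensed, illustrative sketch of that control flow (train_epoch, step_fn and the prepare callables are placeholders, not the repo's API):

from tqdm import tqdm

def train_epoch(dataloader, prepare_mlm, prepare_clm, step_fn, mlm=True, local_rank=-1):
    # one progress bar over the DataLoader, one optimization step per batch
    iter_bar = tqdm(dataloader, desc="-Iter", disable=local_rank not in [-1, 0])
    for batch in iter_bar:
        prepare = prepare_mlm if mlm else prepare_clm
        token_ids, attn_mask, lm_labels = prepare(batch)
        last_loss = step_fn(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
        iter_bar.set_postfix({'Last_loss': f'{last_loss:.2f}'})
    iter_bar.close()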
@@ -317,7 +348,7 @@ class Distiller:
     def step(self,
              input_ids: torch.tensor,
              attention_mask: torch.tensor,
-             mlm_labels: torch.tensor):
+             lm_labels: torch.tensor):
         """
         One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
         and possibly a parameter update (depending on the gradient accumulation).
@@ -326,17 +357,22 @@ class Distiller:
         ------
         input_ids: `torch.tensor(bs, seq_length)` - The token ids.
         attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
-        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
+        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
         """
-        s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)          # (bs, seq_length, voc_size)
-        with torch.no_grad():
-            t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask)      # (bs, seq_length, voc_size)
+        if self.mlm:
+            s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)      # (bs, seq_length, voc_size)
+            with torch.no_grad():
+                t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask)  # (bs, seq_length, voc_size)
+        else:
+            s_logits, _, s_hidden_states = self.student(input_ids=input_ids, attention_mask=None)             # (bs, seq_length, voc_size)
+            with torch.no_grad():
+                t_logits, _, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=None)         # (bs, seq_length, voc_size)
         assert s_logits.size() == t_logits.size()

         #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
         #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
         if self.params.restrict_ce_to_mask:
-            mask = (mlm_labels>-1).unsqueeze(-1).expand_as(s_logits)   # (bs, seq_length, voc_size)
+            mask = (lm_labels>-1).unsqueeze(-1).expand_as(s_logits)    # (bs, seq_length, voc_size)
         else:
             mask = attention_mask.unsqueeze(-1).expand_as(s_logits)    # (bs, seq_length, voc_size)
         s_logits_slct = torch.masked_select(s_logits, mask)            # (bs * seq_length * voc_size) modulo the 1s in mask
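With restrict_ce_to_mask, only positions that carry a label enter the distillation term; the boolean mask is expanded over the vocabulary axis, and torch.masked_select flattens the kept logits before they are reshaped back to (n_kept_positions, voc_size). A toy illustration with made-up shapes:

import torch

s_logits = torch.randn(2, 4, 6)                              # (bs, seq_length, voc_size)
lm_labels = torch.tensor([[5, 2, -1, -1],
                          [1, -1, -1, -1]])                  # -1 = nothing to predict
mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits)    # (bs, seq_length, voc_size)

s_logits_slct = torch.masked_select(s_logits, mask)          # 1-D: n_kept_positions * voc_size elements
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))    # (n_kept_positions, voc_size)
assert s_logits_slct.size() == (3, 6)                        # 3 labeled positions in this toy batch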
@@ -348,13 +384,20 @@ class Distiller:
         loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
                                    F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
         loss = self.alpha_ce*loss_ce

         if self.alpha_mlm > 0.:
-            loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))
+            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
             loss += self.alpha_mlm * loss_mlm
+        if self.alpha_clm > 0.:
+            shift_logits = s_logits[..., :-1, :].contiguous()
+            shift_labels = lm_labels[..., 1:].contiguous()
+            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                                        shift_labels.view(-1))
+            loss += self.alpha_clm * loss_clm

         if self.alpha_mse > 0.:
             loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0)   # Reproducing batchmean reduction
             loss += self.alpha_mse * loss_mse
         if self.alpha_cos > 0.:
             s_hidden_states = s_hidden_states[-1]    # (bs, seq_length, dim)
             t_hidden_states = t_hidden_states[-1]    # (bs, seq_length, dim)
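The LM step combines a temperature-scaled KL term between student and teacher with either the MLM or the shifted CLM cross-entropy. The sketch below reproduces the two formulas on random tensors; the temperature, alphas and shapes are illustrative, and the KL term is taken over all positions for brevity rather than over the masked selection used above.

import torch
import torch.nn as nn
import torch.nn.functional as F

temperature, alpha_ce, alpha_clm = 2.0, 5.0, 1.0                 # illustrative hyperparameters
s_logits = torch.randn(2, 4, 6, requires_grad=True)              # (bs, seq_length, voc_size) student
t_logits = torch.randn(2, 4, 6)                                  # teacher logits, no grad needed
lm_labels = torch.tensor([[5, 2, 1, -1], [1, 3, -1, -1]])        # -1 = ignored position

# distillation term: KL(student_T || teacher_T) scaled by T^2
ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
loss_ce = ce_loss_fct(F.log_softmax(s_logits.view(-1, 6) / temperature, dim=-1),
                      F.softmax(t_logits.view(-1, 6) / temperature, dim=-1)) * temperature ** 2

# CLM term: predict token t+1 from position t, skipping -1 labels
lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
shift_logits = s_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous()
loss_clm = lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

loss = alpha_ce * loss_ce + alpha_clm * loss_clm
loss.backward()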
@@ -376,6 +419,8 @@ class Distiller:
         self.last_loss_ce = loss_ce.item()
         if self.alpha_mlm > 0.:
             self.last_loss_mlm = loss_mlm.item()
+        if self.alpha_clm > 0.:
+            self.last_loss_clm = loss_clm.item()
         if self.alpha_mse > 0.:
             self.last_loss_mse = loss_mse.item()
         if self.alpha_cos > 0.:
@@ -452,6 +497,8 @@ class Distiller:
         self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
         if self.alpha_mlm > 0.:
             self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
+        if self.alpha_clm > 0.:
+            self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter)
         if self.alpha_mse > 0.:
             self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
         if self.alpha_cos > 0.: