chenpangpang / transformers
"docs/source/en/vscode:/vscode.git/clone" did not exist on "9946dcf8db300c22dc8bed660e62f1ba5aa85bfd"
Commit bb9c5ead
authored Oct 02, 2019 by VictorSanh, committed by Victor SANH on Oct 03, 2019

update distiller

parent a12ab0a8
Showing 1 changed file with 118 additions and 71 deletions

examples/distillation/distiller.py  (+118, -71)
@@ -12,8 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" The distiller to distil DistilBERT
-    adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
+""" The distiller to distil the student.
+    Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
 """
 import os
 import math
@@ -28,16 +28,19 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.optim import AdamW
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import RandomSampler, BatchSampler, DataLoader

 from transformers import WarmupLinearSchedule

 from utils import logger
-from dataset import Dataset
+from lm_seqs_dataset import LmSeqsDataset
+from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups


 class Distiller:
     def __init__(self,
                  params: dict,
-                 dataloader: Dataset,
+                 dataset: LmSeqsDataset,
                  token_probs: torch.tensor,
                  student: nn.Module,
                  teacher: nn.Module):
@@ -50,33 +53,47 @@ class Distiller:
         self.student = student
         self.teacher = teacher

-        self.dataloader = dataloader
-        if self.params.n_gpu > 1:
-            self.dataloader.split()
-        self.get_iterator(seed=params.seed)
+        self.student_config = student.config
+        self.vocab_size = student.config.vocab_size
+
+        if params.n_gpu <= 1:
+            sampler = RandomSampler(dataset)
+        else:
+            sampler = DistributedSampler(dataset)
+
+        if params.group_by_size:
+            groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
+            sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
+        else:
+            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
+
+        self.dataloader = DataLoader(dataset=dataset,
+                                     batch_sampler=sampler,
+                                     collate_fn=dataset.batch_sequences)

         self.temperature = params.temperature
         assert self.temperature > 0.

         self.alpha_ce = params.alpha_ce
         self.alpha_mlm = params.alpha_mlm
+        self.alpha_clm = params.alpha_clm
         self.alpha_mse = params.alpha_mse
         self.alpha_cos = params.alpha_cos
-        assert self.alpha_ce >= 0.
-        assert self.alpha_mlm >= 0.
-        assert self.alpha_mse >= 0.
-        assert self.alpha_cos >= 0.
-        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0.

-        self.mlm_mask_prop = params.mlm_mask_prop
-        assert 0.0 <= self.mlm_mask_prop <= 1.0
-        assert params.word_mask + params.word_keep + params.word_rand == 1.0
-        self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
-        self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
-        self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
-        if self.fp16:
-            self.pred_probs = self.pred_probs.half()
-            self.token_probs = self.token_probs.half()
+        self.mlm = params.mlm
+        if self.mlm:
+            logger.info(f'Using MLM loss for LM step.')
+            self.mlm_mask_prop = params.mlm_mask_prop
+            assert 0.0 <= self.mlm_mask_prop <= 1.0
+            assert params.word_mask + params.word_keep + params.word_rand == 1.0
+            self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
+            self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
+            self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
+            if self.fp16:
+                self.pred_probs = self.pred_probs.half()
+                self.token_probs = self.token_probs.half()
+        else:
+            logger.info(f'Using CLM loss for LM step.')

         self.epoch = 0
         self.n_iter = 0
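Note: the rewritten constructor replaces the custom dataloader.split()/get_iterator() machinery with the standard torch.utils.data pipeline: a RandomSampler (or DistributedSampler) is wrapped in a batch sampler and handed to a DataLoader together with the dataset's own collate function. Below is a minimal, self-contained sketch of that pattern; the toy dataset and all names in it are invented for illustration and only stand in for the repo's LmSeqsDataset.

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, BatchSampler
from torch.nn.utils.rnn import pad_sequence

class ToySeqsDataset(Dataset):
    """Toy stand-in for LmSeqsDataset: variable-length token-id sequences."""
    def __init__(self, sequences):
        self.sequences = [torch.tensor(s, dtype=torch.long) for s in sequences]
        self.lengths = torch.tensor([len(s) for s in sequences])

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.lengths[idx]

    def batch_sequences(self, batch):
        # Pad each batch to its longest sequence (pad id 0 here, arbitrary choice).
        seqs, lengths = zip(*batch)
        return pad_sequence(seqs, batch_first=True, padding_value=0), torch.stack(lengths)

dataset = ToySeqsDataset([[5, 6, 7], [8, 9], [10, 11, 12, 13], [14]])
sampler = BatchSampler(RandomSampler(dataset), batch_size=2, drop_last=False)
loader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)

for token_ids, lengths in loader:
    print(token_ids.shape, lengths)   # e.g. torch.Size([2, 4]) tensor([4, 2])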
@@ -86,12 +103,13 @@ class Distiller:
         self.last_loss = 0
         self.last_loss_ce = 0
         self.last_loss_mlm = 0
+        self.last_loss_clm = 0
         if self.alpha_mse > 0.: self.last_loss_mse = 0
         if self.alpha_cos > 0.: self.last_loss_cos = 0
         self.last_log = 0

         self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
-        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
+        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
         if self.alpha_mse > 0.:
             self.mse_loss_fct = nn.MSELoss(reduction='sum')
         if self.alpha_cos > 0.:
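Note: the single lm_loss_fct now serves both the MLM and the CLM branch, and ignore_index=-1 is what makes the -1 label convention used elsewhere in the distiller work. A small sketch of how these two criteria behave, on made-up tensors (toy sizes, random values):

import torch
import torch.nn as nn
import torch.nn.functional as F

ce_loss_fct = nn.KLDivLoss(reduction='batchmean')    # expects log-probs vs. probs
lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)   # positions labelled -1 contribute nothing

logits = torch.randn(4, 10)            # 4 positions, vocabulary of 10
teacher_logits = torch.randn(4, 10)
labels = torch.tensor([3, -1, 7, -1])  # the two -1 positions are skipped

kl = ce_loss_fct(F.log_softmax(logits, dim=-1), F.softmax(teacher_logits, dim=-1))
ce = lm_loss_fct(logits, labels)       # averaged over the positions with labels >= 0
print(kl.item(), ce.item())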
@@ -99,7 +117,7 @@ class Distiller:
         logger.info('--- Initializing model optimizer')
         assert params.gradient_accumulation_steps >= 1
-        self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
+        self.num_steps_epoch = len(self.dataloader)
         num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1

         no_decay = ['bias', 'LayerNorm.weight']
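Note: with a batch sampler, len(self.dataloader) already counts batches per epoch, which is why the old int(len(...) / batch_size) + 1 arithmetic disappears. A tiny self-contained illustration (toy data, my own variable names):

import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, BatchSampler

dataset = TensorDataset(torch.arange(10))                      # 10 toy examples
sampler = BatchSampler(RandomSampler(dataset), batch_size=3, drop_last=False)
loader = DataLoader(dataset, batch_sampler=sampler)

print(len(dataset))   # 10 examples
print(len(loader))    # 4 batches per epoch, usable directly as num_steps_epoch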
@@ -140,43 +158,18 @@ class Distiller:
             logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
             self.student = DistributedDataParallel(self.student,
                                                    device_ids=[params.local_rank],
-                                                   output_device=params.local_rank)
+                                                   output_device=params.local_rank,
+                                                   find_unused_parameters=True)

         self.is_master = params.is_master
         if self.is_master:
             logger.info('--- Initializing Tensorboard')
             self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
-            self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)
-
-    def get_iterator(self,
-                     seed: int = None):
-        """
-        Initialize the data iterator.
-        Each process has its own data iterator (iterating on its own random portion of the dataset).
-
-        Input:
-        ------
-        seed: `int` - The random seed.
-        """
-        logger.info('--- Initializing Data Iterator')
-        self.data_iterator = self.dataloader.get_iterator(seed=seed)
-
-    def get_batch(self):
-        """
-        Call the data iterator to output a new batch.
-        If the data iterator went through the whole dataset, create a new iterator.
-        """
-        assert hasattr(self, 'data_iterator')
-        try:
-            x = next(self.data_iterator)
-        except StopIteration:
-            logger.warning('--- Went through the whole dataset. Creating new data iterator.')
-            self.data_iterator = self.dataloader.get_iterator()
-            x = next(self.data_iterator)
-        return x
-
-    def prepare_batch(self,
-                      batch):
+            self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0)
+            self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0)
+
+    def prepare_batch_mlm(self,
+                          batch):
         """
         Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.
@@ -222,7 +215,7 @@ class Distiller:
         assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

         _token_ids_real = token_ids[pred_mask]
-        _token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size)
+        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
         _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
         probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
         _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
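Note: the multinomial draw above implements the usual BERT-style corruption: for each position selected for prediction, pred_probs gives the probability of replacing the token with the mask token, keeping it, or re-sampling a random id. A standalone sketch of the same trick, with a toy vocabulary and an arbitrarily chosen mask id:

import torch

vocab_size, mask_id = 100, 99
word_mask, word_keep, word_rand = 0.8, 0.1, 0.1
pred_probs = torch.FloatTensor([word_mask, word_keep, word_rand])

token_ids_real = torch.randint(0, vocab_size - 1, (16,))       # tokens selected for prediction
token_ids_rand = token_ids_real.clone().random_(vocab_size)    # random replacement candidates
token_ids_mask = token_ids_real.clone().fill_(mask_id)         # [MASK] replacement candidates

# Draw 0 (mask), 1 (keep) or 2 (random) independently for each position.
probs = torch.multinomial(pred_probs, len(token_ids_real), replacement=True)
corrupted = (token_ids_mask * (probs == 0).long()
             + token_ids_real * (probs == 1).long()
             + token_ids_rand * (probs == 2).long())
print(corrupted)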
@@ -230,8 +223,41 @@ class Distiller:
         mlm_labels[~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

+        # sanity checks
+        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
+
         return token_ids, attn_mask, mlm_labels

+    def prepare_batch_clm(self,
+                          batch):
+        """
+        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.
+
+        Input:
+        ------
+        batch: `Tuple`
+            token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
+            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
+
+        Output:
+        -------
+        token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
+        attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
+        clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -1 where there is nothing to predict.
+        """
+        token_ids, lengths = batch
+        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
+        assert token_ids.size(0) == lengths.size(0)
+
+        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
+        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
+        clm_labels[~attn_mask] = -1  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
+
+        # sanity checks
+        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
+
+        return token_ids, attn_mask, clm_labels
+
     def round_batch(self,
                     x: torch.tensor,
                     lengths: torch.tensor):
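Note: the new prepare_batch_clm builds the attention mask by comparing a position index against each sequence length, then copies token_ids into clm_labels and writes -1 into the padded positions so the CLM loss ignores them. A minimal sketch with hand-made lengths and invented token values:

import torch

token_ids = torch.tensor([[11, 12, 13, 0, 0],
                          [21, 22, 23, 24, 25]])   # 0 is padding here (toy values)
lengths = torch.tensor([3, 5])

attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
clm_labels[~attn_mask] = -1

print(attn_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
print(clm_labels)
# tensor([[11, 12, 13, -1, -1],
#         [21, 22, 23, 24, 25]])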
@@ -269,7 +295,10 @@ class Distiller:
         if ml1 % 8 != 0:
             pad = 8 - (ml1 % 8)
             ml2 = ml1 + pad
-            pad_id = self.params.special_tok_ids['pad_token']
+            if self.mlm:
+                pad_id = self.params.special_tok_ids['pad_token']
+            else:
+                pad_id = self.params.special_tok_ids['unk_token']
             padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
             x = torch.cat([x, padding_tensor], 1)
             assert x.size() == (bs2, ml2)
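Note: round_batch pads the time dimension up to a multiple of 8, and the change above picks the filler token per objective since a GPT2-style vocabulary has no dedicated pad token. A short, self-contained sketch of the padding arithmetic with invented ids (pad_id is a placeholder):

import torch

x = torch.randint(0, 50, (4, 13))   # (bs2, ml1) with ml1 = 13
bs2, ml1 = x.size()
pad_id = 0                           # stand-in for the pad_token / unk_token id

if ml1 % 8 != 0:
    pad = 8 - (ml1 % 8)              # 3 extra columns, so ml2 = 16
    ml2 = ml1 + pad
    padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
    x = torch.cat([x, padding_tensor], 1)
    assert x.size() == (bs2, ml2)
print(x.size())                      # torch.Size([4, 16])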
@@ -292,14 +321,16 @@ class Distiller:
             if self.multi_gpu:
                 torch.distributed.barrier()

-            iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
-            for __ in range(self.num_steps_epoch):
-                batch = self.get_batch()
+            iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
+            for batch in iter_bar:
                 if self.params.n_gpu > 0:
                     batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
-                token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)
-                self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)
+
+                if self.mlm:
+                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
+                else:
+                    token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
+                self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)

                 iter_bar.update()
                 iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
@@ -317,7 +348,7 @@ class Distiller:
     def step(self,
              input_ids: torch.tensor,
              attention_mask: torch.tensor,
-             mlm_labels: torch.tensor):
+             lm_labels: torch.tensor):
         """
         One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
         and possibly a parameter update (depending on the gradient accumulation).
@@ -326,17 +357,22 @@ class Distiller:
         ------
         input_ids: `torch.tensor(bs, seq_length)` - The token ids.
         attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
-        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
+        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
         """
-        s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)         # (bs, seq_length, voc_size)
-        with torch.no_grad():
-            t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask)     # (bs, seq_length, voc_size)
+        if self.mlm:
+            s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)     # (bs, seq_length, voc_size)
+            with torch.no_grad():
+                t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size)
+        else:
+            s_logits, _, s_hidden_states = self.student(input_ids=input_ids, attention_mask=None)            # (bs, seq_length, voc_size)
+            with torch.no_grad():
+                t_logits, _, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=None)        # (bs, seq_length, voc_size)
         assert s_logits.size() == t_logits.size()

         #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
         #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
         if self.params.restrict_ce_to_mask:
-            mask = (mlm_labels>-1).unsqueeze(-1).expand_as(s_logits)   # (bs, seq_length, voc_size)
+            mask = (lm_labels>-1).unsqueeze(-1).expand_as(s_logits)    # (bs, seq_length, voc_size)
         else:
             mask = attention_mask.unsqueeze(-1).expand_as(s_logits)    # (bs, seq_length, voc_size)
         s_logits_slct = torch.masked_select(s_logits, mask)            # (bs * seq_length * voc_size) modulo the 1s in mask
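Note: restrict_ce_to_mask narrows the distillation loss to positions that actually carry a label (lm_labels > -1); otherwise every non-padding position from attention_mask participates. Expanding the 2-D mask to the logits' shape and calling torch.masked_select flattens the kept logits into a 1-D tensor that is then reshaped. A toy sketch of that selection, with invented sizes:

import torch

bs, seq_length, voc_size = 2, 4, 5
s_logits = torch.randn(bs, seq_length, voc_size)
lm_labels = torch.tensor([[ 7, -1,  3, -1],
                          [-1,  2, -1, -1]])                   # -1 = nothing to predict

mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits)      # (bs, seq_length, voc_size)
s_logits_slct = torch.masked_select(s_logits, mask)            # 1-D: kept positions * voc_size
s_logits_slct = s_logits_slct.view(-1, voc_size)               # back to (n_kept, voc_size)
print(s_logits_slct.shape)                                     # torch.Size([3, 5])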
@@ -348,13 +384,20 @@ class Distiller:
         loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
                                    F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
         loss = self.alpha_ce*loss_ce

         if self.alpha_mlm > 0.:
-            loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))
+            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
             loss += self.alpha_mlm * loss_mlm
+        if self.alpha_clm > 0.:
+            shift_logits = s_logits[..., :-1, :].contiguous()
+            shift_labels = lm_labels[..., 1:].contiguous()
+            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                                        shift_labels.view(-1))
+            loss += self.alpha_clm * loss_clm
+
         if self.alpha_mse > 0.:
             loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction
             loss += self.alpha_mse * loss_mse

         if self.alpha_cos > 0.:
             s_hidden_states = s_hidden_states[-1]                              # (bs, seq_length, dim)
             t_hidden_states = t_hidden_states[-1]                              # (bs, seq_length, dim)
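Note: the loss combines a temperature-scaled KL term between student and teacher distributions with a language-modeling term; for CLM the logits and labels are shifted by one so that each position predicts the next token. A condensed, self-contained sketch of both terms on random tensors (the coefficients, temperature, and sizes are placeholders, not values from the repo):

import torch
import torch.nn as nn
import torch.nn.functional as F

temperature, alpha_ce, alpha_clm = 2.0, 5.0, 2.0
bs, seq_length, voc_size = 2, 6, 11
s_logits = torch.randn(bs, seq_length, voc_size, requires_grad=True)
t_logits = torch.randn(bs, seq_length, voc_size)
lm_labels = torch.randint(0, voc_size, (bs, seq_length))

# Soft-target KL, scaled by T^2 as in standard knowledge distillation.
ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
loss_ce = ce_loss_fct(F.log_softmax(s_logits / temperature, dim=-1),
                      F.softmax(t_logits / temperature, dim=-1)) * temperature ** 2
loss = alpha_ce * loss_ce

# CLM term: position t predicts token t+1.
lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
shift_logits = s_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous()
loss_clm = lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss += alpha_clm * loss_clm

loss.backward()
print(loss.item())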
@@ -376,6 +419,8 @@ class Distiller:
         self.last_loss_ce = loss_ce.item()
         if self.alpha_mlm > 0.:
             self.last_loss_mlm = loss_mlm.item()
+        if self.alpha_clm > 0.:
+            self.last_loss_clm = loss_clm.item()
         if self.alpha_mse > 0.:
             self.last_loss_mse = loss_mse.item()
         if self.alpha_cos > 0.:
@@ -452,6 +497,8 @@ class Distiller:
         self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
         if self.alpha_mlm > 0.:
             self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
+        if self.alpha_clm > 0.:
+            self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter)
         if self.alpha_mse > 0.:
             self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
         if self.alpha_cos > 0.: