chenpangpang / transformers / Commits

Commit bb9c5ead
authored Oct 02, 2019 by VictorSanh
Committed by Victor SANH, Oct 03, 2019

update distiller

parent a12ab0a8

Showing 1 changed file with 118 additions and 71 deletions (+118, -71)

examples/distillation/distiller.py
View file @ bb9c5ead
@@ -12,8 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" The distiller to distil DistilBERT
-    adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
+""" The distiller to distil the student.
+    Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
 """
 import os
 import math
@@ -28,16 +28,19 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.optim import AdamW
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import RandomSampler, BatchSampler, DataLoader
 
 from transformers import WarmupLinearSchedule
 
 from utils import logger
-from dataset import Dataset
+from lm_seqs_dataset import LmSeqsDataset
+from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
 
 class Distiller:
     def __init__(self,
                  params: dict,
-                 dataloader: Dataset,
+                 dataset: LmSeqsDataset,
                  token_probs: torch.tensor,
                  student: nn.Module,
                  teacher: nn.Module):
@@ -50,33 +53,47 @@ class Distiller:
         self.student = student
         self.teacher = teacher
 
-        self.dataloader = dataloader
-        if self.params.n_gpu > 1:
-            self.dataloader.split()
-        self.get_iterator(seed=params.seed)
+        self.student_config = student.config
+        self.vocab_size = student.config.vocab_size
+
+        if params.n_gpu <= 1:
+            sampler = RandomSampler(dataset)
+        else:
+            sampler = DistributedSampler(dataset)
+
+        if params.group_by_size:
+            groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
+            sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
+        else:
+            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
+
+        self.dataloader = DataLoader(dataset=dataset,
+                                     batch_sampler=sampler,
+                                     collate_fn=dataset.batch_sequences)
 
         self.temperature = params.temperature
         assert self.temperature > 0.
 
         self.alpha_ce = params.alpha_ce
         self.alpha_mlm = params.alpha_mlm
+        self.alpha_clm = params.alpha_clm
         self.alpha_mse = params.alpha_mse
         self.alpha_cos = params.alpha_cos
-        assert self.alpha_ce >= 0.
-        assert self.alpha_mlm >= 0.
-        assert self.alpha_mse >= 0.
-        assert self.alpha_cos >= 0.
-        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0.
 
-        self.mlm_mask_prop = params.mlm_mask_prop
-        assert 0.0 <= self.mlm_mask_prop <= 1.0
-        assert params.word_mask + params.word_keep + params.word_rand == 1.0
-        self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
-        self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
-        self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
-        if self.fp16:
-            self.pred_probs = self.pred_probs.half()
-            self.token_probs = self.token_probs.half()
+        self.mlm = params.mlm
+        if self.mlm:
+            logger.info(f'Using MLM loss for LM step.')
+            self.mlm_mask_prop = params.mlm_mask_prop
+            assert 0.0 <= self.mlm_mask_prop <= 1.0
+            assert params.word_mask + params.word_keep + params.word_rand == 1.0
+            self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
+            self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
+            self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
+            if self.fp16:
+                self.pred_probs = self.pred_probs.half()
+                self.token_probs = self.token_probs.half()
+        else:
+            logger.info(f'Using CLM loss for LM step.')
 
         self.epoch = 0
         self.n_iter = 0
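
The rewritten constructor drops the custom dataloader splitting in favor of standard PyTorch samplers: a RandomSampler (or DistributedSampler under multi-GPU) is wrapped either by GroupedBatchSampler, which buckets sequences of similar length together, or by a plain BatchSampler. Batching similar lengths reduces the padding per batch, which is the point of params.group_by_size. Below is a minimal, hypothetical sketch of the length-bucketing idea only; the name group_batches_by_length and the bucket_width parameter are illustrative and not part of the repository's GroupedBatchSampler.

import random
from collections import defaultdict

def group_batches_by_length(lengths, batch_size, bucket_width=32):
    # bucket sequence indices by length so each batch pads to a similar size
    buckets = defaultdict(list)
    for idx, length in enumerate(lengths):
        buckets[length // bucket_width].append(idx)
    batches = []
    for indices in buckets.values():
        random.shuffle(indices)
        for i in range(0, len(indices), batch_size):
            batches.append(indices[i:i + batch_size])
    random.shuffle(batches)  # keep epoch order random across buckets
    return batches

# usage: batches = group_batches_by_length([len(s) for s in sequences], batch_size=8)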
@@ -86,12 +103,13 @@ class Distiller:
         self.last_loss = 0
         self.last_loss_ce = 0
         self.last_loss_mlm = 0
+        self.last_loss_clm = 0
         if self.alpha_mse > 0.: self.last_loss_mse = 0
         if self.alpha_cos > 0.: self.last_loss_cos = 0
         self.last_log = 0
 
         self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
-        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
+        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
         if self.alpha_mse > 0.:
             self.mse_loss_fct = nn.MSELoss(reduction='sum')
         if self.alpha_cos > 0.:
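
The renamed lm_loss_fct keeps ignore_index=-1, so any position labeled -1 (padding, or tokens not selected for prediction) contributes nothing to the cross-entropy. A small illustrative sketch with made-up tensors:

import torch
import torch.nn as nn

lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
logits = torch.randn(4, 10)            # 4 token positions, vocabulary of 10
labels = torch.tensor([3, -1, 7, -1])  # only positions 0 and 2 are scored
loss = lm_loss_fct(logits, labels)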
@@ -99,7 +117,7 @@ class Distiller:
 
         logger.info('--- Initializing model optimizer')
         assert params.gradient_accumulation_steps >= 1
-        self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
+        self.num_steps_epoch = len(self.dataloader)
         num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
 
         no_decay = ['bias', 'LayerNorm.weight']
@@ -140,43 +158,18 @@ class Distiller:
             logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
             self.student = DistributedDataParallel(self.student,
                                                    device_ids=[params.local_rank],
-                                                   output_device=params.local_rank)
+                                                   output_device=params.local_rank,
+                                                   find_unused_parameters=True)
 
         self.is_master = params.is_master
         if self.is_master:
             logger.info('--- Initializing Tensorboard')
             self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
-            self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)
+            self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0)
+            self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0)
 
-    def get_iterator(self,
-                     seed: int = None):
-        """
-        Initialize the data iterator.
-        Each process has its own data iterator (iterating on his own random portion of the dataset).
-
-        Input:
-        ------
-            seed: `int` - The random seed.
-        """
-        logger.info('--- Initializing Data Iterator')
-        self.data_iterator = self.dataloader.get_iterator(seed=seed)
-
-    def get_batch(self):
-        """
-        Call the data iterator to output a new batch.
-        If the data iterator went through the whole dataset, create a new iterator.
-        """
-        assert hasattr(self, 'data_iterator')
-        try:
-            x = next(self.data_iterator)
-        except StopIteration:
-            logger.warning('--- Went through the whole dataset. Creating new data iterator.')
-            self.data_iterator = self.dataloader.get_iterator()
-            x = next(self.data_iterator)
-        return x
-
-    def prepare_batch(self,
-                      batch):
+    def prepare_batch_mlm(self,
+                          batch):
         """
         Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the masked label for MLM.
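
The DistributedDataParallel wrapper now passes find_unused_parameters=True, which lets DDP tolerate parameters that receive no gradient on a given step (likely needed because some sub-modules do not always feed the loss). A single-process sketch of the flag; the gloo backend, address and port values are illustrative, and a real run would launch one process per GPU:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)  # single-process demo

model = torch.nn.Linear(8, 8)
ddp_model = DistributedDataParallel(model, find_unused_parameters=True)
ddp_model(torch.randn(2, 8)).sum().backward()
dist.destroy_process_group()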
@@ -222,7 +215,7 @@ class Distiller:
             assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()
 
         _token_ids_real = token_ids[pred_mask]
-        _token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size)
+        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
         _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
         probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
         _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
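
For context, this part of prepare_batch_mlm replaces each selected position with the mask token, the original token, or a random token, sampled according to pred_probs (word_mask / word_keep / word_rand). A self-contained sketch of that choice, with illustrative ids and probabilities:

import torch

# for each selected position, sample 0 = [MASK], 1 = keep original, 2 = random token
pred_probs = torch.FloatTensor([0.8, 0.1, 0.1])   # word_mask, word_keep, word_rand
selected = torch.tensor([15, 42, 7, 99])          # token ids at positions picked for prediction
mask_id, vocab_size = 103, 30522                  # e.g. BERT's [MASK] id and vocab size
probs = torch.multinomial(pred_probs, len(selected), replacement=True)
corrupted = (selected.clone().fill_(mask_id) * (probs == 0).long()
             + selected * (probs == 1).long()
             + selected.clone().random_(vocab_size) * (probs == 2).long())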
@@ -230,8 +223,41 @@ class Distiller:
         mlm_labels[~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
 
+        # sanity checks
+        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
+
         return token_ids, attn_mask, mlm_labels
 
+    def prepare_batch_clm(self,
+                          batch):
+        """
+        Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the labels for CLM.
+
+        Input:
+        ------
+            batch: `Tuple`
+                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
+                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
+
+        Output:
+        -------
+            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
+            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
+            clm_labels: `torch.tensor(bs, seq_length)` - The causal languge modeling labels. There is a -1 where there is nothing to predict.
+        """
+        token_ids, lengths = batch
+        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
+        assert token_ids.size(0) == lengths.size(0)
+
+        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
+        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
+        clm_labels[~attn_mask] = -1  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
+
+        # sanity checks
+        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
+
+        return token_ids, attn_mask, clm_labels
+
     def round_batch(self,
                     x: torch.tensor,
                     lengths: torch.tensor):
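
The new prepare_batch_clm builds the attention mask by broadcasting a position index against the per-sequence lengths, then copies token_ids into clm_labels and sets padded positions to -1 so they are ignored by the loss. A small sketch with made-up tensors:

import torch

lengths = torch.tensor([3, 5])   # two sequences in the batch
seq_len = 5
attn_mask = torch.arange(seq_len, dtype=torch.long) < lengths[:, None]
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
clm_labels = torch.tensor([[11, 12, 13, 0, 0],
                           [21, 22, 23, 24, 25]])
clm_labels[~attn_mask] = -1      # padded positions are never predicted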
@@ -269,7 +295,10 @@ class Distiller:
         if ml1 % 8 != 0:
             pad = 8 - (ml1 % 8)
             ml2 = ml1 + pad
-            pad_id = self.params.special_tok_ids['pad_token']
+            if self.mlm:
+                pad_id = self.params.special_tok_ids['pad_token']
+            else:
+                pad_id = self.params.special_tok_ids['unk_token']
             padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
             x = torch.cat([x, padding_tensor], 1)
             assert x.size() == (bs2, ml2)
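
round_batch pads the sequence dimension up to a multiple of 8, which keeps fp16 matrix shapes friendly to tensor-core kernels; without an MLM-style pad token, the CLM path falls back to unk_token. A hypothetical standalone helper showing the same padding arithmetic (pad_to_multiple_of_8 is illustrative, not from the repository):

import torch

def pad_to_multiple_of_8(x, pad_id):
    # x: (bs, seq_length) token ids; returns x padded so seq_length % 8 == 0
    seq_len = x.size(1)
    if seq_len % 8 == 0:
        return x
    pad = 8 - (seq_len % 8)
    padding = torch.full((x.size(0), pad), pad_id, dtype=torch.long, device=x.device)
    return torch.cat([x, padding], dim=1)

# usage: token_ids = pad_to_multiple_of_8(token_ids, pad_id=0)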
@@ -292,14 +321,16 @@ class Distiller:
         if self.multi_gpu:
             torch.distributed.barrier()
 
-        iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
-        for __ in range(self.num_steps_epoch):
-            batch = self.get_batch()
+        iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
+        for batch in iter_bar:
             if self.params.n_gpu > 0:
                 batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
 
-            token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)
-            self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)
+            if self.mlm:
+                token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
+            else:
+                token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
+            self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
 
             iter_bar.update()
             iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
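
The epoch loop now iterates the DataLoader directly instead of pulling batches from a hand-rolled iterator, and dispatches to prepare_batch_mlm or prepare_batch_clm depending on params.mlm. A stripped-down sketch of the same pattern; run_epoch and step_fn are illustrative names, not functions from the repository:

from tqdm import tqdm

def run_epoch(dataloader, step_fn, device='cpu'):
    # iterate the DataLoader directly and let tqdm wrap it for progress reporting
    iter_bar = tqdm(dataloader, desc="-Iter")
    for batch in iter_bar:
        batch = tuple(t.to(device) for t in batch)
        last_loss = step_fn(*batch)  # assumed to return a float loss
        iter_bar.set_postfix({'Last_loss': f'{last_loss:.2f}'})
    iter_bar.close()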
@@ -317,7 +348,7 @@ class Distiller:
     def step(self,
              input_ids: torch.tensor,
              attention_mask: torch.tensor,
-             mlm_labels: torch.tensor):
+             lm_labels: torch.tensor):
         """
         One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
         and possibly a parameter update (depending on the gradient accumulation).
@@ -326,17 +357,22 @@ class Distiller:
         ------
             input_ids: `torch.tensor(bs, seq_length)` - The token ids.
             attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
-            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
+            lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
         """
-        s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)      # (bs, seq_length, voc_size)
-        with torch.no_grad():
-            t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask)  # (bs, seq_length, voc_size)
+        if self.mlm:
+            s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)      # (bs, seq_length, voc_size)
+            with torch.no_grad():
+                t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask)  # (bs, seq_length, voc_size)
+        else:
+            s_logits, _, s_hidden_states = self.student(input_ids=input_ids, attention_mask=None)             # (bs, seq_length, voc_size)
+            with torch.no_grad():
+                t_logits, _, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=None)         # (bs, seq_length, voc_size)
         assert s_logits.size() == t_logits.size()
 
         #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
         #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
         if self.params.restrict_ce_to_mask:
-            mask = (mlm_labels>-1).unsqueeze(-1).expand_as(s_logits)   # (bs, seq_lenth, voc_size)
+            mask = (lm_labels>-1).unsqueeze(-1).expand_as(s_logits)    # (bs, seq_lenth, voc_size)
         else:
             mask = attention_mask.unsqueeze(-1).expand_as(s_logits)    # (bs, seq_lenth, voc_size)
         s_logits_slct = torch.masked_select(s_logits, mask)            # (bs * seq_length * voc_size) modulo the 1s in mask
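
Note that torch.masked_select returns a flattened 1-D tensor, so the selected logits are reshaped back to (n_selected_positions, vocab_size) before the softmaxes further down. An illustrative sketch with random tensors:

import torch

s_logits = torch.randn(2, 4, 10)                            # (bs, seq_length, voc_size)
attention_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])
mask = (attention_mask > 0).unsqueeze(-1).expand_as(s_logits)
s_logits_slct = torch.masked_select(s_logits, mask)         # 1-D: 6 kept positions * 10 logits
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))   # (6, 10)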
@@ -348,13 +384,20 @@ class Distiller:
         loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
                                    F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
         loss = self.alpha_ce*loss_ce
         if self.alpha_mlm > 0.:
-            loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))
+            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
             loss += self.alpha_mlm * loss_mlm
+        if self.alpha_clm > 0.:
+            shift_logits = s_logits[..., :-1, :].contiguous()
+            shift_labels = lm_labels[..., 1:].contiguous()
+            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                                        shift_labels.view(-1))
+            loss += self.alpha_clm * loss_clm
+
         if self.alpha_mse > 0.:
             loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction
             loss += self.alpha_mse * loss_mse
 
         if self.alpha_cos > 0.:
             s_hidden_states = s_hidden_states[-1]                              # (bs, seq_length, dim)
             t_hidden_states = t_hidden_states[-1]                              # (bs, seq_length, dim)
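
The added CLM branch applies the usual causal shift: the logits at position t are scored against the token at position t+1, so the last logit step and the first label are dropped. A minimal sketch with random tensors:

import torch
import torch.nn as nn

lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
s_logits = torch.randn(2, 5, 10)                   # (bs, seq_length, voc_size)
lm_labels = torch.randint(0, 10, (2, 5))           # (bs, seq_length)
shift_logits = s_logits[..., :-1, :].contiguous()  # (bs, seq_length-1, voc_size)
shift_labels = lm_labels[..., 1:].contiguous()     # (bs, seq_length-1)
loss_clm = lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))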
@@ -376,6 +419,8 @@ class Distiller:
         self.last_loss_ce = loss_ce.item()
         if self.alpha_mlm > 0.:
             self.last_loss_mlm = loss_mlm.item()
+        if self.alpha_clm > 0.:
+            self.last_loss_clm = loss_clm.item()
         if self.alpha_mse > 0.:
             self.last_loss_mse = loss_mse.item()
         if self.alpha_cos > 0.:
@@ -452,6 +497,8 @@ class Distiller:
         self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
         if self.alpha_mlm > 0.:
             self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
+        if self.alpha_clm > 0.:
+            self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter)
         if self.alpha_mse > 0.:
             self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
         if self.alpha_cos > 0.:
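
Each active loss term gets its own scalar tag under losses/, including the new losses/loss_clm. A short sketch of the logging pattern; the log_dir value and the torch.utils.tensorboard import are assumptions, as the training script may obtain SummaryWriter from elsewhere:

from torch.utils.tensorboard import SummaryWriter

tensorboard = SummaryWriter(log_dir='runs/distillation')
fake_history = [{'loss_ce': 2.1, 'loss_clm': 3.0}, {'loss_ce': 1.9, 'loss_clm': 2.7}]
for step, losses in enumerate(fake_history):
    for name, value in losses.items():
        tensorboard.add_scalar(tag=f'losses/{name}', scalar_value=value, global_step=step)
tensorboard.close()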