Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1ae81e4a
Commit
1ae81e4a
authored
Aug 28, 2019
by
VictorSanh
Browse files
add dataset. distiller, utils
parent
5d29f8e9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
727 additions
and
0 deletions
+727
-0
examples/distillation/dataset.py
examples/distillation/dataset.py
+184
-0
examples/distillation/distiller.py
examples/distillation/distiller.py
+431
-0
examples/distillation/utils.py
examples/distillation/utils.py
+112
-0
No files found.
examples/distillation/dataset.py
0 → 100644
View file @
1ae81e4a
from
typing
import
List
import
math
from
itertools
import
chain
from
collections
import
Counter
import
numpy
as
np
import
torch
from
utils
import
logger
class Dataset:
    """
    In-memory collection of tokenized sequences served as padded mini-batches.

    Attributes set by the constructor:
        params: the configuration namespace passed by the caller.
        tokens_per_batch, batch_size, shuffle, group_by_size: copied from `params`.
        token_ids: 1-D object array holding one token-id sequence per entry.
        lengths: integer array holding the length of each sequence.
    """

    def __init__(self, params, data):
        """
        Store the sequences, then split the too long ones and drop the too short ones.

        Input:
        ------
            params: namespace read through attribute access. This class reads
                `tokens_per_batch`, `batch_size`, `shuffle`, `group_by_size`,
                `max_position_embeddings`, `special_tok_ids`, `is_master` and,
                in `split`, `n_gpu`, `world_size` and `global_rank`.
            data: sized iterable of token-id sequences. `batch_sequences` calls
                `.astype` on every sequence, so NumPy arrays are expected.
        """
        self.params = params

        self.tokens_per_batch = params.tokens_per_batch
        self.batch_size = params.batch_size
        self.shuffle = params.shuffle
        self.group_by_size = params.group_by_size

        # Sequences have different lengths: store them in an explicit object
        # array (`np.array` on ragged input raises on recent NumPy versions).
        self.token_ids = self._as_object_array(data)
        # 64-bit lengths: a 16-bit counter silently wraps above 65535 tokens,
        # which would let an over-long sequence pass as a short one.
        self.lengths = np.array([len(t) for t in data], dtype=np.int64)

        self.check()
        self.remove_long_sequences()
        self.remove_empty_sequences()
        self.check()
        self.print_statistics()

    def __len__(self):
        """Return the number of sequences currently stored."""
        return len(self.lengths)

    @staticmethod
    def _as_object_array(sequences):
        """Pack `sequences` into a 1-D object array, one sequence per cell."""
        packed = np.empty(len(sequences), dtype=object)
        for position, sequence in enumerate(sequences):
            packed[position] = sequence
        return packed

    @staticmethod
    def _split_long_sequence(seq, max_len, cls_id, sep_id):
        """Cut `seq` into chunks of at most `max_len` tokens, each starting with `cls_id` and ending with `sep_id`."""
        chunk_len = max_len - 2  # keep room for the two special tokens
        chunks = []
        for start in range(0, len(seq), chunk_len):
            chunk = seq[start:start + chunk_len]
            if chunk[0] != cls_id:
                chunk = np.insert(chunk, 0, cls_id)
            if chunk[-1] != sep_id:
                # The closing token must be the separator (the classification
                # token was appended here before, producing malformed chunks).
                chunk = np.insert(chunk, len(chunk), sep_id)
            assert len(chunk) <= max_len
            chunks.append(chunk)
        return chunks

    def check(self):
        """
        Sanity check: there must be exactly one length per sequence.
        """
        assert len(self.token_ids) == len(self.lengths)

    def remove_long_sequences(self):
        """
        Split every sequence longer than `params.max_position_embeddings` into
        chunks that fit, each chunk wrapped with the cls / sep special tokens.
        Sequences that already fit are kept untouched.
        """
        max_len = self.params.max_position_embeddings
        # Count with the same strict comparison that triggers the split below.
        n_too_long = int((self.lengths > max_len).sum())
        logger.info(f'Splitting {n_too_long} too long sequences.')

        cls_id = self.params.special_tok_ids['cls_token']
        sep_id = self.params.special_tok_ids['sep_token']

        new_tok_ids = []
        new_lengths = []
        for seq_, len_ in zip(self.token_ids, self.lengths):
            if len_ <= max_len:
                new_tok_ids.append(seq_)
                new_lengths.append(len_)
            else:
                sub_seqs = self._split_long_sequence(seq_, max_len, cls_id, sep_id)
                new_tok_ids.extend(sub_seqs)
                new_lengths.extend(len(sub_s) for sub_s in sub_seqs)

        self.token_ids = self._as_object_array(new_tok_ids)
        self.lengths = np.array(new_lengths, dtype=np.int64)

    def remove_empty_sequences(self):
        """
        Drop the sequences of 5 tokens or fewer. The threshold is hard-coded and could be tuned.
        """
        init_size = len(self)
        indices = self.lengths > 5
        self.token_ids = self.token_ids[indices]
        self.lengths = self.lengths[indices]
        new_size = len(self)
        logger.info(f'Remove {init_size - new_size} too short (<=5 tokens) sequences.')

    def print_statistics(self):
        """
        Log the number of sequences. Only the master process logs.
        """
        if not self.params.is_master:
            return
        logger.info(f'{len(self)} sequences')
        # Token-level statistics, currently disabled:
        # data_len = sum(self.lengths)
        # nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
        # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')

        # unk_idx = self.params.special_tok_ids['unk_token']
        # nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids])
        # logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)')

    def select_data(self, a: int, b: int):
        """
        Keep only the sequences whose index is in `[a, b)`.

        Input:
        ------
            a: `int` - First index kept.
            b: `int` - First index dropped; `0 <= a < b <= len(self)` must hold.
        """
        n_sequences = len(self)
        assert 0 <= a < b <= n_sequences, ValueError(f'`0 <= a < b <= n_sequences` is not met with a={a} and b={b}')

        logger.info(f'Selecting sequences from {a} to {b} (excluded).')
        self.token_ids = self.token_ids[a:b]
        self.lengths = self.lengths[a:b]

        self.check()

    def split(self):
        """
        Distributed training: keep only the contiguous slice of the data that
        belongs to this process (`params.global_rank` out of `params.world_size`).
        The slices all have `len(self) // world_size` sequences, so the
        remainder of that division is not assigned to any process.
        """
        assert self.params.n_gpu > 1
        logger.info('Splitting the data accross the processuses.')
        n_seq = len(self)
        n_seq_per_procesus = n_seq // self.params.world_size
        a = n_seq_per_procesus * self.params.global_rank
        b = a + n_seq_per_procesus
        self.select_data(a=a, b=b)

    def batch_sequences(self,
                        token_ids: List[List[int]],
                        lengths: List[int]):
        """
        Pad the sequences to the longest one and convert the batch to tensors.

        Input:
        ------
            token_ids: the sequences of the batch; `.astype` is called on each one.
            lengths: the length of each sequence; `.astype` is called on it.

        Output:
        -------
            tk_t: `torch.tensor(bs, max_seq_len_)` - The padded token ids.
            lg_t: `torch.tensor(bs)` - The lengths.
        """
        assert len(token_ids) == len(lengths)

        # Max for paddings
        max_seq_len_ = int(max(lengths))

        # Pad token ids
        pad_idx = self.params.special_tok_ids['pad_token']
        tk_ = [t.astype(int).tolist() + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
        assert len(tk_) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in tk_)

        tk_t = torch.tensor(tk_)                  # (bs, max_seq_len_)
        lg_t = torch.tensor(lengths.astype(int))  # (bs)
        return tk_t, lg_t

    def get_batches_iterator(self, batches):
        """
        Yield one `(token_ids, lengths)` tensor pair per entry of `batches`,
        where each entry is an array of sequence indices.
        """
        for sequences_ids in batches:
            token_ids, lengths = self.batch_sequences(self.token_ids[sequences_ids],
                                                      self.lengths[sequences_ids])
            yield (token_ids, lengths)

    def get_iterator(self, seed: int = None):
        """
        Group the sequence indices into batches and return an iterator over the padded batches.

        Input:
        ------
            seed: `int` - Seed of the generator used to shuffle the batches (only used if `self.shuffle`).
        """
        rng = np.random.RandomState(seed)

        n_sequences = len(self)
        indices = np.arange(n_sequences)

        if self.group_by_size:
            # Stable sort: sequences of equal length keep their relative order.
            indices = indices[np.argsort(self.lengths[indices], kind='mergesort')]

        if self.tokens_per_batch == -1:
            # Fixed number of sequences per batch.
            batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size))
        else:
            # Token budget per batch: a new batch starts every time the running
            # token count crosses a multiple of `tokens_per_batch`.
            assert self.tokens_per_batch > 0
            batch_ids = np.cumsum(self.lengths[indices]) // self.tokens_per_batch
            _, bounds = np.unique(batch_ids, return_index=True)
            batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
            if bounds[-1] < len(indices):
                batches.append(indices[bounds[-1]:])

        if self.shuffle:
            # Shuffles the order of the batches, not their content.
            rng.shuffle(batches)

        # Every sequence (and every token) must end up in exactly one batch.
        assert n_sequences == sum([len(x) for x in batches])
        assert self.lengths[indices].sum() == sum([self.lengths[x].sum() for x in batches])

        return self.get_batches_iterator(batches=batches)
examples/distillation/distiller.py
0 → 100644
View file @
1ae81e4a
import
os
import
math
from
tensorboardX
import
SummaryWriter
from
tqdm
import
trange
,
tqdm
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
pytorch_transformers
import
AdamW
,
WarmupLinearSchedule
from
utils
import
logger
from
dataset
import
Dataset
class Distiller:
    """
    Training loop that distills a teacher model into a student model.

    The loss is a weighted sum of up to three terms: a temperature-scaled
    KL divergence between student and teacher logits (`alpha_ce`), the masked
    language modeling cross-entropy of the student (`alpha_mlm`) and the mean
    squared error between the logits (`alpha_mse`).
    """

    def __init__(self,
                 params: dict,
                 dataloader: "Dataset",
                 token_probs: torch.tensor,
                 student: nn.Module,
                 teacher: nn.Module):
        """
        Input:
        ------
            params: configuration, read through attribute access (so a namespace-like
                object is needed, despite the `dict` annotation).
            dataloader: the dataset; it is split across processes when `params.n_gpu > 1`.
            token_probs: 1-D tensor indexed by token id, used as sampling weights
                when choosing the positions to predict for MLM.
            student: the model being trained.
            teacher: the model providing the target logits.
        """
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.dataloader = dataloader
        if self.params.n_gpu > 1:
            self.dataloader.split()
        self.get_iterator(seed=params.seed)

        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_mse = params.alpha_mse
        assert self.alpha_ce >= 0.
        assert self.alpha_mlm >= 0.
        assert self.alpha_mse >= 0.
        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse > 0.

        self.mlm_mask_prop = params.mlm_mask_prop
        assert 0.0 <= self.mlm_mask_prop <= 1.0
        # Tolerant comparison: an exact `== 1.0` rejects valid settings such as
        # 0.7 / 0.2 / 0.1, whose floating-point sum is 0.9999999999999999.
        assert math.isclose(params.word_mask + params.word_keep + params.word_rand, 1.0)
        self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
        self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
        self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
        if self.fp16:
            self.pred_probs = self.pred_probs.half()
            self.token_probs = self.token_probs.half()

        # Counters and running statistics (the `*_epoch` ones and `n_iter` are reset by `end_epoch`).
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_mse = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        self.mse_loss_fct = nn.MSELoss(reduction='sum')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
        num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)

        # Parameters whose name contains one of these substrings get no weight decay.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in student.named_parameters()
                        if not any(nd in n for nd in no_decay) and p.requires_grad],
             'weight_decay': params.weight_decay},
            {'params': [p for n, p in student.named_parameters()
                        if any(nd in n for nd in no_decay) and p.requires_grad],
             'weight_decay': 0.0}
        ]
        logger.info("------ Number of trainable parameters (student): %i" % sum([p.numel() for p in self.student.parameters() if p.requires_grad]))
        logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=warmup_steps,
                                              t_total=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError as err:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(self.student,
                                                          self.optimizer,
                                                          opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student,
                                                       device_ids=[params.local_rank],
                                                       output_device=params.local_rank)

        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)

    def get_iterator(self,
                     seed: int = None):
        """
        Initialize the data iterator.
        Each process has its own data iterator (iterating on its own portion of the dataset).

        Input:
        ------
            seed: `int` - The random seed.
        """
        logger.info('--- Initializing Data Iterator')
        self.data_iterator = self.dataloader.get_iterator(seed=seed)

    def get_batch(self):
        """
        Call the data iterator to output a new batch.
        If the data iterator went through the whole dataset, create a new iterator
        (without seed) and return its first batch.
        """
        assert hasattr(self, 'data_iterator')
        try:
            x = next(self.data_iterator)
        except StopIteration:
            logger.warning('--- Went through the whole dataset. Creating new data iterator.')
            self.data_iterator = self.dataloader.get_iterator()
            x = next(self.data_iterator)
        return x

    def prepare_batch(self,
                      batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        # Position j of sequence i is attended iff j < lengths[i].
        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.clone()

        # Sample the positions to predict, weighted by the per-token probabilities.
        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
        # Boolean mask: indexing with a `uint8` mask is deprecated and
        # `masked_scatter` requires a boolean mask on recent PyTorch versions.
        pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device)
        pred_mask[tgt_ids] = True
        pred_mask = pred_mask.view(bs, max_seq_len)

        # Padding positions are never predicted.
        pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = False

        # mask a number of words == 0 [8] (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    # Un-select the first `n1 - n2` selected positions.
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1 - n2]] = False
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        # For every selected position draw one of three actions according to
        # `pred_probs`: 0 -> mask token, 1 -> keep the token, 2 -> random token.
        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
        _token_ids = (_token_ids_mask * (probs == 0).long()
                      + _token_ids_real * (probs == 1).long()
                      + _token_ids_rand * (probs == 2).long())
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        # -1 is the `ignore_index` of the MLM loss.
        mlm_labels[~pred_mask] = -1

        return token_ids, attn_mask, mlm_labels

    def round_batch(self,
                    x: torch.tensor,
                    lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
        The batch is returned untouched when fp16 is off or when it has fewer than 8 sequences.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequence in the batch.

        Output:
        -------
            x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # number of sentences == 0 [8]
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            # Keep a random subset of `bs2` sequences and trim to the longest kept one.
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]

        # sequence length == 0 [8]
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            pad_id = self.params.special_tok_ids['pad_token']
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop: `params.n_epoch` epochs of `num_steps_epoch` steps each.
        """
        if self.is_master:
            logger.info('Starting training')
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')

            # The bar is advanced by hand below; it is hidden unless local_rank is -1 or 0.
            iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for __ in range(self.num_steps_epoch):
                batch = self.get_batch()
                if self.params.n_gpu > 0:
                    batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
                token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)

                self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)

                iter_bar.update()
                iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
                                      'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'})
            iter_bar.close()

            if self.is_master:
                logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        if self.is_master:
            logger.info('Training is finished')

    def step(self,
             input_ids: torch.tensor,
             attention_mask: torch.tensor,
             mlm_labels: torch.tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
            input_ids: `torch.tensor(bs, seq_length)` - The token ids.
            attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
        """
        s_logits = self.student(input_ids=input_ids, attention_mask=attention_mask)[0]      # (bs, seq_length, voc_size)
        with torch.no_grad():
            t_logits = self.teacher(input_ids=input_ids, attention_mask=attention_mask)[0]  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        # Positions entering the distillation losses: the MLM targets only, or every attended position.
        if self.params.restrict_ce_to_mask:
            mask = (mlm_labels > -1).unsqueeze(-1).expand_as(s_logits)   # (bs, seq_lenth, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)      # (bs, seq_lenth, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask)              # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))        # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, mask)              # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))        # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct / self.temperature, dim=-1),
                                   F.softmax(t_logits_slct / self.temperature, dim=-1)) * (self.temperature)**2
        loss = self.alpha_ce * loss_ce
        if self.alpha_mlm > 0.:
            loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))
            loss = loss + self.alpha_mlm * loss_mlm
        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(0)  # Reproducing batchmean reduction
            loss = loss + self.alpha_mse * loss_mse

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self,
                 loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.

        Raises `SystemExit` with a non-zero status if the loss is NaN.
        """
        # Check for NaN
        if torch.isnan(loss).any():
            logger.error('NaN detected')
            # Non-zero status so that the failure is visible to the caller of the process.
            raise SystemExit(1)

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
            # The optimizer must step before the scheduler, otherwise the first
            # value of the learning rate schedule is skipped.
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving (master process only),
        then reset the per-epoch counters.
        """
        logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth')
            self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self,
                        checkpoint_name: str = 'checkpoint.pth'):
        """
        Save the current state. Only by the master process.

        Input:
        ------
            checkpoint_name: `str` - File name of the weights, created inside `self.dump_path`.
        """
        if not self.is_master:
            return
        # Save the wrapped model when the student is wrapped (it then exposes `.module`).
        mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
examples/distillation/utils.py
0 → 100644
View file @
1ae81e4a
import
git
import
json
import
os
import
socket
import
torch
import
numpy
as
np
import
logging
# Configure the root logger once for the whole example: INFO level, with the
# timestamp, level, logger name and process id in front of every message.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s',
)
# Logger named after this module.
logger = logging.getLogger(name=__name__)
def git_log(folder_path: str):
    """
    Write the state of the enclosing git repository to `<folder_path>/git_log.json`.

    The repository is searched from the current working directory upwards.
    The JSON file holds three keys: `repo_id`, `repo_sha` and `repo_branch`.
    `repo_branch` is `null` when HEAD is detached (no active branch).

    Input:
    ------
        folder_path: `str` - Existing folder in which the JSON file is written.
    """
    repo = git.Repo(search_parent_directories=True)

    try:
        branch = str(repo.active_branch)
    except TypeError:
        # GitPython raises TypeError when HEAD is detached (e.g. a checkout of
        # a tag or of a commit sha): there is no branch name to report.
        branch = None

    repo_infos = {
        'repo_id': str(repo),
        'repo_sha': str(repo.head.object.hexsha),
        'repo_branch': branch,
    }

    with open(os.path.join(folder_path, 'git_log.json'), 'w') as f:
        json.dump(repo_infos, f, indent=4)
def init_gpu_params(params):
    """
    Fill `params` with the device / distributed-training settings, in place.

    - `params.n_gpu <= 0`: only `local_rank`, `master_port`, `is_master` and
      `multi_gpu` are set, then the function returns.
    - `params.n_gpu > 1`: the topology is read from the environment variables
      `WORLD_SIZE`, `N_GPU_NODE`, `RANK`, `N_NODES` and `NODE_RANK`.
    - otherwise: single-GPU job, every rank is 0 and every size is 1.

    In the GPU cases the current CUDA device is set to `params.local_rank` and,
    for multi-GPU jobs, the `nccl` process group is initialized.
    """
    # CPU only: nothing to initialize.
    if params.n_gpu <= 0:
        params.local_rank = 0
        params.master_port = -1
        params.is_master = True
        params.multi_gpu = False
        return

    assert torch.cuda.is_available()

    logger.info('Initializing GPUs')
    distributed = params.n_gpu > 1
    if distributed:
        assert params.local_rank != -1

        env = os.environ
        params.world_size = int(env['WORLD_SIZE'])
        params.n_gpu_per_node = int(env['N_GPU_NODE'])
        params.global_rank = int(env['RANK'])

        # number of nodes / node ID
        params.n_nodes = params.world_size // params.n_gpu_per_node
        params.node_id = params.global_rank // params.n_gpu_per_node
        params.multi_gpu = True

        # The derived values must agree with what the launcher exported.
        assert params.n_nodes == int(env['N_NODES'])
        assert params.node_id == int(env['NODE_RANK'])
    else:
        # local job (single GPU)
        assert params.local_rank == -1

        params.n_nodes = 1
        params.node_id = 0
        params.local_rank = 0
        params.global_rank = 0
        params.world_size = 1
        params.n_gpu_per_node = 1
        params.multi_gpu = False

    # sanity checks
    assert params.n_nodes >= 1
    assert 0 <= params.node_id < params.n_nodes
    assert 0 <= params.local_rank <= params.global_rank < params.world_size
    assert params.world_size == params.n_nodes * params.n_gpu_per_node

    # define whether this is the master process / if we are in multi-node distributed mode
    params.is_master = params.node_id == 0 and params.local_rank == 0
    params.multi_node = params.n_nodes > 1

    # summary
    prefix = f"--- Global rank: {params.global_rank} - "
    summary = (
        ("Number of nodes: %i", params.n_nodes),
        ("Node ID : %i", params.node_id),
        ("Local rank : %i", params.local_rank),
        ("World size : %i", params.world_size),
        ("GPUs per node : %i", params.n_gpu_per_node),
        ("Master : %s", str(params.is_master)),
        ("Multi-node : %s", str(params.multi_node)),
        ("Multi-GPU : %s", str(params.multi_gpu)),
        ("Hostname : %s", socket.gethostname()),
    )
    for template, value in summary:
        logger.info(prefix + template % value)

    # set GPU device
    torch.cuda.set_device(params.local_rank)

    # initialize multi-GPU
    if params.multi_gpu:
        logger.info("Initializing PyTorch distributed")
        torch.distributed.init_process_group(
            init_method='env://',
            backend='nccl',
        )
def set_seed(args):
    """
    Seed the NumPy and PyTorch random generators with `args.seed`.
    The CUDA generators are seeded as well when `args.n_gpu > 0`.
    """
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment