Commit 1ae81e4a authored Aug 28, 2019 by VictorSanh
add dataset. distiller, utils
Parent: 5d29f8e9

Showing 3 changed files with 727 additions and 0 deletions:
  examples/distillation/dataset.py     +184  -0
  examples/distillation/distiller.py   +431  -0
  examples/distillation/utils.py       +112  -0
examples/distillation/dataset.py (new file, mode 100644)
from typing import List
import math
from itertools import chain
from collections import Counter
import numpy as np
import torch

from utils import logger


class Dataset:
    def __init__(self, params, data):
        self.params = params
        self.tokens_per_batch = params.tokens_per_batch
        self.batch_size = params.batch_size
        self.shuffle = params.shuffle
        self.group_by_size = params.group_by_size

        self.token_ids = np.array(data)
        self.lengths = np.uint16([len(t) for t in data])

        self.check()
        self.remove_long_sequences()
        self.remove_empty_sequences()
        self.check()
        self.print_statistics()

    def __len__(self):
        return len(self.lengths)

    def check(self):
        """
        Some sanity checks.
        """
        assert len(self.token_ids) == len(self.lengths)

    def remove_long_sequences(self):
        """
        Sequences that are too long are split into chunks of max_position_embeddings.
        """
        indices = self.lengths >= self.params.max_position_embeddings
        logger.info(f'Splitting {sum(indices)} too long sequences.')

        def divide_chunks(l, n):
            return [l[i:i + n] for i in range(0, len(l), n)]

        new_tok_ids = []
        new_lengths = []
        cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token']
        max_len = self.params.max_position_embeddings

        for seq_, len_ in zip(self.token_ids, self.lengths):
            if len_ <= max_len:
                new_tok_ids.append(seq_)
                new_lengths.append(len_)
            else:
                sub_seqs = []
                for sub_s in divide_chunks(seq_, max_len - 2):
                    if sub_s[0] != cls_id:
                        sub_s = np.insert(sub_s, 0, cls_id)
                    if sub_s[-1] != sep_id:
                        # close each chunk with the separator token
                        sub_s = np.insert(sub_s, len(sub_s), sep_id)
                    assert len(sub_s) <= max_len
                    sub_seqs.append(sub_s)

                new_tok_ids.extend(sub_seqs)
                new_lengths.extend([len(l) for l in sub_seqs])

        self.token_ids = np.array(new_tok_ids)
        self.lengths = np.array(new_lengths)

    def remove_empty_sequences(self):
        """
        Too short sequences are simply removed. This could be tuned.
        """
        init_size = len(self)
        indices = self.lengths > 5
        self.token_ids = self.token_ids[indices]
        self.lengths = self.lengths[indices]
        new_size = len(self)
        logger.info(f'Remove {init_size - new_size} too short (<=5 tokens) sequences.')

    def print_statistics(self):
        """
        Print some statistics on the corpus. Only the master process.
        """
        if not self.params.is_master:
            return
        logger.info(f'{len(self)} sequences')
        # data_len = sum(self.lengths)
        # nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
        # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')

        # unk_idx = self.params.special_tok_ids['unk_token']
        # nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids])
        # logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)')

    def select_data(self, a: int, b: int):
        """
        Select a subportion of the data.
        """
        n_sequences = len(self)
        assert 0 <= a < b <= n_sequences, ValueError(f'`0 <= a < b <= n_sequences` is not met with a={a} and b={b}')

        logger.info(f'Selecting sequences from {a} to {b} (excluded).')
        self.token_ids = self.token_ids[a:b]
        self.lengths = self.lengths[a:b]

        self.check()

    def split(self):
        """
        Distributed training: split the data across the processes.
        """
        assert self.params.n_gpu > 1
        logger.info('Splitting the data across the processes.')
        n_seq = len(self)
        n_seq_per_process = n_seq // self.params.world_size
        a = n_seq_per_process * self.params.global_rank
        b = a + n_seq_per_process
        self.select_data(a=a, b=b)

    def batch_sequences(self, token_ids: List[List[int]], lengths: List[int]):
        """
        Do the padding and transform into torch.tensor.
        """
        assert len(token_ids) == len(lengths)

        # Max for paddings
        max_seq_len_ = max(lengths)

        # Pad token ids
        pad_idx = self.params.special_tok_ids['pad_token']
        tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
        assert len(tk_) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in tk_)

        tk_t = torch.tensor(tk_)                  # (bs, max_seq_len_)
        lg_t = torch.tensor(lengths.astype(int))  # (bs)
        return tk_t, lg_t

    def get_batches_iterator(self, batches):
        """
        Return an iterator over batches.
        """
        for sequences_ids in batches:
            token_ids, lengths = self.batch_sequences(self.token_ids[sequences_ids], self.lengths[sequences_ids])
            yield (token_ids, lengths)

    def get_iterator(self, seed: int = None):
        """
        Return a data iterator.
        """
        rng = np.random.RandomState(seed)

        n_sequences = len(self)
        indices = np.arange(n_sequences)

        if self.group_by_size:
            indices = indices[np.argsort(self.lengths[indices], kind='mergesort')]

        if self.tokens_per_batch == -1:
            batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size))
        else:
            assert self.tokens_per_batch > 0
            batch_ids = np.cumsum(self.lengths[indices]) // self.tokens_per_batch
            _, bounds = np.unique(batch_ids, return_index=True)
            batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
            if bounds[-1] < len(indices):
                batches.append(indices[bounds[-1]:])

        if self.shuffle:
            rng.shuffle(batches)

        assert n_sequences == sum([len(x) for x in batches])
        assert self.lengths[indices].sum() == sum([self.lengths[x].sum() for x in batches])

        return self.get_batches_iterator(batches=batches)
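As a usage sketch (not part of this commit): the Dataset above only needs a params object carrying the attributes it reads. The attribute names below mirror the ones accessed in the class; the concrete values and the data path are illustrative assumptions.

# Hypothetical usage sketch; values are assumptions, only the attribute names
# come from the Dataset class above.
from types import SimpleNamespace

params = SimpleNamespace(
    tokens_per_batch=-1,          # -1 -> batch by a fixed number of sequences (batch_size)
    batch_size=8,
    shuffle=True,
    group_by_size=True,
    max_position_embeddings=512,
    special_tok_ids={'cls_token': 101, 'sep_token': 102, 'pad_token': 0, 'unk_token': 100},
    is_master=True,
)

# `data` is expected to be a list of token-id sequences (already tokenized and encoded).
data = [[101] + list(range(1000, 1020)) + [102] for _ in range(64)]  # synthetic example

dataset = Dataset(params=params, data=data)
for token_ids, lengths in dataset.get_iterator(seed=42):
    # token_ids: (bs, max_seq_len) padded LongTensor, lengths: (bs,)
    pass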
examples/distillation/distiller.py (new file, mode 100644)
import os
import math
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_transformers import AdamW, WarmupLinearSchedule

from utils import logger
from dataset import Dataset


class Distiller:
    def __init__(self,
                 params: dict,
                 dataloader: Dataset,
                 token_probs: torch.tensor,
                 student: nn.Module,
                 teacher: nn.Module):
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.dataloader = dataloader
        if self.params.n_gpu > 1:
            self.dataloader.split()
        self.get_iterator(seed=params.seed)

        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_mse = params.alpha_mse
        assert self.alpha_ce >= 0.
        assert self.alpha_mlm >= 0.
        assert self.alpha_mse >= 0.
        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse > 0.

        self.mlm_mask_prop = params.mlm_mask_prop
        assert 0.0 <= self.mlm_mask_prop <= 1.0
        assert params.word_mask + params.word_keep + params.word_rand == 1.0
        self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
        self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
        self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
        if self.fp16:
            self.pred_probs = self.pred_probs.half()
            self.token_probs = self.token_probs.half()

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_mse = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        self.mse_loss_fct = nn.MSELoss(reduction='sum')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
        num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
             'weight_decay': params.weight_decay},
            {'params': [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
             'weight_decay': 0.0}
        ]
        logger.info("------ Number of trainable parameters (student): %i" % sum([p.numel() for p in self.student.parameters() if p.requires_grad]))
        logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=warmup_steps,
                                              t_total=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(self.student,
                                                          self.optimizer,
                                                          opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student,
                                                       device_ids=[params.local_rank],
                                                       output_device=params.local_rank)

        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)

    def get_iterator(self, seed: int = None):
        """
        Initialize the data iterator.
        Each process has its own data iterator (iterating on its own random portion of the dataset).

        Input:
        ------
        seed: `int` - The random seed.
        """
        logger.info('--- Initializing Data Iterator')
        self.data_iterator = self.dataloader.get_iterator(seed=seed)

    def get_batch(self):
        """
        Call the data iterator to output a new batch.
        If the data iterator went through the whole dataset, create a new iterator.
        """
        assert hasattr(self, 'data_iterator')
        try:
            x = next(self.data_iterator)
        except StopIteration:
            logger.warning('--- Went through the whole dataset. Creating new data iterator.')
            self.data_iterator = self.dataloader.get_iterator()
            x = next(self.data_iterator)
        return x

    def prepare_batch(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
        batch: `Tuple`
            token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequences. It is padded.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
        token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
        attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
        pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.uint8, device=token_ids.device)
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0

        # mask a number of words == 0 [8] (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1 - n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
        _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[1 - pred_mask] = -1

        return token_ids, attn_mask, mlm_labels

    def round_batch(self, x: torch.tensor, lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
        x: `torch.tensor(bs, seq_length)` - The token ids.
        lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
        x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
        lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # number of sentences == 0 [8]
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # sequence length == 0 [8]
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            pad_id = self.params.special_tok_ids['pad_token']
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop.
        """
        if self.is_master:
            logger.info('Starting training')
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch - 1}')

            iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for __ in range(self.num_steps_epoch):
                batch = self.get_batch()
                if self.params.n_gpu > 0:
                    batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
                token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)

                self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)

                iter_bar.update()
                iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
                                      'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'})
            iter_bar.close()

            if self.is_master:
                logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch - 1}')
            self.end_epoch()

        if self.is_master:
            logger.info('Training is finished')

    def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, mlm_labels: torch.tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
        """
        s_logits = self.student(input_ids=input_ids, attention_mask=attention_mask)[0]       # (bs, seq_length, voc_size)
        with torch.no_grad():
            t_logits = self.teacher(input_ids=input_ids, attention_mask=attention_mask)[0]   # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (mlm_labels > -1).unsqueeze(-1).expand_as(s_logits)   # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)      # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask)              # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))        # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, mask)              # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))        # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct / self.temperature, dim=-1),
                                   F.softmax(t_logits_slct / self.temperature, dim=-1)) * (self.temperature) ** 2
        loss = self.alpha_ce * loss_ce
        if self.alpha_mlm > 0.:
            loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.
        """
        # Check for NaN
        if (loss != loss).data.any():
            logger.error('NaN detected')
            exit()

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
            self.scheduler.step()
            self.optimizer.step()
            self.optimizer.zero_grad()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving.
        """
        logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth')
            self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = 'checkpoint.pth'):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
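A minimal wiring sketch (not part of this commit) of how a Distiller might be assembled on CPU. The hyper-parameter values, model choices, and synthetic data below are assumptions; in the actual example they come from the training script's argument parser and a pre-tokenized corpus.

# Hypothetical wiring sketch; attribute names mirror what Distiller and Dataset read,
# the values are illustrative only.
import torch
from types import SimpleNamespace
from pytorch_transformers import BertForMaskedLM

from dataset import Dataset
from distiller import Distiller

params = SimpleNamespace(
    dump_path='serialization_dir',                        # assumed, must exist
    n_gpu=0, multi_gpu=False, fp16=False, local_rank=-1, is_master=True, seed=56,
    temperature=2.0, alpha_ce=0.5, alpha_mlm=0.5, alpha_mse=0.0,
    mlm_mask_prop=0.15, word_mask=0.8, word_keep=0.1, word_rand=0.1,
    restrict_ce_to_mask=False, vocab_size=30522,
    gradient_accumulation_steps=1, n_epoch=1, batch_size=8,
    learning_rate=5e-4, adam_epsilon=1e-6, weight_decay=0.0, warmup_prop=0.05,
    max_grad_norm=5.0, log_interval=500, checkpoint_interval=4000,
    # Dataset-side attributes
    tokens_per_batch=-1, shuffle=True, group_by_size=True,
    max_position_embeddings=512,
    special_tok_ids={'cls_token': 101, 'sep_token': 102, 'pad_token': 0,
                     'unk_token': 100, 'mask_token': 103},
)

teacher = BertForMaskedLM.from_pretrained('bert-base-uncased')
student = BertForMaskedLM.from_pretrained('bert-base-uncased')   # a smaller config in practice

data = [[101] + list(range(1000, 1020)) + [102] for _ in range(64)]  # synthetic token-id sequences
dataloader = Dataset(params=params, data=data)
token_probs = torch.ones(params.vocab_size)                          # uniform masking weights as a stand-in

distiller = Distiller(params=params, dataloader=dataloader, token_probs=token_probs,
                      student=student, teacher=teacher)
distiller.train()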
examples/distillation/utils.py (new file, mode 100644)
import git
import json
import os
import socket
import torch
import numpy as np

import logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def git_log(folder_path: str):
    """
    Log commit info.
    """
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        'repo_id': str(repo),
        'repo_sha': str(repo.head.object.hexsha),
        'repo_branch': str(repo.active_branch)
    }

    with open(os.path.join(folder_path, 'git_log.json'), 'w') as f:
        json.dump(repo_infos, f, indent=4)


def init_gpu_params(params):
    """
    Handle single and multi-GPU / multi-node.
    """
    if params.n_gpu <= 0:
        params.local_rank = 0
        params.master_port = -1
        params.is_master = True
        params.multi_gpu = False
        return

    assert torch.cuda.is_available()

    logger.info('Initializing GPUs')
    if params.n_gpu > 1:
        assert params.local_rank != -1

        params.world_size = int(os.environ['WORLD_SIZE'])
        params.n_gpu_per_node = int(os.environ['N_GPU_NODE'])
        params.global_rank = int(os.environ['RANK'])

        # number of nodes / node ID
        params.n_nodes = params.world_size // params.n_gpu_per_node
        params.node_id = params.global_rank // params.n_gpu_per_node
        params.multi_gpu = True

        assert params.n_nodes == int(os.environ['N_NODES'])
        assert params.node_id == int(os.environ['NODE_RANK'])

    # local job (single GPU)
    else:
        assert params.local_rank == -1

        params.n_nodes = 1
        params.node_id = 0
        params.local_rank = 0
        params.global_rank = 0
        params.world_size = 1
        params.n_gpu_per_node = 1
        params.multi_gpu = False

    # sanity checks
    assert params.n_nodes >= 1
    assert 0 <= params.node_id < params.n_nodes
    assert 0 <= params.local_rank <= params.global_rank < params.world_size
    assert params.world_size == params.n_nodes * params.n_gpu_per_node

    # define whether this is the master process / if we are in multi-node distributed mode
    params.is_master = params.node_id == 0 and params.local_rank == 0
    params.multi_node = params.n_nodes > 1

    # summary
    PREFIX = f"--- Global rank: {params.global_rank} - "
    logger.info(PREFIX + "Number of nodes: %i" % params.n_nodes)
    logger.info(PREFIX + "Node ID       : %i" % params.node_id)
    logger.info(PREFIX + "Local rank    : %i" % params.local_rank)
    logger.info(PREFIX + "World size    : %i" % params.world_size)
    logger.info(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node)
    logger.info(PREFIX + "Master        : %s" % str(params.is_master))
    logger.info(PREFIX + "Multi-node    : %s" % str(params.multi_node))
    logger.info(PREFIX + "Multi-GPU     : %s" % str(params.multi_gpu))
    logger.info(PREFIX + "Hostname      : %s" % socket.gethostname())

    # set GPU device
    torch.cuda.set_device(params.local_rank)

    # initialize multi-GPU
    if params.multi_gpu:
        logger.info("Initializing PyTorch distributed")
        torch.distributed.init_process_group(init_method='env://',
                                             backend='nccl')


def set_seed(args):
    """
    Set the random seed.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
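For reference (not part of this commit), these helpers are typically called once at the top of a training script, roughly in the order below; `args` here is an assumed argparse-style namespace carrying n_gpu, local_rank, seed and dump_path.

# Hypothetical call order; not part of this commit.
from argparse import Namespace
from utils import git_log, init_gpu_params, set_seed, logger

args = Namespace(n_gpu=0, local_rank=-1, seed=56, dump_path='serialization_dir')

init_gpu_params(args)        # fills is_master, multi_gpu, world_size, ... on `args`
set_seed(args)               # seeds numpy, torch and (if any GPU) all CUDA devices
if args.is_master:
    git_log(args.dump_path)  # writes git_log.json (repo id, sha, branch) into dump_path
logger.info('Environment ready')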