OpenDAS / FastMoE
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "8c281757a0daf3a8e92cbb4bded0e5e6b389a375"
Commit cf8a61d8, authored Nov 14, 2020 by Jiezhong Qiu

    profile sparsity of relu output in ffn

parent 958c714e
Showing 2 changed files with 105 additions and 31 deletions:

    pytorch/mem_transformer.py   +47 -25
    pytorch/train.py             +58 -6
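The commit threads the FFN's post-ReLU activations out of the model so their sparsity can be logged during training and evaluation. As a minimal standalone sketch (not part of the diff; tensor shapes here are made up), the statistic being profiled is the fraction of nonzero entries in a ReLU output:

    import torch

    relu_out = torch.relu(torch.randn(128, 16, 2048))  # any post-ReLU activation
    # "nnz" here is a fraction of active units, not an absolute count:
    nnz = (relu_out > 0).sum().float().item() / relu_out.numel()
    print('active: %.2f%%' % (nnz * 100))  # roughly 50% for symmetric random input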
pytorch/mem_transformer.py (view file @ cf8a61d8)

@@ -39,8 +39,11 @@ class PositionwiseFF(nn.Module):
         self.d_inner = d_inner
         self.dropout = dropout
 
-        self.CoreNet = nn.Sequential(
-            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
+        self.CoreNet_1 = nn.Sequential(
+            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True)
+        )
+        self.CoreNet_2 = nn.Sequential(
             nn.Dropout(dropout),
             nn.Linear(d_inner, d_model),
             nn.Dropout(dropout),
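The split is purely structural: CoreNet_1 ends at the ReLU so its output can be captured, and CoreNet_2 carries the rest, so CoreNet_2(CoreNet_1(x)) reproduces the old CoreNet(x). A minimal sketch of the two halves (the sizes are placeholders):

    import torch
    import torch.nn as nn

    d_model, d_inner, p = 512, 2048, 0.1  # placeholder hyperparameters
    core_net_1 = nn.Sequential(nn.Linear(d_model, d_inner), nn.ReLU(inplace=True))
    core_net_2 = nn.Sequential(nn.Dropout(p), nn.Linear(d_inner, d_model), nn.Dropout(p))
    x = torch.randn(4, d_model)
    out = core_net_2(core_net_1(x))  # same computation as the original CoreNet(x)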
@@ -53,23 +56,26 @@ class PositionwiseFF(nn.Module):
     def forward(self, inp):
         if self.pre_lnorm:
             ##### layer normalization + positionwise feed-forward
-            core_out = self.CoreNet(self.layer_norm(inp))
+            relu_out = self.CoreNet_1(self.layer_norm(inp))
+            core_out = self.CoreNet_2(relu_out)
 
             ##### residual connection
             output = core_out + inp
         else:
             ##### positionwise feed-forward
-            core_out = self.CoreNet(inp)
+            relu_out = self.CoreNet_1(inp)
+            core_out = self.CoreNet_2(relu_out)
 
             ##### residual connection + layer normalization
             output = self.layer_norm(inp + core_out)
 
-        return output
+        return output, relu_out.detach()
 
 class ExtendedMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  pre_lnorm=False):
-        super(MultiHeadAttn, self).__init__()
+        super(ExtendedMultiHeadAttn, self).__init__()
+        print("ExtendedMultiHeadAttn")
 
         self.n_head = n_head
         self.d_model = d_model
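relu_out.detach() hands back a tensor that shares storage with relu_out but is cut out of the autograd graph, so the profiling path adds no gradient bookkeeping. A small sketch of that behavior:

    import torch

    x = torch.randn(4, requires_grad=True)
    y = torch.relu(x)
    stat = y.detach()            # same data, requires_grad=False
    assert not stat.requires_grad
    y.sum().backward()           # gradients still flow through y itself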
@@ -81,7 +87,7 @@ class ExtendedMultiHeadAttn(nn.Module):
         self.drop = nn.Dropout(dropout)
         self.dropatt = nn.Dropout(dropatt)
-        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
+        self.o_net = nn.Linear(n_head * d_head * 2, d_model, bias=False)
 
         self.layer_norm = nn.LayerNorm(d_model)
@@ -89,6 +95,9 @@ class ExtendedMultiHeadAttn(nn.Module):
         self.pre_lnorm = pre_lnorm
 
+        # self.coeff = nn.Parameter(torch.Tensor(n_head, 2))
+        # nn.init.uniform_(self.coeff, a=-1, b=1)
+
     def forward(self, h, attn_mask=None, mems=None):
         ##### multihead attention
         # [hlen x bsz x n_head x d_head]
@@ -121,9 +130,9 @@ class ExtendedMultiHeadAttn(nn.Module):
             attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None], -float('inf'))
-            mem2other_attn = attn_mask.ones(mem_len, c.size(0))
+            mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
             mem2other_attn[:, :mem_len] = 0
-            attn_score[:mem_len].masked_fill_(mem2other_attn, -float('inf'))
+            attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None], -float('inf'))
 
         # [qlen x klen x bsz x n_head]
         attn_prob = F.softmax(attn_score, dim=1)
@@ -131,8 +140,13 @@ class ExtendedMultiHeadAttn(nn.Module):
         # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
         attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
-        attn_vec = attn_vec.contiguous().view(
-            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
+        attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, attn_vec))
+        # [qlen x bsz x n_head x d_head x 2]
+        attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
+        # attn_vec = torch.einsum('ibndt,nt->ibnd', (attn_vecs, self.coeff))
+        attn_vec = attn_vecs.contiguous().view(
+            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head * 2)
 
         attn_vec = attn_vec[mem_len:]
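ExtendedMultiHeadAttn thus stacks the usual attended values with a second application of the same attention weights (a two-hop read), doubling the per-head feature size, which is why o_net above takes n_head * d_head * 2 inputs. A shape sketch with toy dimensions (all sizes are placeholders):

    import torch

    L, B, H, D = 6, 2, 3, 4  # length, batch, heads, head dim (placeholders)
    A = torch.softmax(torch.randn(L, L, B, H), dim=1)      # attention probs
    V = torch.randn(L, B, H, D)                            # values
    one_hop = torch.einsum('ijbn,jbnd->ibnd', A, V)        # standard attention
    two_hop = torch.einsum('ijbn,jbnd->ibnd', A, one_hop)  # attention applied twice
    both = torch.cat([one_hop.unsqueeze(-1), two_hop.unsqueeze(-1)], dim=-1)
    assert both.reshape(L, B, H * D * 2).shape == (L, B, H * D * 2)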
@@ -463,6 +477,7 @@ class DecoderLayer(nn.Module):
         super(DecoderLayer, self).__init__()
 
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
+        # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))
@@ -470,9 +485,9 @@ class DecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                                mems=mems)
-        output = self.pos_ff(output)
+        output, relu_out = self.pos_ff(output)
 
-        return output
+        return output, relu_out
 
 class RelLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -489,9 +504,9 @@ class RelLearnableDecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                                attn_mask=dec_attn_mask,
                                mems=mems)
-        output = self.pos_ff(output)
+        output, relu_out = self.pos_ff(output)
 
-        return output
+        return output, relu_out
 
 class RelPartialLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -508,9 +523,9 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
                                attn_mask=dec_attn_mask,
                                mems=mems)
-        output = self.pos_ff(output)
+        output, relu_out = self.pos_ff(output)
 
-        return output
+        return output, relu_out
 
 class AdaptiveEmbedding(nn.Module):
@@ -743,6 +758,7 @@ class MemTransformerLM(nn.Module):
                 word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
 
         hids = []
+        relu_outs = []
         if self.attn_type == 0: # default
             pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -756,9 +772,10 @@ class MemTransformerLM(nn.Module):
             hids.append(core_out)
             for i, layer in enumerate(self.layers):
                 mems_i = None if mems is None else mems[i]
-                core_out = layer(core_out, pos_emb, self.r_w_bias,
-                        self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
+                core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
+                        self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
         elif self.attn_type == 1: # learnable
             core_out = self.drop(word_emb)
             hids.append(core_out)
@@ -770,9 +787,10 @@ class MemTransformerLM(nn.Module):
                 r_emb, r_bias = self.r_emb[i], self.r_bias[i]
                 mems_i = None if mems is None else mems[i]
-                core_out = layer(core_out, r_emb, self.r_w_bias[i],
-                        r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
+                core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
+                        r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
         elif self.attn_type == 2: # absolute
             pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -787,9 +805,10 @@ class MemTransformerLM(nn.Module):
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and i == 0:
                     mems_i += pos_emb[:mlen]
-                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
-                                 mems=mems_i)
+                core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                                 mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
         elif self.attn_type == 3:
             core_out = self.drop(word_emb)
@@ -807,15 +826,16 @@ class MemTransformerLM(nn.Module):
                     mems_i += cur_emb.view(mlen, 1, -1)
                 core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
 
-                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
-                                 mems=mems_i)
+                core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                                 mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
 
         core_out = self.drop(core_out)
 
         new_mems = self._update_mems(hids, mems, mlen, qlen)
 
-        return core_out, new_mems
+        return core_out, new_mems, relu_outs
 
     def forward(self, data, target, *mems):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
@@ -825,7 +845,9 @@ class MemTransformerLM(nn.Module):
         if not mems: mems = self.init_mems()
 
         tgt_len = target.size(0)
-        hidden, new_mems = self._forward(data, mems=mems)
+        hidden, new_mems, relu_outs = self._forward(data, mems=mems)
+        # relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1)
 
         pred_hid = hidden[-tgt_len:]
         if self.sample_softmax > 0 and self.training:
@@ -838,9 +860,9 @@ class MemTransformerLM(nn.Module):
             loss = loss.view(tgt_len, -1)
 
         if new_mems is None:
-            return [loss]
+            return [relu_outs, loss]
         else:
-            return [loss] + new_mems
+            return [relu_outs, loss] + new_mems
 
 if __name__ == '__main__':
     import argparse
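The model's return convention changes from [loss] + new_mems to [relu_outs, loss] + new_mems, still a flat list (which nn.DataParallel, mentioned in the code comment above, should be able to gather across replicas), so every call site has to unpack positionally. The corresponding pattern, as updated in train.py below:

    ret = model(data, target, *mems)
    relu_outs, loss, mems = ret[0], ret[1], ret[2:]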
pytorch/train.py (view file @ cf8a61d8)

@@ -16,6 +16,31 @@ from mem_transformer import MemTransformerLM
 from utils.exp_utils import create_exp_dir
 from utils.data_parallel import BalancedDataParallel
 
+class AverageMeter(object):
+    """Computes and stores the average and current value.
+
+       Examples::
+        >>> # Initialize a meter to record loss
+        >>> losses = AverageMeter()
+        >>> # Update meter after every minibatch update
+        >>> losses.update(loss_value, batch_size)
+    """
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
 parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
 parser.add_argument('--data', type=str, default='../data/wikitext-103',
                     help='location of the data corpus')
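AverageMeter keeps a weighted running mean (avg is the sum of val * n divided by the sum of n), so unevenly sized updates are handled correctly. A quick usage check against the class above:

    losses = AverageMeter()
    losses.update(2.0, n=3)   # e.g. a chunk of 3 samples with mean loss 2.0
    losses.update(4.0)        # n defaults to 1
    print(losses.val)         # 4.0, the most recent value
    print(losses.avg)         # (2.0*3 + 4.0*1) / 4 == 2.5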
@@ -386,6 +411,7 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
 def evaluate(eval_iter):
     # Turn on evaluation mode which disables dropout.
     model.eval()
+    avg_nnzs = None
 
     # If the model does not use memory at all, make the ext_len longer.
     # Otherwise, make the mem_len longer and keep the ext_len the same.
@@ -404,16 +430,22 @@ def evaluate(eval_iter):
             if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                 break
             ret = model(data, target, *mems)
-            loss, mems = ret[0], ret[1:]
+            # loss, mems = ret[0], ret[1:]
+            relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.mean()
             total_loss += seq_len * loss.float().item()
             total_len += seq_len
+            nnzs = [(relu_out > 0).sum().float().item() / relu_out.numel() for relu_out in relu_outs]
+            if avg_nnzs is None:
+                avg_nnzs = [AverageMeter() for i in range(len(nnzs))]
+            for i in range(len(nnzs)):
+                avg_nnzs[i].update(nnzs[i])
 
     # Switch back to the training mode
     model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
     model.train()
 
-    return total_loss / total_len
+    return total_loss / total_len, avg_nnzs
 
 
 def train():
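Each element of relu_outs is one layer's detached post-ReLU activation, so nnzs holds one active-unit fraction per layer and avg_nnzs accumulates them across evaluation batches. A self-contained sketch of the per-layer statistic (fake activations stand in for the real ones):

    import torch

    relu_outs = [torch.relu(torch.randn(36, 16, 2048)) for _ in range(4)]  # fake, one per layer
    nnzs = [(r > 0).sum().float().item() / r.numel() for r in relu_outs]
    # Equivalent, more direct form: (r > 0).float().mean().item()
    print(['%.3f' % v for v in nnzs])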
@@ -424,6 +456,9 @@ def train():
         mems = [tuple() for _ in range(args.batch_chunk)]
     else:
         mems = tuple()
+
+    avg_nnzs = None
+
     train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
     for batch, (data, target, seq_len) in enumerate(train_iter):
         model.zero_grad()
@@ -434,7 +469,8 @@ def train():
                 data_i = data_chunks[i].contiguous()
                 target_i = target_chunks[i].contiguous()
                 ret = para_model(data_i, target_i, *mems[i])
-                loss, mems[i] = ret[0], ret[1:]
+                # loss, mems[i] = ret[0], ret[1:]
+                relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
                 loss = loss.float().mean().type_as(loss) / args.batch_chunk
                 if args.fp16:
                     optimizer.backward(loss)
@@ -443,13 +479,19 @@ def train():
                 train_loss += loss.float().item()
         else:
             ret = para_model(data, target, *mems)
-            loss, mems = ret[0], ret[1:]
+            # loss, mems = ret[0], ret[1:]
+            relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.float().mean().type_as(loss)
             if args.fp16:
                 optimizer.backward(loss)
             else:
                 loss.backward()
             train_loss += loss.float().item()
 
+        nnzs = [(relu_out > 0).sum().float().item() / relu_out.numel() for relu_out in relu_outs]
+        if avg_nnzs is None:
+            avg_nnzs = [AverageMeter() for i in range(len(nnzs))]
+        for i in range(len(nnzs)):
+            avg_nnzs[i].update(nnzs[i])
+
         if args.fp16:
             optimizer.clip_master_grads(args.clip)
@@ -488,12 +530,17 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
             else:
                 log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
+            final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
+            for i in range(len(avg_nnzs)):
+                avg_nnzs[i].reset()
+            log_str += " | avg nnz %.2f | max nnz %.2f" % (
+                sum(final_avg_nnzs) / len(final_avg_nnzs) * 100, max(final_avg_nnzs) * 100)
             logging(log_str)
             train_loss = 0
             log_start_time = time.time()
 
         if train_step % args.eval_interval == 0:
-            val_loss = evaluate(va_iter)
+            val_loss, eval_avg_nnzs = evaluate(va_iter)
             logging('-' * 100)
             log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                       '| valid loss {:5.2f}'.format(
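The per-layer averages are collapsed into two scalars for the log line: the mean over layers and the densest single layer, both scaled to percent. For example, with hypothetical per-layer fractions:

    final_avg_nnzs = [0.42, 0.35, 0.57]  # hypothetical per-layer averages
    log_str = " | avg nnz %.2f | max nnz %.2f" % (
        sum(final_avg_nnzs) / len(final_avg_nnzs) * 100, max(final_avg_nnzs) * 100)
    print(log_str)  # " | avg nnz 44.67 | max nnz 57.00"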
@@ -503,6 +550,8 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
             else:
                 log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
+            final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
+            log_str += " | mean nnz %.2f | max nnz %.2f" % (sum(final_eval_avg_nnzs) / len(final_eval_avg_nnzs) * 100, max(final_eval_avg_nnzs) * 100)
             logging(log_str)
             logging('-' * 100)
 
             # Save the model if the validation loss is the best we've seen so far.
@@ -551,7 +600,7 @@ with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
 para_model = model.to(device)
 
 # Run on test data.
-test_loss = evaluate(te_iter)
+test_loss, eval_avg_nnzs = evaluate(te_iter)
 logging('=' * 100)
 if args.dataset in ['enwik8', 'text8']:
     logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
@@ -559,4 +608,7 @@ if args.dataset in ['enwik8', 'text8']:
 else:
     logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
         test_loss, math.exp(test_loss)))
+final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
+logging(" | mean nnz %.2f | max nnz %.2f" % (sum(final_eval_avg_nnzs) / len(final_eval_avg_nnzs) * 100, max(final_eval_avg_nnzs) * 100))
 logging('=' * 100)