OpenDAS / FastMoE · Commits

Commit cf8a61d8, authored Nov 14, 2020 by Jiezhong Qiu
parent 958c714e

    profile sparsity of relu output in ffn
Showing 2 changed files with 105 additions and 31 deletions (+105 −31):

    pytorch/mem_transformer.py   +47 −25
    pytorch/train.py             +58 −6
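The quantity this commit profiles is, for each decoder layer, the fraction of nonzero entries in the ReLU output of the position-wise feed-forward block ("nnz"). As a minimal standalone sketch of that metric (hypothetical tensor shapes, not code from this repository):

import torch

# Hypothetical ReLU outputs for a 4-layer model, each [seq_len x batch x d_inner].
relu_outs = [torch.relu(torch.randn(36, 8, 2048)) for _ in range(4)]

# Per-layer fraction of nonzero activations, the same expression used in train.py below.
nnzs = [(relu_out > 0).sum().float().item() / relu_out.numel() for relu_out in relu_outs]
print(["%.1f%%" % (f * 100) for f in nnzs])  # roughly 50% for Gaussian inputs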
pytorch/mem_transformer.py  (view file @ cf8a61d8)
@@ -39,8 +39,11 @@ class PositionwiseFF(nn.Module):
         self.d_inner = d_inner
         self.dropout = dropout
 
-        self.CoreNet = nn.Sequential(
-            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
+        self.CoreNet_1 = nn.Sequential(
+            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True)
+        )
+
+        self.CoreNet_2 = nn.Sequential(
             nn.Dropout(dropout),
             nn.Linear(d_inner, d_model),
             nn.Dropout(dropout),
@@ -53,23 +56,26 @@ class PositionwiseFF(nn.Module):
     def forward(self, inp):
         if self.pre_lnorm:
             ##### layer normalization + positionwise feed-forward
-            core_out = self.CoreNet(self.layer_norm(inp))
+            relu_out = self.CoreNet_1(self.layer_norm(inp))
+            core_out = self.CoreNet_2(relu_out)
 
             ##### residual connection
             output = core_out + inp
         else:
             ##### positionwise feed-forward
-            core_out = self.CoreNet(inp)
+            relu_out = self.CoreNet_1(inp)
+            core_out = self.CoreNet_2(relu_out)
 
             ##### residual connection + layer normalization
             output = self.layer_norm(inp + core_out)
 
-        return output
+        return output, relu_out.detach()
 
 class ExtendedMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  pre_lnorm=False):
-        super(MultiHeadAttn, self).__init__()
+        super(ExtendedMultiHeadAttn, self).__init__()
+        print("ExtendedMultiHeadAttn")
 
         self.n_head = n_head
         self.d_model = d_model
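Taken together, the two hunks above split CoreNet into CoreNet_1 (projection + ReLU) and CoreNet_2 (dropout + projection back), so the ReLU activations can be returned alongside the layer output without changing the computation. A consolidated sketch of the modified class follows; the __init__ lines not shown in the diff (d_model, layer_norm, pre_lnorm) are filled in from the stock Transformer-XL implementation and should be read as illustrative rather than an exact copy of the file:

import torch.nn as nn

class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        # First half: projection + ReLU; its output is what gets profiled.
        self.CoreNet_1 = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True)
        )
        # Second half: dropout + projection back to d_model.
        self.CoreNet_2 = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            relu_out = self.CoreNet_1(self.layer_norm(inp))
            core_out = self.CoreNet_2(relu_out)
            output = core_out + inp
        else:
            relu_out = self.CoreNet_1(inp)
            core_out = self.CoreNet_2(relu_out)
            output = self.layer_norm(inp + core_out)

        # Detached, so profiling the activations never affects gradients.
        return output, relu_out.detach()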
@@ -81,7 +87,7 @@ class ExtendedMultiHeadAttn(nn.Module):
         self.drop = nn.Dropout(dropout)
         self.dropatt = nn.Dropout(dropatt)
-        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
+        self.o_net = nn.Linear(n_head * d_head * 2, d_model, bias=False)
 
         self.layer_norm = nn.LayerNorm(d_model)
@@ -89,6 +95,9 @@ class ExtendedMultiHeadAttn(nn.Module):
         self.pre_lnorm = pre_lnorm
 
+        # self.coeff = nn.Parameter(torch.Tensor(n_head, 2))
+        # nn.init.uniform_(self.coeff, a=-1, b=1)
+
     def forward(self, h, attn_mask=None, mems=None):
         ##### multihead attention
         # [hlen x bsz x n_head x d_head]
@@ -121,9 +130,9 @@ class ExtendedMultiHeadAttn(nn.Module):
                 attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None], -float('inf'))
-            mem2other_attn = attn_mask.ones(mem_len, c.size(0))
+            mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
             mem2other_attn[:, :mem_len] = 0
-            attn_score[:mem_len].masked_fill_(mem2other_attn, -float('inf'))
+            attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None], -float('inf'))
 
         # [qlen x klen x bsz x n_head]
         attn_prob = F.softmax(attn_score, dim=1)
@@ -131,8 +140,13 @@ class ExtendedMultiHeadAttn(nn.Module):
         # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
         attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
-        attn_vec = attn_vec.contiguous().view(
-            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
+        attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, attn_vec))
+        # [qlen x bsz x n_head x d_head x 2]
+        attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
+        # attn_vec = torch.einsum('ibndt,nt->ibnd', (attn_vecs, self.coeff))
+        attn_vec = attn_vecs.contiguous().view(
+            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head * 2)
 
         attn_vec = attn_vec[mem_len:]
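Here attn_vec_quad applies the same attention weights a second time, to the already-attended values; both results are stacked along a trailing dimension and flattened into the head dimension, which is why o_net was widened to n_head * d_head * 2 in the earlier hunk. A shape-level sketch with hypothetical sizes, assuming query and key lengths are equal (which the second einsum requires):

import torch

hlen, bsz, n_head, d_head = 6, 2, 4, 8
attn_prob = torch.rand(hlen, hlen, bsz, n_head)   # [hlen x hlen x bsz x n_head]
head_v = torch.rand(hlen, bsz, n_head, d_head)    # [hlen x bsz x n_head x d_head]

attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, head_v)
# Apply the attention weights once more, now to the attended values.
attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', attn_prob, attn_vec)

attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
print(attn_vecs.shape)  # torch.Size([6, 2, 4, 8, 2])
attn_vec = attn_vecs.contiguous().view(hlen, bsz, n_head * d_head * 2)
print(attn_vec.shape)   # torch.Size([6, 2, 64])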
@@ -463,6 +477,7 @@ class DecoderLayer(nn.Module):
         super(DecoderLayer, self).__init__()
 
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
+        # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))
@@ -470,9 +485,9 @@ class DecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                                mems=mems)
-        output = self.pos_ff(output)
+        output, relu_out = self.pos_ff(output)
 
-        return output
+        return output, relu_out
 
 class RelLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -489,9 +504,9 @@ class RelLearnableDecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                                attn_mask=dec_attn_mask,
                                mems=mems)
-        output = self.pos_ff(output)
+        output, relu_out = self.pos_ff(output)
 
-        return output
+        return output, relu_out
 
 class RelPartialLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -508,9 +523,9 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
                                attn_mask=dec_attn_mask,
                                mems=mems)
-        output = self.pos_ff(output)
+        output, relu_out = self.pos_ff(output)
 
-        return output
+        return output, relu_out
 
 class AdaptiveEmbedding(nn.Module):
@@ -743,6 +758,7 @@ class MemTransformerLM(nn.Module):
                     word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
 
         hids = []
+        relu_outs = []
         if self.attn_type == 0: # default
             pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -756,9 +772,10 @@ class MemTransformerLM(nn.Module):
             hids.append(core_out)
             for i, layer in enumerate(self.layers):
                 mems_i = None if mems is None else mems[i]
-                core_out = layer(core_out, pos_emb, self.r_w_bias,
+                core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
                         self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
         elif self.attn_type == 1: # learnable
             core_out = self.drop(word_emb)
             hids.append(core_out)
@@ -770,9 +787,10 @@ class MemTransformerLM(nn.Module):
                 r_emb, r_bias = self.r_emb[i], self.r_bias[i]
                 mems_i = None if mems is None else mems[i]
-                core_out = layer(core_out, r_emb, self.r_w_bias[i],
+                core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
                         r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
         elif self.attn_type == 2: # absolute
             pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -787,9 +805,10 @@ class MemTransformerLM(nn.Module):
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and i == 0:
                     mems_i += pos_emb[:mlen]
-                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
         elif self.attn_type == 3:
             core_out = self.drop(word_emb)
@@ -807,15 +826,16 @@ class MemTransformerLM(nn.Module):
                     mems_i += cur_emb.view(mlen, 1, -1)
                 core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
 
-                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
 
         core_out = self.drop(core_out)
 
         new_mems = self._update_mems(hids, mems, mlen, qlen)
 
-        return core_out, new_mems
+        return core_out, new_mems, relu_outs
 
     def forward(self, data, target, *mems):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
@@ -825,7 +845,9 @@ class MemTransformerLM(nn.Module):
         if not mems: mems = self.init_mems()
 
         tgt_len = target.size(0)
-        hidden, new_mems = self._forward(data, mems=mems)
+        hidden, new_mems, relu_outs = self._forward(data, mems=mems)
+        # relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1)
 
         pred_hid = hidden[-tgt_len:]
         if self.sample_softmax > 0 and self.training:
@@ -838,9 +860,9 @@ class MemTransformerLM(nn.Module):
             loss = loss.view(tgt_len, -1)
 
         if new_mems is None:
-            return [loss]
+            return [relu_outs, loss]
         else:
-            return [loss] + new_mems
+            return [relu_outs, loss] + new_mems
 
 if __name__ == '__main__':
     import argparse
pytorch/train.py  (view file @ cf8a61d8)
@@ -16,6 +16,31 @@ from mem_transformer import MemTransformerLM
 from utils.exp_utils import create_exp_dir
 from utils.data_parallel import BalancedDataParallel
 
+class AverageMeter(object):
+    """Computes and stores the average and current value.
+       Examples::
+           >>> # Initialize a meter to record loss
+           >>> losses = AverageMeter()
+           >>> # Update meter after every minibatch update
+           >>> losses.update(loss_value, batch_size)
+    """
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
 parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
 parser.add_argument('--data', type=str, default='../data/wikitext-103',
                     help='location of the data corpus')
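AverageMeter here is the standard running-average helper (as in the PyTorch ImageNet example). In this commit it is used with one meter per layer, updated each batch with that layer's nonzero fraction; a small usage sketch with hypothetical data, assuming the AverageMeter class defined above:

import torch

meters = None
for _ in range(10):  # stand-in for the batch loop
    # Hypothetical per-layer ReLU outputs for one batch.
    relu_outs = [torch.relu(torch.randn(36, 8, 2048)) for _ in range(4)]
    nnzs = [(r > 0).sum().float().item() / r.numel() for r in relu_outs]
    if meters is None:
        meters = [AverageMeter() for _ in range(len(nnzs))]
    for i, frac in enumerate(nnzs):
        meters[i].update(frac)

# Running average of the nonzero fraction for each layer.
print([round(m.avg, 3) for m in meters])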
@@ -386,6 +411,7 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
 def evaluate(eval_iter):
     # Turn on evaluation mode which disables dropout.
     model.eval()
+    avg_nnzs = None
 
     # If the model does not use memory at all, make the ext_len longer.
     # Otherwise, make the mem_len longer and keep the ext_len the same.
@@ -404,16 +430,22 @@ def evaluate(eval_iter):
             if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                 break
             ret = model(data, target, *mems)
-            loss, mems = ret[0], ret[1:]
+            # loss, mems = ret[0], ret[1:]
+            relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.mean()
             total_loss += seq_len * loss.float().item()
             total_len += seq_len
+            nnzs = [(relu_out > 0).sum().float().item() / relu_out.numel() for relu_out in relu_outs]
+            if avg_nnzs is None:
+                avg_nnzs = [AverageMeter() for i in range(len(nnzs))]
+            for i in range(len(nnzs)):
+                avg_nnzs[i].update(nnzs[i])
 
     # Switch back to the training mode
     model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
     model.train()
 
-    return total_loss / total_len
+    return total_loss / total_len, avg_nnzs
 
 def train():
@@ -424,6 +456,9 @@ def train():
         mems = [tuple() for _ in range(args.batch_chunk)]
     else:
         mems = tuple()
+
+    avg_nnzs = None
+
     train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
     for batch, (data, target, seq_len) in enumerate(train_iter):
         model.zero_grad()
@@ -434,7 +469,8 @@ def train():
                 data_i = data_chunks[i].contiguous()
                 target_i = target_chunks[i].contiguous()
                 ret = para_model(data_i, target_i, *mems[i])
-                loss, mems[i] = ret[0], ret[1:]
+                # loss, mems[i] = ret[0], ret[1:]
+                relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
                 loss = loss.float().mean().type_as(loss) / args.batch_chunk
                 if args.fp16:
                     optimizer.backward(loss)
@@ -443,13 +479,19 @@ def train():
                 train_loss += loss.float().item()
         else:
             ret = para_model(data, target, *mems)
-            loss, mems = ret[0], ret[1:]
+            # loss, mems = ret[0], ret[1:]
+            relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.float().mean().type_as(loss)
             if args.fp16:
                 optimizer.backward(loss)
             else:
                 loss.backward()
             train_loss += loss.float().item()
 
+        nnzs = [(relu_out > 0).sum().float().item() / relu_out.numel() for relu_out in relu_outs]
+        if avg_nnzs is None:
+            avg_nnzs = [AverageMeter() for i in range(len(nnzs))]
+        for i in range(len(nnzs)):
+            avg_nnzs[i].update(nnzs[i])
+
         if args.fp16:
             optimizer.clip_master_grads(args.clip)
@@ -488,12 +530,17 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
             else:
                 log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
+            final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
+            for i in range(len(avg_nnzs)):
+                avg_nnzs[i].reset()
+            log_str += " | avg nnz %.2f | max nnz %.2f" % (sum(final_avg_nnzs) / len(final_avg_nnzs) * 100, max(final_avg_nnzs) * 100)
             logging(log_str)
             train_loss = 0
             log_start_time = time.time()
 
         if train_step % args.eval_interval == 0:
-            val_loss = evaluate(va_iter)
+            val_loss, eval_avg_nnzs = evaluate(va_iter)
             logging('-' * 100)
             log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                       '| valid loss {:5.2f}'.format(
@@ -503,6 +550,8 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
             else:
                 log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
+            final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
+            log_str += " | mean nnz %.2f | max nnz %.2f" % (sum(final_eval_avg_nnzs) / len(final_eval_avg_nnzs) * 100, max(final_eval_avg_nnzs) * 100)
             logging(log_str)
             logging('-' * 100)
 
             # Save the model if the validation loss is the best we've seen so far.
@@ -551,7 +600,7 @@ with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
 para_model = model.to(device)
 
 # Run on test data.
-test_loss = evaluate(te_iter)
+test_loss, eval_avg_nnzs = evaluate(te_iter)
 logging('=' * 100)
 if args.dataset in ['enwik8', 'text8']:
     logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
@@ -559,4 +608,7 @@ if args.dataset in ['enwik8', 'text8']:
 else:
     logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
         test_loss, math.exp(test_loss)))
+final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
+logging(" | mean nnz %.2f | max nnz %.2f" % (sum(final_eval_avg_nnzs) / len(final_eval_avg_nnzs) * 100, max(final_eval_avg_nnzs) * 100))
+
 logging('=' * 100)
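The numbers appended to the log lines are the mean and maximum of the per-layer running averages, reported in percent; a tiny worked example with made-up values:

final_avg_nnzs = [0.12, 0.08, 0.20]  # hypothetical per-layer average nonzero fractions
print(" | avg nnz %.2f | max nnz %.2f" % (
    sum(final_avg_nnzs) / len(final_avg_nnzs) * 100, max(final_avg_nnzs) * 100))
# -> " | avg nnz 13.33 | max nnz 20.00"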