OpenDAS / FastMoE · Commits

Commit cf8a61d8
Authored Nov 14, 2020 by Jiezhong Qiu

    profile sparsity of relu output in ffn

Parent: 958c714e
Showing 2 changed files with 105 additions and 31 deletions:
    pytorch/mem_transformer.py  (+47, -25)
    pytorch/train.py            (+58, -6)
pytorch/mem_transformer.py
@@ -39,8 +39,11 @@ class PositionwiseFF(nn.Module):
         self.d_inner = d_inner
         self.dropout = dropout
 
-        self.CoreNet = nn.Sequential(
-            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
+        self.CoreNet_1 = nn.Sequential(
+            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True)
+        )
+        self.CoreNet_2 = nn.Sequential(
             nn.Dropout(dropout),
             nn.Linear(d_inner, d_model),
             nn.Dropout(dropout),
@@ -53,23 +56,26 @@ class PositionwiseFF(nn.Module):
     def forward(self, inp):
         if self.pre_lnorm:
             ##### layer normalization + positionwise feed-forward
-            core_out = self.CoreNet(self.layer_norm(inp))
+            relu_out = self.CoreNet_1(self.layer_norm(inp))
+            core_out = self.CoreNet_2(relu_out)
 
             ##### residual connection
             output = core_out + inp
         else:
             ##### positionwise feed-forward
-            core_out = self.CoreNet(inp)
+            relu_out = self.CoreNet_1(inp)
+            core_out = self.CoreNet_2(relu_out)
 
             ##### residual connection + layer normalization
             output = self.layer_norm(inp + core_out)
 
-        return output
+        return output, relu_out.detach()
 
 class ExtendedMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, pre_lnorm=False):
-        super(MultiHeadAttn, self).__init__()
+        super(ExtendedMultiHeadAttn, self).__init__()
+        print("ExtendedMultiHeadAttn")
 
         self.n_head = n_head
         self.d_model = d_model
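Note (editorial sketch, not part of the commit): the net effect of the two PositionwiseFF hunks above is that the feed-forward block now exposes its intermediate ReLU activation alongside its usual output. A condensed sketch of the reworked block, showing only the post-norm branch, with constructor arguments treated as placeholders:

    import torch.nn as nn

    class PositionwiseFFSketch(nn.Module):
        """Minimal sketch: CoreNet is split so the ReLU output can be profiled."""
        def __init__(self, d_model, d_inner, dropout):
            super().__init__()
            # First stage: projection + ReLU, whose output sparsity is profiled.
            self.CoreNet_1 = nn.Sequential(nn.Linear(d_model, d_inner),
                                           nn.ReLU(inplace=True))
            # Second stage: the remainder of the original CoreNet.
            self.CoreNet_2 = nn.Sequential(nn.Dropout(dropout),
                                           nn.Linear(d_inner, d_model),
                                           nn.Dropout(dropout))
            self.layer_norm = nn.LayerNorm(d_model)

        def forward(self, inp):
            relu_out = self.CoreNet_1(inp)            # [len x bsz x d_inner]
            core_out = self.CoreNet_2(relu_out)
            output = self.layer_norm(inp + core_out)  # post-norm branch
            # Detached so the profiling signal never receives gradients.
            return output, relu_out.detach()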
@@ -81,7 +87,7 @@ class ExtendedMultiHeadAttn(nn.Module):
         self.drop = nn.Dropout(dropout)
         self.dropatt = nn.Dropout(dropatt)
-        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
+        self.o_net = nn.Linear(n_head * d_head * 2, d_model, bias=False)
 
         self.layer_norm = nn.LayerNorm(d_model)
@@ -89,6 +95,9 @@ class ExtendedMultiHeadAttn(nn.Module):
         self.pre_lnorm = pre_lnorm
 
+        # self.coeff = nn.Parameter(torch.Tensor(n_head, 2))
+        # nn.init.uniform_(self.coeff, a=-1, b=1)
+
     def forward(self, h, attn_mask=None, mems=None):
         ##### multihead attention
         # [hlen x bsz x n_head x d_head]
@@ -121,9 +130,9 @@ class ExtendedMultiHeadAttn(nn.Module):
             attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None], -float('inf'))
-            mem2other_attn = attn_mask.ones(mem_len, c.size(0))
+            mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
             mem2other_attn[:, :mem_len] = 0
-            attn_score[:mem_len].masked_fill_(mem2other_attn, -float('inf'))
+            attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None], -float('inf'))
 
         # [qlen x klen x bsz x n_head]
         attn_prob = F.softmax(attn_score, dim=1)
@@ -131,8 +140,13 @@ class ExtendedMultiHeadAttn(nn.Module):
         # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
         attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
-        attn_vec = attn_vec.contiguous().view(
-            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
+        attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, attn_vec))
+        # [qlen x bsz x n_head x d_head x 2]
+        attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
+        # attn_vec = torch.einsum('ibndt,nt->ibnd', (attn_vecs, self.coeff))
+        attn_vec = attn_vecs.contiguous().view(
+            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head * 2)
         attn_vec = attn_vec[mem_len:]
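Note (editorial sketch, not part of the commit): the widened o_net in the earlier hunk (input width n_head * d_head * 2) follows from the concatenation above: attn_vec_quad has the same shape as attn_vec, and the two are stacked along a new trailing axis before being flattened. A minimal shape check with toy dimensions, assuming query and key lengths coincide (no memory):

    import torch

    L, B, H, D = 5, 2, 3, 4  # toy seq length, batch, heads, head dim
    attn_prob = torch.softmax(torch.randn(L, L, B, H), dim=1)
    head_v = torch.randn(L, B, H, D)

    attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))         # [L, B, H, D]
    attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, attn_vec))  # attention applied twice
    attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
    flat = attn_vecs.contiguous().view(L, B, H * D * 2)
    print(flat.shape)  # torch.Size([5, 2, 24]) -> matches the doubled o_net input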
@@ -463,6 +477,7 @@ class DecoderLayer(nn.Module):
         super(DecoderLayer, self).__init__()
 
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
+        # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))
@@ -470,9 +485,9 @@ class DecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, mems=mems)
-        output = self.pos_ff(output)
+        output, relu_out = self.pos_ff(output)
 
-        return output
+        return output, relu_out
 
 class RelLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -489,9 +504,9 @@ class RelLearnableDecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias, attn_mask=dec_attn_mask, mems=mems)
-        output = self.pos_ff(output)
+        output, relu_out = self.pos_ff(output)
 
-        return output
+        return output, relu_out
 
 class RelPartialLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -508,9 +523,9 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias, attn_mask=dec_attn_mask, mems=mems)
-        output = self.pos_ff(output)
+        output, relu_out = self.pos_ff(output)
 
-        return output
+        return output, relu_out
 
 class AdaptiveEmbedding(nn.Module):
@@ -743,6 +758,7 @@ class MemTransformerLM(nn.Module):
                 word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
 
         hids = []
+        relu_outs = []
         if self.attn_type == 0: # default
             pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype)
@@ -756,9 +772,10 @@ class MemTransformerLM(nn.Module):
             hids.append(core_out)
             for i, layer in enumerate(self.layers):
                 mems_i = None if mems is None else mems[i]
-                core_out = layer(core_out, pos_emb, self.r_w_bias,
+                core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
                         self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
         elif self.attn_type == 1: # learnable
             core_out = self.drop(word_emb)
             hids.append(core_out)
@@ -770,9 +787,10 @@ class MemTransformerLM(nn.Module):
                 r_emb, r_bias = self.r_emb[i], self.r_bias[i]
 
                 mems_i = None if mems is None else mems[i]
-                core_out = layer(core_out, r_emb, self.r_w_bias[i],
+                core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
                         r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
         elif self.attn_type == 2: # absolute
             pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype)
@@ -787,9 +805,10 @@ class MemTransformerLM(nn.Module):
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and i == 0:
                     mems_i += pos_emb[:mlen]
-                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
         elif self.attn_type == 3:
             core_out = self.drop(word_emb)
@@ -807,15 +826,16 @@ class MemTransformerLM(nn.Module):
                     mems_i += cur_emb.view(mlen, 1, -1)
                 core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
 
-                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
                 hids.append(core_out)
+                relu_outs.append(relu_out)
 
         core_out = self.drop(core_out)
 
         new_mems = self._update_mems(hids, mems, mlen, qlen)
 
-        return core_out, new_mems
+        return core_out, new_mems, relu_outs
 
     def forward(self, data, target, *mems):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
@@ -825,7 +845,9 @@ class MemTransformerLM(nn.Module):
         if not mems: mems = self.init_mems()
 
         tgt_len = target.size(0)
-        hidden, new_mems = self._forward(data, mems=mems)
+        hidden, new_mems, relu_outs = self._forward(data, mems=mems)
+        # relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1)
 
         pred_hid = hidden[-tgt_len:]
         if self.sample_softmax > 0 and self.training:
@@ -838,9 +860,9 @@ class MemTransformerLM(nn.Module):
             loss = loss.view(tgt_len, -1)
 
         if new_mems is None:
-            return [loss]
+            return [relu_outs, loss]
         else:
-            return [loss] + new_mems
+            return [relu_outs, loss] + new_mems
 
 if __name__ == '__main__':
     import argparse
pytorch/train.py
@@ -16,6 +16,31 @@ from mem_transformer import MemTransformerLM
 from utils.exp_utils import create_exp_dir
 from utils.data_parallel import BalancedDataParallel
 
+class AverageMeter(object):
+    """Computes and stores the average and current value.
+
+       Examples::
+           >>> # Initialize a meter to record loss
+           >>> losses = AverageMeter()
+           >>> # Update meter after every minibatch update
+           >>> losses.update(loss_value, batch_size)
+    """
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
 parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
 parser.add_argument('--data', type=str, default='../data/wikitext-103',
                     help='location of the data corpus')
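Note (editorial sketch, not part of the commit): the meter added in the hunk above keeps a weighted running average, where n weights each update. A quick illustration with made-up values:

    meter = AverageMeter()
    meter.update(0.30, n=4)      # sum = 1.2, count = 4, avg = 0.30
    meter.update(0.60, n=2)      # sum = 2.4, count = 6, avg = 0.40
    print(meter.val, meter.avg)  # -> 0.6 0.4 (up to float rounding)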
@@ -386,6 +411,7 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
 def evaluate(eval_iter):
     # Turn on evaluation mode which disables dropout.
     model.eval()
+    avg_nnzs = None
 
     # If the model does not use memory at all, make the ext_len longer.
     # Otherwise, make the mem_len longer and keep the ext_len the same.
@@ -404,16 +430,22 @@ def evaluate(eval_iter):
             if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                 break
             ret = model(data, target, *mems)
-            loss, mems = ret[0], ret[1:]
+            # loss, mems = ret[0], ret[1:]
+            relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.mean()
             total_loss += seq_len * loss.float().item()
             total_len += seq_len
+            nnzs = [(relu_out > 0).sum().float().item() / relu_out.numel() for relu_out in relu_outs]
+            if avg_nnzs is None:
+                avg_nnzs = [AverageMeter() for i in range(len(nnzs))]
+            for i in range(len(nnzs)):
+                avg_nnzs[i].update(nnzs[i])
 
     # Switch back to the training mode
     model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
     model.train()
 
-    return total_loss / total_len
+    return total_loss / total_len, avg_nnzs
 
 
 def train():
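Note (editorial sketch, not part of the commit): the per-layer statistic collected above is simply the fraction of strictly positive entries in each layer's ReLU output, its "nnz" ratio. A toy illustration, with placeholder tensor shapes:

    import torch

    relu_out = torch.relu(torch.randn(36, 16, 2048))  # e.g. [tgt_len x bsz x d_inner]
    nnz = (relu_out > 0).sum().float().item() / relu_out.numel()
    print('fraction of active ReLU units: {:.2%}'.format(nnz))  # roughly 50% for random input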
@@ -424,6 +456,9 @@ def train():
         mems = [tuple() for _ in range(args.batch_chunk)]
     else:
         mems = tuple()
+
+    avg_nnzs = None
+
     train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
     for batch, (data, target, seq_len) in enumerate(train_iter):
         model.zero_grad()
@@ -434,7 +469,8 @@ def train():
                 data_i = data_chunks[i].contiguous()
                 target_i = target_chunks[i].contiguous()
                 ret = para_model(data_i, target_i, *mems[i])
-                loss, mems[i] = ret[0], ret[1:]
+                # loss, mems[i] = ret[0], ret[1:]
+                relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
                 loss = loss.float().mean().type_as(loss) / args.batch_chunk
                 if args.fp16:
                     optimizer.backward(loss)
@@ -443,13 +479,19 @@ def train():
                 train_loss += loss.float().item()
         else:
             ret = para_model(data, target, *mems)
-            loss, mems = ret[0], ret[1:]
+            # loss, mems = ret[0], ret[1:]
+            relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.float().mean().type_as(loss)
             if args.fp16:
                 optimizer.backward(loss)
             else:
                 loss.backward()
             train_loss += loss.float().item()
 
+        nnzs = [(relu_out > 0).sum().float().item() / relu_out.numel() for relu_out in relu_outs]
+        if avg_nnzs is None:
+            avg_nnzs = [AverageMeter() for i in range(len(nnzs))]
+        for i in range(len(nnzs)):
+            avg_nnzs[i].update(nnzs[i])
+
         if args.fp16:
             optimizer.clip_master_grads(args.clip)
@@ -488,12 +530,17 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
             else:
                 log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
+            final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
+            for i in range(len(avg_nnzs)):
+                avg_nnzs[i].reset()
+            log_str += " | avg nnz %.2f | max nnz %.2f" % (sum(final_avg_nnzs) / len(final_avg_nnzs) * 100, max(final_avg_nnzs) * 100)
             logging(log_str)
             train_loss = 0
             log_start_time = time.time()
 
         if train_step % args.eval_interval == 0:
-            val_loss = evaluate(va_iter)
+            val_loss, eval_avg_nnzs = evaluate(va_iter)
             logging('-' * 100)
             log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                       '| valid loss {:5.2f}'.format(
@@ -503,6 +550,8 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
             else:
                 log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
+            final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
+            log_str += " | mean nnz %.2f | max nnz %.2f" % (sum(final_eval_avg_nnzs) / len(final_eval_avg_nnzs) * 100, max(final_eval_avg_nnzs) * 100)
             logging(log_str)
             logging('-' * 100)
 
             # Save the model if the validation loss is the best we've seen so far.
@@ -551,7 +600,7 @@ with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
 para_model = model.to(device)
 
 # Run on test data.
-test_loss = evaluate(te_iter)
+test_loss, eval_avg_nnzs = evaluate(te_iter)
 logging('=' * 100)
 if args.dataset in ['enwik8', 'text8']:
     logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
@@ -559,4 +608,7 @@ if args.dataset in ['enwik8', 'text8']:
 else:
     logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
         test_loss, math.exp(test_loss)))
+final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
+logging(" | mean nnz %.2f | max nnz %.2f" % (sum(final_eval_avg_nnzs) / len(final_eval_avg_nnzs) * 100, max(final_eval_avg_nnzs) * 100))
 logging('=' * 100)