OpenDAS / FastMoE · Commits · 37d01e9c
"src/diffusers/utils/dummy_note_seq_objects.py" did not exist on "d7b692083c794b4047930cd84c17c0da3272510b"
Commit 37d01e9c authored Nov 18, 2020 by Jiezhong Qiu

    multihead ffn

parent cf8a61d8
Showing 2 changed files with 252 additions and 62 deletions:

    pytorch/mem_transformer.py   +165  -32
    pytorch/train.py              +87  -30
pytorch/mem_transformer.py · view file @ 37d01e9c
@@ -30,6 +30,124 @@ class PositionalEmbedding(nn.Module):
         else:
             return pos_emb[:,None,:]

+
+class MoEPositionwiseFF(nn.Module):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
+        super(MoEPositionwiseFF, self).__init__()
+        print("MoEPositionwiseFF")
+
+        self.top_k = top_k
+        self.d_model = d_model
+        self.d_inner = d_inner
+        self.dropout = dropout
+
+        self.gate = nn.Linear(d_model, d_inner)
+        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
+        self.b2 = nn.Parameter(torch.Tensor(d_model))
+
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.pre_lnorm = pre_lnorm
+
+        ratio = top_k / d_inner
+        self.dropout_middle = nn.Dropout(dropout * ratio)
+        self.dropout_final = nn.Dropout(dropout)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        temp_Linear = nn.Linear(self.d_inner, self.d_model)
+        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
+        self.b2.data = temp_Linear.bias.data
+
+    def forward(self, inp):
+        residual = inp
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)
+
+        gate = self.gate(inp)
+        gate_top_k_val, gate_top_k_idx = torch.topk(
+            gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]
+
+        relu_out = F.relu(gate_top_k_val)
+        x = self.dropout_middle(relu_out)
+
+        W2_select = self.W2[gate_top_k_idx]  # [.. x top_k x d_model]
+        core_out = torch.einsum('ijk,ijkd->ijd', (x, W2_select)) + self.b2  # [.. x d_model]
+        core_out = self.dropout_final(core_out)
+
+        output = core_out + residual
+        if not self.pre_lnorm:
+            output = self.layer_norm(output)
+
+        return output
+        # return output, relu_out.detach()
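For orientation, here is a minimal, self-contained sketch (toy sizes, not part of the commit; all names below are hypothetical stand-ins for the module's attributes) of the top-k gated computation MoEPositionwiseFF performs: only the k largest gate scores survive the ReLU, and only the matching rows of W2 are gathered for the second projection.

    import torch
    import torch.nn.functional as F

    qlen, bsz, d_model, d_inner, top_k = 4, 2, 8, 32, 4         # toy sizes
    inp = torch.randn(qlen, bsz, d_model)
    gate_proj = torch.nn.Linear(d_model, d_inner)                # stands in for self.gate
    W2 = torch.randn(d_inner, d_model)                           # stands in for self.W2
    b2 = torch.zeros(d_model)                                    # stands in for self.b2

    gate = gate_proj(inp)                                        # [qlen x bsz x d_inner]
    val, idx = torch.topk(gate, k=top_k, dim=-1, sorted=False)   # both [qlen x bsz x top_k]
    x = F.relu(val)                                              # activations of the selected units
    W2_select = W2[idx]                                          # [qlen x bsz x top_k x d_model]
    out = torch.einsum('ijk,ijkd->ijd', x, W2_select) + b2       # [qlen x bsz x d_model]
    print(out.shape)                                             # torch.Size([4, 2, 8])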
+
+class MultiHeadPositionwiseFF(nn.Module):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
+        super(MultiHeadPositionwiseFF, self).__init__()
+        print("MultiHeadPositionwiseFF")
+
+        assert d_model % n_head == 0
+        self.n_head = n_head
+        d_head = d_model // n_head  # integer head size; d_model is asserted divisible by n_head
+        self.d_head = d_head
+        self.d_model = d_model
+        self.d_inner = d_inner
+        self.dropout = dropout
+
+        self.q_net = nn.Linear(d_model, d_model)
+        self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
+        self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
+        self.v_weight = nn.Parameter(torch.Tensor(n_head, d_head, d_inner))
+        self.v_bias = nn.Parameter(torch.Tensor(n_head, d_head))
+        self.o_net = nn.Linear(d_model, d_model)
+
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.pre_lnorm = pre_lnorm
+        self.dropout = nn.Dropout(dropout)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        for i in range(self.n_head):
+            tmp = nn.Linear(self.d_head, self.d_inner)
+            self.k_weight.data[i] = tmp.weight.data
+            self.k_bias.data[i] = tmp.bias.data
+            tmp = nn.Linear(self.d_inner, self.d_head)
+            self.v_weight.data[i] = tmp.weight.data
+            self.v_bias.data[i] = tmp.bias.data
+
+    def forward(self, inp):
+        residual = inp
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)
+
+        head_q = self.q_net(inp)
+        head_q = head_q.view(inp.size(0), inp.size(1), self.n_head, self.d_head)  # [.. x n_head x d_head]
+
+        attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) + self.k_bias  # [.. x n_head x d_inner]
+        attn_score = F.relu(attn_score)
+        attn_score = self.dropout(attn_score)
+
+        attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
+        attn_vec = attn_vec.view(inp.size(0), inp.size(1), self.d_model)
+
+        core_out = self.o_net(attn_vec)
+        core_out = self.dropout(core_out)
+
+        output = core_out + residual
+        if not self.pre_lnorm:
+            output = self.layer_norm(output)
+
+        return output
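A quick shape check of the new module (illustrative only; assumes pytorch/mem_transformer.py is importable from the working directory and that d_head is an integer, i.e. d_model // n_head). The residual connection means the output keeps the input's shape:

    import torch
    from mem_transformer import MultiHeadPositionwiseFF   # hypothetical import path

    ff = MultiHeadPositionwiseFF(d_model=16, d_inner=64, dropout=0.1, n_head=2)
    x = torch.randn(4, 2, 16)      # [qlen x bsz x d_model]
    y = ff(x)
    print(y.shape)                 # torch.Size([4, 2, 16]) -- same shape as the input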
+
 class PositionwiseFF(nn.Module):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
@@ -69,7 +187,8 @@ class PositionwiseFF(nn.Module):
             ##### residual connection + layer normalization
             output = self.layer_norm(inp + core_out)

-        return output, relu_out.detach()
+        return output
+        # return output, relu_out.detach()
 class ExtendedMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,

@@ -125,14 +244,14 @@ class ExtendedMultiHeadAttn(nn.Module):
         attn_score.mul_(self.scale)

         if attn_mask is not None and attn_mask.any().item():
             if attn_mask.dim() == 2:
-                attn_score[mem_len:].masked_fill_(attn_mask[None,:,:,None], -float('inf'))
+                attn_score[mem_len:].masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
             elif attn_mask.dim() == 3:
-                attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None], -float('inf'))
+                attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
             mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
             mem2other_attn[:, :mem_len] = 0
-            attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None], -float('inf'))
+            attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None].bool(), -float('inf'))

         # [qlen x klen x bsz x n_head]
         attn_prob = F.softmax(attn_score, dim=1)
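The only functional change in this and the following attention hunks is the .bool() cast on the masks. The masks are built with .byte() elsewhere in the file, and newer PyTorch releases deprecate (and eventually reject) uint8 masks in masked_fill / masked_fill_. A standalone illustration, not part of the commit:

    import torch

    score = torch.zeros(2, 3)
    mask = torch.tensor([[0, 1, 0], [1, 0, 0]]).byte()    # uint8 mask, as produced by .byte()
    score.masked_fill_(mask.bool(), -float('inf'))        # boolean mask: accepted by all versions
    print(score)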
@@ -211,9 +330,9 @@ class MultiHeadAttn(nn.Module):
         attn_score.mul_(self.scale)

         if attn_mask is not None and attn_mask.any().item():
             if attn_mask.dim() == 2:
-                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
+                attn_score.masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
             elif attn_mask.dim() == 3:
-                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
+                attn_score.masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))

         # [qlen x klen x bsz x n_head]
         attn_prob = F.softmax(attn_score, dim=1)
@@ -358,10 +477,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
         if attn_mask is not None and attn_mask.any().item():
             if attn_mask.dim() == 2:
                 attn_score = attn_score.float().masked_fill(
-                    attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)
+                    attn_mask[None,:,:,None].bool(), -float('inf')).type_as(attn_score)
             elif attn_mask.dim() == 3:
                 attn_score = attn_score.float().masked_fill(
-                    attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)
+                    attn_mask[:,:,:,None].bool(), -float('inf')).type_as(attn_score)

         # [qlen x klen x bsz x n_head]
         attn_prob = F.softmax(attn_score, dim=1)
@@ -444,9 +563,9 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         #### compute attention probability
         if attn_mask is not None and attn_mask.any().item():
             if attn_mask.dim() == 2:
-                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
+                attn_score.masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
             elif attn_mask.dim() == 3:
-                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
+                attn_score.masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))

         # [qlen x klen x bsz x n_head]
         attn_prob = F.softmax(attn_score, dim=1)
@@ -478,16 +597,18 @@ class DecoderLayer(nn.Module):
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = MultiHeadPositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, dec_attn_mask=None, mems=None):

         output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                                mems=mems)
-        output, relu_out = self.pos_ff(output)
+        output = self.pos_ff(output)
+        # output, relu_out = self.pos_ff(output)

-        return output, relu_out
+        return output
+        # return output, relu_out
 class RelLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,

@@ -496,7 +617,7 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = MultiHeadPositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):

@@ -504,9 +625,11 @@ class RelLearnableDecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                                attn_mask=dec_attn_mask,
                                mems=mems)
-        output, relu_out = self.pos_ff(output)
+        output = self.pos_ff(output)
+        # output, relu_out = self.pos_ff(output)

-        return output, relu_out
+        return output
+        # return output, relu_out
 class RelPartialLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,

@@ -515,7 +638,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = MultiHeadPositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))

     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):

@@ -523,9 +646,11 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
                                attn_mask=dec_attn_mask,
                                mems=mems)
-        output, relu_out = self.pos_ff(output)
+        output = self.pos_ff(output)
+        # output, relu_out = self.pos_ff(output)

-        return output, relu_out
+        return output
+        # return output, relu_out
 class AdaptiveEmbedding(nn.Module):

@@ -758,7 +883,7 @@ class MemTransformerLM(nn.Module):
                     word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]

         hids = []
-        relu_outs = []
+        # relu_outs = []
         if self.attn_type == 0: # default
             pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -772,10 +897,11 @@ class MemTransformerLM(nn.Module):
             hids.append(core_out)
             for i, layer in enumerate(self.layers):
                 mems_i = None if mems is None else mems[i]
-                core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
+                # core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
+                core_out = layer(core_out, pos_emb, self.r_w_bias,
                         self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
-                relu_outs.append(relu_out)
+                # relu_outs.append(relu_out)
         elif self.attn_type == 1: # learnable
             core_out = self.drop(word_emb)
             hids.append(core_out)
@@ -787,10 +913,11 @@ class MemTransformerLM(nn.Module):
                 r_emb, r_bias = self.r_emb[i], self.r_bias[i]
                 mems_i = None if mems is None else mems[i]
-                core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
+                # core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
+                core_out = layer(core_out, r_emb, self.r_w_bias[i],
                         r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
-                relu_outs.append(relu_out)
+                # relu_outs.append(relu_out)
         elif self.attn_type == 2: # absolute
             pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -805,10 +932,11 @@ class MemTransformerLM(nn.Module):
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and i == 0:
                     mems_i += pos_emb[:mlen]
-                core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                 hids.append(core_out)
-                relu_outs.append(relu_out)
+                # relu_outs.append(relu_out)
         elif self.attn_type == 3:
             core_out = self.drop(word_emb)
@@ -826,16 +954,18 @@ class MemTransformerLM(nn.Module):
                     mems_i += cur_emb.view(mlen, 1, -1)
                 core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

-                core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                 hids.append(core_out)
-                relu_outs.append(relu_out)
+                # relu_outs.append(relu_out)

         core_out = self.drop(core_out)

         new_mems = self._update_mems(hids, mems, mlen, qlen)

-        return core_out, new_mems, relu_outs
+        return core_out, new_mems
+        # return core_out, new_mems, relu_outs
     def forward(self, data, target, *mems):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.

@@ -845,7 +975,8 @@ class MemTransformerLM(nn.Module):
         if not mems: mems = self.init_mems()

         tgt_len = target.size(0)
-        hidden, new_mems, relu_outs = self._forward(data, mems=mems)
+        hidden, new_mems = self._forward(data, mems=mems)
+        # hidden, new_mems, relu_outs = self._forward(data, mems=mems)
         # relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1)
@@ -860,9 +991,11 @@ class MemTransformerLM(nn.Module):
             loss = loss.view(tgt_len, -1)

         if new_mems is None:
-            return [relu_outs, loss]
+            return [loss]
+            # return [relu_outs, loss]
         else:
-            return [relu_outs, loss] + new_mems
+            return [loss] + new_mems
+            # return [relu_outs, loss] + new_mems

 if __name__ == '__main__':
     import argparse
pytorch/train.py · view file @ 37d01e9c
@@ -4,6 +4,7 @@ import time
 import math
 import os, sys
 import itertools
+import pathlib

 import numpy as np
@@ -411,7 +412,9 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
 def evaluate(eval_iter):
     # Turn on evaluation mode which disables dropout.
     model.eval()
-    avg_nnzs = None
+    # avg_nnzs = None
+    # act_hist = None
+    # co_act_hist = None

     # If the model does not use memory at all, make the ext_len longer.
     # Otherwise, make the mem_len longer and keep the ext_len the same.

@@ -430,22 +433,34 @@ def evaluate(eval_iter):
             if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                 break
             ret = model(data, target, *mems)
-            # loss, mems = ret[0], ret[1:]
-            relu_outs, loss, mems = ret[0], ret[1], ret[2:]
+            loss, mems = ret[0], ret[1:]
+            # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.mean()
             total_loss += seq_len * loss.float().item()
             total_len += seq_len
-            nnzs = [(relu_out > 0).sum().float().item() / relu_out.numel() for relu_out in relu_outs]
-            if avg_nnzs is None:
-                avg_nnzs = [AverageMeter() for i in range(len(nnzs))]
-            for i in range(len(nnzs)):
-                avg_nnzs[i].update(nnzs[i])
+            # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
+            # if avg_nnzs is None:
+            #     n_layer = len(acts)
+            #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
+            #     d_inner = acts[0].size(-1)
+            #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
+            #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
+            # for i, act in enumerate(acts):
+            #     nnz = act.sum().item() / act.numel()
+            #     avg_nnzs[i].update(nnz)
+            #     act_hist[i] += torch.sum(act, dim=[0, 1])
+            #     co_act = torch.einsum("ija,ijb->ab", (act, act))
+            #     co_act_hist[i] += co_act

     # Switch back to the training mode
     model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
     model.train()

-    return total_loss / total_len, avg_nnzs
+    return total_loss / total_len
+    # return total_loss / total_len, avg_nnzs, act_hist, co_act_hist
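The bookkeeping disabled above measured activation sparsity: nnz is the fraction of feed-forward units whose ReLU output is non-zero, averaged per layer with an AverageMeter. A standalone sketch of the metric itself (illustrative, not part of the commit):

    import torch

    relu_out = torch.relu(torch.randn(4, 2, 64))                # e.g. [qlen x bsz x d_inner]
    nnz = (relu_out > 0).sum().float().item() / relu_out.numel()
    print('fraction of active units: %.1f%%' % (nnz * 100))     # roughly 50% for random Gaussian inputs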
 def train():

@@ -457,7 +472,9 @@ def train():
     else:
         mems = tuple()
-    avg_nnzs = None
+    # avg_nnzs = None
+    # act_hist = None
+    # co_act_hist = None
     train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
     for batch, (data, target, seq_len) in enumerate(train_iter):
@@ -469,8 +486,8 @@ def train():
                 data_i = data_chunks[i].contiguous()
                 target_i = target_chunks[i].contiguous()
                 ret = para_model(data_i, target_i, *mems[i])
-                # loss, mems[i] = ret[0], ret[1:]
-                relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
+                loss, mems[i] = ret[0], ret[1:]
+                # relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
                 loss = loss.float().mean().type_as(loss) / args.batch_chunk
                 if args.fp16:
                     optimizer.backward(loss)
@@ -479,19 +496,29 @@ def train():
                 train_loss += loss.float().item()
         else:
             ret = para_model(data, target, *mems)
-            # loss, mems = ret[0], ret[1:]
-            relu_outs, loss, mems = ret[0], ret[1], ret[2:]
+            loss, mems = ret[0], ret[1:]
+            # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.float().mean().type_as(loss)
             if args.fp16:
                 optimizer.backward(loss)
             else:
                 loss.backward()
             train_loss += loss.float().item()
-        nnzs = [(relu_out > 0).sum().float().item() / relu_out.numel() for relu_out in relu_outs]
-        if avg_nnzs is None:
-            avg_nnzs = [AverageMeter() for i in range(len(nnzs))]
-        for i in range(len(nnzs)):
-            avg_nnzs[i].update(nnzs[i])
+        # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
+        # # nnzs = [act.sum().item() / act.numel() for act in acts]
+        # if avg_nnzs is None:
+        #     n_layer = len(acts)
+        #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
+        #     d_inner = acts[0].size(-1)
+        #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
+        #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
+        # for i, act in enumerate(acts):
+        #     nnz = act.sum().item() / act.numel()
+        #     avg_nnzs[i].update(nnz)
+        #     act_hist[i] += torch.sum(act, dim=[0, 1])
+        #     co_act = torch.einsum("ija,ijb->ab", (act, act))
+        #     co_act_hist[i] += co_act

         if args.fp16:
             optimizer.clip_master_grads(args.clip)
@@ -530,17 +557,39 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
             else:
                 log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
-            final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
-            for i in range(len(avg_nnzs)):
-                avg_nnzs[i].reset()
-            log_str += " | avg nnz %.2f | max nnz %.2f" % (
-                sum(final_avg_nnzs)/len(final_avg_nnzs)*100, max(final_avg_nnzs)*100)
+            # final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
+            # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
+            #     sum(final_avg_nnzs)/len(final_avg_nnzs)*100,
+            #     max(final_avg_nnzs)*100
+            # )
             logging(log_str)
+            # co_act_dir = pathlib.Path(logging.keywords['log_path']).parent.joinpath("co_act")
+            # co_act_dir.mkdir(parents=True, exist_ok=True)
+            # co_act_path = co_act_dir.joinpath('epoch_%d_train_step_%d.pt' % (epoch, train_step))
+            # torch.save(co_act_hist, co_act_path)
+            # for i in range(len(avg_nnzs)):
+            #     avg_nnzs[i].reset()
+            #     act_hist[i] /= act_hist[i].sum()
+            #     prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
+            #     log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
+            #         i+1,
+            #         prob[:64].sum().item(),
+            #         prob[:128].sum().item(),
+            #         prob[:256].sum().item(),
+            #         prob[:512].sum().item(),
+            #         prob[:1024].sum().item()
+            #     )
+            #     logging(log_str)
+            #     act_hist[i] = 0.
+            #     co_act_hist[i] = 0.
             train_loss = 0
             log_start_time = time.time()

         if train_step % args.eval_interval == 0:
-            val_loss, eval_avg_nnzs = evaluate(va_iter)
+            val_loss = evaluate(va_iter)
+            # val_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(va_iter)
             logging('-' * 100)
             log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                       '| valid loss {:5.2f}'.format(
@@ -550,8 +599,11 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
             else:
                 log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
-            final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-            log_str += " | mean nnz %.2f | max nnz %.2f" % (sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100, max(final_eval_avg_nnzs)*100)
+            # final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
+            # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
+            #     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
+            #     max(final_eval_avg_nnzs)*100
+            # )
             logging(log_str)
             logging('-' * 100)

             # Save the model if the validation loss is the best we've seen so far.
@@ -600,7 +652,8 @@ with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
 para_model = model.to(device)

 # Run on test data.
-test_loss, eval_avg_nnzs = evaluate(te_iter)
+test_loss = evaluate(te_iter)
+# test_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(te_iter)
 logging('=' * 100)
 if args.dataset in ['enwik8', 'text8']:
     logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(

@@ -609,6 +662,10 @@ else:
     logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
         test_loss, math.exp(test_loss)))
-final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-logging(" | mean nnz %.2f | max nnz %.2f" % (sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100, max(final_eval_avg_nnzs)*100))
+# final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
+# log_str = ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
+#     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
+#     max(final_eval_avg_nnzs)*100
+# )
+# logging(log_str)
 logging('=' * 100)