OpenDAS / FastMoE

Commit 958c714e authored Nov 08, 2020 by Jiezhong Qiu

add mem2mem attention

parent 44781ed2
Showing 1 changed file with 98 additions and 15 deletions.

pytorch/mem_transformer.py  (+98, -15)
@@ -66,8 +66,91 @@ class PositionwiseFF(nn.Module):
 
         return output
 
+class ExtendedMultiHeadAttn(nn.Module):
+    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
+                 pre_lnorm=False):
+        super(MultiHeadAttn, self).__init__()
+
+        self.n_head = n_head
+        self.d_model = d_model
+        self.d_head = d_head
+        self.dropout = dropout
+
+        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
+        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
+
+        self.drop = nn.Dropout(dropout)
+        self.dropatt = nn.Dropout(dropatt)
+        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
+
+        self.layer_norm = nn.LayerNorm(d_model)
+
+        self.scale = 1 / (d_head ** 0.5)
+
+        self.pre_lnorm = pre_lnorm
+
+    def forward(self, h, attn_mask=None, mems=None):
+        ##### multihead attention
+        # [hlen x bsz x n_head x d_head]
+
+        if mems is not None:
+            c = torch.cat([mems, h], 0)
+            mem_len = mems.size(0)
+        else:
+            c = h
+            mem_len = 0
+
+        if self.pre_lnorm:
+            ##### layer normalization
+            c = self.layer_norm(c)
+
+        head_q = self.q_net(c)
+        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
+
+        head_q = head_q.view(c.size(0), c.size(1), self.n_head, self.d_head)
+        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
+        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
+
+        # [qlen x klen x bsz x n_head]
+        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
+        attn_score.mul_(self.scale)
+        if attn_mask is not None and attn_mask.any().item():
+            if attn_mask.dim() == 2:
+                attn_score[mem_len:].masked_fill_(attn_mask[None,:,:,None], -float('inf'))
+            elif attn_mask.dim() == 3:
+                attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None], -float('inf'))
+            mem2other_attn = attn_mask.ones(mem_len, c.size(0))
+            mem2other_attn[:, :mem_len] = 0
+            attn_score[:mem_len].masked_fill_(mem2other_attn, -float('inf'))
+
+        # [qlen x klen x bsz x n_head]
+        attn_prob = F.softmax(attn_score, dim=1)
+        attn_prob = self.dropatt(attn_prob)
+
+        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
+        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
+        attn_vec = attn_vec.contiguous().view(
+            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
+        attn_vec = attn_vec[mem_len:]
+
+        ##### linear projection
+        attn_out = self.o_net(attn_vec)
+        attn_out = self.drop(attn_out)
+
+        if self.pre_lnorm:
+            ##### residual connection
+            output = h + attn_out
+        else:
+            ##### residual connection + layer normalization
+            output = self.layer_norm(h + attn_out)
+
+        return output
+
 class MultiHeadAttn(nn.Module):
-    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
+    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  pre_lnorm=False):
         super(MultiHeadAttn, self).__init__()
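For reference, the hunk above adds a variant of MultiHeadAttn in which the memory is not only attended to but also issues queries of its own: mems and the current segment are concatenated, queries/keys/values are computed over the full concatenation, an extra mask restricts memory-position queries to the memory block itself, and only the current-segment outputs are kept. The sketch below is an independent reconstruction of that masking idea, not the repository's code: it uses standard calls (Tensor.new_ones rather than the attn_mask.ones(...) in the commit, which is not a torch.Tensor method, and it avoids the super(MultiHeadAttn, self) call that ExtendedMultiHeadAttn appears to have inherited from copy-pasting), and it assumes a simple [qlen x klen] boolean decoder mask (True = blocked), which is not exactly the mask layout mem_transformer.py passes around.

# Hedged sketch of mem2mem attention masking in the [qlen x klen x bsz x n_head]
# layout used by mem_transformer.py; names and mask convention are assumptions.
import torch
import torch.nn.functional as F


def mem2mem_attention(h, mems, q_net, kv_net, n_head, d_head, dec_attn_mask=None):
    """h: [hlen x bsz x d_model], mems: [mem_len x bsz x d_model] or None."""
    c = h if mems is None else torch.cat([mems, h], 0)   # [klen x bsz x d_model]
    mem_len = 0 if mems is None else mems.size(0)
    klen, bsz = c.size(0), c.size(1)

    # Queries are computed for the memory slots as well as the current segment.
    head_q = q_net(c).view(klen, bsz, n_head, d_head)
    head_k, head_v = torch.chunk(kv_net(c), 2, -1)
    head_k = head_k.view(klen, bsz, n_head, d_head)
    head_v = head_v.view(klen, bsz, n_head, d_head)

    attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k)) / d_head ** 0.5

    # Current-segment queries: apply the usual decoder mask, offset by mem_len
    # because the query axis now starts with the memory slots.
    if dec_attn_mask is not None:
        attn_score[mem_len:].masked_fill_(
            dec_attn_mask[:, :, None, None].bool(), -float('inf'))

    # Memory queries: only allowed to attend to the memory block itself.
    if mem_len > 0:
        mem2other = c.new_ones(mem_len, klen, dtype=torch.bool)  # True = blocked
        mem2other[:, :mem_len] = False
        attn_score[:mem_len].masked_fill_(mem2other[:, :, None, None], -float('inf'))

    attn_prob = F.softmax(attn_score, dim=1)
    attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
    # Keep only the current-segment outputs, matching the committed forward().
    return attn_vec[mem_len:].contiguous().view(h.size(0), bsz, n_head * d_head)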
@@ -380,7 +463,7 @@ class DecoderLayer(nn.Module):
         super(DecoderLayer, self).__init__()
 
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -398,7 +481,7 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -417,7 +500,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                      pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
@@ -431,7 +514,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
 
 class AdaptiveEmbedding(nn.Module):
-    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
+    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                  sample_softmax=False):
         super(AdaptiveEmbedding, self).__init__()
@@ -469,7 +552,7 @@ class AdaptiveEmbedding(nn.Module):
         else:
             param = next(self.parameters())
             inp_flat = inp.view(-1)
-            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
+            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                 dtype=param.dtype, device=param.device)
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
@@ -494,11 +577,11 @@ class AdaptiveEmbedding(nn.Module):
 
 class MemTransformerLM(nn.Module):
     def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
-                 dropout, dropatt, tie_weight=True, d_embed=None,
+                 dropout, dropatt, tie_weight=True, d_embed=None,
                  div_val=1, tie_projs=[False], pre_lnorm=False,
-                 tgt_len=None, ext_len=None, mem_len=None,
+                 tgt_len=None, ext_len=None, mem_len=None,
                  cutoffs=[], adapt_inp=False,
-                 same_length=False, attn_type=0, clamp_len=-1,
+                 same_length=False, attn_type=0, clamp_len=-1,
                  sample_softmax=-1):
         super(MemTransformerLM, self).__init__()
 
         self.n_token = n_token
@@ -509,7 +592,7 @@ class MemTransformerLM(nn.Module):
         self.n_head = n_head
         self.d_head = d_head
 
-        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
+        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                           div_val=div_val)
 
         self.drop = nn.Dropout(dropout)
@@ -559,7 +642,7 @@ class MemTransformerLM(nn.Module):
         # use adaptive softmax (including standard softmax)
         else:
-            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
+            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                     cutoffs, div_val=div_val)
 
             if tie_weight:
@@ -661,7 +744,7 @@ class MemTransformerLM(nn.Module):
         hids = []
         if self.attn_type == 0: # default
-            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
+            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
             if self.clamp_len > 0:
                 pos_seq.clamp_(max=self.clamp_len)
@@ -797,10 +880,10 @@ if __name__ == '__main__':
         for d_embed in [200, 100]:
             model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
                             args.d_model, args.d_head, args.d_inner, args.dropout,
-                            dropatt=args.dropout, tie_weight=True,
-                            d_embed=d_embed, div_val=div_val,
+                            dropatt=args.dropout, tie_weight=True,
+                            d_embed=d_embed, div_val=div_val,
                             tie_projs=tie_projs, pre_lnorm=True,
-                            tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
+                            tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                             cutoffs=cutoffs, attn_type=0).to(device)
 
             print(sum(p.numel() for p in model.parameters()))
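Continuing the sketch shown after the first hunk, a quick shape check confirms that memory-extended attention still returns one vector per current-segment position; the sizes and module names below are hypothetical, not taken from the repository.

# Hedged smoke test for the mem2mem_attention sketch above (assumed names/sizes).
import torch
import torch.nn as nn

n_head, d_model, d_head = 2, 32, 16
q_net = nn.Linear(d_model, n_head * d_head, bias=False)
kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

h = torch.randn(4, 3, d_model)      # [hlen=4, bsz=3, d_model]
mems = torch.randn(5, 3, d_model)   # [mem_len=5, bsz=3, d_model]
out = mem2mem_attention(h, mems, q_net, kv_net, n_head, d_head)
print(out.shape)                    # torch.Size([4, 3, 32])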