OpenDAS / FastMoE, commit 03b2a725 (unverified)
Authored Feb 25, 2021 by Rick Ho; committed by GitHub on Feb 25, 2021.

Merge pull request #6 from xfmr-xl

Test Transformer-XL

Parents: e86dea53, 0a942e3f

Changes: 6 changed files with 88 additions and 620 deletions (+88 / -620)

  .gitignore                                                  +3    -0
  examples/transformer-xl/mem_transformer.py                  +59   -502
  examples/transformer-xl/scripts/run_enwik8_base_gshard.sh   +1    -0
  examples/transformer-xl/train.py                            +9    -110
  examples/transformer-xl/utils/proj_adaptive_softmax.py      +10   -6
  fmoe/transformer.py                                         +6    -2
.gitignore

@@ -10,3 +10,6 @@ a.out
 build
 *swp
 logs
+examples/transformer-xl/data
+examples/data
+examples/transformer-xl/LM-TFM-enwik8
examples/transformer-xl/mem_transformer.py

@@ -7,10 +7,9 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-# import torch_sparse

 sys.path.append('utils')
-from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
+from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax, Projection
 from log_uniform_sampler import LogUniformSampler, sample_logits

 class PositionalEmbedding(nn.Module):
@@ -32,361 +31,16 @@ class PositionalEmbedding(nn.Module):
         return pos_emb[:,None,:]

-# A baseline naive slow implementation
-class MoEPositionwiseFFRaw(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
-        super(MoEPositionwiseFFRaw, self).__init__()
-        print("MoEPositionwiseFF")
-
-        self.top_k = top_k
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.gate = nn.Linear(d_model, d_inner)
-        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-
-        ratio = top_k / d_inner
-        self.dropout_middle = nn.Dropout(dropout * ratio)
-        self.dropout_final = nn.Dropout(dropout)
-
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp_Linear = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
-        self.b2.data = temp_Linear.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-
-        gate = self.gate(inp)
-        gate_top_k_val, gate_top_k_idx = torch.topk(
-            gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]
-        relu_out = F.relu(gate_top_k_val)
-        x = self.dropout_middle(relu_out)
-
-        W2_select = self.W2[gate_top_k_idx]  # [.. x top_k x d_model]
-        core_out = torch.einsum('ijk,ijkd->ijd', (x, W2_select)) + self.b2  # [.. x d_model]
-        core_out = self.dropout_final(core_out)
-
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-
-
-def my_topk(x, k, inplace=True):
-    y = x if inplace else x.clone()
-    top1_val, top1_idx = torch.max(y, dim=-1)
-    top1_val = top1_val.unsqueeze(-1)
-    top1_idx = top1_idx.unsqueeze(-1)
-    if k == 1:
-        return top1_val, top1_idx
-    y.scatter_(-1, top1_idx, value=float('-inf'))
-    top2_val, top2_idx = torch.max(y, dim=-1)
-    top2_val = top2_val.unsqueeze(-1)
-    top2_idx = top2_idx.unsqueeze(-1)
-    top_val = torch.cat((top1_val, top2_val), dim=-1)
-    top_idx = torch.cat((top1_idx, top2_idx), dim=-1)
-    return top_val, top_idx
-
-
-class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=16, top_block=2):
-        super(MultiHeadHierarchicalMoEPositionwiseFF, self).__init__()
-        print("MultiHeadHierarchicalMoEPositionwiseFF")
-
-        assert d_inner % n_block == 0
-        assert top_block in [1, 2]
-        self.top_block = top_block
-        self.n_block = n_block
-        d_block = d_inner // n_block
-        self.d_block = d_block
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.block_net_W = nn.Parameter(torch.Tensor(d_model, top_block, n_block))
-        self.block_net_b = nn.Parameter(torch.Tensor(top_block, n_block))
-        self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
-        self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-
-        ratio = top_block / n_block
-        self.dropout_middle = nn.Dropout(dropout * ratio)
-        self.dropout_final = nn.Dropout(dropout)
-        # self.scale = 1 / (d_model ** 0.5)
-
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp = nn.Linear(self.d_model, self.d_inner)
-        self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
-        self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
-        temp = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(
-            self.n_block, self.d_block, self.d_model)
-        self.b2.data = temp.bias.data
-        for i in range(self.top_block):
-            temp = nn.Linear(self.d_model, self.n_block)
-            self.block_net_W.data[:, i] = temp.weight.data.transpose(0, 1).contiguous()
-            self.block_net_b.data[i] = temp.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-
-        block = torch.einsum("ibd,dan->iban", (inp, self.block_net_W)) + self.block_net_b  # [.. x top_block x n_block]
-        block_val, block_idx = my_topk(block, k=1, inplace=True)
-        # block_val, block_idx = torch.topk(block, k=1, dim=-1, largest=True, sorted=False)  # [.. x top_k x 1]
-        block_val = block_val.squeeze(-1)
-        block_idx = block_idx.squeeze(-1)
-        gate = F.softmax(block_val, dim=-1)
-
-        W1_block = self.W1[block_idx]  # [.. x top_k x d_block x d_model]
-        b1_block = self.b1[block_idx]  # [.. x top_k x d_block]
-        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block  # [.. x top_k x d_block]
-        # x = x + block_val.unsqueeze(-1) # somehow like residual
-        x = x * gate.unsqueeze(-1)
-        relu_out = F.relu(x)
-        relu_out = self.dropout_middle(relu_out)
-
-        W2_block = self.W2[block_idx]  # [.. x top_k x d_model]
-        core_out = torch.einsum('ibnh,ibnhd->ibd', (x, W2_block)) + self.b2  # [.. x d_model]
-        core_out = self.dropout_final(core_out)
-
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-
-
-class HierarchicalMoEPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=16, top_block=2):
-        super(HierarchicalMoEPositionwiseFF, self).__init__()
-        print("HierarchicalMoEPositionwiseFF")
-
-        assert d_inner % n_block == 0
-        assert top_block in [1, 2]
-        self.top_block = top_block
-        self.n_block = n_block
-        d_block = d_inner // n_block
-        self.d_block = d_block
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.block_net = nn.Linear(d_model, n_block, bias=True)
-        self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
-        self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-
-        ratio = top_block / n_block
-        self.dropout_middle = nn.Dropout(dropout * ratio)
-        self.dropout_final = nn.Dropout(dropout)
-        # self.scale = 1 / (d_model ** 0.5)
-
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp = nn.Linear(self.d_model, self.d_inner)
-        self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
-        self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
-        temp = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(
-            self.n_block, self.d_block, self.d_model)
-        self.b2.data = temp.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-
-        block = self.block_net(inp)
-        # block_val, block_idx = my_topk(block, k=self.top_block)
-        block_val, block_idx = torch.topk(
-            block, k=self.top_block, dim=-1, largest=True, sorted=False)  # [.. x top_k]
-        gate = F.softmax(block_val, dim=-1)
-
-        W1_block = self.W1[block_idx]  # [.. x top_k x d_block x d_model]
-        b1_block = self.b1[block_idx]  # [.. x top_k x d_block]
-        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block  # [.. x top_k x d_block]
-        # x = x + block_val.unsqueeze(-1) # somehow like residual
-        x = x * gate.unsqueeze(-1)
-        relu_out = F.relu(x)
-        relu_out = self.dropout_middle(relu_out)
-
-        W2_block = self.W2[block_idx]  # [.. x top_k x d_model]
-        core_out = torch.einsum('ibnh,ibnhd->ibd', (x, W2_block)) + self.b2  # [.. x d_model]
-        core_out = self.dropout_final(core_out)
-
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-
-
-class SparsePositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
-        super(SparsePositionwiseFF, self).__init__()
-        print("SparsePositionwiseFF")
-
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.CoreNet_1 = nn.Sequential(
-            nn.Linear(d_model, d_inner),
-            nn.ReLU(inplace=True),
-            nn.Dropout(dropout)
-        )
-        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        self.dropout_final = nn.Dropout(dropout)
-
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp_Linear = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
-        self.b2.data = temp_Linear.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-
-        relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
-        sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
-        core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
-        core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
-        core_out = self.dropout_final(core_out)
-
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-
-
-class MultiHeadPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
-        super(MultiHeadPositionwiseFF, self).__init__()
-        print("MultiHeadPositionwiseFF")
-
-        assert d_model % n_head == 0
-        self.n_head = n_head
-        d_head = d_model // n_head
-        self.d_head = d_head
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.q_net = nn.Linear(d_model, d_model)
-        self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
-        self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
-        self.v_weight = nn.Parameter(torch.Tensor(n_head, d_head, d_inner))
-        self.v_bias = nn.Parameter(torch.Tensor(n_head, d_head))
-        #self.o_net = nn.Linear(d_model, d_model)
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        self.dropout = nn.Dropout(dropout)
-
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        for i in range(self.n_head):
-            tmp = nn.Linear(self.d_head, self.d_inner)
-            self.k_weight.data[i] = tmp.weight.data
-            self.k_bias.data[i] = tmp.bias.data
-            tmp = nn.Linear(self.d_inner, self.d_head)
-            self.v_weight.data[i] = tmp.weight.data
-            self.v_bias.data[i] = tmp.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-
-        head_q = self.q_net(inp)
-        head_q = head_q.view(inp.size(0), inp.size(1), self.n_head, self.d_head)  # [.. x n_head x d_head]
-        attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) + self.k_bias  # [.. x n_head x d_inner]
-        attn_score = F.relu(attn_score)
-        attn_score = self.dropout(attn_score)
-
-        attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
-        attn_vec = attn_vec.contiguous().view(inp.size(0), inp.size(1), self.d_model)
-
-        # core_out = self.o_net(attn_vec)
-        core_out = self.dropout(attn_vec)
-
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-
-
 class PositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, use_softmax=True):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
         super(PositionwiseFF, self).__init__()

         self.d_model = d_model
         self.d_inner = d_inner
         self.dropout = dropout

-        self.CoreNet_1 = nn.Sequential(
-            nn.Linear(d_model, d_inner),
-            nn.Softmax(dim=-1) if use_softmax else nn.ReLU(inplace=True)
-        )
-        self.CoreNet_2 = nn.Sequential(
+        self.CoreNet = nn.Sequential(
+            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
             nn.Dropout(dropout),
             nn.Linear(d_inner, d_model),
             nn.Dropout(dropout),
@@ -399,113 +53,18 @@ class PositionwiseFF(nn.Module):
     def forward(self, inp):
         if self.pre_lnorm:
             ##### layer normalization + positionwise feed-forward
-            relu_out = self.CoreNet_1(self.layer_norm(inp))
-            core_out = self.CoreNet_2(relu_out)
+            core_out = self.CoreNet(self.layer_norm(inp))

             ##### residual connection
             output = core_out + inp
         else:
             ##### positionwise feed-forward
-            relu_out = self.CoreNet_1(inp)
-            core_out = self.CoreNet_2(relu_out)
+            core_out = self.CoreNet(inp)

             ##### residual connection + layer normalization
             output = self.layer_norm(inp + core_out)

         return output
-        # return output, relu_out.detach()
-
-
-class ExtendedMultiHeadAttn(nn.Module):
-    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, pre_lnorm=False):
-        super(ExtendedMultiHeadAttn, self).__init__()
-        print("ExtendedMultiHeadAttn")
-
-        self.n_head = n_head
-        self.d_model = d_model
-        self.d_head = d_head
-        self.dropout = dropout
-
-        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
-        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
-
-        self.drop = nn.Dropout(dropout)
-        self.dropatt = nn.Dropout(dropatt)
-        self.o_net = nn.Linear(n_head * d_head * 2, d_model, bias=False)
-
-        self.layer_norm = nn.LayerNorm(d_model)
-
-        self.scale = 1 / (d_head ** 0.5)
-
-        self.pre_lnorm = pre_lnorm
-        # self.coeff = nn.Parameter(torch.Tensor(n_head, 2))
-        # nn.init.uniform_(self.coeff, a=-1, b=1)
-
-    def forward(self, h, attn_mask=None, mems=None):
-        ##### multihead attention
-        # [hlen x bsz x n_head x d_head]
-
-        if mems is not None:
-            c = torch.cat([mems, h], 0)
-            mem_len = mems.size(0)
-        else:
-            c = h
-            mem_len = 0
-
-        if self.pre_lnorm:
-            ##### layer normalization
-            c = self.layer_norm(c)
-
-        head_q = self.q_net(c)
-        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
-
-        head_q = head_q.view(c.size(0), c.size(1), self.n_head, self.d_head)
-        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
-        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
-
-        # [qlen x klen x bsz x n_head]
-        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
-        attn_score.mul_(self.scale)
-        if attn_mask is not None and attn_mask.any().item():
-            if attn_mask.dim() == 2:
-                attn_score[mem_len:].masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
-            elif attn_mask.dim() == 3:
-                attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
-            mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
-            mem2other_attn[:, :mem_len] = 0
-            attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None].bool(), -float('inf'))
-
-        # [qlen x klen x bsz x n_head]
-        attn_prob = F.softmax(attn_score, dim=1)
-        attn_prob = self.dropatt(attn_prob)
-
-        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
-        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
-        attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, attn_vec))
-        # [qlen x bsz x n_head x d_head x 2]
-        attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
-        # attn_vec = torch.einsum('ibndt,nt->ibnd', (attn_vecs, self.coeff))
-        attn_vec = attn_vecs.contiguous().view(
-            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head * 2)
-        attn_vec = attn_vec[mem_len:]
-
-        ##### linear projection
-        attn_out = self.o_net(attn_vec)
-        attn_out = self.drop(attn_out)
-
-        if self.pre_lnorm:
-            ##### residual connection
-            output = h + attn_out
-        else:
-            ##### residual connection + layer normalization
-            output = self.layer_norm(h + attn_out)
-
-        return output
-

 class MultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
@@ -583,7 +142,8 @@ class MultiHeadAttn(nn.Module):

 class RelMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
-                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
+                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
+                 moe=False, moe_num_expert=64, moe_top_k=2):
         super(RelMultiHeadAttn, self).__init__()

         self.n_head = n_head
@@ -816,42 +376,41 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         return output

 from fmoe import FMoETransformerMLP
 class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
-        def activation(x):
-            return self.dropout(F.relu(x))
-        super().__init__(num_expert=8, d_model=d_model, d_hidden=d_inner,
-                pre_lnorm=pre_lnorm, activation=activation)
-        self.dropout = nn.Dropout(dropout)
-        self.bias = nn.Parameter(torch.zeros(d_model, dtype=torch.float32))
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
+            moe_num_expert=64, moe_top_k=2):
+        activation = nn.Sequential(
+            nn.Dropout(dropout),
+            nn.ReLU()
+        )
+        super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner,
+                top_k=moe_top_k, do_lnorm=True, pre_lnorm=pre_lnorm,
+                activation=activation, dropout=dropout)

     def forward(self, x):
         x = super().forward(x)
-        return x + self.bias
+        return x

 class DecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
         super(DecoderLayer, self).__init__()

         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-            pre_lnorm=kwargs.get('pre_lnorm'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                pre_lnorm=kwargs.get('pre_lnorm'),
+                moe_num_expert=kwargs.get('moe_num_expert'), moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
         output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                                mems=mems)
         output = self.pos_ff(output)
-        # output, relu_out = self.pos_ff(output)

         return output
-        # return output, relu_out

 class RelLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -860,8 +419,15 @@ class RelLearnableDecoderLayer(nn.Module):
...
@@ -860,8 +419,15 @@ class RelLearnableDecoderLayer(nn.Module):
self
.
dec_attn
=
RelLearnableMultiHeadAttn
(
n_head
,
d_model
,
d_head
,
dropout
,
self
.
dec_attn
=
RelLearnableMultiHeadAttn
(
n_head
,
d_model
,
d_head
,
dropout
,
**
kwargs
)
**
kwargs
)
self
.
pos_ff
=
CustomizedMoEPositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
))
if
kwargs
.
get
(
'moe'
)
is
False
:
self
.
pos_ff
=
PositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
))
else
:
self
.
pos_ff
=
CustomizedMoEPositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
),
moe_num_expert
=
kwargs
.
get
(
'moe_num_expert'
),
moe_top_k
=
kwargs
.
get
(
'moe_top_k'
))
def
forward
(
self
,
dec_inp
,
r_emb
,
r_w_bias
,
r_bias
,
dec_attn_mask
=
None
,
mems
=
None
):
def
forward
(
self
,
dec_inp
,
r_emb
,
r_w_bias
,
r_bias
,
dec_attn_mask
=
None
,
mems
=
None
):
...
@@ -869,10 +435,8 @@ class RelLearnableDecoderLayer(nn.Module):
...
@@ -869,10 +435,8 @@ class RelLearnableDecoderLayer(nn.Module):
attn_mask
=
dec_attn_mask
,
attn_mask
=
dec_attn_mask
,
mems
=
mems
)
mems
=
mems
)
output
=
self
.
pos_ff
(
output
)
output
=
self
.
pos_ff
(
output
)
# output, relu_out = self.pos_ff(output)
return
output
return
output
# return output, relu_out
class
RelPartialLearnableDecoderLayer
(
nn
.
Module
):
class
RelPartialLearnableDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
d_inner
,
dropout
,
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
d_inner
,
dropout
,
...
@@ -881,8 +445,15 @@ class RelPartialLearnableDecoderLayer(nn.Module):
...
@@ -881,8 +445,15 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self
.
dec_attn
=
RelPartialLearnableMultiHeadAttn
(
n_head
,
d_model
,
self
.
dec_attn
=
RelPartialLearnableMultiHeadAttn
(
n_head
,
d_model
,
d_head
,
dropout
,
**
kwargs
)
d_head
,
dropout
,
**
kwargs
)
self
.
pos_ff
=
CustomizedMoEPositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
))
if
kwargs
.
get
(
'moe'
)
is
False
:
self
.
pos_ff
=
PositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
))
else
:
self
.
pos_ff
=
CustomizedMoEPositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
),
moe_num_expert
=
kwargs
.
get
(
'moe_num_expert'
),
moe_top_k
=
kwargs
.
get
(
'moe_top_k'
))
def
forward
(
self
,
dec_inp
,
r
,
r_w_bias
,
r_r_bias
,
dec_attn_mask
=
None
,
mems
=
None
):
def
forward
(
self
,
dec_inp
,
r
,
r_w_bias
,
r_r_bias
,
dec_attn_mask
=
None
,
mems
=
None
):
...
@@ -890,10 +461,8 @@ class RelPartialLearnableDecoderLayer(nn.Module):
...
@@ -890,10 +461,8 @@ class RelPartialLearnableDecoderLayer(nn.Module):
attn_mask
=
dec_attn_mask
,
attn_mask
=
dec_attn_mask
,
mems
=
mems
)
mems
=
mems
)
output
=
self
.
pos_ff
(
output
)
output
=
self
.
pos_ff
(
output
)
# output, relu_out = self.pos_ff(output)
return
output
return
output
# return output, relu_out
class
AdaptiveEmbedding
(
nn
.
Module
):
class
AdaptiveEmbedding
(
nn
.
Module
):
@@ -913,25 +482,26 @@ class AdaptiveEmbedding(nn.Module):
         self.cutoff_ends = [0] + self.cutoffs

         self.emb_layers = nn.ModuleList()
-        self.emb_projs = nn.ParameterList()
+        self.emb_projs = nn.ModuleList()
         if div_val == 1:
             self.emb_layers.append(
                 nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
             )
             if d_proj != d_embed:
-                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
+                self.emb_projs.append(Projection(d_proj, d_embed))
         else:
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                 d_emb_i = d_embed // (div_val ** i)
                 self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
-                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))
+                self.emb_projs.append(Projection(d_proj, d_emb_i))

     def forward(self, inp):
         if self.div_val == 1:
             embed = self.emb_layers[0](inp)
             if self.d_proj != self.d_embed:
-                embed = F.linear(embed, self.emb_projs[0])
+                embed = F.linear(embed, self.emb_projs[0].weight)
         else:
             param = next(self.parameters())
             inp_flat = inp.view(-1)
@@ -948,7 +518,7 @@ class AdaptiveEmbedding(nn.Module):
                 inp_i = inp_flat.index_select(0, indices_i) - l_idx
                 emb_i = self.emb_layers[i](inp_i)
-                emb_i = F.linear(emb_i, self.emb_projs[i])
+                emb_i = F.linear(emb_i, self.emb_projs[i].weight)

                 emb_flat.index_copy_(0, indices_i, emb_i)
@@ -965,7 +535,7 @@ class MemTransformerLM(nn.Module):
                  tgt_len=None, ext_len=None, mem_len=None,
                  cutoffs=[], adapt_inp=False,
                  same_length=False, attn_type=0, clamp_len=-1,
-                 sample_softmax=-1):
+                 sample_softmax=-1, moe=False, moe_num_expert=64, moe_top_k=2):
         super(MemTransformerLM, self).__init__()
         self.n_token = n_token
@@ -996,7 +566,8 @@ class MemTransformerLM(nn.Module):
                     RelPartialLearnableDecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type == 1: # learnable embeddings
             for i in range(n_layer):
@@ -1004,14 +575,16 @@ class MemTransformerLM(nn.Module):
                     RelLearnableDecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type in [2, 3]: # absolute embeddings
             for i in range(n_layer):
                 self.layers.append(
                     DecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )

         self.sample_softmax = sample_softmax
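A construction sketch for the extended MemTransformerLM signature, using the keyword arguments added in the hunks above together with standard Transformer-XL hyper-parameters; every value is an illustrative assumption, not taken from the commit:

    from mem_transformer import MemTransformerLM  # run from examples/transformer-xl

    model = MemTransformerLM(
        n_token=204, n_layer=12, n_head=8, d_model=512, d_head=64, d_inner=2048,
        dropout=0.1, dropatt=0.0, pre_lnorm=False, attn_type=0,
        tgt_len=512, ext_len=0, mem_len=512,
        moe=True, moe_num_expert=64, moe_top_k=2)  # new keywords from this commit

    # moe=True makes every decoder layer build a CustomizedMoEPositionwiseFF;
    # moe=False keeps the dense PositionwiseFF (see the DecoderLayer hunks above).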
@@ -1035,9 +608,9 @@ class MemTransformerLM(nn.Module):
             if tie_projs:
                 for i, tie_proj in enumerate(tie_projs):
                     if tie_proj and div_val == 1 and d_model != d_embed:
-                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
+                        self.crit.out_projs[i].weight = self.word_emb.emb_projs[0].weight
                     elif tie_proj and div_val != 1:
-                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]
+                        self.crit.out_projs[i].weight = self.word_emb.emb_projs[i].weight

         self.same_length = same_length
         self.clamp_len = clamp_len
@@ -1070,12 +643,11 @@ class MemTransformerLM(nn.Module):
         self.mem_len = mem_len
         self.ext_len = ext_len

-    def init_mems(self):
+    def init_mems(self, x):
         if self.mem_len > 0:
             mems = []
-            param = next(self.parameters())
             for i in range(self.n_layer+1):
-                empty = torch.empty(0, dtype=param.dtype, device=param.device)
+                empty = torch.empty(0, dtype=x.dtype, device=x.device)
                 mems.append(empty)

             return mems
@@ -1126,7 +698,6 @@ class MemTransformerLM(nn.Module):
                     word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]

         hids = []
-        # relu_outs = []
         if self.attn_type == 0: # default
             pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -1140,11 +711,9 @@ class MemTransformerLM(nn.Module):
             hids.append(core_out)
             for i, layer in enumerate(self.layers):
                 mems_i = None if mems is None else mems[i]
-                # core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
                 core_out = layer(core_out, pos_emb, self.r_w_bias,
                                  self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)
         elif self.attn_type == 1: # learnable
             core_out = self.drop(word_emb)
             hids.append(core_out)
@@ -1156,11 +725,9 @@ class MemTransformerLM(nn.Module):
                     r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                 mems_i = None if mems is None else mems[i]
-                # core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
                 core_out = layer(core_out, r_emb, self.r_w_bias[i],
                                  r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)
         elif self.attn_type == 2: # absolute
             pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -1175,11 +742,9 @@ class MemTransformerLM(nn.Module):
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and i == 0:
                     mems_i += pos_emb[:mlen]
-                # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)
         elif self.attn_type == 3:
             core_out = self.drop(word_emb)
@@ -1197,31 +762,25 @@ class MemTransformerLM(nn.Module):
                     mems_i += cur_emb.view(mlen, 1, -1)
                 core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

-                # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)

         core_out = self.drop(core_out)

         new_mems = self._update_mems(hids, mems, mlen, qlen)

         return core_out, new_mems
-        # return core_out, new_mems, relu_outs

     def forward(self, data, target, *mems):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
         # So, have to initialize size(0) mems inside the model forward.
         # Moreover, have to return new_mems to allow nn.DataParallel to piece
         # them together.
-        if not mems: mems = self.init_mems()
+        if not mems: mems = self.init_mems(data)

         tgt_len = target.size(0)
         hidden, new_mems = self._forward(data, mems=mems)
-        # hidden, new_mems, relu_outs = self._forward(data, mems=mems)
-        # relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1)

         pred_hid = hidden[-tgt_len:]
         if self.sample_softmax > 0 and self.training:
@@ -1235,10 +794,8 @@ class MemTransformerLM(nn.Module):
         if new_mems is None:
             return [loss]
-            # return [relu_outs, loss]
         else:
             return [loss] + new_mems
-            # return [relu_outs, loss] + new_mems

 if __name__ == '__main__':
     import argparse
...
examples/transformer-xl/scripts/run_enwik8_base_gshard.sh

@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then
         --batch_size 22 \
         --multi_gpu \
         --gpu0_bsz 4 \
+        --moe --moe-num-expert 64 --moe-top-k 2 \
         ${@:2}
 elif [[ $1 == 'eval' ]]; then
     echo 'Run evaluation...'
...
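A usage sketch for the updated script; the work-dir flag is an illustrative extra argument. Because the hard-coded MoE flags come before `${@:2}`, a flag repeated on the command line takes precedence, since argparse keeps the last occurrence of an option:

    # Train enwik8 with the MoE FFN defaults baked into the script (64 experts, top-2 gating).
    bash scripts/run_enwik8_base_gshard.sh train --work_dir ./LM-TFM-enwik8

    # Arguments after 'train' are forwarded via ${@:2}, so the defaults can be overridden:
    bash scripts/run_enwik8_base_gshard.sh train --moe-num-expert 32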
examples/transformer-xl/train.py

@@ -4,7 +4,6 @@ import time
 import math
 import os, sys
 import itertools
-import pathlib

 import numpy as np
@@ -17,31 +16,6 @@ from mem_transformer import MemTransformerLM
 from utils.exp_utils import create_exp_dir
 from utils.data_parallel import BalancedDataParallel

-class AverageMeter(object):
-    """Computes and stores the average and current value.
-
-       Examples::
-           >>> # Initialize a meter to record loss
-           >>> losses = AverageMeter()
-           >>> # Update meter after every minibatch update
-           >>> losses.update(loss_value, batch_size)
-    """
-    def __init__(self):
-        self.reset()
-
-    def reset(self):
-        self.val = 0
-        self.avg = 0
-        self.sum = 0
-        self.count = 0
-
-    def update(self, val, n=1):
-        self.val = val
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
 parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
 parser.add_argument('--data', type=str, default='../data/wikitext-103',
                     help='location of the data corpus')
@@ -167,8 +141,15 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
 parser.add_argument('--dynamic-loss-scale', action='store_true',
                     help='Use dynamic loss scaling. If supplied, this argument'
                     ' supersedes --static-loss-scale.')
+parser.add_argument('--moe', action='store_true',
+                    help='replace position-wise ffn with moe position-wise ffn')
+parser.add_argument('--moe-num-expert', type=int, default=64,
+                    help='number of experts in MoE')
+parser.add_argument('--moe-top-k', type=int, default=2,
+                    help='top_k experts in hard gate of moe')
 args = parser.parse_args()
 args.tied = not args.not_tied
+assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"

 if args.d_embed < 0:
     args.d_embed = args.d_model
@@ -305,7 +286,8 @@ else:
         tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
         same_length=args.same_length, attn_type=args.attn_type,
-        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
+        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
+        moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
     model.apply(weights_init)
     model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing

 args.n_all_param = sum([p.nelement() for p in model.parameters()])
@@ -412,9 +394,6 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
 def evaluate(eval_iter):
     # Turn on evaluation mode which disables dropout.
     model.eval()
-    # avg_nnzs = None
-    # act_hist = None
-    # co_act_hist = None

     # If the model does not use memory at all, make the ext_len longer.
     # Otherwise, make the mem_len longer and keep the ext_len the same.
@@ -434,33 +413,15 @@ def evaluate(eval_iter):
                 break
             ret = model(data, target, *mems)
             loss, mems = ret[0], ret[1:]
-            # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.mean()
             total_loss += seq_len * loss.float().item()
             total_len += seq_len
-            # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
-            # if avg_nnzs is None:
-            #     n_layer = len(acts)
-            #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
-            #     d_inner = acts[0].size(-1)
-            #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
-            #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
-            # for i, act in enumerate(acts):
-            #     nnz = act.sum().item() / act.numel()
-            #     avg_nnzs[i].update(nnz)
-            #     act_hist[i] += torch.sum(act, dim=[0, 1])
-            #     co_act = torch.einsum("ija,ijb->ab", (act, act))
-            #     co_act_hist[i] += co_act

     # Switch back to the training mode
     model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
     model.train()

     return total_loss / total_len
-    # return total_loss / total_len, avg_nnzs, act_hist, co_act_hist


 def train():
@@ -471,11 +432,6 @@ def train():
         mems = [tuple() for _ in range(args.batch_chunk)]
     else:
         mems = tuple()
-    # avg_nnzs = None
-    # act_hist = None
-    # co_act_hist = None
     train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
     for batch, (data, target, seq_len) in enumerate(train_iter):
         model.zero_grad()
@@ -487,7 +443,6 @@ def train():
                 target_i = target_chunks[i].contiguous()
                 ret = para_model(data_i, target_i, *mems[i])
                 loss, mems[i] = ret[0], ret[1:]
-                # relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
                 loss = loss.float().mean().type_as(loss) / args.batch_chunk
                 if args.fp16:
                     optimizer.backward(loss)
@@ -497,28 +452,12 @@ def train():
         else:
             ret = para_model(data, target, *mems)
             loss, mems = ret[0], ret[1:]
-            # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.float().mean().type_as(loss)
             if args.fp16:
                 optimizer.backward(loss)
             else:
                 loss.backward()
             train_loss += loss.float().item()
-            # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
-            # # nnzs = [act.sum().item() / act.numel() for act in acts]
-            # if avg_nnzs is None:
-            #     n_layer = len(acts)
-            #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
-            #     d_inner = acts[0].size(-1)
-            #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
-            #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
-            # for i, act in enumerate(acts):
-            #     nnz = act.sum().item() / act.numel()
-            #     avg_nnzs[i].update(nnz)
-            #     act_hist[i] += torch.sum(act, dim=[0, 1])
-            #     co_act = torch.einsum("ija,ijb->ab", (act, act))
-            #     co_act_hist[i] += co_act

         if args.fp16:
             optimizer.clip_master_grads(args.clip)
@@ -557,39 +496,12 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
             else:
                 log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
-            # final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
-            # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-            #     sum(final_avg_nnzs)/len(final_avg_nnzs)*100,
-            #     max(final_avg_nnzs)*100,
-            # )
             logging(log_str)
-            # co_act_dir = pathlib.Path(logging.keywords['log_path']).parent.joinpath("co_act")
-            # co_act_dir.mkdir(parents=True, exist_ok=True)
-            # co_act_path = co_act_dir.joinpath('epoch_%d_train_step_%d.pt' % (epoch, train_step))
-            # torch.save(co_act_hist, co_act_path)
-            # for i in range(len(avg_nnzs)):
-            #     avg_nnzs[i].reset()
-            #     act_hist[i] /= act_hist[i].sum()
-            #     prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
-            #     log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
-            #         i+1,
-            #         prob[:64].sum().item(),
-            #         prob[:128].sum().item(),
-            #         prob[:256].sum().item(),
-            #         prob[:512].sum().item(),
-            #         prob[:1024].sum().item()
-            #     )
-            #     logging(log_str)
-            #     act_hist[i] = 0.
-            #     co_act_hist[i] = 0.
             train_loss = 0
             log_start_time = time.time()

         if train_step % args.eval_interval == 0:
             val_loss = evaluate(va_iter)
-            # val_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(va_iter)
             logging('-' * 100)
             log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                       '| valid loss {:5.2f}'.format(
@@ -599,11 +511,6 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
             else:
                 log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
-            # final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-            # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-            #     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
-            #     max(final_eval_avg_nnzs)*100
-            # )
             logging(log_str)
             logging('-' * 100)

             # Save the model if the validation loss is the best we've seen so far.
@@ -653,7 +560,6 @@ para_model = model.to(device)
 # Run on test data.
 test_loss = evaluate(te_iter)
-# test_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(te_iter)
 logging('=' * 100)
 if args.dataset in ['enwik8', 'text8']:
     logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
@@ -661,11 +567,4 @@ if args.dataset in ['enwik8', 'text8']:
 else:
     logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
         test_loss, math.exp(test_loss)))
-# final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-# log_str = ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-#     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
-#     max(final_eval_avg_nnzs)*100
-# )
-# logging(log_str)
 logging('=' * 100)
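For reference, an equivalent direct invocation of train.py with the new flags; the data path and model sizes are illustrative assumptions rather than values taken from this commit:

    python train.py \
        --cuda --data ../data/enwik8/ --dataset enwik8 \
        --n_layer 12 --d_model 512 --n_head 8 --d_head 64 --d_inner 2048 \
        --dropout 0.1 --dropatt 0.0 --tgt_len 512 --mem_len 512 --eval_tgt_len 128 \
        --batch_size 22 \
        --moe --moe-num-expert 64 --moe-top-k 2
    # The new assert requires moe-num-expert >= moe-top-k before the model is built.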
examples/transformer-xl/utils/proj_adaptive_softmax.py

@@ -9,6 +9,10 @@ import torch.nn.functional as F
 CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
 CUDA_MINOR = int(torch.version.cuda.split('.')[1])

+class Projection(nn.Module):
+    def __init__(self, out_feat, in_feat):
+        self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))
+
 class ProjectedAdaptiveLogSoftmax(nn.Module):
     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                  keep_order=False):
@@ -31,13 +35,13 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
             self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

         self.out_layers = nn.ModuleList()
-        self.out_projs = nn.ParameterList()
+        self.out_projs = nn.ModuleList()

         if div_val == 1:
             for i in range(len(self.cutoffs)):
                 if d_proj != d_embed:
                     self.out_projs.append(
-                        nn.Parameter(torch.Tensor(d_proj, d_embed))
+                        Projection(d_proj, d_embed)
                     )
                 else:
                     self.out_projs.append(None)
@@ -49,7 +53,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 d_emb_i = d_embed // (div_val ** i)

                 self.out_projs.append(
-                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
+                    Projection(d_proj, d_emb_i)
                 )

                 self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
@@ -82,7 +86,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         if self.n_clusters == 0:
             logit = self._compute_logit(hidden, self.out_layers[0].weight,
-                                        self.out_layers[0].bias, self.out_projs[0])
+                                        self.out_layers[0].bias, self.out_projs[0].weight if self.out_projs[0] is not None else None)
             nll = -F.log_softmax(logit, dim=-1) \
                     .gather(1, target.unsqueeze(1)).squeeze(1)
         else:
@@ -106,7 +110,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 weights.append(weight_i)
                 biases.append(bias_i)

-            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
+            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0].weight

             head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
             head_logprob = F.log_softmax(head_logit, dim=1)
@@ -131,7 +135,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 if i == 0:
                     logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
                 else:
-                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
+                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i].weight

                     hidden_i = hidden.index_select(0, indices_i)
...
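The Projection wrapper exists so that a projection matrix lives inside an nn.Module, and so that it can be tied by assigning .weight (as the tie_projs change in mem_transformer.py above does), instead of sitting bare in an nn.ParameterList. A self-contained sketch of the same pattern, written independently of the hunk above; the sketch adds the super().__init__() call that nn.Module requires before parameters can be assigned:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Projection(nn.Module):
        # Minimal sketch of the wrapper pattern used above.
        def __init__(self, out_feat, in_feat):
            super().__init__()  # required before assigning nn.Parameter attributes
            self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))

    projs = nn.ModuleList([Projection(512, 256)])
    tied = Projection(512, 256)
    tied.weight = projs[0].weight                      # weight tying by sharing the Parameter
    out = F.linear(torch.randn(8, 256), projs[0].weight)  # shape [8, 512]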
fmoe/transformer.py

@@ -49,10 +49,12 @@ class FMoETransformerMLP(FMoE):
         top_k=2,
         do_lnorm=False,
         pre_lnorm=False,
-        expert_dp_comm='none'
+        expert_dp_comm='none',
+        dropout=0.1
     ):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                 top_k=top_k, world_size=world_size, mp_group=mp_group)
+        self.dropout = nn.Dropout(dropout)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
@@ -72,7 +74,9 @@ class FMoETransformerMLP(FMoE):
         inp = inp.reshape(-1, self.d_model)
         if self.pre_lnorm is not None and self.pre_lnorm:
             inp = self.layer_norm(inp)
-        output = super().forward(inp) + inp
+        output = super().forward(inp)
+        output = self.dropout(output)
+        output += inp
         if self.pre_lnorm is not None and not self.pre_lnorm:
             output = self.layer_norm(output)
         return output.reshape(original_shape)
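A minimal single-process sketch of the updated FMoETransformerMLP, using only parameters visible in this hunk; the expert count, tensor sizes, the explicit activation, and the CUDA requirement are assumptions about a typical FastMoE setup at this version:

    import torch
    import torch.nn as nn
    from fmoe.transformer import FMoETransformerMLP

    # Illustrative sizes; FastMoE's expert scatter/gather kernels expect CUDA tensors.
    mlp = FMoETransformerMLP(num_expert=4, d_model=256, d_hidden=1024,
                             activation=nn.Sequential(nn.ReLU(), nn.Dropout(0.1)),
                             do_lnorm=True, pre_lnorm=False, dropout=0.1).cuda()
    x = torch.randn(8, 32, 256, device='cuda')   # any [*, d_model] shape is reshaped internally
    y = mlp(x)                                   # dropout is now applied before the residual add
    assert y.shape == x.shape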