OpenDAS / FastMoE · Commits

Unverified commit 03b2a725, authored Feb 25, 2021 by Rick Ho, committed by GitHub on Feb 25, 2021.

Merge pull request #6 from xfmr-xl

Test Transformer-XL

Parents: e86dea53, 0a942e3f

Showing 6 changed files with 88 additions and 620 deletions.
.gitignore                                                  +3   -0
examples/transformer-xl/mem_transformer.py                  +59  -502
examples/transformer-xl/scripts/run_enwik8_base_gshard.sh   +1   -0
examples/transformer-xl/train.py                            +9   -110
examples/transformer-xl/utils/proj_adaptive_softmax.py      +10  -6
fmoe/transformer.py                                         +6   -2
.gitignore

@@ -10,3 +10,6 @@ a.out
 build
 *swp
 logs
+examples/transformer-xl/data
+examples/data
+examples/transformer-xl/LM-TFM-enwik8
examples/transformer-xl/mem_transformer.py

@@ -7,10 +7,9 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-# import torch_sparse

 sys.path.append('utils')
-from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
+from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax, Projection
 from log_uniform_sampler import LogUniformSampler, sample_logits


 class PositionalEmbedding(nn.Module):
@@ -32,361 +31,16 @@ class PositionalEmbedding(nn.Module):
         return pos_emb[:, None, :]


-# A baseline naive slow implementation
-class MoEPositionwiseFFRaw(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
-        super(MoEPositionwiseFFRaw, self).__init__()
-        print("MoEPositionwiseFF")
-
-        self.top_k = top_k
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.gate = nn.Linear(d_model, d_inner)
-        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-
-        ratio = top_k / d_inner
-        self.dropout_middle = nn.Dropout(dropout * ratio)
-        self.dropout_final = nn.Dropout(dropout)
-
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp_Linear = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
-        self.b2.data = temp_Linear.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-
-        gate = self.gate(inp)
-        gate_top_k_val, gate_top_k_idx = torch.topk(
-            gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]
-
-        relu_out = F.relu(gate_top_k_val)
-        x = self.dropout_middle(relu_out)
-
-        W2_select = self.W2[gate_top_k_idx]  # [.. x top_k x d_model]
-        core_out = torch.einsum('ijk,ijkd->ijd', (x, W2_select)) + self.b2  # [.. x d_model]
-        core_out = self.dropout_final(core_out)
-
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-
-
-def my_topk(x, k, inplace=True):
-    y = x if inplace else x.clone()
-    top1_val, top1_idx = torch.max(y, dim=-1)
-    top1_val = top1_val.unsqueeze(-1)
-    top1_idx = top1_idx.unsqueeze(-1)
-    if k == 1:
-        return top1_val, top1_idx
-    y.scatter_(-1, top1_idx, value=float('-inf'))
-    top2_val, top2_idx = torch.max(y, dim=-1)
-    top2_val = top2_val.unsqueeze(-1)
-    top2_idx = top2_idx.unsqueeze(-1)
-    top_val = torch.cat((top1_val, top2_val), dim=-1)
-    top_idx = torch.cat((top1_idx, top2_idx), dim=-1)
-    return top_val, top_idx
-
-
-class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
-                 n_block=16, top_block=2):
-        super(MultiHeadHierarchicalMoEPositionwiseFF, self).__init__()
-        print("MultiHeadHierarchicalMoEPositionwiseFF")
-
-        assert d_inner % n_block == 0
-        assert top_block in [1, 2]
-        self.top_block = top_block
-        self.n_block = n_block
-        d_block = d_inner // n_block
-        self.d_block = d_block
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.block_net_W = nn.Parameter(torch.Tensor(d_model, top_block, n_block))
-        self.block_net_b = nn.Parameter(torch.Tensor(top_block, n_block))
-        self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
-        self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-
-        ratio = top_block / n_block
-        self.dropout_middle = nn.Dropout(dropout * ratio)
-        self.dropout_final = nn.Dropout(dropout)
-        # self.scale = 1 / (d_model ** 0.5)
-
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp = nn.Linear(self.d_model, self.d_inner)
-        self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
-        self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
-        temp = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(
-            self.n_block, self.d_block, self.d_model)
-        self.b2.data = temp.bias.data
-        for i in range(self.top_block):
-            temp = nn.Linear(self.d_model, self.n_block)
-            self.block_net_W.data[:, i] = temp.weight.data.transpose(0, 1).contiguous()
-            self.block_net_b.data[i] = temp.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-
-        block = torch.einsum("ibd,dan->iban", (inp, self.block_net_W)) \
-                + self.block_net_b  # [.. x top_block x n_block]
-        block_val, block_idx = my_topk(block, k=1, inplace=True)
-        # block_val, block_idx = torch.topk(block, k=1, dim=-1, largest=True, sorted=False) # [.. x top_k x 1]
-        block_val = block_val.squeeze(-1)
-        block_idx = block_idx.squeeze(-1)
-        gate = F.softmax(block_val, dim=-1)
-
-        W1_block = self.W1[block_idx]  # [.. x top_k x d_block x d_model]
-        b1_block = self.b1[block_idx]  # [.. x top_k x d_block]
-        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block  # [.. x top_k x d_block]
-        # x = x + block_val.unsqueeze(-1) # somehow like residual
-        x = x * gate.unsqueeze(-1)
-
-        relu_out = F.relu(x)
-        relu_out = self.dropout_middle(relu_out)
-
-        W2_block = self.W2[block_idx]  # [.. x top_k x d_model]
-        core_out = torch.einsum('ibnh,ibnhd->ibd', (x, W2_block)) + self.b2  # [.. x d_model]
-        core_out = self.dropout_final(core_out)
-
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-
-
-class HierarchicalMoEPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
-                 n_block=16, top_block=2):
-        super(HierarchicalMoEPositionwiseFF, self).__init__()
-        print("HierarchicalMoEPositionwiseFF")
-
-        assert d_inner % n_block == 0
-        assert top_block in [1, 2]
-        self.top_block = top_block
-        self.n_block = n_block
-        d_block = d_inner // n_block
-        self.d_block = d_block
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.block_net = nn.Linear(d_model, n_block, bias=True)
-        self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
-        self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-
-        ratio = top_block / n_block
-        self.dropout_middle = nn.Dropout(dropout * ratio)
-        self.dropout_final = nn.Dropout(dropout)
-        # self.scale = 1 / (d_model ** 0.5)
-
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp = nn.Linear(self.d_model, self.d_inner)
-        self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
-        self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
-        temp = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(
-            self.n_block, self.d_block, self.d_model)
-        self.b2.data = temp.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-
-        block = self.block_net(inp)
-        # block_val, block_idx = my_topk(block, k=self.top_block)
-        block_val, block_idx = torch.topk(
-            block, k=self.top_block, dim=-1, largest=True, sorted=False)  # [.. x top_k]
-        gate = F.softmax(block_val, dim=-1)
-
-        W1_block = self.W1[block_idx]  # [.. x top_k x d_block x d_model]
-        b1_block = self.b1[block_idx]  # [.. x top_k x d_block]
-        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block  # [.. x top_k x d_block]
-        # x = x + block_val.unsqueeze(-1) # somehow like residual
-        x = x * gate.unsqueeze(-1)
-
-        relu_out = F.relu(x)
-        relu_out = self.dropout_middle(relu_out)
-
-        W2_block = self.W2[block_idx]  # [.. x top_k x d_model]
-        core_out = torch.einsum('ibnh,ibnhd->ibd', (x, W2_block)) + self.b2  # [.. x d_model]
-        core_out = self.dropout_final(core_out)
-
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-
-
-class SparsePositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
-        super(SparsePositionwiseFF, self).__init__()
-        print("SparsePositionwiseFF")
-
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.CoreNet_1 = nn.Sequential(
-            nn.Linear(d_model, d_inner),
-            nn.ReLU(inplace=True),
-            nn.Dropout(dropout)
-        )
-        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        self.dropout_final = nn.Dropout(dropout)
-
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp_Linear = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
-        self.b2.data = temp_Linear.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-
-        relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
-        sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
-        core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
-        core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
-        core_out = self.dropout_final(core_out)
-
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-
-
-class MultiHeadPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
-        super(MultiHeadPositionwiseFF, self).__init__()
-        print("MultiHeadPositionwiseFF")
-
-        assert d_model % n_head == 0
-        self.n_head = n_head
-        d_head = d_model // n_head
-        self.d_head = d_head
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.q_net = nn.Linear(d_model, d_model)
-        self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
-        self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
-        self.v_weight = nn.Parameter(torch.Tensor(n_head, d_head, d_inner))
-        self.v_bias = nn.Parameter(torch.Tensor(n_head, d_head))
-        #self.o_net = nn.Linear(d_model, d_model)
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        self.dropout = nn.Dropout(dropout)
-
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        for i in range(self.n_head):
-            tmp = nn.Linear(self.d_head, self.d_inner)
-            self.k_weight.data[i] = tmp.weight.data
-            self.k_bias.data[i] = tmp.bias.data
-            tmp = nn.Linear(self.d_inner, self.d_head)
-            self.v_weight.data[i] = tmp.weight.data
-            self.v_bias.data[i] = tmp.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-
-        head_q = self.q_net(inp)
-        head_q = head_q.view(inp.size(0), inp.size(1), self.n_head, self.d_head)  # [.. x n_head x d_head]
-
-        attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) \
-                + self.k_bias  # [.. x n_head x d_inner]
-        attn_score = F.relu(attn_score)
-        attn_score = self.dropout(attn_score)
-
-        attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
-        attn_vec = attn_vec.contiguous().view(inp.size(0), inp.size(1), self.d_model)
-
-        # core_out = self.o_net(attn_vec)
-        core_out = self.dropout(attn_vec)
-
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-
-
 class PositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, use_softmax=True):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
         super(PositionwiseFF, self).__init__()

         self.d_model = d_model
         self.d_inner = d_inner
         self.dropout = dropout

-        self.CoreNet_1 = nn.Sequential(
-            nn.Linear(d_model, d_inner),
-            nn.Softmax(dim=-1) if use_softmax else nn.ReLU(inplace=True)
-        )
-        self.CoreNet_2 = nn.Sequential(
+        self.CoreNet = nn.Sequential(
+            nn.Linear(d_model, d_inner),
+            nn.ReLU(inplace=True),
             nn.Dropout(dropout),
             nn.Linear(d_inner, d_model),
             nn.Dropout(dropout),
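For orientation: the deleted MoEPositionwiseFFRaw above was a naive mixture baseline that gated individual hidden units of the FFN rather than whole expert networks, keeping only the top_k highest-scoring units per token. A minimal, self-contained sketch of that gating step in isolation (all sizes are illustrative, not taken from the diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

seq_len, batch, d_model, d_inner, top_k = 4, 2, 8, 32, 8    # illustrative sizes
inp = torch.randn(seq_len, batch, d_model)
gate = nn.Linear(d_model, d_inner)
W2 = torch.randn(d_inner, d_model)

scores = gate(inp)                                          # [seq x batch x d_inner]
top_val, top_idx = torch.topk(scores, k=top_k, dim=-1,
                              largest=True, sorted=False)   # keep only the top_k units
relu_out = F.relu(top_val)                                  # [seq x batch x top_k]
W2_select = W2[top_idx]                                     # gather rows: [seq x batch x top_k x d_model]
out = torch.einsum('ijk,ijkd->ijd', relu_out, W2_select)    # [seq x batch x d_model]
assert out.shape == inp.shape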
@@ -399,113 +53,18 @@ class PositionwiseFF(nn.Module):
     def forward(self, inp):
         if self.pre_lnorm:
             ##### layer normalization + positionwise feed-forward
-            relu_out = self.CoreNet_1(self.layer_norm(inp))
-            core_out = self.CoreNet_2(relu_out)
+            core_out = self.CoreNet(self.layer_norm(inp))

             ##### residual connection
             output = core_out + inp
         else:
             ##### positionwise feed-forward
-            relu_out = self.CoreNet_1(inp)
-            core_out = self.CoreNet_2(relu_out)
+            core_out = self.CoreNet(inp)

             ##### residual connection + layer normalization
             output = self.layer_norm(inp + core_out)

         return output
-        # return output, relu_out.detach()
-
-
-class ExtendedMultiHeadAttn(nn.Module):
-    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, pre_lnorm=False):
-        super(ExtendedMultiHeadAttn, self).__init__()
-        print("ExtendedMultiHeadAttn")
-
-        self.n_head = n_head
-        self.d_model = d_model
-        self.d_head = d_head
-        self.dropout = dropout
-
-        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
-        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
-
-        self.drop = nn.Dropout(dropout)
-        self.dropatt = nn.Dropout(dropatt)
-        self.o_net = nn.Linear(n_head * d_head * 2, d_model, bias=False)
-
-        self.layer_norm = nn.LayerNorm(d_model)
-
-        self.scale = 1 / (d_head ** 0.5)
-
-        self.pre_lnorm = pre_lnorm
-        # self.coeff = nn.Parameter(torch.Tensor(n_head, 2))
-        # nn.init.uniform_(self.coeff, a=-1, b=1)
-
-    def forward(self, h, attn_mask=None, mems=None):
-        ##### multihead attention
-        # [hlen x bsz x n_head x d_head]
-
-        if mems is not None:
-            c = torch.cat([mems, h], 0)
-            mem_len = mems.size(0)
-        else:
-            c = h
-            mem_len = 0
-
-        if self.pre_lnorm:
-            ##### layer normalization
-            c = self.layer_norm(c)
-
-        head_q = self.q_net(c)
-        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
-
-        head_q = head_q.view(c.size(0), c.size(1), self.n_head, self.d_head)
-        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
-        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
-
-        # [qlen x klen x bsz x n_head]
-        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
-        attn_score.mul_(self.scale)
-        if attn_mask is not None and attn_mask.any().item():
-            if attn_mask.dim() == 2:
-                attn_score[mem_len:].masked_fill_(attn_mask[None, :, :, None].bool(), -float('inf'))
-            elif attn_mask.dim() == 3:
-                attn_score[mem_len:].masked_fill_(attn_mask[:, :, :, None].bool(), -float('inf'))
-            mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
-            mem2other_attn[:, :mem_len] = 0
-            attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None].bool(), -float('inf'))
-
-        # [qlen x klen x bsz x n_head]
-        attn_prob = F.softmax(attn_score, dim=1)
-        attn_prob = self.dropatt(attn_prob)
-
-        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
-        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
-        attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, attn_vec))
-        # [qlen x bsz x n_head x d_head x 2]
-        attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
-        # attn_vec = torch.einsum('ibndt,nt->ibnd', (attn_vecs, self.coeff))
-        attn_vec = attn_vecs.contiguous().view(
-            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head * 2)
-        attn_vec = attn_vec[mem_len:]
-
-        ##### linear projection
-        attn_out = self.o_net(attn_vec)
-        attn_out = self.drop(attn_out)
-
-        if self.pre_lnorm:
-            ##### residual connection
-            output = h + attn_out
-        else:
-            ##### residual connection + layer normalization
-            output = self.layer_norm(h + attn_out)
-
-        return output
-
-
 class MultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
@@ -583,7 +142,8 @@ class MultiHeadAttn(nn.Module):

 class RelMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
-                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
+                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
+                 moe=False, moe_num_expert=64, moe_top_k=2):
         super(RelMultiHeadAttn, self).__init__()

         self.n_head = n_head
@@ -816,42 +376,41 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         return output


 from fmoe import FMoETransformerMLP

 class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
-        def activation(x):
-            return self.dropout(F.relu(x))
-        super().__init__(num_expert=8, d_model=d_model, d_hidden=d_inner,
-                pre_lnorm=pre_lnorm, activation=activation)
-        self.dropout = nn.Dropout(dropout)
-        self.bias = nn.Parameter(torch.zeros(d_model, dtype=torch.float32))
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
+                 moe_num_expert=64, moe_top_k=2):
+        activation = nn.Sequential(
+            nn.Dropout(dropout),
+            nn.ReLU()
+        )
+        super().__init__(num_expert=moe_num_expert, d_model=d_model,
+                d_hidden=d_inner, top_k=moe_top_k, do_lnorm=True,
+                pre_lnorm=pre_lnorm, activation=activation, dropout=dropout)

     def forward(self, x):
         x = super().forward(x)
-        return x + self.bias
+        return x


 class DecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
         super(DecoderLayer, self).__init__()

         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                pre_lnorm=kwargs.get('pre_lnorm'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                    pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                    pre_lnorm=kwargs.get('pre_lnorm'),
+                    moe_num_expert=kwargs.get('moe_num_expert'),
+                    moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, dec_attn_mask=None, mems=None):

         output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, mems=mems)
         output = self.pos_ff(output)
         # output, relu_out = self.pos_ff(output)

         return output
         # return output, relu_out


 class RelLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
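With this hunk the FastMoE FFN is no longer hard-wired into every decoder layer: kwargs.get('moe') selects between the dense PositionwiseFF and CustomizedMoEPositionwiseFF, and the expert count and top-k come from the new kwargs. A minimal construction sketch of the MoE branch (all sizes are illustrative; it assumes mem_transformer.py is importable, the fastmoe extension from this repo is built and, in practice, a CUDA device is available):

import torch
from mem_transformer import CustomizedMoEPositionwiseFF   # defined in the hunk above

d_model, d_inner, dropout = 16, 64, 0.1        # illustrative sizes, not from the diff
pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=False,
                                     moe_num_expert=4, moe_top_k=2)

x = torch.randn(10, 3, d_model)                # [seq_len x batch x d_model]
y = pos_ff(x)                                  # experts, dropout and residual handled inside
assert y.shape == x.shape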
@@ -860,8 +419,15 @@ class RelLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                pre_lnorm=kwargs.get('pre_lnorm'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                    pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                    pre_lnorm=kwargs.get('pre_lnorm'),
+                    moe_num_expert=kwargs.get('moe_num_expert'),
+                    moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -869,10 +435,8 @@ class RelLearnableDecoderLayer(nn.Module):
                                attn_mask=dec_attn_mask,
                                mems=mems)
         output = self.pos_ff(output)
-        # output, relu_out = self.pos_ff(output)

         return output
-        # return output, relu_out


 class RelPartialLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -881,8 +445,15 @@ class RelPartialLearnableDecoderLayer(nn.Module):
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head,
                                                          dropout, **kwargs)
-        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
-                pre_lnorm=kwargs.get('pre_lnorm'))
+        if kwargs.get('moe') is False:
+            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                    pre_lnorm=kwargs.get('pre_lnorm'))
+        else:
+            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
+                    pre_lnorm=kwargs.get('pre_lnorm'),
+                    moe_num_expert=kwargs.get('moe_num_expert'),
+                    moe_top_k=kwargs.get('moe_top_k'))

     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
@@ -890,10 +461,8 @@ class RelPartialLearnableDecoderLayer(nn.Module):
                                attn_mask=dec_attn_mask,
                                mems=mems)
         output = self.pos_ff(output)
-        # output, relu_out = self.pos_ff(output)

         return output
-        # return output, relu_out


 class AdaptiveEmbedding(nn.Module):
@@ -913,25 +482,26 @@ class AdaptiveEmbedding(nn.Module):
         self.cutoff_ends = [0] + self.cutoffs

         self.emb_layers = nn.ModuleList()
-        self.emb_projs = nn.ParameterList()
+        self.emb_projs = nn.ModuleList()
         if div_val == 1:
             self.emb_layers.append(
                 nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
             )
             if d_proj != d_embed:
-                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
+                self.emb_projs.append(Projection(d_proj, d_embed))
         else:
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                 d_emb_i = d_embed // (div_val ** i)
                 self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
-                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))
+                self.emb_projs.append(Projection(d_proj, d_emb_i))

     def forward(self, inp):
         if self.div_val == 1:
             embed = self.emb_layers[0](inp)
             if self.d_proj != self.d_embed:
-                embed = F.linear(embed, self.emb_projs[0])
+                embed = F.linear(embed, self.emb_projs[0].weight)
         else:
             param = next(self.parameters())
             inp_flat = inp.view(-1)
@@ -948,7 +518,7 @@ class AdaptiveEmbedding(nn.Module):
                 inp_i = inp_flat.index_select(0, indices_i) - l_idx
                 emb_i = self.emb_layers[i](inp_i)
-                emb_i = F.linear(emb_i, self.emb_projs[i])
+                emb_i = F.linear(emb_i, self.emb_projs[i].weight)

                 emb_flat.index_copy_(0, indices_i, emb_i)
@@ -965,7 +535,7 @@ class MemTransformerLM(nn.Module):
                  tgt_len=None, ext_len=None, mem_len=None,
                  cutoffs=[], adapt_inp=False,
                  same_length=False, attn_type=0, clamp_len=-1,
-                 sample_softmax=-1):
+                 sample_softmax=-1, moe=False, moe_num_expert=64, moe_top_k=2):
         super(MemTransformerLM, self).__init__()
         self.n_token = n_token
@@ -996,7 +566,8 @@
                 RelPartialLearnableDecoderLayer(
                     n_head, d_model, d_head, d_inner, dropout,
                     tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                    dropatt=dropatt, pre_lnorm=pre_lnorm)
+                    dropatt=dropatt, pre_lnorm=pre_lnorm,
+                    moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type == 1:  # learnable embeddings
             for i in range(n_layer):
@@ -1004,14 +575,16 @@
                 RelLearnableDecoderLayer(
                     n_head, d_model, d_head, d_inner, dropout,
                     tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                    dropatt=dropatt, pre_lnorm=pre_lnorm)
+                    dropatt=dropatt, pre_lnorm=pre_lnorm,
+                    moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )
         elif attn_type in [2, 3]:  # absolute embeddings
             for i in range(n_layer):
                 self.layers.append(
                     DecoderLayer(
                         n_head, d_model, d_head, d_inner, dropout,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                        dropatt=dropatt, pre_lnorm=pre_lnorm,
+                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                 )

         self.sample_softmax = sample_softmax
@@ -1035,9 +608,9 @@
         if tie_projs:
             for i, tie_proj in enumerate(tie_projs):
                 if tie_proj and div_val == 1 and d_model != d_embed:
-                    self.crit.out_projs[i] = self.word_emb.emb_projs[0]
+                    self.crit.out_projs[i].weight = self.word_emb.emb_projs[0].weight
                 elif tie_proj and div_val != 1:
-                    self.crit.out_projs[i] = self.word_emb.emb_projs[i]
+                    self.crit.out_projs[i].weight = self.word_emb.emb_projs[i].weight

         self.same_length = same_length
         self.clamp_len = clamp_len
@@ -1070,12 +643,11 @@
         self.mem_len = mem_len
         self.ext_len = ext_len

-    def init_mems(self):
+    def init_mems(self, x):
         if self.mem_len > 0:
             mems = []
-            param = next(self.parameters())
             for i in range(self.n_layer + 1):
-                empty = torch.empty(0, dtype=param.dtype, device=param.device)
+                empty = torch.empty(0, dtype=x.dtype, device=x.device)
                 mems.append(empty)

             return mems
@@ -1126,7 +698,6 @@
                     word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:, :, None]

         hids = []
-        # relu_outs = []
         if self.attn_type == 0:  # default
             pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -1140,11 +711,9 @@
             hids.append(core_out)
             for i, layer in enumerate(self.layers):
                 mems_i = None if mems is None else mems[i]
-                # core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
                 core_out = layer(core_out, pos_emb, self.r_w_bias,
                                  self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)
         elif self.attn_type == 1:  # learnable
             core_out = self.drop(word_emb)
             hids.append(core_out)
@@ -1156,11 +725,9 @@
                 r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                 mems_i = None if mems is None else mems[i]
-                # core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
                 core_out = layer(core_out, r_emb, self.r_w_bias[i],
                                  r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)
         elif self.attn_type == 2:  # absolute
             pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -1175,11 +742,9 @@
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and i == 0:
                     mems_i += pos_emb[:mlen]
-                # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)
         elif self.attn_type == 3:
             core_out = self.drop(word_emb)
@@ -1197,31 +762,25 @@
                     mems_i += cur_emb.view(mlen, 1, -1)
                 core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

-                # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)

         core_out = self.drop(core_out)

         new_mems = self._update_mems(hids, mems, mlen, qlen)

         return core_out, new_mems
-        # return core_out, new_mems, relu_outs

     def forward(self, data, target, *mems):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
         # So, have to initialize size(0) mems inside the model forward.
         # Moreover, have to return new_mems to allow nn.DataParallel to piece
         # them together.
-        if not mems: mems = self.init_mems()
+        if not mems: mems = self.init_mems(data)

         tgt_len = target.size(0)
         hidden, new_mems = self._forward(data, mems=mems)
-        # hidden, new_mems, relu_outs = self._forward(data, mems=mems)
-        # relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1)

         pred_hid = hidden[-tgt_len:]
         if self.sample_softmax > 0 and self.training:
@@ -1235,10 +794,8 @@
         if new_mems is None:
             return [loss]
-            # return [relu_outs, loss]
         else:
            return [loss] + new_mems
-            # return [relu_outs, loss] + new_mems


 if __name__ == '__main__':
     import argparse
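Taken together, the mem_transformer.py changes thread three new switches (moe, moe_num_expert, moe_top_k) from MemTransformerLM down to every decoder layer. A construction sketch mirroring the updated call site in train.py; every hyper-parameter value below is illustrative rather than taken from the diff, and it assumes mem_transformer.py and the built fastmoe package are importable:

import torch
from mem_transformer import MemTransformerLM

model = MemTransformerLM(
    n_token=10000, n_layer=4, n_head=8, d_model=256, d_head=32,
    d_inner=1024, dropout=0.1, dropatt=0.0,
    tgt_len=128, ext_len=0, mem_len=128, attn_type=0,
    moe=True, moe_num_expert=64, moe_top_k=2)    # the new switches added here

# In practice the model and tensors are moved to a CUDA device first,
# since the FastMoE experts run on GPU.
data = torch.randint(0, 10000, (128, 4))         # [tgt_len x batch] token ids
target = torch.randint(0, 10000, (128, 4))
ret = model(data, target)                        # mems are now created via init_mems(data)
loss, mems = ret[0], ret[1:]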
examples/transformer-xl/scripts/run_enwik8_base_gshard.sh

@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then
         --batch_size 22 \
         --multi_gpu \
         --gpu0_bsz 4 \
+        --moe --moe-num-expert 64 --moe-top-k 2 \
         ${@:2}
 elif [[ $1 == 'eval' ]]; then
     echo 'Run evaluation...'
examples/transformer-xl/train.py

@@ -4,7 +4,6 @@ import time
 import math
 import os, sys
 import itertools
-import pathlib

 import numpy as np
@@ -17,31 +16,6 @@ from mem_transformer import MemTransformerLM
 from utils.exp_utils import create_exp_dir
 from utils.data_parallel import BalancedDataParallel

-class AverageMeter(object):
-    """Computes and stores the average and current value.
-       Examples::
-        >>> # Initialize a meter to record loss
-        >>> losses = AverageMeter()
-        >>> # Update meter after every minibatch update
-        >>> losses.update(loss_value, batch_size)
-    """
-    def __init__(self):
-        self.reset()
-
-    def reset(self):
-        self.val = 0
-        self.avg = 0
-        self.sum = 0
-        self.count = 0
-
-    def update(self, val, n=1):
-        self.val = val
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
 parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
 parser.add_argument('--data', type=str, default='../data/wikitext-103',
                     help='location of the data corpus')
@@ -167,8 +141,15 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
 parser.add_argument('--dynamic-loss-scale', action='store_true',
                     help='Use dynamic loss scaling.  If supplied, this argument'
                     ' supersedes --static-loss-scale.')
+parser.add_argument('--moe', action='store_true',
+                    help='replace position-wise ffn with moe position-wise ffn')
+parser.add_argument('--moe-num-expert', type=int, default=64,
+                    help='number of experts in MoE')
+parser.add_argument('--moe-top-k', type=int, default=2,
+                    help='top_k experts in hard gate of moe')
 args = parser.parse_args()
 args.tied = not args.not_tied

+assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
 if args.d_embed < 0:
     args.d_embed = args.d_model
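The three new flags map one-to-one onto the kwargs added to MemTransformerLM, and the updated run_enwik8_base_gshard.sh passes exactly --moe --moe-num-expert 64 --moe-top-k 2. A self-contained sketch of how those arguments parse (the flag names, defaults and the guard are from the diff; the parse_args list merely simulates the script's command line):

import argparse

parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--moe', action='store_true',
                    help='replace position-wise ffn with moe position-wise ffn')
parser.add_argument('--moe-num-expert', type=int, default=64,
                    help='number of experts in MoE')
parser.add_argument('--moe-top-k', type=int, default=2,
                    help='top_k experts in hard gate of moe')

# simulate the flags the updated shell script passes
args = parser.parse_args(['--moe', '--moe-num-expert', '64', '--moe-top-k', '2'])
assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
print(args.moe, args.moe_num_expert, args.moe_top_k)   # True 64 2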
@@ -305,7 +286,8 @@ else:
         tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
         same_length=args.same_length, attn_type=args.attn_type,
-        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
+        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
+        moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
     model.apply(weights_init)
     model.word_emb.apply(weights_init)  # ensure embedding init is not overridden by out_layer in case of weight sharing
 args.n_all_param = sum([p.nelement() for p in model.parameters()])
@@ -412,9 +394,6 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
 def evaluate(eval_iter):
     # Turn on evaluation mode which disables dropout.
     model.eval()
-    # avg_nnzs = None
-    # act_hist = None
-    # co_act_hist = None

     # If the model does not use memory at all, make the ext_len longer.
     # Otherwise, make the mem_len longer and keep the ext_len the same.
@@ -434,33 +413,15 @@ def evaluate(eval_iter):
                 break
             ret = model(data, target, *mems)
             loss, mems = ret[0], ret[1:]
-            # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.mean()
             total_loss += seq_len * loss.float().item()
             total_len += seq_len
-            # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
-            # if avg_nnzs is None:
-            #     n_layer = len(acts)
-            #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
-            #     d_inner = acts[0].size(-1)
-            #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
-            #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
-            # for i, act in enumerate(acts):
-            #     nnz = act.sum().item() / act.numel()
-            #     avg_nnzs[i].update(nnz)
-            #     act_hist[i] += torch.sum(act, dim=[0, 1])
-            #     co_act = torch.einsum("ija,ijb->ab", (act, act))
-            #     co_act_hist[i] += co_act

     # Switch back to the training mode
     model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
     model.train()

     return total_loss / total_len
-    # return total_loss / total_len, avg_nnzs, act_hist, co_act_hist


 def train():
@@ -471,11 +432,6 @@ def train():
         mems = [tuple() for _ in range(args.batch_chunk)]
     else:
         mems = tuple()
-    # avg_nnzs = None
-    # act_hist = None
-    # co_act_hist = None
     train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
     for batch, (data, target, seq_len) in enumerate(train_iter):
         model.zero_grad()
@@ -487,7 +443,6 @@ def train():
                 target_i = target_chunks[i].contiguous()
                 ret = para_model(data_i, target_i, *mems[i])
                 loss, mems[i] = ret[0], ret[1:]
-                # relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
                 loss = loss.float().mean().type_as(loss) / args.batch_chunk
                 if args.fp16:
                     optimizer.backward(loss)
@@ -497,28 +452,12 @@ def train():
         else:
             ret = para_model(data, target, *mems)
             loss, mems = ret[0], ret[1:]
-            # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.float().mean().type_as(loss)
             if args.fp16:
                 optimizer.backward(loss)
             else:
                 loss.backward()
             train_loss += loss.float().item()
-            # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
-            # # nnzs = [act.sum().item() / act.numel() for act in acts]
-            # if avg_nnzs is None:
-            #     n_layer = len(acts)
-            #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
-            #     d_inner = acts[0].size(-1)
-            #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
-            #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
-            # for i, act in enumerate(acts):
-            #     nnz = act.sum().item() / act.numel()
-            #     avg_nnzs[i].update(nnz)
-            #     act_hist[i] += torch.sum(act, dim=[0, 1])
-            #     co_act = torch.einsum("ija,ijb->ab", (act, act))
-            #     co_act_hist[i] += co_act

         if args.fp16:
             optimizer.clip_master_grads(args.clip)
@@ -557,39 +496,12 @@ def train():
             log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
         else:
             log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
-        # final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
-        # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-        #     sum(final_avg_nnzs)/len(final_avg_nnzs)*100,
-        #     max(final_avg_nnzs)*100,
-        # )
         logging(log_str)
-        # co_act_dir = pathlib.Path(logging.keywords['log_path']).parent.joinpath("co_act")
-        # co_act_dir.mkdir(parents=True, exist_ok=True)
-        # co_act_path = co_act_dir.joinpath('epoch_%d_train_step_%d.pt' % (epoch, train_step))
-        # torch.save(co_act_hist, co_act_path)
-        # for i in range(len(avg_nnzs)):
-        #     avg_nnzs[i].reset()
-        #     act_hist[i] /= act_hist[i].sum()
-        #     prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
-        #     log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
-        #         i+1,
-        #         prob[:64].sum().item(),
-        #         prob[:128].sum().item(),
-        #         prob[:256].sum().item(),
-        #         prob[:512].sum().item(),
-        #         prob[:1024].sum().item()
-        #     )
-        #     logging(log_str)
-        #     act_hist[i] = 0.
-        #     co_act_hist[i] = 0.
         train_loss = 0
         log_start_time = time.time()

     if train_step % args.eval_interval == 0:
         val_loss = evaluate(va_iter)
-        # val_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(va_iter)
         logging('-' * 100)
         log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                   '| valid loss {:5.2f}'.format(
@@ -599,11 +511,6 @@ def train():
             log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
         else:
             log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
-        # final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-        # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-        #     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
-        #     max(final_eval_avg_nnzs)*100
-        # )
         logging(log_str)
         logging('-' * 100)
         # Save the model if the validation loss is the best we've seen so far.
@@ -653,7 +560,6 @@ para_model = model.to(device)
 # Run on test data.
 test_loss = evaluate(te_iter)
-# test_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(te_iter)
 logging('=' * 100)
 if args.dataset in ['enwik8', 'text8']:
     logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
@@ -661,11 +567,4 @@ if args.dataset in ['enwik8', 'text8']:
 else:
     logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
         test_loss, math.exp(test_loss)))
-    # final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-    # log_str = ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-    #     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
-    #     max(final_eval_avg_nnzs)*100
-    # )
-    # logging(log_str)
 logging('=' * 100)
examples/transformer-xl/utils/proj_adaptive_softmax.py

@@ -9,6 +9,10 @@ import torch.nn.functional as F
 CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
 CUDA_MINOR = int(torch.version.cuda.split('.')[1])

+class Projection(nn.Module):
+    def __init__(self, out_feat, in_feat):
+        self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))
+
 class ProjectedAdaptiveLogSoftmax(nn.Module):
     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                  keep_order=False):
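The new Projection wrapper exists so the projection matrices can live in an nn.ModuleList and be tied by assigning a shared .weight, which is exactly what the tie_projs hunk in mem_transformer.py now does. A standalone sketch of that pattern; note the super().__init__() call, which nn.Module requires before parameters can be registered (the diff view above does not show that line, so treat it as an assumption about the full file):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Projection(nn.Module):
    def __init__(self, out_feat, in_feat):
        super().__init__()  # required before assigning nn.Parameter attributes
        self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))

emb_proj = Projection(8, 4)          # e.g. word_emb.emb_projs[0]
out_proj = Projection(8, 4)          # e.g. crit.out_projs[i]
out_proj.weight = emb_proj.weight    # weight tying, as in the tie_projs branch
assert out_proj.weight is emb_proj.weight

x = torch.randn(5, 4)
y = F.linear(x, emb_proj.weight)     # callers now pass .weight explicitly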
@@ -31,13 +35,13 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
             self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

         self.out_layers = nn.ModuleList()
-        self.out_projs = nn.ParameterList()
+        self.out_projs = nn.ModuleList()

         if div_val == 1:
             for i in range(len(self.cutoffs)):
                 if d_proj != d_embed:
                     self.out_projs.append(
-                        nn.Parameter(torch.Tensor(d_proj, d_embed))
+                        Projection(d_proj, d_embed)
                     )
                 else:
                     self.out_projs.append(None)
@@ -49,7 +53,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 d_emb_i = d_embed // (div_val ** i)

                 self.out_projs.append(
-                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
+                    Projection(d_proj, d_emb_i)
                 )

                 self.out_layers.append(nn.Linear(d_emb_i, r_idx - l_idx))
@@ -82,7 +86,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         if self.n_clusters == 0:
             logit = self._compute_logit(hidden, self.out_layers[0].weight,
-                                        self.out_layers[0].bias, self.out_projs[0])
+                                        self.out_layers[0].bias, self.out_projs[0].weight if self.out_projs[0] is not None else None)
             nll = -F.log_softmax(logit, dim=-1) \
                     .gather(1, target.unsqueeze(1)).squeeze(1)
         else:
@@ -106,7 +110,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 weights.append(weight_i)
                 biases.append(bias_i)

-            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
+            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0].weight

             head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
             head_logprob = F.log_softmax(head_logit, dim=1)
@@ -131,7 +135,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 if i == 0:
                     logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
                 else:
-                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
+                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i].weight

                     hidden_i = hidden.index_select(0, indices_i)
fmoe/transformer.py

@@ -49,10 +49,12 @@ class FMoETransformerMLP(FMoE):
         top_k=2,
         do_lnorm=False,
         pre_lnorm=False,
-        expert_dp_comm='none'
+        expert_dp_comm='none',
+        dropout=0.1
     ):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                 top_k=top_k, world_size=world_size, mp_group=mp_group)
+        self.dropout = nn.Dropout(dropout)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
@@ -72,7 +74,9 @@ class FMoETransformerMLP(FMoE):
         inp = inp.reshape(-1, self.d_model)
         if self.pre_lnorm is not None and self.pre_lnorm:
             inp = self.layer_norm(inp)
-        output = super().forward(inp) + inp
+        output = super().forward(inp)
+        output = self.dropout(output)
+        output += inp
         if self.pre_lnorm is not None and not self.pre_lnorm:
             output = self.layer_norm(output)
         return output.reshape(original_shape)
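FMoETransformerMLP now accepts a dropout argument and applies it to the expert output before the residual add, matching the FFN ordering of the dense PositionwiseFF. A construction sketch with the same activation that CustomizedMoEPositionwiseFF passes in (all values are illustrative; running it assumes the fastmoe CUDA extension is built and the module and inputs live on a GPU):

import torch
import torch.nn as nn
from fmoe import FMoETransformerMLP

activation = nn.Sequential(nn.Dropout(0.1), nn.ReLU())
ffn = FMoETransformerMLP(num_expert=64, d_model=256, d_hidden=1024,
                         top_k=2, do_lnorm=True, pre_lnorm=False,
                         activation=activation, dropout=0.1).cuda()

x = torch.randn(8, 4, 256, device='cuda')    # [seq x batch x d_model]
y = ffn(x)                                    # experts, new dropout, residual, post-norm
assert y.shape == x.shape                     # forward reshapes back to the original shape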