OpenDAS / FastMoE · Commits

Commit bdb64914
authored Feb 25, 2021 by Jiezhong Qiu

remove unused code

parent 5dc62b41
Showing 2 changed files with 5 additions and 580 deletions (+5 −580)

examples/transformer-xl/mem_transformer.py   +5 −471
examples/transformer-xl/train.py             +0 −109
examples/transformer-xl/mem_transformer.py  (+5 −471)
@@ -7,7 +7,6 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-# import torch_sparse
 
 sys.path.append('utils')
 from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax, Projection
@@ -32,361 +31,16 @@ class PositionalEmbedding(nn.Module):
         return pos_emb[:,None,:]
 
-# A baseline naive slow implementation
-class MoEPositionwiseFFRaw(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
-        super(MoEPositionwiseFFRaw, self).__init__()
-        print("MoEPositionwiseFF")
-        self.top_k = top_k
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.gate = nn.Linear(d_model, d_inner)
-        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        ratio = top_k / d_inner
-        self.dropout_middle = nn.Dropout(dropout * ratio)
-        self.dropout_final = nn.Dropout(dropout)
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp_Linear = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
-        self.b2.data = temp_Linear.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-        gate = self.gate(inp)
-        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]
-        relu_out = F.relu(gate_top_k_val)
-        x = self.dropout_middle(relu_out)
-        W2_select = self.W2[gate_top_k_idx]  # [.. x top_k x d_model]
-        core_out = torch.einsum('ijk,ijkd->ijd', (x, W2_select)) + self.b2  # [.. x d_model]
-        core_out = self.dropout_final(core_out)
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
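Aside (not part of the commit): the removed baseline can be exercised in isolation as below, assuming the class above is in scope. The sizes are made up; the shape comments follow the [qlen x bsz x d_model] convention used throughout this file.

import torch

# Illustrative only: drive the removed MoEPositionwiseFFRaw baseline once.
ff = MoEPositionwiseFFRaw(d_model=16, d_inner=64, dropout=0.1, top_k=8)
inp = torch.randn(5, 2, 16)        # [qlen x bsz x d_model]
out = ff(inp)                      # only the top-8 of the 64 inner units contribute per position
assert out.shape == inp.shape      # residual + output projection keep d_model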
-def my_topk(x, k, inplace=True):
-    y = x if inplace else x.clone()
-    top1_val, top1_idx = torch.max(y, dim=-1)
-    top1_val = top1_val.unsqueeze(-1)
-    top1_idx = top1_idx.unsqueeze(-1)
-    if k == 1:
-        return top1_val, top1_idx
-    y.scatter_(-1, top1_idx, value=float('-inf'))
-    top2_val, top2_idx = torch.max(y, dim=-1)
-    top2_val = top2_val.unsqueeze(-1)
-    top2_idx = top2_idx.unsqueeze(-1)
-    top_val = torch.cat((top1_val, top2_val), dim=-1)
-    top_idx = torch.cat((top1_idx, top2_idx), dim=-1)
-    return top_val, top_idx
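my_topk unrolls a top-k (k ≤ 2 only) into repeated torch.max calls, masking the first winner with -inf before taking the second. A quick equivalence check against torch.topk (illustrative, not from the repo):

import torch

x = torch.randn(4, 3, 10)
val, idx = my_topk(x, k=2, inplace=False)   # inplace=False so x is not clobbered with -inf
ref_val, ref_idx = torch.topk(x, k=2, dim=-1, largest=True, sorted=True)
# my_topk emits winners in descending order, so it should match sorted torch.topk
# (up to ties, which have probability zero for random floats).
assert torch.equal(idx, ref_idx) and torch.allclose(val, ref_val)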
-class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=16, top_block=2):
-        super(MultiHeadHierarchicalMoEPositionwiseFF, self).__init__()
-        print("MultiHeadHierarchicalMoEPositionwiseFF")
-        assert d_inner % n_block == 0
-        assert top_block in [1, 2]
-        self.top_block = top_block
-        self.n_block = n_block
-        d_block = d_inner // n_block
-        self.d_block = d_block
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.block_net_W = nn.Parameter(torch.Tensor(d_model, top_block, n_block))
-        self.block_net_b = nn.Parameter(torch.Tensor(top_block, n_block))
-        self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
-        self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        ratio = top_block / n_block
-        self.dropout_middle = nn.Dropout(dropout * ratio)
-        self.dropout_final = nn.Dropout(dropout)
-        # self.scale = 1 / (d_model ** 0.5)
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp = nn.Linear(self.d_model, self.d_inner)
-        self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
-        self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
-        temp = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(self.n_block, self.d_block, self.d_model)
-        self.b2.data = temp.bias.data
-        for i in range(self.top_block):
-            temp = nn.Linear(self.d_model, self.n_block)
-            self.block_net_W.data[:, i] = temp.weight.data.transpose(0, 1).contiguous()
-            self.block_net_b.data[i] = temp.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-        block = torch.einsum("ibd,dan->iban", (inp, self.block_net_W)) + self.block_net_b  # [.. x top_block x n_block ]
-        block_val, block_idx = my_topk(block, k=1, inplace=True)
-        # block_val, block_idx = torch.topk(block, k=1, dim=-1, largest=True, sorted=False) # [.. x top_k x 1]
-        block_val = block_val.squeeze(-1)
-        block_idx = block_idx.squeeze(-1)
-        gate = F.softmax(block_val, dim=-1)
-        W1_block = self.W1[block_idx]  # [.. x top_k x d_block x d_model]
-        b1_block = self.b1[block_idx]  # [.. x top_k x d_block]
-        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block  # [.. x top_k x d_block]
-        # x = x + block_val.unsqueeze(-1) # somehow like residual
-        x = x * gate.unsqueeze(-1)
-        relu_out = F.relu(x)
-        relu_out = self.dropout_middle(relu_out)
-        W2_block = self.W2[block_idx]  # [.. x top_k x d_model]
-        core_out = torch.einsum('ibnh,ibnhd->ibd', (x, W2_block)) + self.b2  # [.. x d_model]
-        core_out = self.dropout_final(core_out)
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-class HierarchicalMoEPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=16, top_block=2):
-        super(HierarchicalMoEPositionwiseFF, self).__init__()
-        print("HierarchicalMoEPositionwiseFF")
-        assert d_inner % n_block == 0
-        assert top_block in [1, 2]
-        self.top_block = top_block
-        self.n_block = n_block
-        d_block = d_inner // n_block
-        self.d_block = d_block
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.block_net = nn.Linear(d_model, n_block, bias=True)
-        self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
-        self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        ratio = top_block / n_block
-        self.dropout_middle = nn.Dropout(dropout * ratio)
-        self.dropout_final = nn.Dropout(dropout)
-        # self.scale = 1 / (d_model ** 0.5)
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp = nn.Linear(self.d_model, self.d_inner)
-        self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
-        self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
-        temp = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(self.n_block, self.d_block, self.d_model)
-        self.b2.data = temp.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-        block = self.block_net(inp)
-        # block_val, block_idx = my_topk(block, k=self.top_block)
-        block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False)  # [.. x top_k]
-        gate = F.softmax(block_val, dim=-1)
-        W1_block = self.W1[block_idx]  # [.. x top_k x d_block x d_model]
-        b1_block = self.b1[block_idx]  # [.. x top_k x d_block]
-        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block  # [.. x top_k x d_block]
-        # x = x + block_val.unsqueeze(-1) # somehow like residual
-        x = x * gate.unsqueeze(-1)
-        relu_out = F.relu(x)
-        relu_out = self.dropout_middle(relu_out)
-        W2_block = self.W2[block_idx]  # [.. x top_k x d_model]
-        core_out = torch.einsum('ibnh,ibnhd->ibd', (x, W2_block)) + self.b2  # [.. x d_model]
-        core_out = self.dropout_final(core_out)
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
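Both hierarchical variants rely on the same gather-and-contract idiom: advanced indexing pulls one weight slab per selected block (self.W1[block_idx]) and einsum applies them per position. A standalone sketch of that step with made-up sizes (illustrative, not from the repo):

import torch

qlen, bsz, d_model, n_block, d_block, top_k = 5, 2, 16, 8, 4, 2
W1 = torch.randn(n_block, d_block, d_model)          # one weight slab per expert block
inp = torch.randn(qlen, bsz, d_model)
block_idx = torch.randint(0, n_block, (qlen, bsz, top_k))

W1_block = W1[block_idx]                              # [qlen, bsz, top_k, d_block, d_model]
x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block))  # apply each selected block to its position
assert x.shape == (qlen, bsz, top_k, d_block)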
-class SparsePositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
-        super(SparsePositionwiseFF, self).__init__()
-        print("SparsePositionwiseFF")
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.CoreNet_1 = nn.Sequential(
-            nn.Linear(d_model, d_inner),
-            nn.ReLU(inplace=True),
-            nn.Dropout(dropout)
-        )
-        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
-        self.b2 = nn.Parameter(torch.Tensor(d_model))
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        self.dropout_final = nn.Dropout(dropout)
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        temp_Linear = nn.Linear(self.d_inner, self.d_model)
-        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
-        self.b2.data = temp_Linear.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-        relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
-        sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
-        core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
-        core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
-        core_out = self.dropout_final(core_out)
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
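SparsePositionwiseFF only changes how the second projection is computed: the ReLU output is converted to a torch_sparse SparseTensor before the matmul, but mathematically it is still relu_out @ W2 + b2. A dense sketch of the same computation (my own, illustrative, avoiding the torch_sparse dependency):

import torch
import torch.nn as nn

d_model, d_inner = 16, 64
W2 = nn.Parameter(torch.randn(d_inner, d_model))
b2 = nn.Parameter(torch.zeros(d_model))

relu_out = torch.relu(torch.randn(5 * 2, d_inner))   # flattened [qlen*bsz, d_inner]
core_out = relu_out @ W2 + b2                         # what the sparse matmul computes here
core_out = core_out.view(5, 2, d_model)               # back to [qlen, bsz, d_model]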
-class MultiHeadPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
-        super(MultiHeadPositionwiseFF, self).__init__()
-        print("MultiHeadPositionwiseFF")
-        assert d_model % n_head == 0
-        self.n_head = n_head
-        d_head = d_model // n_head
-        self.d_head = d_head
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.q_net = nn.Linear(d_model, d_model)
-        self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
-        self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
-        self.v_weight = nn.Parameter(torch.Tensor(n_head, d_head, d_inner))
-        self.v_bias = nn.Parameter(torch.Tensor(n_head, d_head))
-        #self.o_net = nn.Linear(d_model, d_model)
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        self.dropout = nn.Dropout(dropout)
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        for i in range(self.n_head):
-            tmp = nn.Linear(self.d_head, self.d_inner)
-            self.k_weight.data[i] = tmp.weight.data
-            self.k_bias.data[i] = tmp.bias.data
-            tmp = nn.Linear(self.d_inner, self.d_head)
-            self.v_weight.data[i] = tmp.weight.data
-            self.v_bias.data[i] = tmp.bias.data
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-        head_q = self.q_net(inp)
-        head_q = head_q.view(inp.size(0), inp.size(1), self.n_head, self.d_head)  # [.. x n_head x d_head]
-        attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) + self.k_bias  # [.. x n_head x d_inner]
-        attn_score = F.relu(attn_score)
-        attn_score = self.dropout(attn_score)
-        attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
-        attn_vec = attn_vec.contiguous().view(inp.size(0), inp.size(1), self.d_model)
-        # core_out = self.o_net(attn_vec)
-        core_out = self.dropout(attn_vec)
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
 class PositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, use_softmax=True):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
         super(PositionwiseFF, self).__init__()
 
         self.d_model = d_model
         self.d_inner = d_inner
         self.dropout = dropout
 
-        self.CoreNet_1 = nn.Sequential(
+        self.CoreNet = nn.Sequential(
             nn.Linear(d_model, d_inner),
-            nn.Softmax(dim=-1) if use_softmax else nn.ReLU(inplace=True)
-        )
-        self.CoreNet_2 = nn.Sequential(
+            nn.ReLU(inplace=True),
             nn.Dropout(dropout),
             nn.Linear(d_inner, d_model),
             nn.Dropout(dropout),
@@ -399,113 +53,18 @@ class PositionwiseFF(nn.Module):
     def forward(self, inp):
         if self.pre_lnorm:
             ##### layer normalization + positionwise feed-forward
-            relu_out = self.CoreNet_1(self.layer_norm(inp))
-            core_out = self.CoreNet_2(relu_out)
+            core_out = self.CoreNet(self.layer_norm(inp))
 
             ##### residual connection
             output = core_out + inp
         else:
             ##### positionwise feed-forward
-            relu_out = self.CoreNet_1(inp)
-            core_out = self.CoreNet_2(relu_out)
+            core_out = self.CoreNet(inp)
 
             ##### residual connection + layer normalization
             output = self.layer_norm(inp + core_out)
 
         return output
-        # return output, relu_out.detach()
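Taken together, the two PositionwiseFF hunks restore the stock Transformer-XL feed-forward block. A consolidated view of the module as it reads after this commit (reassembled from the '+' side of the diff; the layer_norm/pre_lnorm setup comes from unchanged lines the diff does not show, so treat those as assumed):

import torch.nn as nn

class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout
        self.CoreNet = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model)   # assumed from elided context
        self.pre_lnorm = pre_lnorm                # assumed from elided context

    def forward(self, inp):
        if self.pre_lnorm:
            core_out = self.CoreNet(self.layer_norm(inp))
            output = core_out + inp
        else:
            core_out = self.CoreNet(inp)
            output = self.layer_norm(inp + core_out)
        return output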
-class ExtendedMultiHeadAttn(nn.Module):
-    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, pre_lnorm=False):
-        super(ExtendedMultiHeadAttn, self).__init__()
-        print("ExtendedMultiHeadAttn")
-        self.n_head = n_head
-        self.d_model = d_model
-        self.d_head = d_head
-        self.dropout = dropout
-
-        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
-        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
-
-        self.drop = nn.Dropout(dropout)
-        self.dropatt = nn.Dropout(dropatt)
-        self.o_net = nn.Linear(n_head * d_head * 2, d_model, bias=False)
-
-        self.layer_norm = nn.LayerNorm(d_model)
-
-        self.scale = 1 / (d_head ** 0.5)
-
-        self.pre_lnorm = pre_lnorm
-        # self.coeff = nn.Parameter(torch.Tensor(n_head, 2))
-        # nn.init.uniform_(self.coeff, a=-1, b=1)
-
-    def forward(self, h, attn_mask=None, mems=None):
-        ##### multihead attention
-        # [hlen x bsz x n_head x d_head]
-
-        if mems is not None:
-            c = torch.cat([mems, h], 0)
-            mem_len = mems.size(0)
-        else:
-            c = h
-            mem_len = 0
-
-        if self.pre_lnorm:
-            ##### layer normalization
-            c = self.layer_norm(c)
-
-        head_q = self.q_net(c)
-        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
-
-        head_q = head_q.view(c.size(0), c.size(1), self.n_head, self.d_head)
-        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
-        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
-
-        # [qlen x klen x bsz x n_head]
-        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
-        attn_score.mul_(self.scale)
-        if attn_mask is not None and attn_mask.any().item():
-            if attn_mask.dim() == 2:
-                attn_score[mem_len:].masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
-            elif attn_mask.dim() == 3:
-                attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
-            mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
-            mem2other_attn[:, :mem_len] = 0
-            attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None].bool(), -float('inf'))
-
-        # [qlen x klen x bsz x n_head]
-        attn_prob = F.softmax(attn_score, dim=1)
-        attn_prob = self.dropatt(attn_prob)
-
-        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
-        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
-        attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, attn_vec))
-        # [qlen x bsz x n_head x d_head x 2]
-        attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
-        # attn_vec = torch.einsum('ibndt,nt->ibnd', (attn_vecs, self.coeff))
-        attn_vec = attn_vecs.contiguous().view(
-            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head * 2)
-        attn_vec = attn_vec[mem_len:]
-
-        ##### linear projection
-        attn_out = self.o_net(attn_vec)
-        attn_out = self.drop(attn_out)
-
-        if self.pre_lnorm:
-            ##### residual connection
-            output = h + attn_out
-        else:
-            ##### residual connection + layer normalization
-            output = self.layer_norm(h + attn_out)
-
-        return output
 
 class MultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
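The removed ExtendedMultiHeadAttn differs from stock attention mainly in its "quadratic" step: the attention matrix is applied a second time to its own output, and both results are concatenated before the output projection. A minimal sketch of just that step (illustrative sizes, no memory segment):

import torch

qlen, bsz, n_head, d_head = 4, 2, 3, 5
attn_prob = torch.softmax(torch.randn(qlen, qlen, bsz, n_head), dim=1)
head_v = torch.randn(qlen, bsz, n_head, d_head)

attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))         # A·V, as in standard attention
attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, attn_vec))  # A·(A·V), the "quadratic" term
attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
assert attn_vecs.shape == (qlen, bsz, n_head, d_head, 2)                # hence o_net takes n_head*d_head*2 inputs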
@@ -817,7 +376,6 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         return output
 
-
 from fmoe import FMoETransformerMLP
 class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, moe_num_expert=64, moe_top_k=2):
@@ -827,7 +385,6 @@ class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
         )
         super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
                 do_lnorm=True, pre_lnorm=pre_lnorm, activation=activation, dropout=dropout)
-        #self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
         x = super().forward(x)
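CustomizedMoEPositionwiseFF is what remains after the cleanup: it subclasses FastMoE's FMoETransformerMLP with the constructor arguments shown above. A hedged construction sketch (argument values are illustrative, the activation module is built in the elided part of __init__, and torch plus the class above are assumed to be in scope):

# Illustrative only; mirrors how DecoderLayer below builds the MoE feed-forward.
moe_ff = CustomizedMoEPositionwiseFF(
    d_model=256,
    d_inner=1024,
    dropout=0.1,
    pre_lnorm=False,
    moe_num_expert=64,   # number of experts, matching the default above
    moe_top_k=2,         # each token is routed to its top-2 experts
)
out = moe_ff(torch.randn(5, 2, 256))   # expected to preserve the [qlen x bsz x d_model] shape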
@@ -838,7 +395,6 @@ class DecoderLayer(nn.Module):
         super(DecoderLayer, self).__init__()
 
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
             pre_lnorm=kwargs.get('pre_lnorm'),
             moe_num_expert=kwargs.get('moe_num_expert'),
@@ -849,10 +405,8 @@ class DecoderLayer(nn.Module):
         output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                                mems=mems)
         output = self.pos_ff(output)
-        # output, relu_out = self.pos_ff(output)
 
         return output
-        # return output, relu_out
 
 class RelLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -872,10 +426,8 @@ class RelLearnableDecoderLayer(nn.Module):
                                attn_mask=dec_attn_mask,
                                mems=mems)
         output = self.pos_ff(output)
-        # output, relu_out = self.pos_ff(output)
 
         return output
-        # return output, relu_out
 
 class RelPartialLearnableDecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -895,11 +447,8 @@ class RelPartialLearnableDecoderLayer(nn.Module):
                                attn_mask=dec_attn_mask,
                                mems=mems)
         output = self.pos_ff(output)
-        # output, relu_out = self.pos_ff(output)
 
         return output
-        # return output, relu_out
 
 class AdaptiveEmbedding(nn.Module):
@@ -1135,7 +684,6 @@ class MemTransformerLM(nn.Module):
                     word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
 
         hids = []
-        # relu_outs = []
         if self.attn_type == 0: # default
             pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -1149,11 +697,9 @@ class MemTransformerLM(nn.Module):
             hids.append(core_out)
             for i, layer in enumerate(self.layers):
                 mems_i = None if mems is None else mems[i]
-                # core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
                 core_out = layer(core_out, pos_emb, self.r_w_bias,
                         self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)
         elif self.attn_type == 1: # learnable
             core_out = self.drop(word_emb)
             hids.append(core_out)
@@ -1165,11 +711,9 @@ class MemTransformerLM(nn.Module):
                     r_emb, r_bias = self.r_emb[i], self.r_bias[i]
                 mems_i = None if mems is None else mems[i]
-                # core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
                 core_out = layer(core_out, r_emb, self.r_w_bias[i],
                         r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)
         elif self.attn_type == 2: # absolute
             pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -1184,11 +728,9 @@ class MemTransformerLM(nn.Module):
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and i == 0:
                     mems_i += pos_emb[:mlen]
-                # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)
         elif self.attn_type == 3:
             core_out = self.drop(word_emb)
@@ -1206,18 +748,15 @@ class MemTransformerLM(nn.Module):
                         mems_i += cur_emb.view(mlen, 1, -1)
                 core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
-                # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
                 hids.append(core_out)
-                # relu_outs.append(relu_out)
 
         core_out = self.drop(core_out)
 
         new_mems = self._update_mems(hids, mems, mlen, qlen)
 
         return core_out, new_mems
-        # return core_out, new_mems, relu_outs
 
     def forward(self, data, target, *mems):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
@@ -1228,9 +767,6 @@ class MemTransformerLM(nn.Module):
         tgt_len = target.size(0)
         hidden, new_mems = self._forward(data, mems=mems)
-        # hidden, new_mems, relu_outs = self._forward(data, mems=mems)
-        # relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1)
 
         pred_hid = hidden[-tgt_len:]
         if self.sample_softmax > 0 and self.training:
@@ -1244,10 +780,8 @@ class MemTransformerLM(nn.Module):
         if new_mems is None:
             return [loss]
-            # return [relu_outs, loss]
         else:
             return [loss] + new_mems
-            # return [relu_outs, loss] + new_mems
 
 if __name__ == '__main__':
     import argparse
examples/transformer-xl/train.py  (+0 −109)
@@ -4,7 +4,6 @@ import time
 import math
 import os, sys
 import itertools
-import pathlib
 
 import numpy as np
@@ -17,31 +16,6 @@ from mem_transformer import MemTransformerLM
...
@@ -17,31 +16,6 @@ from mem_transformer import MemTransformerLM
from
utils.exp_utils
import
create_exp_dir
from
utils.exp_utils
import
create_exp_dir
from
utils.data_parallel
import
BalancedDataParallel
from
utils.data_parallel
import
BalancedDataParallel
class
AverageMeter
(
object
):
"""Computes and stores the average and current value.
Examples::
>>> # Initialize a meter to record loss
>>> losses = AverageMeter()
>>> # Update meter after every minibatch update
>>> losses.update(loss_value, batch_size)
"""
def
__init__
(
self
):
self
.
reset
()
def
reset
(
self
):
self
.
val
=
0
self
.
avg
=
0
self
.
sum
=
0
self
.
count
=
0
def
update
(
self
,
val
,
n
=
1
):
self
.
val
=
val
self
.
sum
+=
val
*
n
self
.
count
+=
n
self
.
avg
=
self
.
sum
/
self
.
count
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch Transformer Language Model'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch Transformer Language Model'
)
parser
.
add_argument
(
'--data'
,
type
=
str
,
default
=
'../data/wikitext-103'
,
parser
.
add_argument
(
'--data'
,
type
=
str
,
default
=
'../data/wikitext-103'
,
help
=
'location of the data corpus'
)
help
=
'location of the data corpus'
)
...
@@ -418,9 +392,6 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
...
@@ -418,9 +392,6 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
def
evaluate
(
eval_iter
):
def
evaluate
(
eval_iter
):
# Turn on evaluation mode which disables dropout.
# Turn on evaluation mode which disables dropout.
model
.
eval
()
model
.
eval
()
# avg_nnzs = None
# act_hist = None
# co_act_hist = None
# If the model does not use memory at all, make the ext_len longer.
# If the model does not use memory at all, make the ext_len longer.
# Otherwise, make the mem_len longer and keep the ext_len the same.
# Otherwise, make the mem_len longer and keep the ext_len the same.
...
@@ -440,33 +411,15 @@ def evaluate(eval_iter):
...
@@ -440,33 +411,15 @@ def evaluate(eval_iter):
break
break
ret
=
model
(
data
,
target
,
*
mems
)
ret
=
model
(
data
,
target
,
*
mems
)
loss
,
mems
=
ret
[
0
],
ret
[
1
:]
loss
,
mems
=
ret
[
0
],
ret
[
1
:]
# relu_outs, loss, mems = ret[0], ret[1], ret[2:]
loss
=
loss
.
mean
()
loss
=
loss
.
mean
()
total_loss
+=
seq_len
*
loss
.
float
().
item
()
total_loss
+=
seq_len
*
loss
.
float
().
item
()
total_len
+=
seq_len
total_len
+=
seq_len
# acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
# if avg_nnzs is None:
# n_layer = len(acts)
# avg_nnzs = [AverageMeter() for i in range(n_layer)]
# d_inner = acts[0].size(-1)
# act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
# co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
# for i, act in enumerate(acts):
# nnz = act.sum().item() / act.numel()
# avg_nnzs[i].update(nnz)
# act_hist[i] += torch.sum(act, dim=[0, 1])
# co_act = torch.einsum("ija,ijb->ab", (act, act))
# co_act_hist[i] += co_act
# Switch back to the training mode
# Switch back to the training mode
model
.
reset_length
(
args
.
tgt_len
,
args
.
ext_len
,
args
.
mem_len
)
model
.
reset_length
(
args
.
tgt_len
,
args
.
ext_len
,
args
.
mem_len
)
model
.
train
()
model
.
train
()
return
total_loss
/
total_len
return
total_loss
/
total_len
# return total_loss / total_len, avg_nnzs, act_hist, co_act_hist
def
train
():
def
train
():
...
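The commented-out block deleted here tracked how sparse the feed-forward ReLU activations were: the fraction of non-zero units, a per-unit activation histogram, and a co-activation matrix per layer. A hedged, standalone reconstruction of that bookkeeping, for readers who want to see what was being measured (the helper name and packaging are mine, not the repo's):

# Hypothetical helper reconstructing the deleted instrumentation; not in the repo.
import torch

def update_activation_stats(relu_outs, avg_nnzs, act_hist, co_act_hist):
    """relu_outs: list of per-layer ReLU outputs shaped [qlen, bsz, d_inner]."""
    acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
    for i, act in enumerate(acts):
        nnz = act.sum().item() / act.numel()            # fraction of active units
        avg_nnzs[i].update(nnz)                         # AverageMeter from the deleted class
        act_hist[i] += torch.sum(act, dim=[0, 1])       # per-unit activation counts
        co_act = torch.einsum("ija,ijb->ab", (act, act))
        co_act_hist[i] += co_act                        # pairwise co-activation counts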
@@ -477,11 +430,6 @@ def train():
...
@@ -477,11 +430,6 @@ def train():
mems
=
[
tuple
()
for
_
in
range
(
args
.
batch_chunk
)]
mems
=
[
tuple
()
for
_
in
range
(
args
.
batch_chunk
)]
else
:
else
:
mems
=
tuple
()
mems
=
tuple
()
# avg_nnzs = None
# act_hist = None
# co_act_hist = None
train_iter
=
tr_iter
.
get_varlen_iter
()
if
args
.
varlen
else
tr_iter
train_iter
=
tr_iter
.
get_varlen_iter
()
if
args
.
varlen
else
tr_iter
for
batch
,
(
data
,
target
,
seq_len
)
in
enumerate
(
train_iter
):
for
batch
,
(
data
,
target
,
seq_len
)
in
enumerate
(
train_iter
):
model
.
zero_grad
()
model
.
zero_grad
()
...
@@ -493,7 +441,6 @@ def train():
...
@@ -493,7 +441,6 @@ def train():
target_i
=
target_chunks
[
i
].
contiguous
()
target_i
=
target_chunks
[
i
].
contiguous
()
ret
=
para_model
(
data_i
,
target_i
,
*
mems
[
i
])
ret
=
para_model
(
data_i
,
target_i
,
*
mems
[
i
])
loss
,
mems
[
i
]
=
ret
[
0
],
ret
[
1
:]
loss
,
mems
[
i
]
=
ret
[
0
],
ret
[
1
:]
# relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
loss
=
loss
.
float
().
mean
().
type_as
(
loss
)
/
args
.
batch_chunk
loss
=
loss
.
float
().
mean
().
type_as
(
loss
)
/
args
.
batch_chunk
if
args
.
fp16
:
if
args
.
fp16
:
optimizer
.
backward
(
loss
)
optimizer
.
backward
(
loss
)
...
@@ -503,28 +450,12 @@ def train():
...
@@ -503,28 +450,12 @@ def train():
else
:
else
:
ret
=
para_model
(
data
,
target
,
*
mems
)
ret
=
para_model
(
data
,
target
,
*
mems
)
loss
,
mems
=
ret
[
0
],
ret
[
1
:]
loss
,
mems
=
ret
[
0
],
ret
[
1
:]
# relu_outs, loss, mems = ret[0], ret[1], ret[2:]
loss
=
loss
.
float
().
mean
().
type_as
(
loss
)
loss
=
loss
.
float
().
mean
().
type_as
(
loss
)
if
args
.
fp16
:
if
args
.
fp16
:
optimizer
.
backward
(
loss
)
optimizer
.
backward
(
loss
)
else
:
else
:
loss
.
backward
()
loss
.
backward
()
train_loss
+=
loss
.
float
().
item
()
train_loss
+=
loss
.
float
().
item
()
# acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
# # nnzs = [act.sum().item() / act.numel() for act in acts]
# if avg_nnzs is None:
# n_layer = len(acts)
# avg_nnzs = [AverageMeter() for i in range(n_layer)]
# d_inner = acts[0].size(-1)
# act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
# co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
# for i, act in enumerate(acts):
# nnz = act.sum().item() / act.numel()
# avg_nnzs[i].update(nnz)
# act_hist[i] += torch.sum(act, dim=[0, 1])
# co_act = torch.einsum("ija,ijb->ab", (act, act))
# co_act_hist[i] += co_act
if
args
.
fp16
:
if
args
.
fp16
:
optimizer
.
clip_master_grads
(
args
.
clip
)
optimizer
.
clip_master_grads
(
args
.
clip
)
...
@@ -563,39 +494,12 @@ def train():
...
@@ -563,39 +494,12 @@ def train():
log_str
+=
' | bpc {:9.5f}'
.
format
(
cur_loss
/
math
.
log
(
2
))
log_str
+=
' | bpc {:9.5f}'
.
format
(
cur_loss
/
math
.
log
(
2
))
else
:
else
:
log_str
+=
' | ppl {:9.3f}'
.
format
(
math
.
exp
(
cur_loss
))
log_str
+=
' | ppl {:9.3f}'
.
format
(
math
.
exp
(
cur_loss
))
# final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
# log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
# sum(final_avg_nnzs)/len(final_avg_nnzs)*100,
# max(final_avg_nnzs)*100,
# )
logging
(
log_str
)
logging
(
log_str
)
# co_act_dir = pathlib.Path(logging.keywords['log_path']).parent.joinpath("co_act")
# co_act_dir.mkdir(parents=True, exist_ok=True)
# co_act_path = co_act_dir.joinpath('epoch_%d_train_step_%d.pt' % (epoch, train_step))
# torch.save(co_act_hist, co_act_path)
# for i in range(len(avg_nnzs)):
# avg_nnzs[i].reset()
# act_hist[i] /= act_hist[i].sum()
# prob, index = torch.top_k(act_hist[i], min(1024, act_hist[i].size(-1)))
# log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
# i+1,
# prob[:64].sum().item(),
# prob[:128].sum().item(),
# prob[:256].sum().item(),
# prob[:512].sum().item(),
# prob[:1024].sum().item()
# )
# logging(log_str)
# act_hist[i] = 0.
# co_act_hist[i] = 0.
train_loss
=
0
train_loss
=
0
log_start_time
=
time
.
time
()
log_start_time
=
time
.
time
()
if
train_step
%
args
.
eval_interval
==
0
:
if
train_step
%
args
.
eval_interval
==
0
:
val_loss
=
evaluate
(
va_iter
)
val_loss
=
evaluate
(
va_iter
)
# val_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(va_iter)
logging
(
'-'
*
100
)
logging
(
'-'
*
100
)
log_str
=
'| Eval {:3d} at step {:>8d} | time: {:5.2f}s '
\
log_str
=
'| Eval {:3d} at step {:>8d} | time: {:5.2f}s '
\
'| valid loss {:5.2f}'
.
format
(
'| valid loss {:5.2f}'
.
format
(
...
@@ -605,11 +509,6 @@ def train():
...
@@ -605,11 +509,6 @@ def train():
log_str
+=
' | bpc {:9.5f}'
.
format
(
val_loss
/
math
.
log
(
2
))
log_str
+=
' | bpc {:9.5f}'
.
format
(
val_loss
/
math
.
log
(
2
))
else
:
else
:
log_str
+=
' | valid ppl {:9.3f}'
.
format
(
math
.
exp
(
val_loss
))
log_str
+=
' | valid ppl {:9.3f}'
.
format
(
math
.
exp
(
val_loss
))
# final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
# log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
# sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
# max(final_eval_avg_nnzs)*100
# )
logging
(
log_str
)
logging
(
log_str
)
logging
(
'-'
*
100
)
logging
(
'-'
*
100
)
# Save the model if the validation loss is the best we've seen so far.
# Save the model if the validation loss is the best we've seen so far.
...
@@ -659,7 +558,6 @@ para_model = model.to(device)
...
@@ -659,7 +558,6 @@ para_model = model.to(device)
# Run on test data.
# Run on test data.
test_loss
=
evaluate
(
te_iter
)
test_loss
=
evaluate
(
te_iter
)
# test_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(te_iter)
logging
(
'='
*
100
)
logging
(
'='
*
100
)
if
args
.
dataset
in
[
'enwik8'
,
'text8'
]:
if
args
.
dataset
in
[
'enwik8'
,
'text8'
]:
logging
(
'| End of training | test loss {:5.2f} | test bpc {:9.5f}'
.
format
(
logging
(
'| End of training | test loss {:5.2f} | test bpc {:9.5f}'
.
format
(
...
@@ -667,11 +565,4 @@ if args.dataset in ['enwik8', 'text8']:
...
@@ -667,11 +565,4 @@ if args.dataset in ['enwik8', 'text8']:
else
:
else
:
logging
(
'| End of training | test loss {:5.2f} | test ppl {:9.3f}'
.
format
(
logging
(
'| End of training | test loss {:5.2f} | test ppl {:9.3f}'
.
format
(
test_loss
,
math
.
exp
(
test_loss
)))
test_loss
,
math
.
exp
(
test_loss
)))
# final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
# log_str = ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
# sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
# max(final_eval_avg_nnzs)*100
# )
# logging(log_str)
logging
(
'='
*
100
)
logging
(
'='
*
100
)