OpenDAS / FastMoE · Commit 0f091a1d
Authored May 17, 2023 by Sugon_ldc

add fastmoe project

Pipeline #263: failed, in 0 seconds
Showing 20 changed files with 2625 additions and 0 deletions.
examples/transformer-xl/mem_transformer.py                  +868  -0
examples/transformer-xl/scripts/getdata.sh                   +90  -0
examples/transformer-xl/scripts/run_enwik8_base.sh           +41  -0
examples/transformer-xl/scripts/run_enwik8_base_moe.sh       +42  -0
examples/transformer-xl/scripts/run_enwik8_large.sh          +41  -0
examples/transformer-xl/scripts/run_lm1b_base.sh             +43  -0
examples/transformer-xl/scripts/run_lm1b_large.sh            +43  -0
examples/transformer-xl/scripts/run_text8_base.sh            +41  -0
examples/transformer-xl/scripts/run_text8_large.sh           +38  -0
examples/transformer-xl/scripts/run_wt103_base.sh            +42  -0
examples/transformer-xl/scripts/run_wt103_large.sh           +44  -0
examples/transformer-xl/train.py                            +570  -0
examples/transformer-xl/utils/adaptive_softmax.py            +90  -0
examples/transformer-xl/utils/data_parallel.py               +91  -0
examples/transformer-xl/utils/exp_utils.py                   +40  -0
examples/transformer-xl/utils/log_uniform_sampler.py        +147  -0
examples/transformer-xl/utils/proj_adaptive_softmax.py      +155  -0
examples/transformer-xl/utils/vocabulary.py                 +163  -0
fmoe/__init__.py                                              +8  -0
fmoe/balance.py                                              +28  -0
examples/transformer-xl/mem_transformer.py  0 → 100644 (new file)

import sys
import math
import functools

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append('utils')
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax, Projection
from log_uniform_sampler import LogUniformSampler, sample_logits


class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()

        self.demb = demb

        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        else:
            return pos_emb[:, None, :]


class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)

        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            ##### layer normalization + positionwise feed-forward
            core_out = self.CoreNet(self.layer_norm(inp))

            ##### residual connection
            output = core_out + inp
        else:
            ##### positionwise feed-forward
            core_out = self.CoreNet(inp)

            ##### residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output


class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

    def forward(self, h, attn_mask=None, mems=None):
        ##### multihead attention
        # [hlen x bsz x n_head x d_head]

        if mems is not None:
            c = torch.cat([mems, h], 0)
        else:
            c = h

        if self.pre_lnorm:
            ##### layer normalization
            c = self.layer_norm(c)

        head_q = self.q_net(h)
        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)

        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None].bool(), -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None].bool(), -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = h + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(h + attn_out)

        return output


class RelMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 moe=False, moe_num_expert=64, moe_top_k=2):
        super(RelMultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

    def _parallelogram_mask(self, h, w, left=False):
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m, :m] = torch.triu(mask[:m, :m])
        mask[-m:, -m:] = torch.tril(mask[-m:, -m:])

        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
                                   device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        x = x_padded.masked_select(mask[:, :, None, None]) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x

    def _rel_shift(self, x, zero_triu=False):
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        x = x_padded[1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(0), x.size(1)))
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None]

        return x

    def forward(self, w, r, attn_mask=None, mems=None):
        raise NotImplementedError


class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)      # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)      # qlen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)      # qlen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)           # qlen x n_head x d_head

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias                                    # qlen x bsz x n_head x d_head
        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))        # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + r_r_bias
        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))         # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None, :, :, None].bool(), -float('inf')).type_as(attn_score)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:, :, :, None].bool(), -float('inf')).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output


class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
        # r_emb: [klen, n_head, d_head], used for term B
        # r_w_bias: [n_head, d_head], used for term C
        # r_bias: [klen, n_head], used for term D

        qlen, bsz = w.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

        if klen > r_emb.size(0):
            r_emb_pad = r_emb[0:1].expand(klen - r_emb.size(0), -1, -1)
            r_emb = torch.cat([r_emb_pad, r_emb], 0)
            r_bias_pad = r_bias[0:1].expand(klen - r_bias.size(0), -1)
            r_bias = torch.cat([r_bias_pad, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[None]                              # qlen x bsz x n_head x d_head

        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))        # qlen x klen x bsz x n_head
        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))             # qlen x klen x bsz x n_head
        D_ = r_bias[None, :, None]                                         # 1    x klen x 1   x n_head
        BD = self._rel_shift(B_ + D_)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None].bool(), -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None].bool(), -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output


from fmoe import FMoETransformerMLP

class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
                 moe_num_expert=64, moe_top_k=2):
        activation = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        super().__init__(num_expert=moe_num_expert, d_model=d_model,
                         d_hidden=d_inner, top_k=moe_top_k, activation=activation)

        self.pre_lnorm = pre_lnorm
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp):
        if self.pre_lnorm:
            ##### layer normalization + positionwise feed-forward
            core_out = super().forward(self.layer_norm(inp))
            core_out = self.dropout(core_out)

            ##### residual connection
            output = core_out + inp
        else:
            ##### positionwise feed-forward
            core_out = super().forward(inp)
            core_out = self.dropout(core_out)

            ##### residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output


class DecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(DecoderLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        if kwargs.get('moe') is False:
            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                         pre_lnorm=kwargs.get('pre_lnorm'))
        else:
            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                                      pre_lnorm=kwargs.get('pre_lnorm'),
                                                      moe_num_expert=kwargs.get('moe_num_expert'),
                                                      moe_top_k=kwargs.get('moe_top_k'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None):

        output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output


class RelLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(RelLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                  **kwargs)
        if kwargs.get('moe') is False:
            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                         pre_lnorm=kwargs.get('pre_lnorm'))
        else:
            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                                      pre_lnorm=kwargs.get('pre_lnorm'),
                                                      moe_num_expert=kwargs.get('moe_num_expert'),
                                                      moe_top_k=kwargs.get('moe_top_k'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):

        output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                               attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output


class RelPartialLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head,
                                                         dropout, **kwargs)
        if kwargs.get('moe') is False:
            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                         pre_lnorm=kwargs.get('pre_lnorm'))
        else:
            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                                      pre_lnorm=kwargs.get('pre_lnorm'),
                                                      moe_num_expert=kwargs.get('moe_num_expert'),
                                                      moe_top_k=kwargs.get('moe_top_k'))

    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):

        output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
                               attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output


class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ModuleList()
        if div_val == 1:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(Projection(d_proj, d_embed))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(Projection(d_proj, d_emb_i))

    def forward(self, inp):
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0].weight)
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                   dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i].weight)

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed = emb_flat.view(*inp.size(), self.d_proj)

        embed.mul_(self.emb_scale)

        return embed


class MemTransformerLM(nn.Module):
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None,
                 div_val=1, tie_projs=[False], pre_lnorm=False,
                 tgt_len=None, ext_len=None, mem_len=None,
                 cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1, moe=False, moe_num_expert=64, moe_top_k=2):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.attn_type = attn_type

        self.layers = nn.ModuleList()
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm,
                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                )
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm,
                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                )
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        dropatt=dropatt, pre_lnorm=pre_lnorm,
                        moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
                )

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                    cutoffs, div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i].weight = self.word_emb.emb_projs[0].weight
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i].weight = self.word_emb.emb_projs[i].weight

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()

    def backward_compatible(self):
        self.sample_softmax = -1

    def _create_params(self):
        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        elif self.attn_type == 1:  # learnable
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2:  # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3:  # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self, x):
        if self.mem_len > 0:
            mems = []
            for i in range(self.n_layer + 1):
                empty = torch.empty(0, dtype=x.dtype, device=x.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        # does not deal with None
        if mems is None: return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):

                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
                             + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None]  # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1 + mlen).byte()[:, :, None]

        hids = []
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, pos_emb, self.r_w_bias,
                                 self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1:  # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len:]
                    r_bias = self.r_bias[i][-self.clamp_len:]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, r_emb, self.r_w_bias[i],
                                 r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2:  # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen - cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

        return core_out, new_mems

    def forward(self, data, target, *mems):
        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
        # So, have to initialize size(0) mems inside the model forward.
        # Moreover, have to return new_mems to allow nn.DataParallel to piece
        # them together.
        if not mems: mems = self.init_mems(data)

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]
        if self.sample_softmax > 0 and self.training:
            assert self.tie_weight
            logit = sample_logits(self.word_emb,
                                  self.out_layer.bias, target, pred_hid, self.sampler)
            loss = -F.log_softmax(logit, -1)[:, :, 0]
        else:
            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.contiguous().view(-1))
            loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
        else:
            return [loss] + new_mems


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='unit test')

    parser.add_argument('--n_layer', type=int, default=4, help='')
    parser.add_argument('--n_rel_layer', type=int, default=4, help='')
    parser.add_argument('--n_head', type=int, default=2, help='')
    parser.add_argument('--d_head', type=int, default=2, help='')
    parser.add_argument('--d_model', type=int, default=200, help='')
    parser.add_argument('--d_embed', type=int, default=200, help='')
    parser.add_argument('--d_inner', type=int, default=200, help='')
    parser.add_argument('--dropout', type=float, default=0.0, help='')
    parser.add_argument('--cuda', action='store_true', help='')
    parser.add_argument('--seed', type=int, default=1111, help='')
    parser.add_argument('--multi_gpu', action='store_true', help='')

    args = parser.parse_args()

    device = torch.device("cuda" if args.cuda else "cpu")

    B = 4
    tgt_len, mem_len, ext_len = 36, 36, 0
    data_len = tgt_len * 20
    args.n_token = 10000

    import data_utils

    data = torch.LongTensor(data_len * B).random_(0, args.n_token).to(device)
    diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)

    cutoffs = [args.n_token // 2]
    tie_projs = [False] + [True] * len(cutoffs)

    for div_val in [1, 2]:
        for d_embed in [200, 100]:
            model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
                                     args.d_model, args.d_head, args.d_inner,
                                     args.dropout, dropatt=args.dropout,
                                     tie_weight=True, d_embed=d_embed,
                                     div_val=div_val, tie_projs=tie_projs,
                                     pre_lnorm=True, tgt_len=tgt_len,
                                     ext_len=ext_len, mem_len=mem_len,
                                     cutoffs=cutoffs, attn_type=0).to(device)

            print(sum(p.numel() for p in model.parameters()))

            mems = tuple()
            for idx, (inp, tgt, seqlen) in enumerate(diter):
                print('batch {}'.format(idx))
                out = model(inp, tgt, *mems)
                mems = out[1:]
examples/transformer-xl/scripts/getdata.sh  0 → 100755 (new file)

echo "=== Acquiring datasets ==="
echo "---"

mkdir -p ../data
cd ../data

if [[ ! -d 'wikitext-2' ]]; then
    echo "- Downloading WikiText-2 (WT2)"
    wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
    unzip -q wikitext-2-v1.zip
    cd wikitext-2
    mv wiki.train.tokens train.txt
    mv wiki.valid.tokens valid.txt
    mv wiki.test.tokens test.txt
    cd ..
fi

echo "- Downloading WikiText-103 (WT103)"
if [[ ! -d 'wikitext-103' ]]; then
    wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
    unzip -q wikitext-103-v1.zip
    cd wikitext-103
    mv wiki.train.tokens train.txt
    mv wiki.valid.tokens valid.txt
    mv wiki.test.tokens test.txt
    cd ..
fi

echo "- Downloading enwik8 (Character)"
if [[ ! -d 'enwik8' ]]; then
    mkdir -p enwik8
    cd enwik8
    wget --continue http://mattmahoney.net/dc/enwik8.zip
    wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
    python3 prep_enwik8.py
    cd ..
fi

echo "- Downloading text8 (Character)"
if [[ ! -d 'text8' ]]; then
    mkdir -p text8
    cd text8
    wget --continue http://mattmahoney.net/dc/text8.zip
    python ../../prep_text8.py
    cd ..
fi

echo "- Downloading Penn Treebank (PTB)"
if [[ ! -d 'penn' ]]; then
    wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
    tar -xzf simple-examples.tgz

    mkdir -p penn
    cd penn
    mv ../simple-examples/data/ptb.train.txt train.txt
    mv ../simple-examples/data/ptb.test.txt test.txt
    mv ../simple-examples/data/ptb.valid.txt valid.txt
    cd ..

    echo "- Downloading Penn Treebank (Character)"
    mkdir -p pennchar
    cd pennchar
    mv ../simple-examples/data/ptb.char.train.txt train.txt
    mv ../simple-examples/data/ptb.char.test.txt test.txt
    mv ../simple-examples/data/ptb.char.valid.txt valid.txt
    cd ..

    rm -rf simple-examples/
fi

echo "- Downloading 1B words"
if [[ ! -d 'one-billion-words' ]]; then
    mkdir -p one-billion-words
    cd one-billion-words

    wget --no-proxy http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
    tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz

    path="1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/"
    cat ${path}/news.en.heldout-00000-of-00050 > valid.txt
    cat ${path}/news.en.heldout-00000-of-00050 > test.txt

    wget https://github.com/rafaljozefowicz/lm/raw/master/1b_word_vocab.txt

    cd ..
fi

echo "---"
echo "Happy language modeling :)"
examples/transformer-xl/scripts/run_enwik8_base.sh  0 → 100755 (new file)

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --n_layer 12 \
        --d_model 512 \
        --n_head 8 \
        --d_head 64 \
        --d_inner 2048 \
        --dropout 0.1 \
        --dropatt 0.0 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 0 \
        --max_step 400000 \
        --tgt_len 512 \
        --mem_len 512 \
        --eval_tgt_len 128 \
        --batch_size 22 \
        --multi_gpu \
        --gpu0_bsz 4 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --tgt_len 80 \
        --mem_len 2100 \
        --clamp_len 820 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
examples/transformer-xl/scripts/run_enwik8_base_moe.sh  0 → 100755 (new file)

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --n_layer 12 \
        --d_model 512 \
        --n_head 8 \
        --d_head 64 \
        --d_inner 1024 \
        --dropout 0.1 \
        --dropatt 0.0 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 0 \
        --max_step 400000 \
        --tgt_len 512 \
        --mem_len 512 \
        --eval_tgt_len 128 \
        --batch_size 22 \
        --multi_gpu \
        --gpu0_bsz 4 \
        --moe --moe-num-expert 64 --moe-top-k 2 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --tgt_len 80 \
        --mem_len 2100 \
        --clamp_len 820 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
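Relative to run_enwik8_base.sh above, this MoE variant halves --d_inner (2048 to 1024, the per-expert hidden size) and adds --moe --moe-num-expert 64 --moe-top-k 2; train.py below forwards those flags to MemTransformerLM, which then builds each decoder layer with CustomizedMoEPositionwiseFF instead of PositionwiseFF.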
examples/transformer-xl/scripts/run_enwik8_large.sh  0 → 100755 (new file)

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --n_layer 24 \
        --d_model 1024 \
        --n_head 8 \
        --d_head 128 \
        --d_inner 3072 \
        --dropout 0.15 \
        --dropatt 0.15 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 4000 \
        --max_step 400000 \
        --tgt_len 768 \
        --mem_len 768 \
        --eval_tgt_len 128 \
        --batch_size 64 \
        --multi_gpu \
        --gpu0_bsz 0 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --tgt_len 128 \
        --mem_len 3800 \
        --clamp_len 1000 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
examples/transformer-xl/scripts/run_lm1b_base.sh  0 → 100755 (new file)

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --adaptive \
        --n_layer 18 \
        --d_model 1024 \
        --div_val 4 \
        --n_head 8 \
        --d_head 128 \
        --d_inner 4096 \
        --dropout 0.0 \
        --dropatt 0.0 \
        --optim adam \
        --warmup_step 20000 \
        --max_step 500000 \
        --lr 0.00025 \
        --tgt_len 32 \
        --mem_len 32 \
        --eval_tgt_len 32 \
        --batch_size 224 \
        --multi_gpu \
        --gpu0_bsz 32 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --batch_size 64 \
        --tgt_len 32 \
        --mem_len 128 \
        --split test \
        --same_length \
        ${@:2}
else
    echo 'unknown argument 1'
fi
examples/transformer-xl/scripts/run_lm1b_large.sh  0 → 100755 (new file)

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --adaptive \
        --div_val 4 \
        --n_layer 24 \
        --d_model 1280 \
        --n_head 16 \
        --d_head 80 \
        --d_inner 8192 \
        --dropout 0.05 \
        --dropatt 0.05 \
        --optim adam \
        --warmup_step 30000 \
        --max_step 1200000 \
        --lr 0.00025 \
        --tgt_len 32 \
        --mem_len 32 \
        --eval_tgt_len 32 \
        --batch_size 512 \
        --multi_gpu \
        --gpu0_bsz 0 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --batch_size 8 \
        --tgt_len 32 \
        --mem_len 128 \
        --split test \
        --same_length \
        ${@:2}
else
    echo 'unknown argument 1'
fi
examples/transformer-xl/scripts/run_text8_base.sh  0 → 100755 (new file)

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --n_layer 12 \
        --d_model 512 \
        --n_head 8 \
        --d_head 64 \
        --d_inner 2048 \
        --dropout 0.1 \
        --dropatt 0.0 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 0 \
        --max_step 400000 \
        --tgt_len 512 \
        --mem_len 512 \
        --eval_tgt_len 128 \
        --batch_size 22 \
        --multi_gpu \
        --gpu0_bsz 4 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --tgt_len 80 \
        --mem_len 2100 \
        --clamp_len 820 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
examples/transformer-xl/scripts/run_text8_large.sh  0 → 100755 (new file)

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --n_layer 24 \
        --d_model 1024 \
        --n_head 8 \
        --d_head 128 \
        --d_inner 3072 \
        --dropout 0.15 \
        --dropatt 0.15 \
        --optim adam \
        --lr 0.00025 \
        --tgt_len 768 \
        --mem_len 768 \
        --eval_tgt_len 128 \
        --batch_size 64 \
        --max_step 400000 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --tgt_len 128 \
        --mem_len 3800 \
        --clamp_len 1000 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
examples/transformer-xl/scripts/run_wt103_base.sh  0 → 100755 (new file)

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --adaptive \
        --n_layer 16 \
        --d_model 410 \
        --n_head 10 \
        --d_head 41 \
        --d_inner 2100 \
        --dropout 0.1 \
        --dropatt 0.0 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 0 \
        --max_step 200000 \
        --tgt_len 150 \
        --mem_len 150 \
        --eval_tgt_len 150 \
        --batch_size 60 \
        --multi_gpu \
        --gpu0_bsz 4 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --tgt_len 64 \
        --mem_len 640 \
        --clamp_len 400 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
examples/transformer-xl/scripts/run_wt103_large.sh  0 → 100755 (new file)

#!/bin/bash

export PYTHONPATH=$PWD/cuda/build/lib.linux-x86_64-3.7

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python3 train.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --adaptive \
        --div_val 4 \
        --n_layer 18 \
        --d_model 1024 \
        --n_head 16 \
        --d_head 64 \
        --d_inner 4096 \
        --dropout 0.2 \
        --dropatt 0.2 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 16000 \
        --max_step 4000000 \
        --tgt_len 384 \
        --mem_len 384 \
        --eval_tgt_len 128 \
        --batch_size 128 \
        --multi_gpu \
        --gpu0_bsz 0 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --tgt_len 128 \
        --mem_len 1600 \
        --clamp_len 1000 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
examples/transformer-xl/train.py  0 → 100644 (new file)

# coding: utf-8
import argparse
import time
import math
import os, sys
import itertools

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.exp_utils import create_exp_dir
from utils.data_parallel import BalancedDataParallel

parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='wt103',
                    choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                    help='dataset name')
parser.add_argument('--n_layer', type=int, default=12,
                    help='number of total layers')
parser.add_argument('--n_head', type=int, default=10,
                    help='number of heads')
parser.add_argument('--d_head', type=int, default=50,
                    help='head dimension')
parser.add_argument('--d_embed', type=int, default=-1,
                    help='embedding dimension')
parser.add_argument('--d_model', type=int, default=500,
                    help='model dimension')
parser.add_argument('--d_inner', type=int, default=1000,
                    help='inner dimension in FF')
parser.add_argument('--dropout', type=float, default=0.0,
                    help='global dropout rate')
parser.add_argument('--dropatt', type=float, default=0.0,
                    help='attention probability dropout rate')
parser.add_argument('--init', default='normal', type=str,
                    help='parameter initializer to use.')
parser.add_argument('--emb_init', default='normal', type=str,
                    help='parameter initializer to use.')
parser.add_argument('--init_range', type=float, default=0.1,
                    help='parameters initialized by U(-init_range, init_range)')
parser.add_argument('--emb_init_range', type=float, default=0.01,
                    help='parameters initialized by U(-init_range, init_range)')
parser.add_argument('--init_std', type=float, default=0.02,
                    help='parameters initialized by N(0, init_std)')
parser.add_argument('--proj_init_std', type=float, default=0.01,
                    help='parameters initialized by N(0, init_std)')
parser.add_argument('--optim', default='adam', type=str,
                    choices=['adam', 'sgd', 'adagrad'],
                    help='optimizer to use.')
parser.add_argument('--lr', type=float, default=0.00025,
                    help='initial learning rate (0.00025|5 for adam|sgd)')
parser.add_argument('--mom', type=float, default=0.0,
                    help='momentum for sgd')
parser.add_argument('--scheduler', default='cosine', type=str,
                    choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
                    help='lr scheduler to use.')
parser.add_argument('--warmup_step', type=int, default=0,
                    help='upper epoch limit')
parser.add_argument('--decay_rate', type=float, default=0.5,
                    help='decay factor when ReduceLROnPlateau is used')
parser.add_argument('--lr_min', type=float, default=0.0,
                    help='minimum learning rate during annealing')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--clip_nonemb', action='store_true',
                    help='only clip the gradient of non-embedding params')
parser.add_argument('--max_step', type=int, default=100000,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=60,
                    help='batch size')
parser.add_argument('--batch_chunk', type=int, default=1,
                    help='split batch into chunks to save memory')
parser.add_argument('--tgt_len', type=int, default=70,
                    help='number of tokens to predict')
parser.add_argument('--eval_tgt_len', type=int, default=50,
                    help='number of tokens to predict for evaluation')
parser.add_argument('--ext_len', type=int, default=0,
                    help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
                    help='length of the retained previous heads')
parser.add_argument('--not_tied', action='store_true',
                    help='do not tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--adaptive', action='store_true',
                    help='use adaptive softmax')
parser.add_argument('--div_val', type=int, default=1,
                    help='dividend value for adaptive input and softmax')
parser.add_argument('--pre_lnorm', action='store_true',
                    help='apply LayerNorm to the input instead of the output')
parser.add_argument('--varlen', action='store_true',
                    help='use variable length')
parser.add_argument('--multi_gpu', action='store_true',
                    help='use multiple GPU')
parser.add_argument('--log-interval', type=int, default=200,
                    help='report interval')
parser.add_argument('--eval-interval', type=int, default=4000,
                    help='evaluation interval')
parser.add_argument('--work_dir', default='LM-TFM', type=str,
                    help='experiment directory.')
parser.add_argument('--restart', action='store_true',
                    help='restart training from the saved checkpoint')
parser.add_argument('--restart_dir', type=str, default='',
                    help='restart dir')
parser.add_argument('--debug', action='store_true',
                    help='run in debug mode (do not create exp dir)')
parser.add_argument('--same_length', action='store_true',
                    help='use the same attn length for all tokens')
parser.add_argument('--attn_type', type=int, default=0,
                    help='attention type. 0 for ours, 1 for Shaw et al,'
                    '2 for Vaswani et al, 3 for Al Rfou et al.')
parser.add_argument('--clamp_len', type=int, default=-1,
                    help='use the same pos embeddings after clamp_len')
parser.add_argument('--eta_min', type=float, default=0.0,
                    help='min learning rate for cosine scheduler')
parser.add_argument('--gpu0_bsz', type=int, default=-1,
                    help='batch size on gpu 0')
parser.add_argument('--max_eval_steps', type=int, default=-1,
                    help='max eval steps')
parser.add_argument('--sample_softmax', type=int, default=-1,
                    help='number of samples in sampled softmax')
parser.add_argument('--patience', type=int, default=0,
                    help='patience')
parser.add_argument('--finetune_v2', action='store_true',
                    help='finetune v2')
parser.add_argument('--finetune_v3', action='store_true',
                    help='finetune v3')
parser.add_argument('--fp16', action='store_true',
                    help='Run in pseudo-fp16 mode (fp16 storage fp32 math).')
parser.add_argument('--static-loss-scale', type=float, default=1,
                    help='Static loss scale, positive power of 2 values can '
                    'improve fp16 convergence.')
parser.add_argument('--dynamic-loss-scale', action='store_true',
                    help='Use dynamic loss scaling. If supplied, this argument'
                    ' supersedes --static-loss-scale.')
parser.add_argument('--moe', action='store_true',
                    help='replace position-wise ffn with moe position-wise ffn')
parser.add_argument('--moe-num-expert', type=int, default=64,
                    help='number of experts in MoE')
parser.add_argument('--moe-top-k', type=int, default=2,
                    help='top_k experts in hard gate of moe')
args = parser.parse_args()
args.tied = not args.not_tied

assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"

if args.d_embed < 0:
    args.d_embed = args.d_model

assert args.ext_len >= 0, 'extended context length must be non-negative'
assert args.batch_size % args.batch_chunk == 0

args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
logging = create_exp_dir(args.work_dir,
    scripts_to_save=['train.py', 'mem_transformer.py'], debug=args.debug)

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print('WARNING: You have a CUDA device, so you should probably run with --cuda')
    else:
        torch.cuda.manual_seed_all(args.seed)

# Validate `--fp16` option
if args.fp16:
    if not args.cuda:
        print('WARNING: --fp16 requires --cuda, ignoring --fp16 option')
        args.fp16 = False
    else:
        try:
            from apex.fp16_utils import FP16_Optimizer
        except:
            print('WARNING: apex not installed, ignoring --fp16 option')
            args.fp16 = False

device = torch.device('cuda' if args.cuda else 'cpu')

###############################################################################
# Load data
###############################################################################
corpus = get_lm_corpus(args.data, args.dataset)
ntokens = len(corpus.vocab)
args.n_token = ntokens

eval_batch_size = 10
tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
    device=device, ext_len=args.ext_len)
va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len,
    device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len,
    device=device, ext_len=args.ext_len)

# adaptive softmax / embedding
cutoffs, tie_projs = [], [False]
if args.adaptive:
    assert args.dataset in ['wt103', 'lm1b']
    if args.dataset == 'wt103':
        cutoffs = [20000, 40000, 200000]
        tie_projs += [True] * len(cutoffs)
    elif args.dataset == 'lm1b':
        cutoffs = [60000, 100000, 640000]
        tie_projs += [False] * len(cutoffs)

###############################################################################
# Build the model
###############################################################################
def init_weight(weight):
    if args.init == 'uniform':
        nn.init.uniform_(weight, -args.init_range, args.init_range)
    elif args.init == 'normal':
        nn.init.normal_(weight, 0.0, args.init_std)

def init_bias(bias):
    nn.init.constant_(bias, 0.0)

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        if hasattr(m, 'weight') and m.weight is not None:
            init_weight(m.weight)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    elif classname.find('AdaptiveEmbedding') != -1:
        if hasattr(m, 'emb_projs'):
            for i in range(len(m.emb_projs)):
                if m.emb_projs[i] is not None:
                    nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
    elif classname.find('Embedding') != -1:
        if hasattr(m, 'weight'):
            init_weight(m.weight)
    elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
        if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
            init_weight(m.cluster_weight)
        if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
            init_bias(m.cluster_bias)
        if hasattr(m, 'out_projs'):
            for i in range(len(m.out_projs)):
                if m.out_projs[i] is not None:
                    nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
    elif classname.find('LayerNorm') != -1:
        if hasattr(m, 'weight'):
            nn.init.normal_(m.weight, 1.0, args.init_std)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    elif classname.find('TransformerLM') != -1:
        if hasattr(m, 'r_emb'):
            init_weight(m.r_emb)
        if hasattr(m, 'r_w_bias'):
            init_weight(m.r_w_bias)
        if hasattr(m, 'r_r_bias'):
            init_weight(m.r_r_bias)
        if hasattr(m, 'r_bias'):
            init_bias(m.r_bias)

def update_dropout(m):
    classname = m.__class__.__name__
    if classname.find('Dropout') != -1:
        if hasattr(m, 'p'):
            m.p = args.dropout

def update_dropatt(m):
    if hasattr(m, 'dropatt'):
        m.dropatt.p = args.dropatt

if args.restart:
    with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
        model = torch.load(f)
    if not args.fp16:
        model = model.float()
    model.apply(update_dropout)
    model.apply(update_dropatt)
else:
    model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
        args.d_head, args.d_inner, args.dropout, args.dropatt,
        tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
        tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
        ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
        same_length=args.same_length, attn_type=args.attn_type,
        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
        moe=args.moe, moe_num_expert=args.moe_num_expert,
        moe_top_k=args.moe_top_k)
    model.apply(weights_init)
    model.word_emb.apply(weights_init)  # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

if args.fp16:
    model = model.half()

if args.multi_gpu:
    model = model.to(device)
    if args.gpu0_bsz >= 0:
        para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                          model, dim=1).to(device)
    else:
        para_model = nn.DataParallel(model, dim=1).to(device)
else:
    para_model = model.to(device)

#### optimizer
if args.optim.lower() == 'sgd':
    if args.sample_softmax > 0:
        dense_params, sparse_params = [], []
        for param in model.parameters():
            if param.size() == model.word_emb.weight.size():
                sparse_params.append(param)
            else:
                dense_params.append(param)
        optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
        optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
            momentum=args.mom)
elif args.optim.lower() == 'adam':
    if args.sample_softmax > 0:
        dense_params, sparse_params = [], []
        for param in model.parameters():
            if param.size() == model.word_emb.weight.size():
                sparse_params.append(param)
            else:
                dense_params.append(param)
        optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
        optimizer = optim.Adam(dense_params, lr=args.lr)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
elif args.optim.lower() == 'adagrad':
    optimizer = optim.Adagrad(model.parameters(), lr=args.lr)

#### scheduler
if args.scheduler == 'cosine':
    # here we do not set eta_min to lr_min to be backward compatible
    # because in previous versions eta_min is default to 0
    # rather than the default value of lr_min 1e-6
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
        args.max_step, eta_min=args.eta_min)  # should use eta_min arg
    if args.sample_softmax > 0:
        scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse,
            args.max_step, eta_min=args.eta_min)  # should use eta_min arg
elif args.scheduler == 'inv_sqrt':
    # originally used for Transformer (in Attention is all you need)
    def lr_lambda(step):
        # return a multiplier instead of a learning rate
        if step == 0 and args.warmup_step == 0:
            return 1.
        else:
            return 1. / (step ** 0.5) if step > args.warmup_step \
                   else step / (args.warmup_step ** 1.5)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
elif args.scheduler == 'dev_perf':
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
        factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
    if args.sample_softmax > 0:
        scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse,
            factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
elif args.scheduler == 'constant':
    pass

if args.cuda and args.fp16:
    # If args.dynamic_loss_scale is False, static_loss_scale will be used.
    # If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale.
    optimizer = FP16_Optimizer(optimizer,
                               static_loss_scale=args.static_loss_scale,
                               dynamic_loss_scale=args.dynamic_loss_scale,
                               dynamic_loss_args={'init_scale': 2 ** 16})

if args.restart:
    if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')):
        with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f:
            opt_state_dict = torch.load(f)
            optimizer.load_state_dict(opt_state_dict)
    else:
        print('Optimizer was not saved. Start from scratch.')

logging('=' * 100)
for k, v in args.__dict__.items():
    logging('    - {} : {}'.format(k, v))
logging('=' * 100)
logging('#params = {}'.format(args.n_all_param))
logging('#non emb params = {}'.format(args.n_nonemb_param))

###############################################################################
# Training code
###############################################################################

def evaluate(eval_iter):
    # Turn on evaluation mode which disables dropout.
    model.eval()

    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    if args.mem_len == 0:
        model.reset_length(args.eval_tgt_len,
            args.ext_len + args.tgt_len - args.eval_tgt_len, args.mem_len)
    else:
        model.reset_length(args.eval_tgt_len,
            args.ext_len, args.mem_len + args.tgt_len - args.eval_tgt_len)

    # Evaluation
    total_len, total_loss = 0, 0.
    with torch.no_grad():
        mems = tuple()
        for i, (data, target, seq_len) in enumerate(eval_iter):
            if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                break
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            total_loss += seq_len * loss.float().item()
            total_len += seq_len

    # Switch back to the training mode
    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    model.train()

    return total_loss / total_len


def train():
    # Turn on training mode which enables dropout.
    global train_step, train_loss, best_val_loss, eval_start_time, log_start_time
    model.train()
    if args.batch_chunk > 1:
        mems = [tuple() for _ in range(args.batch_chunk)]
    else:
        mems = tuple()
    train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
    for batch, (data, target, seq_len) in enumerate(train_iter):
        model.zero_grad()
        if args.batch_chunk > 1:
            data_chunks = torch.chunk(data, args.batch_chunk, 1)
            target_chunks = torch.chunk(target, args.batch_chunk, 1)
            for i in range(args.batch_chunk):
                data_i = data_chunks[i].contiguous()
                target_i = target_chunks[i].contiguous()
                ret = para_model(data_i, target_i, *mems[i])
                loss, mems[i] = ret[0], ret[1:]
                loss = loss.float().mean().type_as(loss) / args.batch_chunk
                if args.fp16:
                    optimizer
.
backward
(
loss
)
else
:
loss
.
backward
()
train_loss
+=
loss
.
float
().
item
()
else
:
ret
=
para_model
(
data
,
target
,
*
mems
)
loss
,
mems
=
ret
[
0
],
ret
[
1
:]
loss
=
loss
.
float
().
mean
().
type_as
(
loss
)
if
args
.
fp16
:
optimizer
.
backward
(
loss
)
else
:
loss
.
backward
()
train_loss
+=
loss
.
float
().
item
()
if
args
.
fp16
:
optimizer
.
clip_master_grads
(
args
.
clip
)
else
:
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
args
.
clip
)
optimizer
.
step
()
if
args
.
sample_softmax
>
0
:
optimizer_sparse
.
step
()
# step-wise learning rate annealing
train_step
+=
1
if
args
.
scheduler
in
[
'cosine'
,
'constant'
,
'dev_perf'
]:
# linear warmup stage
if
train_step
<
args
.
warmup_step
:
curr_lr
=
args
.
lr
*
train_step
/
args
.
warmup_step
optimizer
.
param_groups
[
0
][
'lr'
]
=
curr_lr
if
args
.
sample_softmax
>
0
:
optimizer_sparse
.
param_groups
[
0
][
'lr'
]
=
curr_lr
*
2
else
:
if
args
.
scheduler
==
'cosine'
:
scheduler
.
step
(
train_step
)
if
args
.
sample_softmax
>
0
:
scheduler_sparse
.
step
(
train_step
)
elif
args
.
scheduler
==
'inv_sqrt'
:
scheduler
.
step
(
train_step
)
if
train_step
%
args
.
log_interval
==
0
:
cur_loss
=
train_loss
/
args
.
log_interval
elapsed
=
time
.
time
()
-
log_start_time
log_str
=
'| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} '
\
'| ms/batch {:5.2f} | loss {:5.2f}'
.
format
(
epoch
,
train_step
,
batch
+
1
,
optimizer
.
param_groups
[
0
][
'lr'
],
elapsed
*
1000
/
args
.
log_interval
,
cur_loss
)
if
args
.
dataset
in
[
'enwik8'
,
'text8'
]:
log_str
+=
' | bpc {:9.5f}'
.
format
(
cur_loss
/
math
.
log
(
2
))
else
:
log_str
+=
' | ppl {:9.3f}'
.
format
(
math
.
exp
(
cur_loss
))
logging
(
log_str
)
train_loss
=
0
log_start_time
=
time
.
time
()
if
train_step
%
args
.
eval_interval
==
0
:
val_loss
=
evaluate
(
va_iter
)
logging
(
'-'
*
100
)
log_str
=
'| Eval {:3d} at step {:>8d} | time: {:5.2f}s '
\
'| valid loss {:5.2f}'
.
format
(
train_step
//
args
.
eval_interval
,
train_step
,
(
time
.
time
()
-
eval_start_time
),
val_loss
)
if
args
.
dataset
in
[
'enwik8'
,
'text8'
]:
log_str
+=
' | bpc {:9.5f}'
.
format
(
val_loss
/
math
.
log
(
2
))
else
:
log_str
+=
' | valid ppl {:9.3f}'
.
format
(
math
.
exp
(
val_loss
))
logging
(
log_str
)
logging
(
'-'
*
100
)
# Save the model if the validation loss is the best we've seen so far.
if
not
best_val_loss
or
val_loss
<
best_val_loss
:
if
not
args
.
debug
:
with
open
(
os
.
path
.
join
(
args
.
work_dir
,
'model.pt'
),
'wb'
)
as
f
:
torch
.
save
(
model
,
f
)
with
open
(
os
.
path
.
join
(
args
.
work_dir
,
'optimizer.pt'
),
'wb'
)
as
f
:
torch
.
save
(
optimizer
.
state_dict
(),
f
)
best_val_loss
=
val_loss
# dev-performance based learning rate annealing
if
args
.
scheduler
==
'dev_perf'
:
scheduler
.
step
(
val_loss
)
if
args
.
sample_softmax
>
0
:
scheduler_sparse
.
step
(
val_loss
)
eval_start_time
=
time
.
time
()
if
train_step
==
args
.
max_step
:
break
# Loop over epochs.
train_step
=
0
train_loss
=
0
best_val_loss
=
None
log_start_time
=
time
.
time
()
eval_start_time
=
time
.
time
()
# At any point you can hit Ctrl + C to break out of training early.
try
:
for
epoch
in
itertools
.
count
(
start
=
1
):
train
()
if
train_step
==
args
.
max_step
:
logging
(
'-'
*
100
)
logging
(
'End of training'
)
break
except
KeyboardInterrupt
:
logging
(
'-'
*
100
)
logging
(
'Exiting from training early'
)
# Load the best saved model.
with
open
(
os
.
path
.
join
(
args
.
work_dir
,
'model.pt'
),
'rb'
)
as
f
:
model
=
torch
.
load
(
f
)
para_model
=
model
.
to
(
device
)
# Run on test data.
test_loss
=
evaluate
(
te_iter
)
logging
(
'='
*
100
)
if
args
.
dataset
in
[
'enwik8'
,
'text8'
]:
logging
(
'| End of training | test loss {:5.2f} | test bpc {:9.5f}'
.
format
(
test_loss
,
test_loss
/
math
.
log
(
2
)))
else
:
logging
(
'| End of training | test loss {:5.2f} | test ppl {:9.3f}'
.
format
(
test_loss
,
math
.
exp
(
test_loss
)))
logging
(
'='
*
100
)
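As a quick illustration of the 'inv_sqrt' branch above, which returns a multiplier rather than an absolute learning rate, here is a standalone sketch of how the effective rate evolves; the warmup_step and base_lr values are hypothetical and not taken from this script's defaults.

# Standalone sketch of the inv_sqrt schedule; warmup_step and base_lr are
# hypothetical example values, not the defaults used by train.py.
warmup_step = 4000
base_lr = 2.5e-4

def lr_multiplier(step):
    # linear warmup up to warmup_step, then decay proportional to 1/sqrt(step)
    if step == 0 and warmup_step == 0:
        return 1.
    return 1. / (step ** 0.5) if step > warmup_step \
        else step / (warmup_step ** 1.5)

for step in [1, 1000, 4000, 16000, 64000]:
    print(step, base_lr * lr_multiplier(step))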
examples/transformer-xl/utils/adaptive_softmax.py
0 → 100644
from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveLogSoftmax(nn.Module):
    def __init__(self, in_features, n_classes, cutoffs, keep_order=False):
        super(AdaptiveLogSoftmax, self).__init__()

        cutoffs = list(cutoffs)

        if (cutoffs != sorted(cutoffs)) \
                or (min(cutoffs) <= 0) \
                or (max(cutoffs) >= (n_classes - 1)) \
                or (len(set(cutoffs)) != len(cutoffs)) \
                or any([int(c) != c for c in cutoffs]):
            raise ValueError("cutoffs should be a sequence of unique, positive "
                             "integers sorted in an increasing order, where "
                             "each value is between 1 and n_classes-1")

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features))
        self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.keep_order = keep_order

    def forward(self, hidden, target, weight, bias, keep_order=False):
        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        head_weight = torch.cat(
            [weight[:self.shortlist_size], self.cluster_weight], dim=0)
        head_bias = torch.cat(
            [bias[:self.shortlist_size], self.cluster_bias], dim=0)

        head_logit = F.linear(hidden, head_weight, bias=head_bias)
        head_logprob = F.log_softmax(head_logit, dim=1)

        nll = torch.zeros_like(target,
            dtype=hidden.dtype, device=hidden.device)

        offset = 0
        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):
            l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1]

            mask_i = (target >= l_idx) & (target < h_idx)
            indices_i = mask_i.nonzero().squeeze()

            if indices_i.numel() == 0:
                continue

            target_i = target.index_select(0, indices_i) - l_idx
            head_logprob_i = head_logprob.index_select(0, indices_i)

            if i == 0:
                logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
            else:
                weight_i = weight[l_idx:h_idx]
                bias_i = bias[l_idx:h_idx]

                hidden_i = hidden.index_select(0, indices_i)

                tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i)
                tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                logprob_i = head_logprob_i[:, -i] \
                          + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)

            if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
                nll.index_copy_(0, indices_i, -logprob_i)
            else:
                nll[offset:offset + logprob_i.size(0)].copy_(-logprob_i)

            offset += logprob_i.size(0)

        return nll
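A minimal smoke test for the class above, assuming it runs in the same module as the definition; the vocabulary size, cutoffs, and target ids are toy values chosen so every cluster receives several tokens.

# Illustrative usage of AdaptiveLogSoftmax; all sizes below are toy values.
import torch
import torch.nn as nn

n_classes, d_hidden = 1000, 32
crit = AdaptiveLogSoftmax(d_hidden, n_classes, cutoffs=[100, 500])

# full output embedding matrix and bias, shared with the criterion
weight = nn.Parameter(torch.randn(n_classes, d_hidden) * 0.02)
bias = nn.Parameter(torch.zeros(n_classes))

hidden = torch.randn(16, d_hidden)
# targets spread across the shortlist and both tail clusters
target = torch.tensor([5, 7, 9, 42, 50, 99, 150, 250, 300, 499,
                       500, 600, 700, 800, 900, 999])

nll = crit(hidden, target, weight, bias)   # per-token negative log-likelihood
print(nll.shape, nll.mean().item())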
examples/transformer-xl/utils/data_parallel.py
0 → 100644
from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply


def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except:
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                quit()
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs


class BalancedDataParallel(DataParallel):
    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids)
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
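The scatter override above gives device 0 a smaller chunk so that its memory can also hold the gathered outputs. The following standalone sketch reproduces only the chunk-size arithmetic in plain Python with hypothetical sizes; no GPUs are required.

# Pure-Python sketch of BalancedDataParallel's chunk-size computation.
# bsz, gpu0_bsz and num_dev below are hypothetical example values.
def balanced_chunk_sizes(bsz, gpu0_bsz, num_dev):
    bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
    if gpu0_bsz >= bsz_unit:
        return None  # falls back to the parent class's even split
    chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
    delta = bsz - sum(chunk_sizes)
    for i in range(delta):
        chunk_sizes[i + 1] += 1  # spread the remainder over the other GPUs
    if gpu0_bsz == 0:
        chunk_sizes = chunk_sizes[1:]
    return chunk_sizes

print(balanced_chunk_sizes(bsz=22, gpu0_bsz=4, num_dev=4))  # [4, 6, 6, 6]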
examples/transformer-xl/utils/exp_utils.py
0 → 100644
import functools
import os, shutil

import numpy as np

import torch


def logging(s, log_path, print_=True, log_=True):
    if print_:
        print(s)
    if log_:
        with open(log_path, 'a+') as f_log:
            f_log.write(s + '\n')


def get_logger(log_path, **kwargs):
    return functools.partial(logging, log_path=log_path, **kwargs)


def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
    if debug:
        print('Debug Mode : no experiment dir created')
        return functools.partial(logging, log_path=None, log_=False)

    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    print('Experiment dir : {}'.format(dir_path))
    if scripts_to_save is not None:
        script_path = os.path.join(dir_path, 'scripts')
        if not os.path.exists(script_path):
            os.makedirs(script_path)
        for script in scripts_to_save:
            dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)

    return get_logger(log_path=os.path.join(dir_path, 'log.txt'))


def save_checkpoint(model, optimizer, path, epoch):
    torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
    torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))
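For reference, a hedged sketch of how these helpers fit together, assuming the functions above are in scope; the work_dir path is made up.

# Illustrative wiring of create_exp_dir / get_logger; the path is hypothetical.
import os, tempfile

work_dir = os.path.join(tempfile.mkdtemp(), 'demo-run')
log = create_exp_dir(work_dir, scripts_to_save=None, debug=False)
log('experiment started')                       # printed and appended to log.txt
print(open(os.path.join(work_dir, 'log.txt')).read())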
examples/transformer-xl/utils/log_uniform_sampler.py
0 → 100644
import torch
from torch import nn
import numpy as np


class LogUniformSampler(object):
    def __init__(self, range_max, n_sample):
        """
        Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
            `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`

        expected count can be approximated by 1 - (1 - p)^n
        and we use a numerically stable version -expm1(num_tries * log1p(-p))

        Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
        """
        with torch.no_grad():
            self.range_max = range_max
            log_indices = torch.arange(1., range_max + 2., 1.).log_()
            self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
            # print('P', self.dist.numpy().tolist()[-30:])

            self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()

        self.n_sample = n_sample

    def sample(self, labels):
        """
            labels: [b1, b2]
        Return
            true_log_probs: [b1, b2]
            samp_log_probs: [n_sample]
            neg_samples: [n_sample]
        """

        # neg_samples = torch.empty(0).long()
        n_sample = self.n_sample
        n_tries = 2 * n_sample

        with torch.no_grad():
            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
            device = labels.device
            neg_samples = neg_samples.to(device)
            true_log_probs = self.log_q[labels].to(device)
            samp_log_probs = self.log_q[neg_samples].to(device)
            return true_log_probs, samp_log_probs, neg_samples


def sample_logits(embedding, bias, labels, inputs, sampler):
    """
        embedding: an nn.Embedding layer
        bias: [n_vocab]
        labels: [b1, b2]
        inputs: [b1, b2, n_emb]
        sampler: you may use a LogUniformSampler
    Return
        logits: [b1, b2, 1 + n_sample]
    """
    true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
    n_sample = neg_samples.size(0)
    b1, b2 = labels.size(0), labels.size(1)
    all_ids = torch.cat([labels.view(-1), neg_samples])
    all_w = embedding(all_ids)
    true_w = all_w[:-n_sample].view(b1, b2, -1)
    sample_w = all_w[-n_sample:].view(n_sample, -1)

    all_b = bias[all_ids]
    true_b = all_b[:-n_sample].view(b1, b2)
    sample_b = all_b[-n_sample:]

    hit = (labels[:, :, None] == neg_samples).detach()

    true_logits = torch.einsum('ijk,ijk->ij',
        [true_w, inputs]) + true_b - true_log_probs
    sample_logits = torch.einsum('lk,ijk->ijl',
        [sample_w, inputs]) + sample_b - samp_log_probs
    sample_logits.masked_fill_(hit, -1e30)
    logits = torch.cat([true_logits[:, :, None], sample_logits], -1)

    return logits
# class LogUniformSampler(object):
# def __init__(self, range_max, unique=False):
# """
# Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
# `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
# """
# self.range_max = range_max
# log_indices = torch.arange(1., range_max+2., 1.).log_()
# self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
# self.unique = unique
# if self.unique:
# self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
# def sample(self, n_sample, labels):
# pos_sample, new_labels = labels.unique(return_inverse=True)
# n_pos_sample = pos_sample.size(0)
# n_neg_sample = n_sample - n_pos_sample
# if self.unique:
# self.exclude_mask.index_fill_(0, pos_sample, 1)
# sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
# self.exclude_mask.index_fill_(0, pos_sample, 0)
# else:
# sample_dist = self.dist
# neg_sample = torch.multinomial(sample_dist, n_neg_sample)
# sample = torch.cat([pos_sample, neg_sample])
# sample_prob = self.dist[sample]
# return new_labels, sample, sample_prob
if __name__ == '__main__':
    # quick smoke test with toy sizes; the sampler draws 2 * n_sample
    # candidates internally and de-duplicates them
    S, B = 3, 4
    n_vocab = 10000
    n_sample = 5
    H = 32

    labels = torch.LongTensor(S, B).random_(0, n_vocab)

    sampler = LogUniformSampler(n_vocab, n_sample)

    embedding = nn.Embedding(n_vocab, H)
    bias = torch.zeros(n_vocab)
    inputs = torch.Tensor(S, B, H).normal_()

    logits = sample_logits(embedding, bias, labels, inputs, sampler)
    print('logits shape', logits.size())  # [S, B, 1 + #negative samples]
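The docstring above notes that the expected count 1 - (1 - p)^n is evaluated as -expm1(num_tries * log1p(-p)). A tiny standalone check, with an illustrative small p, that the two forms agree:

# Check of the numerically stable expected-count formula from the docstring.
import math

p, num_tries = 1e-6, 50          # illustrative values
naive = 1 - (1 - p) ** num_tries
stable = -math.expm1(num_tries * math.log1p(-p))
print(naive, stable)             # both ~5e-5; the expm1/log1p form avoids cancellation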
examples/transformer-xl/utils/proj_adaptive_softmax.py
0 → 100644
from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
CUDA_MINOR = int(torch.version.cuda.split('.')[1])


class Projection(nn.Module):
    def __init__(self, out_feat, in_feat):
        # Module.__init__ must run before a Parameter can be assigned
        super(Projection, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))

class ProjectedAdaptiveLogSoftmax(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False):
        super(ProjectedAdaptiveLogSoftmax, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        if self.n_clusters > 0:
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ModuleList()

        if div_val == 1:
            for i in range(len(self.cutoffs)):
                if d_proj != d_embed:
                    self.out_projs.append(
                        Projection(d_proj, d_embed)
                    )
                else:
                    self.out_projs.append(None)

            self.out_layers.append(nn.Linear(d_embed, n_token))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val ** i)

                self.out_projs.append(
                    Projection(d_proj, d_emb_i)
                )

                self.out_layers.append(nn.Linear(d_emb_i, r_idx - l_idx))

        self.keep_order = keep_order

    def _compute_logit(self, hidden, weight, bias, proj):
        if proj is None:
            logit = F.linear(hidden, weight, bias=bias)
        else:
            # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
            proj_hid = F.linear(hidden, proj.t().contiguous())
            logit = F.linear(proj_hid, weight, bias=bias)
            # else:
            #     logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
            #     if bias is not None:
            #         logit = logit + bias

        return logit

    def forward(self, hidden, target, keep_order=False):
        '''
            hidden :: [len*bsz x d_proj]
            target :: [len*bsz]
        '''

        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        if self.n_clusters == 0:
            logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                        self.out_layers[0].bias,
                                        self.out_projs[0].weight if self.out_projs[0] is not None else None)
            nll = -F.log_softmax(logit, dim=-1) \
                    .gather(1, target.unsqueeze(1)).squeeze(1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:
                    weight_i = torch.cat(
                        [weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat(
                        [bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)

            head_weight, head_bias, head_proj = weights[0], biases[0], \
                self.out_projs[0].weight if self.out_projs[0] is not None else None

            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
            head_logprob = F.log_softmax(head_logit, dim=1)

            nll = torch.zeros_like(target,
                dtype=hidden.dtype, device=hidden.device)

            offset = 0
            cutoff_values = [0] + self.cutoffs
            for i in range(len(cutoff_values) - 1):
                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]

                mask_i = (target >= l_idx) & (target < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                target_i = target.index_select(0, indices_i) - l_idx
                head_logprob_i = head_logprob.index_select(0, indices_i)

                if i == 0:
                    logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], \
                        self.out_projs[i].weight if self.out_projs[i] is not None else None

                    hidden_i = hidden.index_select(0, indices_i)

                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                    logprob_i = head_logprob_i[:, -i] \
                              + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)

                if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
                    nll.index_copy_(0, indices_i, -logprob_i)
                else:
                    nll[offset:offset + logprob_i.size(0)].copy_(-logprob_i)

                offset += logprob_i.size(0)

        return nll
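A minimal smoke test for ProjectedAdaptiveLogSoftmax, assuming it runs in the same module as the class; d_proj equals d_embed and div_val is 1 so no projection is constructed, and all sizes are toy values.

# Illustrative usage of ProjectedAdaptiveLogSoftmax; sizes are toy values.
import torch

n_token, d_embed, d_proj = 1000, 32, 32
crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_proj,
                                   cutoffs=[100, 500], div_val=1)

hidden = torch.randn(12, d_proj)                       # [len*bsz, d_proj]
target = torch.tensor([3, 5, 7, 50, 150, 200, 499,     # spread over the clusters
                       500, 650, 800, 900, 999])
nll = crit(hidden, target)                             # per-token NLL, shape [12]
print(nll.shape, nll.mean().item())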
examples/transformer-xl/utils/vocabulary.py
0 → 100644
import os
from collections import Counter, OrderedDict

import torch


class Vocab(object):
    def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
                 delimiter=None, vocab_file=None):
        self.counter = Counter()
        self.special = special
        self.min_freq = min_freq
        self.max_size = max_size
        self.lower_case = lower_case
        self.delimiter = delimiter
        self.vocab_file = vocab_file

    def tokenize(self, line, add_eos=False, add_double_eos=False):
        line = line.strip()
        # convert to lower case
        if self.lower_case:
            line = line.lower()

        # empty delimiter '' will evaluate False
        if self.delimiter == '':
            symbols = line
        else:
            symbols = line.split(self.delimiter)

        if add_double_eos:  # lm1b
            return ['<S>'] + symbols + ['<S>']
        elif add_eos:
            return symbols + ['<eos>']
        else:
            return symbols

    def count_file(self, path, verbose=False, add_eos=False):
        if verbose:
            print('counting file {} ...'.format(path))
        assert os.path.exists(path)

        sents = []
        with open(path, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('    line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=add_eos)
                self.counter.update(symbols)
                sents.append(symbols)

        return sents

    def count_sents(self, sents, verbose=False):
        """
            sents : a list of sentences, each a list of tokenized symbols
        """
        if verbose:
            print('counting {} sents ...'.format(len(sents)))
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print('    line {}'.format(idx))
            self.counter.update(symbols)

    def _build_from_file(self, vocab_file):
        self.idx2sym = []
        self.sym2idx = OrderedDict()

        with open(vocab_file, 'r', encoding='utf-8') as f:
            for line in f:
                symb = line.strip().split()[0]
                self.add_symbol(symb)
        self.unk_idx = self.sym2idx['<UNK>']

    def build_vocab(self):
        if self.vocab_file:
            print('building vocab from {}'.format(self.vocab_file))
            self._build_from_file(self.vocab_file)
            print('final vocab size {}'.format(len(self)))
        else:
            print('building vocab with min_freq={}, max_size={}'.format(
                self.min_freq, self.max_size))
            self.idx2sym = []
            self.sym2idx = OrderedDict()

            for sym in self.special:
                self.add_special(sym)

            for sym, cnt in self.counter.most_common(self.max_size):
                if cnt < self.min_freq:
                    break
                self.add_symbol(sym)

            print('final vocab size {} from {} unique tokens'.format(
                len(self), len(self.counter)))

    def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
                    add_double_eos=False):
        if verbose:
            print('encoding file {} ...'.format(path))
        assert os.path.exists(path)
        encoded = []
        with open(path, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('    line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=add_eos,
                    add_double_eos=add_double_eos)
                encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def encode_sents(self, sents, ordered=False, verbose=False):
        if verbose:
            print('encoding {} sents ...'.format(len(sents)))
        encoded = []
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print('    line {}'.format(idx))
            encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def add_special(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1
            setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

    def add_symbol(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def get_sym(self, idx):
        assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
        return self.idx2sym[idx]

    def get_idx(self, sym):
        if sym in self.sym2idx:
            return self.sym2idx[sym]
        else:
            # print('encounter unk {}'.format(sym))
            assert '<eos>' not in sym
            assert hasattr(self, 'unk_idx')
            return self.sym2idx.get(sym, self.unk_idx)

    def get_symbols(self, indices):
        return [self.get_sym(idx) for idx in indices]

    def get_indices(self, symbols):
        return [self.get_idx(sym) for sym in symbols]

    def convert_to_tensor(self, symbols):
        return torch.LongTensor(self.get_indices(symbols))

    def convert_to_sent(self, indices, exclude=None):
        if exclude is None:
            return ' '.join([self.get_sym(idx) for idx in indices])
        else:
            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])

    def __len__(self):
        return len(self.idx2sym)
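A short illustrative round trip through Vocab, assuming the class above is in scope; the example sentences are made up.

# Illustrative Vocab round trip; the sentences are made up.
vocab = Vocab(special=['<eos>'], lower_case=True)
sents = [vocab.tokenize('the quick brown fox', add_eos=True),
         vocab.tokenize('the lazy dog', add_eos=True)]
vocab.count_sents(sents)
vocab.build_vocab()                       # '<eos>' first, then tokens by frequency
ids = vocab.convert_to_tensor(sents[0])   # LongTensor of symbol indices
print(len(vocab), ids.tolist())
print(vocab.convert_to_sent(ids.tolist()))   # 'the quick brown fox <eos>'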
fmoe/__init__.py
0 → 100644
r"""
The fmoe package contains MoE Layers only.
"""
from .layers import FMoE
from .linear import FMoELinear
from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
fmoe/balance.py
0 → 100644
import torch
import torch.nn.functional as F


metrics = {
    "coefficient-variation": lambda c_e: torch.std(c_e) / torch.mean(c_e),
    "Lmax-over-Lmin": lambda c_e: (torch.max(c_e) + 1) / (torch.min(c_e) + 1),
    "Lmax-over-Lmean": lambda c_e: torch.max(c_e) / torch.mean(c_e),
}


def reset_balance_profile(balance_dict, num_layers, balance_strategy):
    for key in metrics:
        balance_dict[key] = [None for _ in range(num_layers)]
    if balance_strategy:
        balance_dict[f"{balance_strategy}_loss"] = [None for _ in range(num_layers)]


def update_balance_profile(
    balance_dict,
    gate_top_k_idx,
    _gate_score_top_k,
    gate_context,
    layer_idx,
    num_expert,
    balance_strategy,
):
    # Fill in this function to conduct balance related jobs
    pass
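To make the metric definitions above concrete, here is a toy sketch with hypothetical per-expert token counts for one layer; the "gshard" label passed to reset_balance_profile is just an example strategy string.

# Toy per-expert load vector (hypothetical counts for 4 experts in one layer).
c_e = torch.tensor([10., 30., 25., 15.])
for name, fn in metrics.items():
    print(name, round(fn(c_e).item(), 3))

# reset_balance_profile pre-allocates one slot per layer for each metric,
# plus a "<strategy>_loss" entry when a balance strategy is given.
balance_dict = {}
reset_balance_profile(balance_dict, num_layers=2, balance_strategy="gshard")
print(sorted(balance_dict.keys()))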