OpenDAS / dgl · Commits
"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c9bd4d433845921ddf7c0b0a50be3c7bdf7a80fc"
Commit 1ea0bcf4
[Model] fix link & beam search (#394)
Authored Feb 18, 2019 by Zihao Ye; committed Feb 17, 2019 by Minjie Wang
Parent: ae1806f6

Showing 5 changed files with 25 additions and 21 deletions (+25, -21):
examples/pytorch/transformer/dataset/__init__.py    +1  -1
examples/pytorch/transformer/dataset/utils.py       +2  -2
examples/pytorch/transformer/modules/models.py      +20 -16
examples/pytorch/transformer/translation_test.py    +1  -1
examples/pytorch/transformer/translation_train.py   +1  -1
examples/pytorch/transformer/dataset/__init__.py

@@ -176,7 +176,7 @@ def get_dataset(dataset):
                     ('en', 'de'),
                     train='train.tok.clean.bpe.32000',
                     valid='newstest2013.tok.bpe.32000',
-                    test='newstest2014.tok.bpe.32000',
+                    test='newstest2014.tok.bpe.32000.ende',
                     vocab='vocab.bpe.32000')
     else:
         raise KeyError()
examples/pytorch/transformer/dataset/utils.py

@@ -4,7 +4,7 @@ import os
 from dgl.data.utils import *

 _urls = {
-    'wmt': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt16_en_de.tar.gz',
+    'wmt': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt14bpe_de_en.zip',
     'scripts': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/transformer_scripts.zip',
 }

@@ -23,7 +23,7 @@ def prepare_dataset(dataset_name):
     if dataset_name == 'multi30k':
         os.system('bash scripts/prepare-multi30k.sh')
     elif dataset_name == 'wmt14':
-        download(_urls['wmt'], path='wmt16_en_de.tar.gz')
+        download(_urls['wmt'], path='wmt14.zip')
         os.system('bash scripts/prepare-wmt14.sh')
     elif dataset_name == 'copy' or dataset_name == 'tiny_copy':
         train_size = 9000
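For orientation, `download` comes from the `dgl.data.utils` wildcard import above; a minimal sketch of how this preparation step is typically driven (the import path and `__main__` wrapper are assumptions for illustration, not part of the commit):

    # Hypothetical driver: fetch the archive registered in _urls and run the
    # matching prepare-*.sh script; after this commit, 'wmt14' downloads
    # wmt14bpe_de_en.zip instead of wmt16_en_de.tar.gz.
    from dataset.utils import prepare_dataset

    if __name__ == '__main__':
        prepare_dataset('wmt14')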
examples/pytorch/transformer/modules/models.py

@@ -128,16 +128,15 @@ class Transformer(nn.Module):
         """
         return self.generator(g.ndata['x'][nids['dec']])

-    def infer(self, graph, max_len, eos_id, k):
+    def infer(self, graph, max_len, eos_id, k, alpha=1.0):
         '''
         This function implements beam search in DGL, which is required in the inference phase.
+        Length normalization is given by (5 + len) ^ alpha / 6 ^ alpha. Please refer to https://arxiv.org/pdf/1609.08144.pdf.
         args:
             graph: a `Graph` object defined in `dgl.contrib.transformer.graph`.
             max_len: the maximum length of decoding.
             eos_id: the index of the end-of-sequence symbol.
             k: beam size
         return:
             ret: a list of index arrays corresponding to the input sequences specified by `graph`.
         '''
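The normalization cited in the docstring is the length penalty from GNMT (Wu et al., 2016). As a standalone illustration, with a helper name of our choosing rather than anything in the commit:

    def length_penalty(length, alpha):
        # GNMT length normalization: lp(Y) = (5 + |Y|)^alpha / 6^alpha.
        # alpha = 0 disables normalization; the paper found values around
        # 0.6-0.7 to work best, which is what translation_test.py passes below.
        return ((5. + length) / 6.) ** alpha

Dividing a hypothesis's accumulated log-probability by `length_penalty(len, alpha)` keeps long and short candidates comparable; for example, `length_penalty(10, 0.6)` is about 1.73.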
@@ -187,30 +186,35 @@ class Transformer(nn.Module):
             out = self.generator(g.ndata['x'][frontiers])
             batch_size = frontiers.shape[0] // k
             vocab_size = out.shape[-1]
             # Mask output for complete sequence
             one_hot = th.zeros(vocab_size).fill_(-1e9).to(device)
             one_hot[eos_id] = 0
             mask = g.ndata['mask'][frontiers].unsqueeze(-1).float()
             out = out * (1 - mask) + one_hot.unsqueeze(0) * mask
             if log_prob is None:
                 log_prob, pos = out.view(batch_size, k, -1)[:, 0, :].topk(k, dim=-1)
-                eos = th.zeros(batch_size).byte()
+                eos = th.zeros(batch_size, k).byte()
             else:
-                log_prob, pos = (out.view(batch_size, k, -1) + log_prob.unsqueeze(-1)).view(batch_size, -1).topk(k, dim=-1)
+                norm_old = eos.float().to(device) + (1 - eos.float().to(device)) * np.power((4. + step) / 6, alpha)
+                norm_new = eos.float().to(device) + (1 - eos.float().to(device)) * np.power((5. + step) / 6, alpha)
+                log_prob, pos = ((out.view(batch_size, k, -1) + (log_prob * norm_old).unsqueeze(-1)) / norm_new.unsqueeze(-1)).view(batch_size, -1).topk(k, dim=-1)
             _y = y.view(batch_size * k, -1)
             y = th.zeros_like(_y)
+            _eos = eos.clone()
             for i in range(batch_size):
-                if not eos[i]:
-                    for j in range(k):
-                        _j = pos[i, j].item() // vocab_size
-                        token = pos[i, j].item() % vocab_size
-                        y[i * k + j, :] = _y[i * k + _j, :]
-                        y[i * k + j, step] = token
-                        if j == 0:
-                            eos[i] = eos[i] | (token == eos_id)
-                else:
-                    y[i * k:(i + 1) * k, :] = _y[i * k:(i + 1) * k, :]
+                for j in range(k):
+                    _j = pos[i, j].item() // vocab_size
+                    token = pos[i, j].item() % vocab_size
+                    y[i * k + j, :] = _y[i * k + _j, :]
+                    y[i * k + j, step] = token
+                    eos[i, j] = _eos[i, _j] | (token == eos_id)
             if eos.all():
                 break
             else:
-                g.ndata['mask'][nids['dec']] = eos.unsqueeze(-1).repeat(1, k * max_len).view(-1).to(device)
+                g.ndata['mask'][nids['dec']] = eos.unsqueeze(-1).repeat(1, 1, max_len).view(-1).to(device)
         return y.view(batch_size, k, -1)[:, 0, :].tolist()

     def _register_att_map(self, g, enc_ids, dec_ids):
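The `norm_old` / `norm_new` pair is the subtle part of this hunk: `log_prob` holds scores already divided by the step-(t-1) penalty, so each step multiplies the old penalty back in, adds the new token's log-probability, and divides by the step-t penalty; finished beams (where `eos` is set) get a norm of 1 on both sides, so their scores freeze. A self-contained sketch of that update, assuming the same tensor shapes as the diff (`out`: (batch*k, vocab); `log_prob`, `eos`: (batch, k), all on one device) and a function name of our own:

    def renormalized_topk(out, log_prob, eos, step, k, alpha=1.0):
        # Undo the previous step's length penalty, add the new token scores,
        # then apply the current step's penalty; finished beams keep norm 1.
        batch_size = out.shape[0] // k
        done = eos.float()
        norm_old = done + (1 - done) * (((4. + step) / 6.) ** alpha)
        norm_new = done + (1 - done) * (((5. + step) / 6.) ** alpha)
        scores = (out.view(batch_size, k, -1)
                  + (log_prob * norm_old).unsqueeze(-1)) / norm_new.unsqueeze(-1)
        # Flatten beams and vocabulary together, keep the k best continuations.
        return scores.view(batch_size, -1).topk(k, dim=-1)

Since the active-beam penalty at step t is ((5 + t) / 6)^alpha, the factor ((4 + t) / 6)^alpha is exactly the penalty that was applied at step t - 1, so the two cancel for hypotheses that keep growing.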
examples/pytorch/transformer/translation_test.py

@@ -38,7 +38,7 @@ if __name__ == '__main__':
     test_iter = dataset(graph_pool, mode='test', batch_size=args.batch, devices=[device], k=k)
     for i, g in enumerate(test_iter):
         with th.no_grad():
-            output = model.infer(g, dataset.MAX_LENGTH, dataset.eos_id, k)
+            output = model.infer(g, dataset.MAX_LENGTH, dataset.eos_id, k, alpha=0.6)
             for line in dataset.get_sequence(output):
                 if args.print:
                     print(line)
examples/pytorch/transformer/translation_train.py

@@ -83,7 +83,7 @@ def main(dev_id, args):
             param.data /= ndev
     # Optimizer
-    model_opt = NoamOpt(dim_model, 1, 4000,
+    model_opt = NoamOpt(dim_model, 0.1, 4000,
                         T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9))
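`NoamOpt` is the warmup-then-decay schedule from "Attention Is All You Need": lr = factor * dim_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)). The second positional argument is the factor, so this change scales the entire learning-rate curve down by 10x (the `lr=1e-3` passed to Adam is overwritten by the schedule at each step in this style of wrapper). A sketch of the rate function for intuition, with a helper name of our own:

    def noam_rate(step, dim_model, factor=0.1, warmup=4000):
        # Linear warmup for `warmup` steps, then inverse-square-root decay;
        # `factor` uniformly rescales the whole curve.
        return factor * dim_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

    # The peak rate occurs at step == warmup; e.g. with dim_model = 512:
    #   noam_rate(4000, 512) ~= 0.1 * 512**-0.5 * 4000**-0.5 ~= 7.0e-5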