OpenDAS / dgl · Commits
Commit 1ea0bcf4
authored Feb 18, 2019 by Zihao Ye, committed by Minjie Wang, Feb 17, 2019

[Model] fix link & beam search (#394)

parent ae1806f6
Showing 5 changed files with 25 additions and 21 deletions (+25 −21)
examples/pytorch/transformer/dataset/__init__.py  (+1 −1)
examples/pytorch/transformer/dataset/utils.py  (+2 −2)
examples/pytorch/transformer/modules/models.py  (+20 −16)
examples/pytorch/transformer/translation_test.py  (+1 −1)
examples/pytorch/transformer/translation_train.py  (+1 −1)
examples/pytorch/transformer/dataset/__init__.py

@@ -176,7 +176,7 @@ def get_dataset(dataset):
         ('en', 'de'),
         train='train.tok.clean.bpe.32000',
         valid='newstest2013.tok.bpe.32000',
-        test='newstest2014.tok.bpe.32000',
+        test='newstest2014.tok.bpe.32000.ende',
         vocab='vocab.bpe.32000')
     else:
         raise KeyError()
examples/pytorch/transformer/dataset/utils.py

@@ -4,7 +4,7 @@ import os
 from dgl.data.utils import *

 _urls = {
-    'wmt': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt16_en_de.tar.gz',
+    'wmt': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt14bpe_de_en.zip',
     'scripts': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/transformer_scripts.zip',
 }

@@ -23,7 +23,7 @@ def prepare_dataset(dataset_name):
     if dataset_name == 'multi30k':
         os.system('bash scripts/prepare-multi30k.sh')
     elif dataset_name == 'wmt14':
-        download(_urls['wmt'], path='wmt16_en_de.tar.gz')
+        download(_urls['wmt'], path='wmt14.zip')
         os.system('bash scripts/prepare-wmt14.sh')
     elif dataset_name == 'copy' or dataset_name == 'tiny_copy':
         train_size = 9000
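The fixed link points at a prebuilt BPE archive in zip form, which is why the local path changes as well. For orientation, a stdlib-only sketch of what this download step amounts to; the example itself calls dgl.data.utils.download and delegates preprocessing to scripts/prepare-wmt14.sh, so everything below is illustrative:

import urllib.request
import zipfile

# Illustrative only: fetch the new WMT14 BPE archive and unpack it.
url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt14bpe_de_en.zip'
archive = 'wmt14.zip'

urllib.request.urlretrieve(url, archive)
with zipfile.ZipFile(archive) as zf:
    zf.extractall('.')  # the prepare script then works on the extracted files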
examples/pytorch/transformer/modules/models.py

@@ -128,16 +128,15 @@ class Transformer(nn.Module):
         """
         return self.generator(g.ndata['x'][nids['dec']])

-    def infer(self, graph, max_len, eos_id, k):
+    def infer(self, graph, max_len, eos_id, k, alpha=1.0):
         '''
         This function implements Beam Search in DGL, which is required in inference phase.
+        Length normalization is given by (5 + len) ^ alpha / 6 ^ alpha. Please refer to https://arxiv.org/pdf/1609.08144.pdf.
         args:
             graph: a `Graph` object defined in `dgl.contrib.transformer.graph`.
             max_len: the maximum length of decoding.
             eos_id: the index of end-of-sequence symbol.
             k: beam size
         return:
             ret: a list of index array correspond to the input sequence specified by `graph``.
         '''

@@ -187,30 +186,35 @@ class Transformer(nn.Module):
                 out = self.generator(g.ndata['x'][frontiers])
                 batch_size = frontiers.shape[0] // k
                 vocab_size = out.shape[-1]
+                # Mask output for complete sequence
+                one_hot = th.zeros(vocab_size).fill_(-1e9).to(device)
+                one_hot[eos_id] = 0
+                mask = g.ndata['mask'][frontiers].unsqueeze(-1).float()
+                out = out * (1 - mask) + one_hot.unsqueeze(0) * mask
                 if log_prob is None:
                     log_prob, pos = out.view(batch_size, k, -1)[:, 0, :].topk(k, dim=-1)
-                    eos = th.zeros(batch_size).byte()
+                    eos = th.zeros(batch_size, k).byte()
                 else:
-                    log_prob, pos = (out.view(batch_size, k, -1) + log_prob.unsqueeze(-1)).view(batch_size, -1).topk(k, dim=-1)
+                    norm_old = eos.float().to(device) + (1 - eos.float().to(device)) * np.power((4. + step) / 6, alpha)
+                    norm_new = eos.float().to(device) + (1 - eos.float().to(device)) * np.power((5. + step) / 6, alpha)
+                    log_prob, pos = ((out.view(batch_size, k, -1) + (log_prob * norm_old).unsqueeze(-1)) / norm_new.unsqueeze(-1)).view(batch_size, -1).topk(k, dim=-1)
                 _y = y.view(batch_size * k, -1)
                 y = th.zeros_like(_y)
+                _eos = eos.clone()
                 for i in range(batch_size):
-                    if not eos[i]:
-                        for j in range(k):
-                            _j = pos[i, j].item() // vocab_size
-                            token = pos[i, j].item() % vocab_size
-                            y[i * k + j, :] = _y[i * k + _j, :]
-                            y[i * k + j, step] = token
-                            if j == 0:
-                                eos[i] = eos[i] | (token == eos_id)
-                    else:
-                        y[i * k:(i + 1) * k, :] = _y[i * k:(i + 1) * k, :]
+                    for j in range(k):
+                        _j = pos[i, j].item() // vocab_size
+                        token = pos[i, j].item() % vocab_size
+                        y[i * k + j, :] = _y[i * k + _j, :]
+                        y[i * k + j, step] = token
+                        eos[i, j] = _eos[i, _j] | (token == eos_id)
                 if eos.all():
                     break
                 else:
-                    g.ndata['mask'][nids['dec']] = eos.unsqueeze(-1).repeat(1, k * max_len).view(-1).to(device)
+                    g.ndata['mask'][nids['dec']] = eos.unsqueeze(-1).repeat(1, 1, max_len).view(-1).to(device)
         return y.view(batch_size, k, -1)[:, 0, :].tolist()

     def _register_att_map(self, g, enc_ids, dec_ids):
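The core of the beam-search fix: eos is now tracked per hypothesis (shape (batch_size, k) instead of (batch_size,)), finished beams are frozen by masking their output to -1e9 everywhere except eos_id, and scores are kept length-normalized. Because the stored log_prob is already divided by the GNMT penalty lp(len) = ((5 + len) / 6) ** alpha, each step multiplies it back by lp(step - 1) (the norm_old term, hence the (4. + step) / 6), adds the new token's log-probability, and divides by lp(step) (norm_new). A minimal standalone sketch of that update, with illustrative names that are not part of the diff:

import numpy as np

def length_penalty(length, alpha):
    # GNMT length penalty: lp(Y) = ((5 + |Y|) / 6) ** alpha
    return np.power((5. + length) / 6., alpha)

def update_score(stored_score, token_logp, step, alpha=1.0, finished=False):
    # stored_score holds sum(log p) / lp(step - 1). Finished beams get a
    # penalty of 1 on both sides (the eos mask inside norm_old / norm_new),
    # and their token_logp is 0 at eos_id, so their score passes through.
    norm_old = 1.0 if finished else length_penalty(step - 1, alpha)
    norm_new = 1.0 if finished else length_penalty(step, alpha)
    return (stored_score * norm_old + token_logp) / norm_new

With alpha = 0 both penalties are 1 and the update collapses to the plain accumulation the old code did; the default alpha=1.0 keeps existing infer() calls working while letting callers tune the normalization.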
examples/pytorch/transformer/translation_test.py

@@ -38,7 +38,7 @@ if __name__ == '__main__':
     test_iter = dataset(graph_pool, mode='test', batch_size=args.batch, devices=[device], k=k)
     for i, g in enumerate(test_iter):
         with th.no_grad():
-            output = model.infer(g, dataset.MAX_LENGTH, dataset.eos_id, k)
+            output = model.infer(g, dataset.MAX_LENGTH, dataset.eos_id, k, alpha=0.6)
         for line in dataset.get_sequence(output):
             if args.print:
                 print(line)
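alpha=0.6 is the length-normalization coefficient reported in the GNMT paper that the new docstring links. For a feel of how strongly it rewards longer hypotheses, a throwaway sketch (not part of the test script):

import numpy as np

# GNMT penalty ((5 + len) / 6) ** alpha at alpha = 0.6 for a few lengths:
for length in (5, 10, 20):
    print(length, np.power((5. + length) / 6., 0.6))  # ~1.36, ~1.73, ~2.35

Dividing by a penalty that grows with length offsets the bias of summed log-probabilities toward short outputs.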
examples/pytorch/transformer/translation_train.py

@@ -83,7 +83,7 @@ def main(dev_id, args):
         param.data /= ndev
     # Optimizer
-    model_opt = NoamOpt(dim_model, 1, 4000,
+    model_opt = NoamOpt(dim_model, 0.1, 4000,
                         T.optim.Adam(model.parameters(), lr=1e-3,
                                      betas=(0.9, 0.98), eps=1e-9))
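The second positional argument of NoamOpt is the scale factor of the warmup schedule. Assuming DGL's NoamOpt mirrors the one in The Annotated Transformer (a reasonable guess, not confirmed by this diff), the rate at step t is factor * dim_model ** -0.5 * min(t ** -0.5, t * warmup ** -1.5), so the change from 1 to 0.1 lowers the learning rate tenfold across the whole run. A quick sketch:

def noam_rate(step, dim_model=512, factor=0.1, warmup=4000):
    # Linear warmup for `warmup` steps, then step ** -0.5 decay;
    # `factor` scales the entire curve. dim_model=512 is assumed here.
    return factor * dim_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# Peak rate at step == warmup: about 7.0e-4 with factor=1, 7.0e-5 with factor=0.1.
print(noam_rate(4000, factor=1.0), noam_rate(4000, factor=0.1))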