Commit a315e568, authored Jan 25, 2017 by Lukasz Kaiser
Parent: d66941ac

    Update to the Neural GPU.

Showing 6 changed files with 2615 additions and 607 deletions (+2615 / -607):

  * neural_gpu/README.md              (+21  / -12)
  * neural_gpu/data_utils.py          (+193 / -52)
  * neural_gpu/neural_gpu.py          (+637 / -217)
  * neural_gpu/neural_gpu_trainer.py  (+889 / -326)
  * neural_gpu/program_utils.py       (+440 / -0)
  * neural_gpu/wmt_utils.py           (+435 / -0)
neural_gpu/README.md @ a315e568 (+21 / -12)

@@ -4,7 +4,6 @@ in [[http://arxiv.org/abs/1511.08228]].

 Requirements:
 * TensorFlow (see tensorflow.org for how to install)
-* Matplotlib for Python (sudo apt-get install python-matplotlib)

 The model can be trained on the following algorithmic tasks:

@@ -26,17 +25,27 @@ The model can be trained on the following algorithmic tasks:

 * `qadd` - Long quaternary addition
 * `search` - Search for symbol key in dictionary

-The value range for symbols are defined by the `niclass` and `noclass` flags.
-In particular, the values are in the range `min(--niclass, noclass) - 1`.
-So if you set `--niclass=33` and `--noclass=33` (the default) then `--task=rev`
-will be reversing lists of 32 symbols, and `--task=id` will be identity on a
-list of up to 32 symbols.
+It can also be trained on the WMT English-French translation task:
+
+* `wmt` - WMT English-French translation (data will be downloaded)
+
+The value range for symbols are defined by the `vocab_size` flag.
+In particular, the values are in the range `vocab_size - 1`.
+So if you set `--vocab_size=16` (the default) then `--problem=rev`
+will be reversing lists of 15 symbols, and `--problem=id` will be identity
+on a list of up to 15 symbols.

-To train the model on the reverse task run:
+To train the model on the binary multiplication task run:

 ```
-python neural_gpu_trainer.py --task=rev
+python neural_gpu_trainer.py --problem=bmul
 ```

+This trains the Extended Neural GPU, to train the original model run:
+
+```
+python neural_gpu_trainer.py --problem=bmul --beam_size=0
+```
+
 While training, interim / checkpoint model parameters will be

@@ -47,16 +56,16 @@ with, hit `Ctrl-C` to stop the training process. The latest

 model parameters will be in `/tmp/neural_gpu/neural_gpu.ckpt-<step>`
 and used on any subsequent run.

-To test a trained model on how well it decodes run:
+To evaluate a trained model on how well it decodes run:

 ```
-python neural_gpu_trainer.py --task=rev --mode=1
+python neural_gpu_trainer.py --problem=bmul --mode=1
 ```

-To produce an animation of the result run:
+To interact with a model (experimental, see code) run:

 ```
-python neural_gpu_trainer.py --task=rev --mode=1 --animate=True
+python neural_gpu_trainer.py --problem=bmul --mode=2
 ```

 Maintained by Lukasz Kaiser (lukaszkaiser)
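The `--problem` and `--vocab_size` flags described in the updated README map directly onto the data generator in `data_utils.py`. As a purely illustrative sketch (not part of this commit), this is roughly what a `rev` training pair looks like under the convention that symbols are the integers `1 .. vocab_size - 1` and `0` is reserved for padding:

```python
# Illustrative sketch only: a "rev" pair with --vocab_size=16.
# Assumption: symbols are 1..vocab_size-1 and 0 is the padding symbol.
import random

vocab_size = 16
length = 8

inp = [random.randint(1, vocab_size - 1) for _ in range(length)]
target = list(reversed(inp))        # the "rev" problem: target is the reversed input
padded = inp + [0] * (12 - length)  # sequences get padded up to a bucket length

print inp, "->", target
print "padded to a bucket of length 12:", padded
```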
neural_gpu/data_utils.py @ a315e568 (+193 / -52)

@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Convolutional Gated Recurrent Networks for Algorithm Learning."""
+"""Neural GPU -- data generation and batching utilities."""

 import math
+import os
 import random
 import sys
 import time
@@ -22,22 +23,28 @@ import time

 import numpy as np
 import tensorflow as tf

-from tensorflow.python.platform import gfile
+import program_utils

 FLAGS = tf.app.flags.FLAGS

-bins = [8, 12, 16, 20, 24, 28, 32, 36, 40, 48, 64, 128]
+bins = [2 + bin_idx_i for bin_idx_i in xrange(256)]
 all_tasks = ["sort", "kvsort", "id", "rev", "rev2", "incr", "add", "left",
              "right", "left-shift", "right-shift", "bmul", "mul", "dup",
-             "badd", "qadd", "search"]
-forward_max = 128
+             "badd", "qadd", "search", "progeval", "progsynth"]
 log_filename = ""
+vocab, rev_vocab = None, None


 def pad(l):
   for b in bins:
     if b >= l: return b
-  return forward_max
+  return bins[-1]


+def bin_for(l):
+  for i, b in enumerate(bins):
+    if b >= l: return i
+  return len(bins) - 1
+
+
 train_set = {}
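The new `bins` list replaces the old handful of bucket sizes (8, 12, ..., 128) with every length from 2 to 257, and the added `bin_for` returns a bucket index rather than a padded length. A standalone copy of the two helpers, for illustration of the bucketing behaviour only:

```python
# Standalone copy of the bucketing helpers above, for illustration only.
bins = [2 + i for i in range(256)]   # bucket lengths 2, 3, ..., 257

def pad(l):
  """Smallest bucket length that fits a sequence of length l."""
  for b in bins:
    if b >= l:
      return b
  return bins[-1]

def bin_for(l):
  """Index of the smallest bucket that fits a sequence of length l."""
  for i, b in enumerate(bins):
    if b >= l:
      return i
  return len(bins) - 1

print pad(10), bin_for(10)    # 10 8: length 10 fits bucket length 10, which is index 8
print pad(500), bin_for(500)  # 257 255: overlong sequences fall into the last bucket
```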
@@ -50,6 +57,35 @@ for some_task in all_tasks:

   test_set[some_task].append([])


+def read_tmp_file(name):
+  """Read from a file with the given name in our log directory or above."""
+  dirname = os.path.dirname(log_filename)
+  fname = os.path.join(dirname, name + ".txt")
+  if not tf.gfile.Exists(fname):
+    print_out("== not found file: " + fname)
+    fname = os.path.join(dirname, "../" + name + ".txt")
+  if not tf.gfile.Exists(fname):
+    print_out("== not found file: " + fname)
+    fname = os.path.join(dirname, "../../" + name + ".txt")
+  if not tf.gfile.Exists(fname):
+    print_out("== not found file: " + fname)
+    return None
+  print_out("== found file: " + fname)
+  res = []
+  with tf.gfile.GFile(fname, mode="r") as f:
+    for line in f:
+      res.append(line.strip())
+  return res
+
+
+def write_tmp_file(name, lines):
+  dirname = os.path.dirname(log_filename)
+  fname = os.path.join(dirname, name + ".txt")
+  with tf.gfile.GFile(fname, mode="w") as f:
+    for line in lines:
+      f.write(line + "\n")
+
+
 def add(n1, n2, base=10):
   """Add two numbers represented as lower-endian digit lists."""
   k = max(len(n1), len(n2)) + 1
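`add` (unchanged by this hunk) operates on numbers stored as lower-endian digit lists, least-significant digit first, which appears to be the same representation the arithmetic tasks (`badd`, `bmul`, `mul`, ...) use for their operands. A small worked example of that representation (illustration only; the helper names below are not from the module):

```python
# Lower-endian digit lists: least-significant digit first.
# 13 in binary is 1101, stored as [1, 0, 1, 1]; 6 is 110, stored as [0, 1, 1].

def to_digits(n, base=2):
  digits = []
  while n:
    digits.append(n % base)
    n //= base
  return digits or [0]

def from_digits(digits, base=2):
  return sum(d * base ** i for i, d in enumerate(digits))

a, b = to_digits(13), to_digits(6)
print a, b                             # [1, 0, 1, 1] [0, 1, 1]
print from_digits(a) * from_digits(b)  # 78, the value a "bmul" target would encode
```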
@@ -130,6 +166,30 @@ def init_data(task, length, nbr_cases, nclass):

     sorted_kv = [(k, vals[i]) for (k, i) in sorted(keys)]
     return [x for p in kv for x in p], [x for p in sorted_kv for x in p]

+  def prog_io_pair(prog, max_len, counter=0):
+    try:
+      ilen = np.random.randint(max_len - 3) + 1
+      bound = max(15 - (counter / 20), 1)
+      inp = [random.choice(range(-bound, bound)) for _ in range(ilen)]
+      inp_toks = [program_utils.prog_rev_vocab[t]
+                  for t in program_utils.tokenize(str(inp)) if t != ","]
+      out = program_utils.evaluate(prog, {"a": inp})
+      out_toks = [program_utils.prog_rev_vocab[t]
+                  for t in program_utils.tokenize(str(out)) if t != ","]
+      if counter > 400:
+        out_toks = []
+      if (out_toks and out_toks[0] == program_utils.prog_rev_vocab["["] and
+          len(out_toks) != len([o for o in out if o == ","]) + 3):
+        raise ValueError("generated list with too long ints")
+      if (out_toks and out_toks[0] != program_utils.prog_rev_vocab["["] and
+          len(out_toks) > 1):
+        raise ValueError("generated one int but tokenized it to many")
+      if len(out_toks) > max_len:
+        raise ValueError("output too long")
+      return (inp_toks, out_toks)
+    except ValueError:
+      return prog_io_pair(prog, max_len, counter + 1)
+
   def spec(inp):
     """Return the target given the input for some tasks."""
     if task == "sort":
@@ -164,43 +224,114 @@ def init_data(task, length, nbr_cases, nclass):

   l = length
   cur_time = time.time()
   total_time = 0.0
-  for case in xrange(nbr_cases):
+
+  is_prog = task in ["progeval", "progsynth"]
+  if is_prog:
+    inputs_per_prog = 5
+    program_utils.make_vocab()
+    progs = read_tmp_file("programs_len%d" % (l / 10))
+    if not progs:
+      progs = program_utils.gen(l / 10, 1.2 * nbr_cases / inputs_per_prog)
+      write_tmp_file("programs_len%d" % (l / 10), progs)
+    prog_ios = read_tmp_file("programs_len%d_io" % (l / 10))
+    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
+    if not prog_ios:
+      # Generate program io data.
+      prog_ios = []
+      for pidx, prog in enumerate(progs):
+        if pidx % 500 == 0:
+          print_out("== generating io pairs for program %d" % pidx)
+        if pidx * inputs_per_prog > nbr_cases * 1.2:
+          break
+        ptoks = [program_utils.prog_rev_vocab[t]
+                 for t in program_utils.tokenize(prog)]
+        ptoks.append(program_utils.prog_rev_vocab["_EOS"])
+        plen = len(ptoks)
+        for _ in xrange(inputs_per_prog):
+          if task == "progeval":
+            inp, out = prog_io_pair(prog, plen)
+            prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
+          elif task == "progsynth":
+            plen = max(len(ptoks), 8)
+            for _ in xrange(3):
+              inp, out = prog_io_pair(prog, plen / 2)
+              prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
+      write_tmp_file("programs_len%d_io" % (l / 10), prog_ios)
+    prog_ios_dict = {}
+    for s in prog_ios:
+      i, o, p = s.split("\t")
+      i_clean = "".join([c for c in i if c.isdigit() or c == " "])
+      o_clean = "".join([c for c in o if c.isdigit() or c == " "])
+      inp = [int(x) for x in i_clean.split()]
+      out = [int(x) for x in o_clean.split()]
+      if inp and out:
+        if p in prog_ios_dict:
+          prog_ios_dict[p].append([inp, out])
+        else:
+          prog_ios_dict[p] = [[inp, out]]
+    # Use prog_ios_dict to create data.
+    progs = []
+    for prog in prog_ios_dict:
+      if len([c for c in prog if c == ";"]) <= (l / 10):
+        progs.append(prog)
+    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
+    print_out("== %d training cases on %d progs" % (nbr_cases, len(progs)))
+    for pidx, prog in enumerate(progs):
+      if pidx * inputs_per_prog > nbr_cases * 1.2:
+        break
+      ptoks = [program_utils.prog_rev_vocab[t]
+               for t in program_utils.tokenize(prog)]
+      ptoks.append(program_utils.prog_rev_vocab["_EOS"])
+      plen = len(ptoks)
+      dset = train_set if pidx < nbr_cases / inputs_per_prog else test_set
+      for _ in xrange(inputs_per_prog):
+        if task == "progeval":
+          inp, out = prog_ios_dict[prog].pop()
+          dset[task][bin_for(plen)].append([[ptoks, inp, [], []], [out]])
+        elif task == "progsynth":
+          plen, ilist = max(len(ptoks), 8), [[]]
+          for _ in xrange(3):
+            inp, out = prog_ios_dict[prog].pop()
+            ilist.append(inp + out)
+          dset[task][bin_for(plen)].append([ilist, [ptoks]])
+
+  for case in xrange(0 if is_prog else nbr_cases):
     total_time += time.time() - cur_time
     cur_time = time.time()
     if l > 10000 and case % 100 == 1:
       print_out("  avg gen time %.4f s" % (total_time / float(case)))
     if task in ["add", "badd", "qadd", "bmul", "mul"]:
       i, t = rand_pair(l, task)
-      train_set[task][len(i)].append([i, t])
+      train_set[task][bin_for(len(i))].append([[[], i, [], []], [t]])
       i, t = rand_pair(l, task)
-      test_set[task][len(i)].append([i, t])
+      test_set[task][bin_for(len(i))].append([[[], i, [], []], [t]])
     elif task == "dup":
       i, t = rand_dup_pair(l)
-      train_set[task][len(i)].append([i, t])
+      train_set[task][bin_for(len(i))].append([[i], [t]])
       i, t = rand_dup_pair(l)
-      test_set[task][len(i)].append([i, t])
+      test_set[task][bin_for(len(i))].append([[i], [t]])
     elif task == "rev2":
       i, t = rand_rev2_pair(l)
-      train_set[task][len(i)].append([i, t])
+      train_set[task][bin_for(len(i))].append([[i], [t]])
       i, t = rand_rev2_pair(l)
-      test_set[task][len(i)].append([i, t])
+      test_set[task][bin_for(len(i))].append([[i], [t]])
     elif task == "search":
       i, t = rand_search_pair(l)
-      train_set[task][len(i)].append([i, t])
+      train_set[task][bin_for(len(i))].append([[i], [t]])
       i, t = rand_search_pair(l)
-      test_set[task][len(i)].append([i, t])
+      test_set[task][bin_for(len(i))].append([[i], [t]])
     elif task == "kvsort":
       i, t = rand_kvsort_pair(l)
-      train_set[task][len(i)].append([i, t])
+      train_set[task][bin_for(len(i))].append([[i], [t]])
       i, t = rand_kvsort_pair(l)
-      test_set[task][len(i)].append([i, t])
-    else:
+      test_set[task][bin_for(len(i))].append([[i], [t]])
+    elif task not in ["progeval", "progsynth"]:
       inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
       target = spec(inp)
-      train_set[task][l].append([inp, target])
+      train_set[task][bin_for(l)].append([[inp], [target]])
       inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
       target = spec(inp)
-      test_set[task][l].append([inp, target])
+      test_set[task][bin_for(l)].append([[inp], [target]])


 def to_symbol(i):
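The main change in this hunk is the example layout: every stored case is now a pair `[input_tracks, target_tracks]`, so single-sequence tasks wrap their input as `[i]`, the arithmetic tasks store `[[], i, [], []]`, and program evaluation stores the program tokens alongside its input. A minimal sketch of that nesting (the numbers are made up for illustration, not taken from the code):

```python
# Sketch of the new case layout; values are illustrative only.
rev_case = [[[3, 1, 4, 1, 5]], [[5, 1, 4, 1, 3]]]        # one input track, one target track
bmul_case = [[[], [1, 0, 1, 1], [], []], [[0, 1, 1]]]    # arithmetic input sits in track 2 of 4
progeval_case = [[[7, 8, 9], [2, 3], [], []], [[4, 5]]]  # program tokens, program input, padding tracks

for case in (rev_case, bmul_case, progeval_case):
  input_tracks, target_tracks = case
  print len(input_tracks), "input track(s) ->", len(target_tracks), "target track(s)"
```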
@@ -218,37 +349,31 @@ def to_id(s):

   return int(s) + 1


-def get_batch(max_length, batch_size, do_train, task, offset=None, preset=None):
+def get_batch(bin_id, batch_size, data_set, height, offset=None, preset=None):
   """Get a batch of data, training or testing."""
-  inputs = []
-  targets = []
-  length = max_length
-  if preset is None:
-    cur_set = test_set[task]
-    if do_train: cur_set = train_set[task]
-    while not cur_set[length]:
-      length -= 1
-  pad_length = pad(length)
+  inputs, targets = [], []
+  pad_length = bins[bin_id]
   for b in xrange(batch_size):
     if preset is None:
-      elem = random.choice(cur_set[length])
-      if offset is not None and offset + b < len(cur_set[length]):
-        elem = cur_set[length][offset + b]
+      elem = random.choice(data_set[bin_id])
+      if offset is not None and offset + b < len(data_set[bin_id]):
+        elem = data_set[bin_id][offset + b]
     else:
       elem = preset
-    inp, target = elem[0], elem[1]
-    assert len(inp) == length
-    inputs.append(inp + [0 for l in xrange(pad_length - len(inp))])
-    targets.append(target + [0 for l in xrange(pad_length - len(target))])
-  res_input = []
-  res_target = []
-  for l in xrange(pad_length):
-    new_input = np.array([inputs[b][l] for b in xrange(batch_size)],
-                         dtype=np.int32)
-    new_target = np.array([targets[b][l] for b in xrange(batch_size)],
-                          dtype=np.int32)
-    res_input.append(new_input)
-    res_target.append(new_target)
+    inpt, targett, inpl, targetl = elem[0], elem[1], [], []
+    for inp in inpt:
+      inpl.append(inp + [0 for _ in xrange(pad_length - len(inp))])
+    if len(inpl) == 1:
+      for _ in xrange(height - 1):
+        inpl.append([0 for _ in xrange(pad_length)])
+    for target in targett:
+      targetl.append(target + [0 for _ in xrange(pad_length - len(target))])
+    inputs.append(inpl)
+    targets.append(targetl)
+  res_input = np.array(inputs, dtype=np.int32)
+  res_target = np.array(targets, dtype=np.int32)
+  assert list(res_input.shape) == [batch_size, height, pad_length]
+  assert list(res_target.shape) == [batch_size, 1, pad_length]
   return res_input, res_target
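`get_batch` now pads every input track to the bucket length and stacks the whole batch into fixed-shape arrays of `[batch_size, height, pad_length]` (inputs) and `[batch_size, 1, pad_length]` (targets). A minimal NumPy sketch of that padding-and-stacking step (illustration only, not the module function):

```python
import numpy as np

# Illustration of the shapes the rewritten get_batch produces.
batch = [([1, 2, 3], [3, 2, 1]), ([4, 5], [5, 4])]  # (input, target) pairs, made up
pad_length, height = 6, 2

inputs, targets = [], []
for inp, tgt in batch:
  tracks = [inp + [0] * (pad_length - len(inp))]
  tracks += [[0] * pad_length for _ in range(height - 1)]  # extra all-zero tracks
  inputs.append(tracks)
  targets.append([tgt + [0] * (pad_length - len(tgt))])

res_input = np.array(inputs, dtype=np.int32)
res_target = np.array(targets, dtype=np.int32)
print res_input.shape, res_target.shape  # (2, 2, 6) (2, 1, 6)
```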
@@ -256,11 +381,11 @@ def print_out(s, newline=True):

   """Print a message out and log it to file."""
   if log_filename:
     try:
-      with gfile.GFile(log_filename, mode="a") as f:
+      with tf.gfile.GFile(log_filename, mode="a") as f:
         f.write(s + ("\n" if newline else ""))
     # pylint: disable=bare-except
     except:
-      sys.stdout.write("Error appending to %s\n" % log_filename)
+      sys.stderr.write("Error appending to %s\n" % log_filename)
   sys.stdout.write(s + ("\n" if newline else ""))
   sys.stdout.flush()
@@ -269,21 +394,36 @@ def decode(output):

   return [np.argmax(o, axis=1) for o in output]


-def accuracy(inpt, output, target, batch_size, nprint):
+def accuracy(inpt_t, output, target_t, batch_size, nprint,
+             beam_out=None, beam_scores=None):
   """Calculate output accuracy given target."""
   assert nprint < batch_size + 1
+  inpt = []
+  for h in xrange(inpt_t.shape[1]):
+    inpt.extend([inpt_t[:, h, l] for l in xrange(inpt_t.shape[2])])
+  target = [target_t[:, 0, l] for l in xrange(target_t.shape[2])]
+  def tok(i):
+    if rev_vocab and i < len(rev_vocab): return rev_vocab[i]
+    return str(i - 1)
   def task_print(inp, output, target):
     stop_bound = 0
     print_len = 0
     while print_len < len(target) and target[print_len] > stop_bound:
       print_len += 1
-    print_out("    i: " + " ".join([str(i - 1) for i in inp if i > 0]))
-    print_out("    o: " +
-              " ".join([str(output[l] - 1) for l in xrange(print_len)]))
-    print_out("    t: " +
-              " ".join([str(target[l] - 1) for l in xrange(print_len)]))
+    print_out("    i: " + " ".join([tok(i) for i in inp if i > 0]))
+    print_out("    o: " +
+              " ".join([tok(output[l]) for l in xrange(print_len)]))
+    print_out("    t: " +
+              " ".join([tok(target[l]) for l in xrange(print_len)]))
   decoded_target = target
   decoded_output = decode(output)
+  # Use beam output if given and score is high enough.
+  if beam_out is not None:
+    for b in xrange(batch_size):
+      if beam_scores[b] >= 10.0:
+        for l in xrange(min(len(decoded_output), beam_out.shape[2])):
+          decoded_output[l][b] = int(beam_out[b, 0, l])
   total = 0
   errors = 0
   seq = [0 for b in xrange(batch_size)]
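The new `beam_out` / `beam_scores` arguments let `accuracy` overwrite the greedy decode with the top beam whenever that beam's score clears the 10.0 threshold. A small self-contained sketch of just that override rule (made-up values):

```python
import numpy as np

# Greedy decode: one array per output position, one entry per batch element.
decoded_output = [np.array([1, 1]), np.array([2, 2]), np.array([3, 3])]
beam_out = np.array([[[7, 8, 9]], [[4, 5, 6]]])  # shape [batch, 1, length]
beam_scores = np.array([12.0, 3.0])              # only batch element 0 clears 10.0

for b in range(2):
  if beam_scores[b] >= 10.0:
    for l in range(min(len(decoded_output), beam_out.shape[2])):
      decoded_output[l][b] = int(beam_out[b, 0, l])

print [o.tolist() for o in decoded_output]  # [[7, 1], [8, 2], [9, 3]]
```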
@@ -311,6 +451,7 @@ def accuracy(inpt, output, target, batch_size, nprint):

 def safe_exp(x):
   perp = 10000
+  x = float(x)
   if x < 100: perp = math.exp(x)
   if perp > 10000: return 10000
   return perp
neural_gpu/neural_gpu.py @ a315e568 (+637 / -217): this diff is collapsed in the web view and not shown here.
neural_gpu/neural_gpu_trainer.py @ a315e568 (+889 / -326): this diff is collapsed in the web view and not shown here.
neural_gpu/program_utils.py @ a315e568 (new file, mode 100644, +440 / -0)

# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for generating program synthesis and evaluation data."""

import contextlib
import sys
import StringIO
import random
import os


class ListType(object):

  def __init__(self, arg):
    self.arg = arg

  def __str__(self):
    return "[" + str(self.arg) + "]"

  def __eq__(self, other):
    if not isinstance(other, ListType):
      return False
    return self.arg == other.arg

  def __hash__(self):
    return hash(self.arg)


class VarType(object):

  def __init__(self, arg):
    self.arg = arg

  def __str__(self):
    return str(self.arg)

  def __eq__(self, other):
    if not isinstance(other, VarType):
      return False
    return self.arg == other.arg

  def __hash__(self):
    return hash(self.arg)


class FunctionType(object):

  def __init__(self, args):
    self.args = args

  def __str__(self):
    return str(self.args[0]) + " -> " + str(self.args[1])

  def __eq__(self, other):
    if not isinstance(other, FunctionType):
      return False
    return self.args == other.args

  def __hash__(self):
    return hash(tuple(self.args))


class Function(object):

  def __init__(self, name, arg_types, output_type, fn_arg_types=None):
    self.name = name
    self.arg_types = arg_types
    self.fn_arg_types = fn_arg_types or []
    self.output_type = output_type


Null = 100

## Functions
f_head = Function("c_head", [ListType("Int")], "Int")
def c_head(xs): return xs[0] if len(xs) > 0 else Null

f_last = Function("c_last", [ListType("Int")], "Int")
def c_last(xs): return xs[-1] if len(xs) > 0 else Null

f_take = Function("c_take", ["Int", ListType("Int")], ListType("Int"))
def c_take(n, xs): return xs[:n]

f_drop = Function("c_drop", ["Int", ListType("Int")], ListType("Int"))
def c_drop(n, xs): return xs[n:]

f_access = Function("c_access", ["Int", ListType("Int")], "Int")
def c_access(n, xs): return xs[n] if n >= 0 and len(xs) > n else Null

f_max = Function("c_max", [ListType("Int")], "Int")
def c_max(xs): return max(xs) if len(xs) > 0 else Null

f_min = Function("c_min", [ListType("Int")], "Int")
def c_min(xs): return min(xs) if len(xs) > 0 else Null

f_reverse = Function("c_reverse", [ListType("Int")], ListType("Int"))
def c_reverse(xs): return list(reversed(xs))

f_sort = Function("sorted", [ListType("Int")], ListType("Int"))
# def c_sort(xs): return sorted(xs)

f_sum = Function("sum", [ListType("Int")], "Int")
# def c_sum(xs): return sum(xs)

## Lambdas
# Int -> Int
def plus_one(x): return x + 1
def minus_one(x): return x - 1
def times_two(x): return x * 2
def neg(x): return x * (-1)
def div_two(x): return int(x / 2)
def sq(x): return x ** 2
def times_three(x): return x * 3
def div_three(x): return int(x / 3)
def times_four(x): return x * 4
def div_four(x): return int(x / 4)

# Int -> Bool
def pos(x): return x > 0
def neg(x): return x < 0
def even(x): return x % 2 == 0
def odd(x): return x % 2 == 1

# Int -> Int -> Int
def add(x, y): return x + y
def sub(x, y): return x - y
def mul(x, y): return x * y

# HOFs
f_map = Function("map", [ListType("Int")], ListType("Int"),
                 [FunctionType(["Int", "Int"])])
f_filter = Function("filter", [ListType("Int")], ListType("Int"),
                    [FunctionType(["Int", "Bool"])])
f_count = Function("c_count", [ListType("Int")], "Int",
                   [FunctionType(["Int", "Bool"])])
def c_count(f, xs): return len([x for x in xs if f(x)])

f_zipwith = Function("c_zipwith", [ListType("Int"), ListType("Int")],
                     ListType("Int"), [FunctionType(["Int", "Int", "Int"])])
#FIX
def c_zipwith(f, xs, ys): return [f(x, y) for (x, y) in zip(xs, ys)]

f_scan = Function("c_scan", [ListType("Int")], ListType("Int"),
                  [FunctionType(["Int", "Int", "Int"])])
def c_scan(f, xs):
  out = xs
  for i in range(1, len(xs)):
    out[i] = f(xs[i], xs[i - 1])
  return out


@contextlib.contextmanager
def stdoutIO(stdout=None):
  old = sys.stdout
  if stdout is None:
    stdout = StringIO.StringIO()
  sys.stdout = stdout
  yield stdout
  sys.stdout = old


def evaluate(program_str, input_names_to_vals, default="ERROR"):
  exec_str = []
  for name, val in input_names_to_vals.iteritems():
    exec_str += name + " = " + str(val) + "; "
  exec_str += program_str
  if type(exec_str) is list:
    exec_str = "".join(exec_str)
  with stdoutIO() as s:
    # pylint: disable=bare-except
    try:
      exec exec_str + " print(out)"
      return s.getvalue()[:-1]
    except:
      return default
    # pylint: enable=bare-except


class Statement(object):
  """Statement class."""

  def __init__(self, fn, output_var, arg_vars, fn_args=None):
    self.fn = fn
    self.output_var = output_var
    self.arg_vars = arg_vars
    self.fn_args = fn_args or []

  def __str__(self):
    return "%s = %s(%s%s%s)" % (self.output_var, self.fn.name,
                                ", ".join(self.fn_args),
                                ", " if self.fn_args else "",
                                ", ".join(self.arg_vars))

  def substitute(self, env):
    self.output_var = env.get(self.output_var, self.output_var)
    self.arg_vars = [env.get(v, v) for v in self.arg_vars]


class ProgramGrower(object):
  """Grow programs."""

  def __init__(self, functions, types_to_lambdas):
    self.functions = functions
    self.types_to_lambdas = types_to_lambdas

  def grow_body(self, new_var_name, dependencies, types_to_vars):
    """Grow the program body."""
    choices = []
    for f in self.functions:
      if all([a in types_to_vars.keys() for a in f.arg_types]):
        choices.append(f)
    f = random.choice(choices)
    args = []
    for t in f.arg_types:
      possible_vars = random.choice(types_to_vars[t])
      var = random.choice(possible_vars)
      args.append(var)
      dependencies.setdefault(new_var_name, []).extend(
          [var] + (dependencies[var]))
    fn_args = [random.choice(self.types_to_lambdas[t]) for t in f.fn_arg_types]
    types_to_vars.setdefault(f.output_type, []).append(new_var_name)
    return Statement(f, new_var_name, args, fn_args)

  def grow(self, program_len, input_types):
    """Grow the program."""
    var_names = list(reversed(map(chr, range(97, 123))))
    dependencies = dict()
    types_to_vars = dict()
    input_names = []
    for t in input_types:
      var = var_names.pop()
      dependencies[var] = []
      types_to_vars.setdefault(t, []).append(var)
      input_names.append(var)
    statements = []
    for _ in range(program_len - 1):
      var = var_names.pop()
      statements.append(self.grow_body(var, dependencies, types_to_vars))
    statements.append(self.grow_body("out", dependencies, types_to_vars))
    new_var_names = [c for c in map(chr, range(97, 123))
                     if c not in input_names]
    new_var_names.reverse()
    keep_statements = []
    env = dict()
    for s in statements:
      if s.output_var in dependencies["out"]:
        keep_statements.append(s)
        env[s.output_var] = new_var_names.pop()
      if s.output_var == "out":
        keep_statements.append(s)
    for k in keep_statements:
      k.substitute(env)
    return Program(input_names, input_types,
                   ";".join([str(k) for k in keep_statements]))


class Program(object):
  """The program class."""

  def __init__(self, input_names, input_types, body):
    self.input_names = input_names
    self.input_types = input_types
    self.body = body

  def evaluate(self, inputs):
    """Evaluate this program."""
    if len(inputs) != len(self.input_names):
      raise AssertionError("inputs and input_names have to"
                           "have the same len. inp: %s , names: %s"
                           % (str(inputs), str(self.input_names)))
    inp_str = ""
    for (name, inp) in zip(self.input_names, inputs):
      inp_str += name + " = " + str(inp) + "; "
    with stdoutIO() as s:
      # pylint: disable=exec-used
      exec inp_str + self.body + "; print(out)"
      # pylint: enable=exec-used
      return s.getvalue()[:-1]

  def flat_str(self):
    out = ""
    for s in self.body.split(";"):
      out += s + ";"
    return out

  def __str__(self):
    out = ""
    for (n, t) in zip(self.input_names, self.input_types):
      out += n + " = " + str(t) + "\n"
    for s in self.body.split(";"):
      out += s + "\n"
    return out


prog_vocab = []
prog_rev_vocab = {}


def tokenize(string, tokens=None):
  """Tokenize the program string."""
  if tokens is None:
    tokens = prog_vocab
  tokens = sorted(tokens, key=len, reverse=True)
  out = []
  string = string.strip()
  while string:
    found = False
    for t in tokens:
      if string.startswith(t):
        out.append(t)
        string = string[len(t):]
        found = True
        break
    if not found:
      raise ValueError("Couldn't tokenize this: " + string)
    string = string.strip()
  return out


def clean_up(output, max_val=100):
  o = eval(str(output))
  if isinstance(o, bool):
    return o
  if isinstance(o, int):
    if o >= 0:
      return min(o, max_val)
    else:
      return max(o, -1 * max_val)
  if isinstance(o, list):
    return [clean_up(l) for l in o]


def make_vocab():
  gen(2, 0)


def gen(max_len, how_many):
  """Generate some programs."""
  functions = [f_head, f_last, f_take, f_drop, f_access, f_max, f_min,
               f_reverse, f_sort, f_sum, f_map, f_filter, f_count,
               f_zipwith, f_scan]
  types_to_lambdas = {
      FunctionType(["Int", "Int"]): ["plus_one", "minus_one", "times_two",
                                     "div_two", "sq", "times_three",
                                     "div_three", "times_four", "div_four"],
      FunctionType(["Int", "Bool"]): ["pos", "neg", "even", "odd"],
      FunctionType(["Int", "Int", "Int"]): ["add", "sub", "mul"]
  }
  tokens = []
  for f in functions:
    tokens.append(f.name)
  for v in types_to_lambdas.values():
    tokens.extend(v)
  tokens.extend(["=", ";", ",", "(", ")", "[", "]", "Int", "out"])
  tokens.extend(map(chr, range(97, 123)))
  io_tokens = map(str, range(-220, 220))
  if not prog_vocab:
    prog_vocab.extend(["_PAD", "_EOS"] + tokens + io_tokens)
    for i, t in enumerate(prog_vocab):
      prog_rev_vocab[t] = i
  io_tokens += [",", "[", "]", ")", "(", "None"]

  grower = ProgramGrower(functions=functions,
                         types_to_lambdas=types_to_lambdas)

  def mk_inp(l):
    return [random.choice(range(-5, 5)) for _ in range(l)]

  tar = [ListType("Int")]
  inps = [[mk_inp(3)], [mk_inp(5)], [mk_inp(7)], [mk_inp(15)]]
  save_prefix = None

  outcomes_to_programs = dict()
  tried = set()
  counter = 0
  choices = [0] if max_len == 0 else range(max_len)
  while counter < 100 * how_many and len(outcomes_to_programs) < how_many:
    counter += 1
    length = random.choice(choices)
    t = grower.grow(length, tar)
    while t in tried:
      length = random.choice(choices)
      t = grower.grow(length, tar)
    # print(t.flat_str())
    tried.add(t)
    outcomes = [clean_up(t.evaluate(i)) for i in inps]
    outcome_str = str(zip(inps, outcomes))
    if outcome_str in outcomes_to_programs:
      outcomes_to_programs[outcome_str] = min(
          [t.flat_str(), outcomes_to_programs[outcome_str]],
          key=lambda x: len(tokenize(x, tokens)))
    else:
      outcomes_to_programs[outcome_str] = t.flat_str()
    if counter % 5000 == 0:
      print "== proggen: tried: " + str(counter)
      print "== proggen: kept: " + str(len(outcomes_to_programs))

    if counter % 250000 == 0 and save_prefix is not None:
      print "saving..."
      save_counter = 0
      progfilename = os.path.join(save_prefix, "prog_" + str(counter) + ".txt")
      iofilename = os.path.join(save_prefix, "io_" + str(counter) + ".txt")
      prog_token_filename = os.path.join(
          save_prefix, "prog_tokens_" + str(counter) + ".txt")
      io_token_filename = os.path.join(
          save_prefix, "io_tokens_" + str(counter) + ".txt")
      with open(progfilename, "a+") as fp, \
           open(iofilename, "a+") as fi, \
           open(prog_token_filename, "a+") as ftp, \
           open(io_token_filename, "a+") as fti:
        for (o, p) in outcomes_to_programs.iteritems():
          save_counter += 1
          if save_counter % 500 == 0:
            print "saving %d of %d" % (save_counter, len(outcomes_to_programs))
          fp.write(p + "\n")
          fi.write(o + "\n")
          ftp.write(str(tokenize(p, tokens)) + "\n")
          fti.write(str(tokenize(o, io_tokens)) + "\n")

  return list(outcomes_to_programs.values())
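A hedged usage sketch for the new module (Python 2, like the file itself; it assumes `program_utils.py` is on the import path, and the generated programs depend on the random seed):

```python
import program_utils

# Grow a few short list-processing programs and run them on a sample input.
programs = program_utils.gen(max_len=2, how_many=5)
for p in programs:
  print p
  # evaluate() executes the program body with the input list bound to "a".
  print program_utils.evaluate(p, {"a": [3, -1, 4, 1, 5]})
```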
neural_gpu/wmt_utils.py @ a315e568 (new file, mode 100644, +435 / -0)

# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for downloading data from WMT, tokenizing, vocabularies."""

import gzip
import os
import re
import tarfile

from six.moves import urllib

import tensorflow as tf

# Special vocabulary symbols - we always put them at the start.
_PAD = b"_PAD"
_GO = b"_GO"
_EOS = b"_EOS"
_UNK = b"_CHAR_UNK"
_SPACE = b"_SPACE"
_START_VOCAB = [_PAD, _GO, _EOS, _UNK, _SPACE]

PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3
SPACE_ID = 4

# Regular expressions used to tokenize.
_CHAR_MARKER = "_CHAR_"
_CHAR_MARKER_LEN = len(_CHAR_MARKER)
_SPEC_CHARS = "" + chr(226) + chr(153) + chr(128)
_PUNCTUATION = "][.,!?\"':;%$#@&*+}{|><=/^~)(_`,0123456789" + _SPEC_CHARS + "-"
_WORD_SPLIT = re.compile(b"([" + _PUNCTUATION + "])")
_OLD_WORD_SPLIT = re.compile(b"([.,!?\"':;)(])")
_DIGIT_RE = re.compile(br"\d")

# URLs for WMT data.
_WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar"
_WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz"


def maybe_download(directory, filename, url):
  """Download filename from url unless it's already in directory."""
  if not tf.gfile.Exists(directory):
    print "Creating directory %s" % directory
    os.mkdir(directory)
  filepath = os.path.join(directory, filename)
  if not tf.gfile.Exists(filepath):
    print "Downloading %s to %s" % (url, filepath)
    filepath, _ = urllib.request.urlretrieve(url, filepath)
    statinfo = os.stat(filepath)
    print "Succesfully downloaded", filename, statinfo.st_size, "bytes"
  return filepath


def gunzip_file(gz_path, new_path):
  """Unzips from gz_path into new_path."""
  print "Unpacking %s to %s" % (gz_path, new_path)
  with gzip.open(gz_path, "rb") as gz_file:
    with open(new_path, "wb") as new_file:
      for line in gz_file:
        new_file.write(line)


def get_wmt_enfr_train_set(directory):
  """Download the WMT en-fr training corpus to directory unless it's there."""
  train_path = os.path.join(directory, "giga-fren.release2.fixed")
  if not (tf.gfile.Exists(train_path + ".fr") and
          tf.gfile.Exists(train_path + ".en")):
    corpus_file = maybe_download(directory, "training-giga-fren.tar",
                                 _WMT_ENFR_TRAIN_URL)
    print "Extracting tar file %s" % corpus_file
    with tarfile.open(corpus_file, "r") as corpus_tar:
      corpus_tar.extractall(directory)
    gunzip_file(train_path + ".fr.gz", train_path + ".fr")
    gunzip_file(train_path + ".en.gz", train_path + ".en")
  return train_path


def get_wmt_enfr_dev_set(directory):
  """Download the WMT en-fr training corpus to directory unless it's there."""
  dev_name = "newstest2013"
  dev_path = os.path.join(directory, dev_name)
  if not (tf.gfile.Exists(dev_path + ".fr") and
          tf.gfile.Exists(dev_path + ".en")):
    dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL)
    print "Extracting tgz file %s" % dev_file
    with tarfile.open(dev_file, "r:gz") as dev_tar:
      fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr")
      en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en")
      fr_dev_file.name = dev_name + ".fr"  # Extract without "dev/" prefix.
      en_dev_file.name = dev_name + ".en"
      dev_tar.extract(fr_dev_file, directory)
      dev_tar.extract(en_dev_file, directory)
  return dev_path


def is_char(token):
  if len(token) > _CHAR_MARKER_LEN:
    if token[:_CHAR_MARKER_LEN] == _CHAR_MARKER:
      return True
  return False


def basic_detokenizer(tokens):
  """Reverse the process of the basic tokenizer below."""
  result = []
  previous_nospace = True
  for t in tokens:
    if is_char(t):
      result.append(t[_CHAR_MARKER_LEN:])
      previous_nospace = True
    elif t == _SPACE:
      result.append(" ")
      previous_nospace = True
    elif previous_nospace:
      result.append(t)
      previous_nospace = False
    else:
      result.extend([" ", t])
      previous_nospace = False
  return "".join(result)


old_style = False


def basic_tokenizer(sentence):
  """Very basic tokenizer: split the sentence into a list of tokens."""
  words = []
  if old_style:
    for space_separated_fragment in sentence.strip().split():
      words.extend(re.split(_OLD_WORD_SPLIT, space_separated_fragment))
    return [w for w in words if w]
  for space_separated_fragment in sentence.strip().split():
    tokens = [t for t in re.split(_WORD_SPLIT, space_separated_fragment) if t]
    first_is_char = False
    for i, t in enumerate(tokens):
      if len(t) == 1 and t in _PUNCTUATION:
        tokens[i] = _CHAR_MARKER + t
        if i == 0:
          first_is_char = True
    if words and words[-1] != _SPACE and (first_is_char or is_char(words[-1])):
      tokens = [_SPACE] + tokens
    spaced_tokens = []
    for i, tok in enumerate(tokens):
      spaced_tokens.append(tokens[i])
      if i < len(tokens) - 1:
        if tok != _SPACE and not (is_char(tok) or is_char(tokens[i + 1])):
          spaced_tokens.append(_SPACE)
    words.extend(spaced_tokens)
  return words


def space_tokenizer(sentence):
  return sentence.strip().split()


def is_pos_tag(token):
  """Check if token is a part-of-speech tag."""
  return (token in ["CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS",
                    "LS", "MD", "NN", "NNS", "NNP", "NNPS", "PDT", "POS",
                    "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM", "TO",
                    "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WDT",
                    "WP", "WP$", "WRB", ".", ",", ":", ")", "-LRB-", "(",
                    "-RRB-", "HYPH", "$", "``", "''", "ADD", "AFX", "QTR",
                    "BES", "-DFL-", "GW", "HVS", "NFP"])


def parse_constraints(inpt, res):
  ntags = len(res)
  nwords = len(inpt)
  npostags = len([x for x in res if is_pos_tag(x)])
  nclose = len([x for x in res if x[0] == "/"])
  nopen = ntags - nclose - npostags
  return (abs(npostags - nwords), abs(nclose - nopen))


def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
                      tokenizer=None, normalize_digits=False):
  """Create vocabulary file (if it does not exist yet) from data file.

  Data file is assumed to contain one sentence per line. Each sentence is
  tokenized and digits are normalized (if normalize_digits is set).
  Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
  We write it to vocabulary_path in a one-token-per-line format, so that later
  token in the first line gets id=0, second line gets id=1, and so on.

  Args:
    vocabulary_path: path where the vocabulary will be created.
    data_path: data file that will be used to create vocabulary.
    max_vocabulary_size: limit on the size of the created vocabulary.
    tokenizer: a function to use to tokenize each data sentence;
      if None, basic_tokenizer will be used.
    normalize_digits: Boolean; if true, all digits are replaced by 0s.
  """
  if not tf.gfile.Exists(vocabulary_path):
    print "Creating vocabulary %s from data %s" % (vocabulary_path, data_path)
    vocab, chars = {}, {}
    for c in _PUNCTUATION:
      chars[c] = 1
    # Read French file.
    with tf.gfile.GFile(data_path + ".fr", mode="rb") as f:
      counter = 0
      for line_in in f:
        line = " ".join(line_in.split())
        counter += 1
        if counter % 100000 == 0:
          print "  processing fr line %d" % counter
        for c in line:
          if c in chars:
            chars[c] += 1
          else:
            chars[c] = 1
        tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
        tokens = [t for t in tokens if not is_char(t) and t != _SPACE]
        for w in tokens:
          word = re.sub(_DIGIT_RE, b"0", w) if normalize_digits else w
          if word in vocab:
            vocab[word] += 1000000000  # We want target words first.
          else:
            vocab[word] = 1000000000
    # Read English file.
    with tf.gfile.GFile(data_path + ".en", mode="rb") as f:
      counter = 0
      for line_in in f:
        line = " ".join(line_in.split())
        counter += 1
        if counter % 100000 == 0:
          print "  processing en line %d" % counter
        for c in line:
          if c in chars:
            chars[c] += 1
          else:
            chars[c] = 1
        tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
        tokens = [t for t in tokens if not is_char(t) and t != _SPACE]
        for w in tokens:
          word = re.sub(_DIGIT_RE, b"0", w) if normalize_digits else w
          if word in vocab:
            vocab[word] += 1
          else:
            vocab[word] = 1
    sorted_vocab = sorted(vocab, key=vocab.get, reverse=True)
    sorted_chars = sorted(chars, key=vocab.get, reverse=True)
    sorted_chars = [_CHAR_MARKER + c for c in sorted_chars]
    vocab_list = _START_VOCAB + sorted_chars + sorted_vocab
    if tokenizer:
      vocab_list = _START_VOCAB + sorted_vocab
    if len(vocab_list) > max_vocabulary_size:
      vocab_list = vocab_list[:max_vocabulary_size]
    with tf.gfile.GFile(vocabulary_path, mode="wb") as vocab_file:
      for w in vocab_list:
        vocab_file.write(w + b"\n")


def initialize_vocabulary(vocabulary_path):
  """Initialize vocabulary from file.

  We assume the vocabulary is stored one-item-per-line, so a file:
    dog
    cat
  will result in a vocabulary {"dog": 0, "cat": 1}, and this function will
  also return the reversed-vocabulary ["dog", "cat"].

  Args:
    vocabulary_path: path to the file containing the vocabulary.

  Returns:
    a pair: the vocabulary (a dictionary mapping string to integers), and
    the reversed vocabulary (a list, which reverses the vocabulary mapping).

  Raises:
    ValueError: if the provided vocabulary_path does not exist.
  """
  if tf.gfile.Exists(vocabulary_path):
    rev_vocab = []
    with tf.gfile.GFile(vocabulary_path, mode="rb") as f:
      rev_vocab.extend(f.readlines())
    rev_vocab = [line.strip() for line in rev_vocab]
    vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
    return vocab, rev_vocab
  else:
    raise ValueError("Vocabulary file %s not found.", vocabulary_path)


def sentence_to_token_ids_raw(sentence, vocabulary,
                              tokenizer=None, normalize_digits=old_style):
  """Convert a string to list of integers representing token-ids.

  For example, a sentence "I have a dog" may become tokenized into
  ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2,
  "a": 4, "dog": 7"} this function will return [1, 2, 4, 7].

  Args:
    sentence: the sentence in bytes format to convert to token-ids.
    vocabulary: a dictionary mapping tokens to integers.
    tokenizer: a function to use to tokenize each sentence;
      if None, basic_tokenizer will be used.
    normalize_digits: Boolean; if true, all digits are replaced by 0s.

  Returns:
    a list of integers, the token-ids for the sentence.
  """
  if tokenizer:
    words = tokenizer(sentence)
  else:
    words = basic_tokenizer(sentence)
  result = []
  for w in words:
    if normalize_digits:
      w = re.sub(_DIGIT_RE, b"0", w)
    if w in vocabulary:
      result.append(vocabulary[w])
    else:
      if tokenizer:
        result.append(UNK_ID)
      else:
        result.append(SPACE_ID)
        for c in w:
          result.append(vocabulary.get(_CHAR_MARKER + c, UNK_ID))
        result.append(SPACE_ID)
  while result and result[0] == SPACE_ID:
    result = result[1:]
  while result and result[-1] == SPACE_ID:
    result = result[:-1]
  return result


def sentence_to_token_ids(sentence, vocabulary,
                          tokenizer=None, normalize_digits=old_style):
  """Convert a string to list of integers representing token-ids, tab=0."""
  tab_parts = sentence.strip().split("\t")
  toks = [sentence_to_token_ids_raw(t, vocabulary, tokenizer, normalize_digits)
          for t in tab_parts]
  res = []
  for t in toks:
    res.extend(t)
    res.append(0)
  return res[:-1]


def data_to_token_ids(data_path, target_path, vocabulary_path,
                      tokenizer=None, normalize_digits=False):
  """Tokenize data file and turn into token-ids using given vocabulary file.

  This function loads data line-by-line from data_path, calls the above
  sentence_to_token_ids, and saves the result to target_path. See comment
  for sentence_to_token_ids on the details of token-ids format.

  Args:
    data_path: path to the data file in one-sentence-per-line format.
    target_path: path where the file with token-ids will be created.
    vocabulary_path: path to the vocabulary file.
    tokenizer: a function to use to tokenize each sentence;
      if None, basic_tokenizer will be used.
    normalize_digits: Boolean; if true, all digits are replaced by 0s.
  """
  if not tf.gfile.Exists(target_path):
    print "Tokenizing data in %s" % data_path
    vocab, _ = initialize_vocabulary(vocabulary_path)
    with tf.gfile.GFile(data_path, mode="rb") as data_file:
      with tf.gfile.GFile(target_path, mode="w") as tokens_file:
        counter = 0
        for line in data_file:
          counter += 1
          if counter % 100000 == 0:
            print "  tokenizing line %d" % counter
          token_ids = sentence_to_token_ids(line, vocab, tokenizer,
                                            normalize_digits)
          tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")


def prepare_wmt_data(data_dir, vocabulary_size,
                     tokenizer=None, normalize_digits=False):
  """Get WMT data into data_dir, create vocabularies and tokenize data.

  Args:
    data_dir: directory in which the data sets will be stored.
    vocabulary_size: size of the joint vocabulary to create and use.
    tokenizer: a function to use to tokenize each data sentence;
      if None, basic_tokenizer will be used.
    normalize_digits: Boolean; if true, all digits are replaced by 0s.

  Returns:
    A tuple of 6 elements:
      (1) path to the token-ids for English training data-set,
      (2) path to the token-ids for French training data-set,
      (3) path to the token-ids for English development data-set,
      (4) path to the token-ids for French development data-set,
      (5) path to the vocabulary file,
      (6) path to the vocabulary file (for compatibility with non-joint vocab).
  """
  # Get wmt data to the specified directory.
  train_path = get_wmt_enfr_train_set(data_dir)
  dev_path = get_wmt_enfr_dev_set(data_dir)

  # Create vocabularies of the appropriate sizes.
  vocab_path = os.path.join(data_dir, "vocab%d.txt" % vocabulary_size)
  create_vocabulary(vocab_path, train_path, vocabulary_size,
                    tokenizer=tokenizer, normalize_digits=normalize_digits)

  # Create token ids for the training data.
  fr_train_ids_path = train_path + (".ids%d.fr" % vocabulary_size)
  en_train_ids_path = train_path + (".ids%d.en" % vocabulary_size)
  data_to_token_ids(train_path + ".fr", fr_train_ids_path, vocab_path,
                    tokenizer=tokenizer, normalize_digits=normalize_digits)
  data_to_token_ids(train_path + ".en", en_train_ids_path, vocab_path,
                    tokenizer=tokenizer, normalize_digits=normalize_digits)

  # Create token ids for the development data.
  fr_dev_ids_path = dev_path + (".ids%d.fr" % vocabulary_size)
  en_dev_ids_path = dev_path + (".ids%d.en" % vocabulary_size)
  data_to_token_ids(dev_path + ".fr", fr_dev_ids_path, vocab_path,
                    tokenizer=tokenizer, normalize_digits=normalize_digits)
  data_to_token_ids(dev_path + ".en", en_dev_ids_path, vocab_path,
                    tokenizer=tokenizer, normalize_digits=normalize_digits)

  return (en_train_ids_path, fr_train_ids_path,
          en_dev_ids_path, fr_dev_ids_path,
          vocab_path, vocab_path)
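A short usage sketch for the tokenizer pair in this new file (Python 2; it assumes `wmt_utils.py` is importable, which requires TensorFlow for the `tf.gfile` helpers even though the tokenizer itself does not touch files):

```python
import wmt_utils

sentence = "Hello, world! 42"
tokens = wmt_utils.basic_tokenizer(sentence)
print tokens
# Punctuation and digits come back as _CHAR_ tokens separated by _SPACE markers;
# basic_detokenizer is the approximate inverse.
print wmt_utils.basic_detokenizer(tokens)
```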