Unverified Commit 26aded58 authored by Yanhui Liang's avatar Yanhui Liang Committed by GitHub
Browse files

Research/minigo unit test (#4023)

* Add minigo unit test

* Add minigo unit test

* Fix gpylints and update readme
parent c86f7916
......@@ -56,7 +56,9 @@ parameterized layers respectively for the residual tower, plus an additional 2
layers for the policy head and 3 layers for the value head.
## Getting Started
Please follow the [instructions](https://github.com/tensorflow/minigo/blob/master/README.md#getting-started) in the original Minigo repo to set up the environment.
This project assumes you have virtualenv, TensorFlow (>= 1.5), and two other Go-related
packages: pygtp (>= 0.4) and sgf (== 0.5).
## Training Model
One iteration of reinforcement learning consists of the following steps:
......
......@@ -40,6 +40,9 @@ SGF 'aa' 'sa' ''
KGS 'A19' 'T19' 'pass'
pygtp (1, 19) (19, 19) (0, 0)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gtp
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for coords."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
import coords
import numpy
import utils_test
tf.logging.set_verbosity(tf.logging.ERROR)
class TestCoords(utils_test.MiniGoUnitTest):
  """Conversion tests between (row, col) tuples and SGF/flat/KGS/pygtp forms.

  The expected literals (e.g. 'A9', or flat index 81 for a pass) correspond to
  a 9x9 board, so these tests assume utils_test.BOARD_SIZE == 9.
  """

  def test_upperleft(self):
    """The (0, 0) corner converts to and from every representation."""
    self.assertEqual(coords.from_sgf('aa'), (0, 0))
    self.assertEqual(coords.from_flat(utils_test.BOARD_SIZE, 0), (0, 0))
    self.assertEqual(coords.from_kgs(utils_test.BOARD_SIZE, 'A9'), (0, 0))
    self.assertEqual(coords.from_pygtp(utils_test.BOARD_SIZE, (1, 9)), (0, 0))

    self.assertEqual(coords.to_sgf((0, 0)), 'aa')
    self.assertEqual(coords.to_flat(utils_test.BOARD_SIZE, (0, 0)), 0)
    self.assertEqual(coords.to_kgs(utils_test.BOARD_SIZE, (0, 0)), 'A9')
    self.assertEqual(coords.to_pygtp(utils_test.BOARD_SIZE, (0, 0)), (1, 9))

  def test_topleft(self):
    """The (0, 8) corner converts to and from every representation.

    Note: despite the method name, (0, 8) is row 0 and the last column, i.e.
    the opposite end of the top row from test_upperleft.
    """
    self.assertEqual(coords.from_sgf('ia'), (0, 8))
    self.assertEqual(coords.from_flat(utils_test.BOARD_SIZE, 8), (0, 8))
    self.assertEqual(coords.from_kgs(utils_test.BOARD_SIZE, 'J9'), (0, 8))
    self.assertEqual(coords.from_pygtp(utils_test.BOARD_SIZE, (9, 9)), (0, 8))

    self.assertEqual(coords.to_sgf((0, 8)), 'ia')
    self.assertEqual(coords.to_flat(utils_test.BOARD_SIZE, (0, 8)), 8)
    self.assertEqual(coords.to_kgs(utils_test.BOARD_SIZE, (0, 8)), 'J9')
    self.assertEqual(coords.to_pygtp(utils_test.BOARD_SIZE, (0, 8)), (9, 9))

  def test_pass(self):
    """A pass is None internally and '' / 81 / 'pass' / (0, 0) externally."""
    self.assertIsNone(coords.from_sgf(''))
    self.assertIsNone(coords.from_flat(utils_test.BOARD_SIZE, 81))
    self.assertIsNone(coords.from_kgs(utils_test.BOARD_SIZE, 'pass'))
    self.assertIsNone(coords.from_pygtp(utils_test.BOARD_SIZE, (0, 0)))

    self.assertEqual(coords.to_sgf(None), '')
    self.assertEqual(coords.to_flat(utils_test.BOARD_SIZE, None), 81)
    self.assertEqual(coords.to_kgs(utils_test.BOARD_SIZE, None), 'pass')
    self.assertEqual(coords.to_pygtp(utils_test.BOARD_SIZE, None), (0, 0))

  def test_parsing_9x9(self):
    """Spot-check parsing and formatting of assorted coordinates."""
    self.assertEqual(coords.from_sgf('aa'), (0, 0))
    self.assertEqual(coords.from_sgf('ac'), (2, 0))
    self.assertEqual(coords.from_sgf('ca'), (0, 2))
    self.assertIsNone(coords.from_sgf(''))
    self.assertEqual(coords.to_sgf(None), '')
    # The SGF helpers take no board size, so coordinates beyond the 9x9 range
    # ('sa', (1, 17)) are expected to round-trip as well.
    self.assertEqual('aa', coords.to_sgf(coords.from_sgf('aa')))
    self.assertEqual('sa', coords.to_sgf(coords.from_sgf('sa')))
    self.assertEqual((1, 17), coords.from_sgf(coords.to_sgf((1, 17))))

    self.assertEqual(coords.from_kgs(utils_test.BOARD_SIZE, 'A1'), (8, 0))
    self.assertEqual(coords.from_kgs(utils_test.BOARD_SIZE, 'A9'), (0, 0))
    self.assertEqual(coords.from_kgs(utils_test.BOARD_SIZE, 'C2'), (7, 2))
    self.assertEqual(coords.from_kgs(utils_test.BOARD_SIZE, 'J2'), (7, 8))

    self.assertEqual(coords.from_pygtp(utils_test.BOARD_SIZE, (1, 1)), (8, 0))
    self.assertEqual(coords.from_pygtp(utils_test.BOARD_SIZE, (1, 9)), (0, 0))
    self.assertEqual(coords.from_pygtp(utils_test.BOARD_SIZE, (3, 2)), (7, 2))

    self.assertEqual(coords.to_pygtp(utils_test.BOARD_SIZE, (8, 0)), (1, 1))
    self.assertEqual(coords.to_pygtp(utils_test.BOARD_SIZE, (0, 0)), (1, 9))
    self.assertEqual(coords.to_pygtp(utils_test.BOARD_SIZE, (7, 2)), (3, 2))

    self.assertEqual(coords.to_kgs(utils_test.BOARD_SIZE, (0, 8)), 'J9')
    self.assertEqual(coords.to_kgs(utils_test.BOARD_SIZE, (8, 0)), 'A1')

  def test_flatten(self):
    """Flat indices are row-major and round-trip through to/from_flat."""
    self.assertEqual(coords.to_flat(utils_test.BOARD_SIZE, (0, 0)), 0)
    self.assertEqual(coords.to_flat(utils_test.BOARD_SIZE, (0, 3)), 3)
    self.assertEqual(coords.to_flat(utils_test.BOARD_SIZE, (3, 0)), 27)

    self.assertEqual(coords.from_flat(utils_test.BOARD_SIZE, 27), (3, 0))
    self.assertEqual(coords.from_flat(utils_test.BOARD_SIZE, 10), (1, 1))
    self.assertEqual(coords.from_flat(utils_test.BOARD_SIZE, 80), (8, 8))

    self.assertEqual(
        coords.to_flat(
            utils_test.BOARD_SIZE,
            coords.from_flat(utils_test.BOARD_SIZE, 10)),
        10)
    self.assertEqual(
        coords.from_flat(
            utils_test.BOARD_SIZE,
            coords.to_flat(utils_test.BOARD_SIZE, (5, 4))),
        (5, 4))

  def test_from_flat_ndindex_equivalence(self):
    """from_flat enumerates the board in the same order as numpy.ndindex."""
    ndindices = list(numpy.ndindex(
        utils_test.BOARD_SIZE, utils_test.BOARD_SIZE))
    flat_coords = range(utils_test.BOARD_SIZE * utils_test.BOARD_SIZE)
    self.assertEqual(
        [coords.from_flat(utils_test.BOARD_SIZE, flat) for flat in flat_coords],
        ndindices)
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dualnet and dualnet_model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import tensorflow as tf # pylint: disable=g-bad-import-order
import dualnet
import go
import model_params
import preprocessing
import utils_test
tf.logging.set_verbosity(tf.logging.ERROR)
class TestDualNet(utils_test.MiniGoUnitTest):
  """Smoke tests that run dualnet training and exported-model inference."""

  def test_train(self):
    """Run dualnet.train on a record file built from example_game.sgf."""
    with tempfile.TemporaryDirectory() as working_dir:
      with tempfile.NamedTemporaryFile() as tf_record:
        record_path = tf_record.name
        preprocessing.make_dataset_from_sgf(
            utils_test.BOARD_SIZE, 'example_game.sgf', record_path)
        dualnet.train(
            working_dir, [record_path], 1, model_params.DummyMiniGoParams())

  def test_inference(self):
    """Bootstrap, export, then run two runners on an empty position."""
    with tempfile.TemporaryDirectory() as working_dir:
      with tempfile.TemporaryDirectory() as export_dir:
        dualnet.bootstrap(working_dir, model_params.DummyMiniGoParams())
        exported_model = os.path.join(export_dir, 'bootstrap-model')
        dualnet.export_model(working_dir, exported_model)

        # Both runners load the same exported model; the first one is kept
        # referenced while the second is created and run.
        runners = []
        for _ in range(2):
          runner = dualnet.DualNetRunner(
              exported_model, model_params.DummyMiniGoParams())
          runners.append(runner)
          runner.run(go.Position(utils_test.BOARD_SIZE))
if __name__ == '__main__':
tf.test.main()
(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]
RU[Japanese]SZ[9]KM[0.00]
PW[White]PB[Black]RE[B+4.00]
;B[de]
;W[fe]
;B[ee]
;W[fd]
;B[ff]
;W[gf]
;B[gg]
;W[fg]
;B[ef]
;W[gh]
;B[hg]
;W[hh]
;B[eg]
;W[fh]
;B[ge]
;W[hf]
;B[he]
;W[ig]
;B[fc]
;W[gd]
;B[gc]
;W[hd]
;B[ed]
;W[be]
;B[hc]
;W[ie]
;B[bc]
;W[cg]
;B[cf]
;W[bf]
;B[ch]
(;W[dg]
;B[dh]
;W[bh]
;B[eh]
;W[cc]
;B[cb])
(;W[cc]
;B[cb]
(;W[bh]
;B[dh])
(;W[dg]
;B[dh]
;W[bh]
;B[eh]
;W[dc]
;B[bd]
;W[ec]
;B[cd]
;W[fb]
;B[gb]
(;W[db])
(;W[bb]
;B[eb]
;W[db]
;B[fa]
;W[ca]
;B[ea]
;W[da]
;B[df]
;W[bg]
;B[bi]
;W[ab]
;B[ah]
;W[ci]
;B[di]
;W[ag]
;B[ae]
;W[ac]
;B[ad]
;W[ha]
;B[hb]
;W[fi]
;B[ce]
;W[ai]
;B[ci]
;W[ei]
;B[ah]
;W[ic]
;B[ib]
;W[ai]
;B[ba]
;W[aa]
;B[ah]
;W[ga]
;B[ia]
;W[ai]
;B[ga]
;W[id]
;B[ah]
;W[dd]
;B[af]TW[ba][cb][ge][he][if][gg][hg][ih][gi][hi][ii]TB[ha][fb][be][bf][ag][bg][cg][dg][bh][ai]))))
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for features."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
import features
import go
import numpy as np
import utils_test
tf.logging.set_verbosity(tf.logging.ERROR)
EMPTY_ROW = '.' * utils_test.BOARD_SIZE + '\n'
TEST_BOARD = utils_test.load_board('''
.X.....OO
X........
XXXXXXXXX
''' + EMPTY_ROW * 6)
TEST_POSITION = go.Position(
utils_test.BOARD_SIZE,
board=TEST_BOARD,
n=3,
komi=6.5,
caps=(1, 2),
ko=None,
recent=(go.PlayerMove(go.BLACK, (0, 1)),
go.PlayerMove(go.WHITE, (0, 8)),
go.PlayerMove(go.BLACK, (1, 0))),
to_play=go.BLACK,
)
TEST_BOARD2 = utils_test.load_board('''
.XOXXOO..
XO.OXOX..
XXO..X...
''' + EMPTY_ROW * 6)
TEST_POSITION2 = go.Position(
utils_test.BOARD_SIZE,
board=TEST_BOARD2,
n=0,
komi=6.5,
caps=(0, 0),
ko=None,
recent=tuple(),
to_play=go.BLACK,
)
TEST_POSITION3 = go.Position(utils_test.BOARD_SIZE)
for coord in ((0, 0), (0, 1), (0, 2), (0, 3), (1, 1)):
TEST_POSITION3.play_move(coord, mutate=True)
# resulting position should look like this:
# X.XO.....
# .X.......
# .........
class TestFeatureExtraction(utils_test.MiniGoUnitTest):
  """Checks the planes produced by features.stone_features."""

  def test_stone_features(self):
    """Compare planes 0-5 and 10-15 of TEST_POSITION3 with expected boards."""
    planes = features.stone_features(utils_test.BOARD_SIZE, TEST_POSITION3)
    self.assertEqual(TEST_POSITION3.to_play, go.WHITE)
    self.assertEqual(planes.shape, (9, 9, 16))

    # Top two rows expected in planes 0..5, in plane order; the remaining
    # seven rows of each plane are empty.
    expected_top_rows = (
        '\n...X.....\n.........',
        '\nX.X......\n.X.......',
        '\n.X.X.....\n.........',
        '\nX.X......\n.........',
        '\n.X.......\n.........',
        '\nX.X......\n.........',
    )
    for index, top_rows in enumerate(expected_top_rows):
      self.assertEqualNPArray(
          planes[:, :, index],
          utils_test.load_board(top_rows + EMPTY_ROW * 7))

    # Planes 10..15 must be all zeros (planes 6..9 are not asserted here).
    for index in range(10, 16):
      self.assertEqualNPArray(
          planes[:, :, index],
          np.zeros([utils_test.BOARD_SIZE, utils_test.BOARD_SIZE]))
if __name__ == '__main__':
tf.test.main()
......@@ -43,11 +43,13 @@ MISSING_GROUP_ID = -1
BLACK_NAME = 'BLACK'
WHITE_NAME = 'WHITE'
def _check_bounds(board_size, c):
return c[0] % board_size == c[0] and c[1] % board_size == c[1]
def get_neighbors_diagonals(board_size):
"""Return coordinates of neighbors and diagonals for a go board."""
all_coords = [(i, j) for i in range(board_size) for j in range(board_size)]
def check_bounds(c):
return _check_bounds(board_size, c)
......@@ -80,17 +82,13 @@ def place_stones(board, color, stones):
def replay_position(board_size, position, result):
"""Wrapper for a go.Position which replays its history.
Assumes an empty start position! (i.e. no handicap, and history must
be exhaustive.)
Result must be passed in, since a resign cannot be inferred from position
history alone.
for position_w_context in replay_position(position):
print(position_w_context.position)
"""
"""Wrapper for a go.Position which replays its history."""
# Assumes an empty start position! (i.e. no handicap, and history must
# be exhaustive.)
# Result must be passed in, since a resign cannot be inferred from position
# history alone.
# for position_w_context in replay_position(position):
# print(position_w_context.position)
if position.n != len(position.recent):
raise ValueError('Position history is incomplete!')
pos = Position(board_size=board_size, komi=position.komi)
......@@ -101,6 +99,7 @@ def replay_position(board_size, position, result):
def find_reached(board_size, board, c):
"""Find the chain to reach c."""
color = board[c]
chain = set([c])
reached = set()
......@@ -138,11 +137,12 @@ def is_eyeish(board_size, board, c):
if color is None:
return None
diagonal_faults = 0
_, diagonals = get_neighbors_diagonals[c]
_, all_diagonals = get_neighbors_diagonals(board_size)
diagonals = all_diagonals[c]
if len(diagonals) < 4:
diagonal_faults += 1
for d in diagonals:
if not board[d] in (color, EMPTY):
if board[d] not in (color, EMPTY):
diagonal_faults += 1
if diagonal_faults > 1:
return None
......@@ -151,7 +151,8 @@ def is_eyeish(board_size, board, c):
class Group(namedtuple('Group', ['id', 'stones', 'liberties', 'color'])):
"""
"""Group class.
stones: a frozenset of Coordinates belonging to this group
liberties: a frozenset of Coordinates that are empty and adjacent to
this group.
......@@ -164,6 +165,7 @@ class Group(namedtuple('Group', ['id', 'stones', 'liberties', 'color'])):
class LibertyTracker(object):
"""LibertyTracker class."""
@staticmethod
def from_board(board_size, board):
......@@ -201,15 +203,16 @@ class LibertyTracker(object):
# groups: a dict of group_id to groups
# liberty_cache: a NxN numpy array of liberty counts
self.board_size = board_size
self.group_index = group_index if group_index is not None else - \
np.ones([board_size, board_size], dtype=np.int32)
self.group_index = (group_index if group_index is not None else
-np.ones([board_size, board_size], dtype=np.int32))
self.groups = groups or {}
self.liberty_cache = liberty_cache if liberty_cache is not None else - \
np.zeros([board_size, board_size], dtype=np.uint8)
self.liberty_cache = (
liberty_cache if liberty_cache is not None
else -np.zeros([board_size, board_size], dtype=np.uint8))
self.max_group_id = max_group_id
self.neighbors, _ = get_neighbors_diagonals(board_size)
def __deepcopy__(self, memodict={}):
def __deepcopy__(self, memodict=None):
new_group_index = np.copy(self.group_index)
new_lib_cache = np.copy(self.liberty_cache)
# shallow copy
......@@ -254,7 +257,7 @@ class LibertyTracker(object):
self._handle_captures(captured_stones)
# suicide is illegal
if len(self.groups[new_group.id].liberties) == 0:
if self.groups[new_group.id].liberties is None:
raise IllegalMove('Move at {} would commit suicide!\n'.format(c))
return captured_stones
......@@ -313,24 +316,27 @@ class Position(object):
def __init__(self, board_size, board=None, n=0, komi=7.5, caps=(0, 0),
lib_tracker=None, ko=None, recent=tuple(),
board_deltas=None, to_play=BLACK):
"""Initialize position class.
Args:
board_size: the go board size.
board: a numpy array
n: an int representing moves played so far
komi: a float, representing points given to the second player.
caps: a (int, int) tuple of captures for B, W.
lib_tracker: a LibertyTracker object
ko: a Move
recent: a tuple of PlayerMoves, such that recent[-1] is the last move.
board_deltas: a np.array of shape (n, go.N, go.N) representing changes
made to the board at each move (played move and captures).
Should satisfy next_pos.board - next_pos.board_deltas[0] == pos.board
to_play: BLACK or WHITE
"""
board_size: the go board size.
board: a numpy array
n: an int representing moves played so far
komi: a float, representing points given to the second player.
caps: a (int, int) tuple of captures for B, W.
lib_tracker: a LibertyTracker object
ko: a Move
recent: a tuple of PlayerMoves, such that recent[-1] is the last move.
board_deltas: a np.array of shape (n, go.N, go.N) representing changes
made to the board at each move (played move and captures).
Should satisfy next_pos.board - next_pos.board_deltas[0] == pos.board
to_play: BLACK or WHITE
"""
assert type(recent) is tuple
if not isinstance(recent, tuple):
raise TypeError('Recent must be a tuple!')
self.board_size = board_size
self.board = board if board is not None else - \
np.zeros([board_size, board_size], dtype=np.int8)
self.board = (board if board is not None else
-np.zeros([board_size, board_size], dtype=np.int8))
self.n = n
self.komi = komi
self.caps = caps
......@@ -338,13 +344,13 @@ class Position(object):
self.board_size, self.board)
self.ko = ko
self.recent = recent
self.board_deltas = board_deltas if board_deltas is not None else - \
np.zeros([0, board_size, board_size], dtype=np.int8)
self.board_deltas = (board_deltas if board_deltas is not None else
-np.zeros([0, board_size, board_size], dtype=np.int8))
self.to_play = to_play
self.last_eight = None
self.neighbors, _ = get_neighbors_diagonals(board_size)
def __deepcopy__(self, memodict={}):
def __deepcopy__(self, memodict=None):
new_board = np.copy(self.board)
new_lib_tracker = copy.deepcopy(self.lib_tracker)
return Position(
......@@ -465,10 +471,24 @@ class Position(object):
return self.lib_tracker.liberty_cache
def play_move(self, c, color=None, mutate=False):
# Obeys CGOS Rules of Play. In short:
# No suicides
# Chinese/area scoring
# Positional superko (this is very crudely approximate at the moment.)
"""Obeys CGOS Rules of Play.
In short:
No suicides
Chinese/area scoring
Positional superko (this is very crudely approximate at the moment.)
Args:
c: the coordinate to play from.
color: the color of the player to play.
mutate:
Returns:
The position of next move.
Raises:
IllegalMove: if the input c is an illegal move.
"""
if color is None:
color = self.to_play
......@@ -481,7 +501,7 @@ class Position(object):
if not self.is_move_legal(c):
raise IllegalMove('{} move at {} is illegal: \n{}'.format(
'Black' if self.to_play == BLACK else 'White',
coords.to_kgs(c), self))
coords.to_kgs(self.board_size, c), self))
potential_ko = is_koish(self.board_size, self.board, c)
......@@ -489,7 +509,7 @@ class Position(object):
captured_stones = pos.lib_tracker.add_stone(color, c)
place_stones(pos.board, EMPTY, captured_stones)
opp_color = color * -1
opp_color = -1 * color
new_board_delta = np.zeros([self.board_size, self.board_size],
dtype=np.int8)
......@@ -532,8 +552,8 @@ class Position(object):
c = unassigned_spaces[0][0], unassigned_spaces[1][0]
territory, borders = find_reached(self.board_size, working_board, c)
border_colors = set(working_board[b] for b in borders)
X_border = BLACK in border_colors
O_border = WHITE in border_colors
X_border = BLACK in border_colors # pylint: disable=invalid-name
O_border = WHITE in border_colors # pylint: disable=invalid-name
if X_border and not O_border:
territory_color = BLACK
elif O_border and not X_border:
......
This diff is collapsed.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import itertools
import sys
import coords
import go
......@@ -28,40 +28,40 @@ import sgf_wrapper
def parse_message(message):
  """Split a raw GTP message into (message_id, command, arguments).

  Args:
    message: the raw message text; it is passed through gtp.pre_engine and
      stripped before parsing.

  Returns:
    A (message_id, command, arguments) tuple. message_id is an int when the
    first space-separated token is all digits, otherwise None. command is a
    string with '-' replaced by '_'; it is '' when the message carries no
    command. arguments is the text after the command, or None if absent.
  """
  message = gtp.pre_engine(message).strip()
  first, rest = (message.split(' ', 1) + [None])[:2]
  if first.isdigit():
    message_id = int(first)
    if rest is not None:
      command, arguments = (rest.split(' ', 1) + [None])[:2]
    else:
      # Only an id was sent. Use an empty command (as an empty message
      # already produces) so the '-' replacement below does not fail on None.
      command, arguments = '', None
  else:
    message_id = None
    command, arguments = first, rest

  command = command.replace('-', '_')  # for kgs extensions.
  return message_id, command, arguments
class KgsExtensionsMixin(gtp.Engine):
def __init__(self, game_obj, name='gtp (python, kgs-chat extensions)',
             version='0.1'):
  """Initialize the engine and add 'kgs-chat' to its known commands.

  Args:
    game_obj: forwarded unchanged to the parent initializer.
    name: engine name forwarded to the parent initializer.
    version: engine version forwarded to the parent initializer.
  """
  super().__init__(game_obj=game_obj, name=name, version=version)
  self.known_commands += ['kgs-chat']
def send(self, message):
  """Parse a message, run the matching cmd_* handler and format the reply.

  Args:
    message: the raw message text, parsed with parse_message.

  Returns:
    The gtp.format_success response of the handler, or a gtp.format_error
    response if the command is unknown or the handler raises ValueError.
  """
  message_id, command, arguments = parse_message(message)
  if command not in self.known_commands:
    return gtp.format_error(message_id, 'unknown command: ' + command)

  try:
    handler = getattr(self, 'cmd_' + command)
    retval = handler(arguments)
    response = gtp.format_success(message_id, retval)
    sys.stderr.flush()
  except ValueError as exception:
    # Handlers signal a user-visible failure with ValueError; its first
    # argument becomes the error text.
    return gtp.format_error(message_id, exception.args[0])
  return response
# Nice to implement this, as KGS sends it each move.
def cmd_time_left(self, arguments):
......@@ -74,19 +74,20 @@ class KgsExtensionsMixin(gtp.Engine):
try:
arg_list = arguments.split()
msg_type, sender, text = arg_list[0], arg_list[1], arg_list[2:]
text = " ".join(text)
text = ' '.join(text)
except ValueError:
return "Unparseable message, args: %r" % arguments
return 'Unparseable message, args: %r' % arguments
return self._game.chat(msg_type, sender, text)
class RegressionsMixin(gtp.Engine):
def cmd_loadsgf(self, arguments):
args = arguments.split()
if len(args) == 2:
file_, movenum = args
movenum = int(movenum)
print("movenum =", movenum, file=sys.stderr)
print('movenum =', movenum, file=sys.stderr)
else:
file_ = args[0]
movenum = None
......@@ -95,7 +96,7 @@ class RegressionsMixin(gtp.Engine):
with open(file_, 'r') as f:
contents = f.read()
except:
raise ValueError("Unreadable file: " + file_)
raise ValueError('Unreadable file: ' + file_)
try:
# This is kinda bad, because replay_sgf is already calling
......@@ -103,7 +104,7 @@ class RegressionsMixin(gtp.Engine):
# want to advance the engine along with us rather than try to
# push in some finished Position object.
for idx, p in enumerate(sgf_wrapper.replay_sgf(contents)):
print("playing #", idx, p.next_move, file=sys.stderr)
print('playing #', idx, p.next_move, file=sys.stderr)
self._game.play_move(p.next_move)
if movenum and idx == movenum:
break
......@@ -112,23 +113,25 @@ class RegressionsMixin(gtp.Engine):
class GoGuiMixin(gtp.Engine):
""" GTP extensions of 'analysis commands' for gogui.
"""GTP extensions of 'analysis commands' for gogui.
We reach into the game_obj (an instance of the players in strategies.py),
and extract stuff from its root nodes, etc. These could be extracted into
methods on the Player object, but its a little weird to do that on a Player,
which doesn't really care about GTP commands, etc. So instead, we just
violate encapsulation a bit... Suggestions welcome :) """
violate encapsulation a bit.
"""
def __init__(self, game_obj, name='gtp (python, gogui extensions)',
             version='0.1'):
  """Initialize the engine and add 'gogui-analyze_commands' to its commands.

  Args:
    game_obj: forwarded unchanged to the parent initializer.
    name: engine name forwarded to the parent initializer.
    version: engine version forwarded to the parent initializer.
  """
  super().__init__(game_obj=game_obj, name=name, version=version)
  self.known_commands += ['gogui-analyze_commands']
def cmd_gogui_analyze_commands(self, arguments):
  """Return the fixed list of analyze commands, one per line.

  Args:
    arguments: unused.

  Returns:
    A newline-joined string of four entries, each made of three
    '/'-separated fields.
  """
  analyze_commands = (
      'var/Most Read Variation/nextplay',
      'var/Think a spell/spin',
      'pspairs/Visit Heatmap/visit_heatmap',
      'pspairs/Q Heatmap/q_heatmap',
  )
  return '\n'.join(analyze_commands)
def cmd_nextplay(self, arguments):
  """Return self._game.root.mvp_gg(); the arguments parameter is unused."""
  return self._game.root.mvp_gg()
......@@ -142,27 +145,26 @@ class GoGuiMixin(gtp.Engine):
sort_order = list(range(self._game.size * self._game.size + 1))
reverse = True if self._game.root.position.to_play is go.BLACK else False
sort_order.sort(
key=lambda i: self._game.root.child_Q[i], reverse=reverse)
key=lambda i: self._game.root.child_Q[i], reverse=reverse)
return self.heatmap(sort_order, self._game.root, 'child_Q')
def heatmap(self, sort_order, node, prop):
  """Format the first 20 visited moves of node as 'coordinate value' lines.

  Args:
    sort_order: iterable of flattened move indices, in reporting order.
    node: object providing child_N and, in its __dict__, the array named by
      prop.
    prop: name of the per-move array to print, e.g. 'child_Q'.

  Returns:
    A newline-joined string with one line per move whose child_N entry is
    positive, limited to the first 20 such moves in sort_order.
  """
  # coords.to_kgs and coords.from_flat take the board size as their first
  # argument at every other call site; the previous calls here omitted it.
  # NOTE(review): self._game.size is the attribute the heatmap commands use
  # to build sort_order -- confirm it always equals the engine's board size.
  board_size = self._game.size
  values = node.__dict__.get(prop)
  visited = [key for key in sort_order if node.child_N[key] > 0][:20]
  return '\n'.join(
      '{!s:6} {}'.format(
          coords.to_kgs(board_size, coords.from_flat(board_size, key)),
          values[key])
      for key in visited)
def cmd_spin(self, arguments):
for i in range(50):
for j in range(100):
for _ in range(50):
for _ in range(100):
self._game.tree_search()
moves = self.cmd_nextplay(None).lower()
moves = moves.split()
colors = "bw" if self._game.root.position.to_play is go.BLACK else "wb"
moves_cols = " ".join(['{} {}'.format(*z)
for z in zip(itertools.cycle(colors), moves)])
print("gogui-gfx: TEXT", "{:.3f} after {}".format(
self._game.root.Q, self._game.root.N), file=sys.stderr, flush=True)
print("gogui-gfx: VAR", moves_cols, file=sys.stderr, flush=True)
colors = 'bw' if self._game.root.position.to_play is go.BLACK else 'wb'
moves_cols = ' '.join(['{} {}'.format(*z)
for z in zip(itertools.cycle(colors), moves)])
print('gogui-gfx: TEXT', '{:.3f} after {}'.format(
self._game.root.Q, self._game.root.N), file=sys.stderr, flush=True)
print('gogui-gfx: VAR', moves_cols, file=sys.stderr, flush=True)
return self.cmd_nextplay(None)
......
......@@ -39,6 +39,7 @@ def translate_gtp_colors(gtp_color):
class GtpInterface(object):
def __init__(self, board_size):
self.size = 9
self.position = None
......@@ -47,9 +48,9 @@ class GtpInterface(object):
def set_size(self, n):
  """Validate a requested board size against the configured one.

  Args:
    n: the requested board size.

  Raises:
    ValueError: if n differs from self.board_size.
  """
  if n != self.board_size:
    # The two sentences were previously joined with no separator, which
    # rendered as "...boardsize 19!Restart with...".
    raise ValueError(
        ("Can't handle boardsize {n}! "
         'Restart with env var BOARD_SIZE={n}').format(n=n))
def set_komi(self, komi):
self.komi = komi
......@@ -60,17 +61,17 @@ class GtpInterface(object):
try:
sgf = self.to_sgf()
with open(datetime.datetime.now().strftime(
"%Y-%m-%d-%H:%M.sgf"), 'w') as f:
'%Y-%m-%d-%H:%M.sgf'), 'w') as f:
f.write(sgf)
except NotImplementedError:
pass
except:
print("Error saving sgf", file=sys.stderr, flush=True)
print('Error saving sgf', file=sys.stderr, flush=True)
self.position = go.Position(komi=self.komi)
self.initialize_game(self.position)
def accomodate_out_of_turn(self, color):
  """Flip the side to move when color is not the player currently to play.

  Args:
    color: a GTP color value, converted with translate_gtp_colors.
  """
  requested = translate_gtp_colors(color)
  if requested == self.position.to_play:
    return
  self.position.flip_playerturn(mutate=True)
def make_move(self, color, vertex):
......@@ -131,10 +132,10 @@ def make_gtp_instance(board_size, read_file, readouts_per_move=100,
gtp_engine = gtp.Engine(instance)
if cgos_mode:
instance = CGOSPlayer(board_size, n, seconds_per_move=5,
verbosity=verbosity, two_player_mode=True)
verbosity=verbosity, two_player_mode=True)
else:
instance = MCTSPlayer(board_size, n, simulations_per_move=readouts_per_move,
verbosity=verbosity, two_player_mode=True)
name = "Somebot-" + os.path.basename(read_file)
name = 'Somebot-' + os.path.basename(read_file)
gtp_engine = gtp_extensions.GTPDeluxe(instance, name=name)
return gtp_engine
......@@ -33,9 +33,12 @@ import coords
import numpy as np
# Exploration constant
c_PUCT = 1.38
c_PUCT = 1.38 # pylint: disable=invalid-name
# Dirichlet noise, as a function of board_size
def D_NOISE_ALPHA(board_size):  # pylint: disable=invalid-name
  """Return the Dirichlet noise alpha for the given board size.

  The value is about 0.03 for board_size 19 (19 ** 2 == 361) and scales
  inversely with the number of board points.
  """
  reference = 0.03 * 361
  points = board_size ** 2
  return reference / points
class DummyNode(object):
......@@ -43,8 +46,10 @@ class DummyNode(object):
This node is intended to be a placeholder for the root node, which would
otherwise have no parent node. If all nodes have parents, code becomes
simpler."""
simpler.
"""
# pylint: disable=invalid-name
def __init__(self, board_size):
self.board_size = board_size
self.parent = None
......@@ -64,6 +69,7 @@ class MCTSNode(object):
(raw number between 0-N^2, with None a pass)
parent: A parent MCTSNode.
"""
# pylint: disable=invalid-name
def __init__(self, board_size, position, fmove=None, parent=None):
if parent is None:
......@@ -85,7 +91,7 @@ class MCTSNode(object):
self.children = {} # map of flattened moves to resulting MCTSNode
def __repr__(self):
return "<MCTSNode move=%s, N=%s, to_play=%s>" % (
return '<MCTSNode move={}, N={}, to_play={}>'.format(
self.position.recent[-1:], self.N, self.position.to_play)
@property
......@@ -124,7 +130,7 @@ class MCTSNode(object):
@property
def Q_perspective(self):
  """Return value of position, from perspective of player to play.

  Computed as self.Q multiplied by self.position.to_play.
  """
  return self.Q * self.position.to_play
def select_leaf(self):
......@@ -174,22 +180,21 @@ class MCTSNode(object):
def revert_virtual_loss(self, up_to):
  """Undo one virtual loss on this node and on its ancestors through up_to.

  Args:
    up_to: the node at which to stop; it is updated, its parent is not.
  """
  self.losses_applied -= 1
  self.W += -self.position.to_play
  reached_top = self.parent is None or self is up_to
  if not reached_top:
    self.parent.revert_virtual_loss(up_to)
def revert_visits(self, up_to):
"""Revert visit increments.
Sometimes, repeated calls to select_leaf return the same node.
This is rare and we're okay with the wasted computation to evaluate
the position multiple times by the dual_net. But select_leaf has the
side effect of incrementing visit counts. Since we want the value to
only count once for the repeatedly selected node, we also have to
revert the incremented visit counts.
"""
"""Revert visit increments."""
# Sometimes, repeated calls to select_leaf return the same node.
# This is rare and we're okay with the wasted computation to evaluate
# the position multiple times by the dual_net. But select_leaf has the
# side effect of incrementing visit counts. Since we want the value to
# only count once for the repeatedly selected node, we also have to
# revert the incremented visit counts.
self.N -= 1
if self.parent is None or self is up_to:
return
......@@ -231,9 +236,9 @@ class MCTSNode(object):
self.parent.backup_value(value, up_to)
def is_done(self):
  """Return a truthy value once the game is over or the move cap is hit."""
  # The cap is 1.4 moves per board point: 505 moves for 19x19, 113 for 9x9.
  move_cap = 1.4 * (self.board_size ** 2)
  position = self.position
  return position.is_game_over() or position.n >= move_cap
......@@ -243,14 +248,14 @@ class MCTSNode(object):
self.child_prior = self.child_prior * 0.75 + dirch * 0.25
def children_as_pi(self, squash=False):
  """Return the child visit counts as a probability distribution, pi.

  Args:
    squash: if True, raise the counts to the power .95 before normalizing,
      to encourage diversity in early play and hopefully to move away from
      3-3s.

  Returns:
    self.child_N (optionally squashed) divided by its sum. self.child_N
    itself is left unmodified.
  """
  probs = self.child_N
  if squash:
    # Build a new array here: an in-place `**=` would write through the
    # alias and overwrite the node's own visit counts.
    probs = probs ** .95
  return probs / np.sum(probs)
def most_visited_path(self):
......@@ -260,12 +265,13 @@ class MCTSNode(object):
next_kid = np.argmax(node.child_N)
node = node.children.get(next_kid)
if node is None:
output.append("GAME END")
output.append('GAME END')
break
output.append("%s (%d) ==> " % (
coords.to_kgs(self.board_size,
coords.from_flat(self.board_size, node.fmove)), node.N))
output.append("Q: {:.5f}\n".format(node.Q))
output.append('{} ({}) ==> '.format(
coords.to_kgs(
self.board_size,
coords.from_flat(self.board_size, node.fmove)), node.N))
output.append('Q: {:.5f}\n'.format(node.Q))
return ''.join(output)
def mvp_gg(self):
......@@ -275,8 +281,8 @@ class MCTSNode(object):
while node.children and max(node.child_N) > 1:
next_kid = np.argmax(node.child_N)
node = node.children[next_kid]
output.append("%s" % coords.to_kgs(
self.board_size, coords.from_flat(self.board_size, node.fmove)))
output.append('{}'.format(coords.to_kgs(
self.board_size, coords.from_flat(self.board_size, node.fmove))))
return ' '.join(output)
def describe(self):
......@@ -288,22 +294,25 @@ class MCTSNode(object):
p_rel = p_delta / self.child_prior
# Dump out some statistics
output = []
output.append("{q:.4f}\n".format(q=self.Q))
output.append('{q:.4f}\n'.format(q=self.Q))
output.append(self.most_visited_path())
output.append(
"move: action Q U P P-Dir N soft-N" +
" p-delta p-rel\n")
'''move: action Q U P P-Dir N soft-N
p-delta p-rel\n''')
output.append(
"\n".join(["{!s:6}: {: .3f}, {: .3f}, {:.3f}, {:.3f}, {:.3f}, {:4d} {:.4f} {: .5f} {: .2f}".format(
coords.to_kgs(self.board_size, coords.from_flat(self.board_size, key)),
self.child_action_score[key],
self.child_Q[key],
self.child_U[key],
self.child_prior[key],
self.original_prior[key],
int(self.child_N[key]),
soft_n[key],
p_delta[key],
p_rel[key])
for key in sort_order][:15]))
return "".join(output)
'\n'.join([
'''{!s:6}: {: .3f}, {: .3f}, {:.3f}, {:.3f}, {:.3f}, {:4d} {:.4f}
{: .5f} {: .2f}'''.format(
coords.to_kgs(self.board_size, coords.from_flat(
self.board_size, key)),
self.child_action_score[key],
self.child_Q[key],
self.child_U[key],
self.child_prior[key],
self.original_prior[key],
int(self.child_N[key]),
soft_n[key],
p_delta[key],
p_rel[key])
for key in sort_order][:15]))
return ''.join(output)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mcts."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf # pylint: disable=g-bad-import-order
import coords
import go
from mcts import MCTSNode
import numpy as np
import utils_test
tf.logging.set_verbosity(tf.logging.ERROR)
# Shared fixtures for the MCTS tests below.

# A 9-row by 9-column board literal with only a handful of empty ('.') points
# left; parsed by utils_test.load_board.
ALMOST_DONE_BOARD = utils_test.load_board('''
.XO.XO.OO
X.XXOOOO.
XXXXXOOOO
XXXXXOOOO
.XXXXOOO.
XXXXXOOOO
.XXXXOOO.
XXXXXOOOO
XXXXOOOOO
''')

# Late-game position on ALMOST_DONE_BOARD: move 105, komi 2.5, black to play.
TEST_POSITION = go.Position(
    utils_test.BOARD_SIZE,
    board=ALMOST_DONE_BOARD,
    n=105,
    komi=2.5,
    caps=(1, 4),
    ko=None,
    recent=(go.PlayerMove(go.BLACK, (0, 1)),
            go.PlayerMove(go.WHITE, (0, 8))),
    to_play=go.BLACK
)

# Same board at move 75 with komi 0.5 and white to play; used as the root
# position by several tests in TestMctsNodes.
# NOTE(review): the name presumably refers to the "send two, return one" Go
# pattern -- confirm against the upstream Minigo tests.
SEND_TWO_RETURN_ONE = go.Position(
    utils_test.BOARD_SIZE,
    board=ALMOST_DONE_BOARD,
    n=75,
    komi=0.5,
    caps=(0, 0),
    ko=None,
    recent=(
        go.PlayerMove(go.BLACK, (0, 1)),
        go.PlayerMove(go.WHITE, (0, 8)),
        go.PlayerMove(go.BLACK, (1, 0))),
    to_play=go.WHITE
)

# 1.4 * BOARD_SIZE^2.
# NOTE(review): not referenced by any test visible in this file -- confirm
# whether it is still needed.
MAX_DEPTH = (utils_test.BOARD_SIZE ** 2) * 1.4
class TestMctsNodes(utils_test.MiniGoUnitTest):
  """Unit tests for MCTSNode selection, backup and child management."""

  def test_action_flipping(self):
    """Black and white roots with identical priors select the same move."""
    np.random.seed(1)
    probs = np.array([.02] * (
        utils_test.BOARD_SIZE * utils_test.BOARD_SIZE + 1))
    # Small seeded perturbation on top of the uniform priors.
    probs += np.random.random(
        [utils_test.BOARD_SIZE * utils_test.BOARD_SIZE + 1]) * 0.001
    black_root = MCTSNode(
        utils_test.BOARD_SIZE, go.Position(utils_test.BOARD_SIZE))
    white_root = MCTSNode(utils_test.BOARD_SIZE, go.Position(
        utils_test.BOARD_SIZE, to_play=go.WHITE))
    black_root.select_leaf().incorporate_results(probs, 0, black_root)
    white_root.select_leaf().incorporate_results(probs, 0, white_root)
    # No matter who is to play, when we know nothing else, the priors
    # should be respected, and the same move should be picked
    black_leaf = black_root.select_leaf()
    white_leaf = white_root.select_leaf()
    self.assertEqual(black_leaf.fmove, white_leaf.fmove)
    self.assertEqualNPArray(
        black_root.child_action_score, white_root.child_action_score)

  def test_select_leaf(self):
    """select_leaf returns the child for the move with the dominant prior."""
    flattened = coords.to_flat(utils_test.BOARD_SIZE, coords.from_kgs(
        utils_test.BOARD_SIZE, 'D9'))
    probs = np.array([.02] * (
        utils_test.BOARD_SIZE * utils_test.BOARD_SIZE + 1))
    # Give D9 a much larger prior than every other move.
    probs[flattened] = 0.4
    root = MCTSNode(utils_test.BOARD_SIZE, SEND_TWO_RETURN_ONE)
    root.select_leaf().incorporate_results(probs, 0, root)
    self.assertEqual(root.position.to_play, go.WHITE)
    self.assertEqual(root.select_leaf(), root.children[flattened])

  def test_backup_incorporate_results(self):
    """incorporate_results backs visit counts and Q up from leaf to root."""
    probs = np.array([.02] * (
        utils_test.BOARD_SIZE * utils_test.BOARD_SIZE + 1))
    root = MCTSNode(utils_test.BOARD_SIZE, SEND_TWO_RETURN_ONE)
    root.select_leaf().incorporate_results(probs, 0, root)
    leaf = root.select_leaf()
    leaf.incorporate_results(probs, -1, root)  # white wins!
    # Root was visited twice: first at the root, then at this child.
    self.assertEqual(root.N, 2)
    # Root has 0 as a prior and two visits with value 0, -1
    self.assertAlmostEqual(root.Q, -1/3)  # average of 0, 0, -1
    # Leaf should have one visit
    self.assertEqual(root.child_N[leaf.fmove], 1)
    self.assertEqual(leaf.N, 1)
    # And that leaf's value had its parent's Q (0) as a prior, so the Q
    # should now be the average of 0, -1
    self.assertAlmostEqual(root.child_Q[leaf.fmove], -0.5)
    self.assertAlmostEqual(leaf.Q, -0.5)
    # We're assuming that select_leaf() returns a leaf like:
    #   root
    #     \
    #     leaf
    #       \
    #       leaf2
    # which happens in this test because root is W to play and leaf was a W win.
    self.assertEqual(root.position.to_play, go.WHITE)
    leaf2 = root.select_leaf()
    leaf2.incorporate_results(probs, -0.2, root)  # another white semi-win
    self.assertEqual(root.N, 3)
    # average of 0, 0, -1, -0.2
    self.assertAlmostEqual(root.Q, -0.3)
    self.assertEqual(leaf.N, 2)
    self.assertEqual(leaf2.N, 1)
    # average of 0, -1, -0.2
    self.assertAlmostEqual(leaf.Q, root.child_Q[leaf.fmove])
    self.assertAlmostEqual(leaf.Q, -0.4)
    # average of -1, -0.2
    self.assertAlmostEqual(leaf.child_Q[leaf2.fmove], -0.6)
    self.assertAlmostEqual(leaf2.Q, -0.6)

  def test_do_not_explore_past_finish(self):
    """After two consecutive passes the search stays on the final node."""
    probs = np.array([0.02] * (
        utils_test.BOARD_SIZE * utils_test.BOARD_SIZE + 1), dtype=np.float32)
    root = MCTSNode(utils_test.BOARD_SIZE, go.Position(utils_test.BOARD_SIZE))
    root.select_leaf().incorporate_results(probs, 0, root)
    # coords.to_flat(board_size, None) is the flat index used for a pass here.
    first_pass = root.maybe_add_child(
        coords.to_flat(utils_test.BOARD_SIZE, None))
    first_pass.incorporate_results(probs, 0, root)
    second_pass = first_pass.maybe_add_child(
        coords.to_flat(utils_test.BOARD_SIZE, None))
    # Incorporating results into the finished position must trip an assert.
    with self.assertRaises(AssertionError):
      second_pass.incorporate_results(probs, 0, root)
    node_to_explore = second_pass.select_leaf()
    # should just stop exploring at the end position.
    self.assertEqual(node_to_explore, second_pass)

  def test_add_child(self):
    """maybe_add_child registers the child and links it to its parent."""
    root = MCTSNode(utils_test.BOARD_SIZE, go.Position(utils_test.BOARD_SIZE))
    child = root.maybe_add_child(17)
    self.assertIn(17, root.children)
    self.assertEqual(child.parent, root)
    self.assertEqual(child.fmove, 17)

  def test_add_child_idempotency(self):
    """Adding the same move twice returns the existing child unchanged."""
    root = MCTSNode(utils_test.BOARD_SIZE, go.Position(utils_test.BOARD_SIZE))
    child = root.maybe_add_child(17)
    # Shallow snapshot of the children mapping before the second add.
    current_children = copy.copy(root.children)
    child2 = root.maybe_add_child(17)
    self.assertEqual(child, child2)
    self.assertEqual(current_children, root.children)

  def test_never_select_illegal_moves(self):
    """An illegal move is never selected, even with a huge prior or noise."""
    probs = np.array([0.02] * (
        utils_test.BOARD_SIZE * utils_test.BOARD_SIZE + 1))
    # let's say the NN were to accidentally put a high weight on an illegal move
    probs[1] = 0.99
    root = MCTSNode(utils_test.BOARD_SIZE, SEND_TWO_RETURN_ONE)
    root.incorporate_results(probs, 0, root)
    # and let's say the root were visited a lot of times, which pumps up the
    # action score for unvisited moves...
    root.N = 100000
    root.child_N[root.position.all_legal_moves()] = 10000
    # this should not throw an error...
    leaf = root.select_leaf()
    # the returned leaf should not be the illegal move
    self.assertNotEqual(leaf.fmove, 1)
    # and even after injecting noise, we should still not select an illegal move
    for _ in range(10):
      root.inject_noise()
      leaf = root.select_leaf()
      self.assertNotEqual(leaf.fmove, 1)

  def test_dont_pick_unexpanded_child(self):
    """A selected but not yet evaluated child is returned again."""
    probs = np.array([0.001] * (
        utils_test.BOARD_SIZE * utils_test.BOARD_SIZE + 1))
    # make one move really likely so that tree search goes down that path twice
    # even with a virtual loss
    probs[17] = 0.999
    root = MCTSNode(utils_test.BOARD_SIZE, go.Position(utils_test.BOARD_SIZE))
    root.incorporate_results(probs, 0, root)
    leaf1 = root.select_leaf()
    self.assertEqual(leaf1.fmove, 17)
    leaf1.add_virtual_loss(up_to=root)
    # the second select_leaf pick should return the same thing, since the child
    # hasn't yet been sent to neural net for eval + result incorporation
    leaf2 = root.select_leaf()
    self.assertIs(leaf1, leaf2)
if __name__ == '__main__':
tf.test.main()
......@@ -39,11 +39,16 @@ def _one_hot(board_size, index):
def make_tf_example(features, pi, value):
"""
"""Make tf examples.
Args:
features: [N, N, FEATURE_DIM] nparray of uint8
pi: [N * N + 1] nparray of float32
value: float
Returns:
tf example.
"""
return tf.train.Example(
features=tf.train.Features(
......@@ -57,9 +62,9 @@ def make_tf_example(features, pi, value):
}))
# Write tf.Example to files
def write_tf_examples(filename, tf_examples, serialize=True):
"""
"""Write tf.Example to files.
Args:
filename: Where to write tf.records
tf_examples: An iterable of tf.Example
......@@ -76,9 +81,13 @@ def write_tf_examples(filename, tf_examples, serialize=True):
# Read tf.Example from files
def _batch_parse_tf_example(board_size, batch_size, example_batch):
"""
"""Parse tf examples.
Args:
board_size: the go board size
batch_size: the batch size
example_batch: a batch of tf.Example
Returns:
A tuple (feature_tensor, dict of output tensors)
"""
......@@ -102,21 +111,20 @@ def _batch_parse_tf_example(board_size, batch_size, example_batch):
def read_tf_records(
shuffle_buffer_size, batch_size, tf_records, num_repeats=None,
shuffle_records=True, shuffle_examples=True, filter_amount=1.0):
"""
"""Read tf records.
Args:
shuffle_buffer_size: how big of a buffer to fill before shuffling
batch_size: batch size to return
tf_records: a list of tf_record filenames
num_repeats: how many times the data should be read (default: infinite)
shuffle_records: whether to shuffle the order of files read
shuffle_examples: whether to shuffle the tf.Examples
shuffle_buffer_size: how big of a buffer to fill before shuffling.
filter_amount: what fraction of records to keep
Returns:
a tf dataset of batched tensors
"""
if shuffle_buffer_size is None:
shuffle_buffer_size = params.shuffle_buffer_size
if shuffle_records:
random.shuffle(tf_records)
record_list = tf.data.Dataset.from_tensor_slices(tf_records)
......@@ -130,8 +138,8 @@ def read_tf_records(
dataset = record_list.interleave(
lambda x: tf.data.TFRecordDataset(x, compression_type='ZLIB'),
cycle_length=64, block_length=16)
dataset = dataset.filter(lambda x: tf.less(
tf.random_uniform([1]), filter_amount)[0])
dataset = dataset.filter(
lambda x: tf.less(tf.random_uniform([1]), filter_amount)[0])
# TODO(amj): apply py_func for transforms here.
if num_repeats is not None:
dataset = dataset.repeat(num_repeats)
......@@ -146,10 +154,19 @@ def read_tf_records(
def get_input_tensors(params, batch_size, tf_records, num_repeats=None,
shuffle_records=True, shuffle_examples=True,
filter_amount=0.05):
"""Read tf.Records and prepare them for ingestion by dual_net. See
`read_tf_records` for parameter documentation.
"""Read tf.Records and prepare them for ingestion by dualnet.
Args:
params: An object of hyperparameters
batch_size: batch size to return
tf_records: a list of tf_record filenames
num_repeats: how many times the data should be read (default: infinite)
shuffle_records: whether to shuffle the order of files read
shuffle_examples: whether to shuffle the tf.Examples
filter_amount: what fraction of records to keep
Returns a dict of tensors (see return value of batch_parse_tf_example)
Returns:
A dict of tensors (see return value of batch_parse_tf_example)
"""
shuffle_buffer_size = params.shuffle_buffer_size
dataset = read_tf_records(
......@@ -170,8 +187,10 @@ def make_dataset_from_selfplay(data_extracts, params):
Args:
data_extracts: An iterable of (position, pi, result) tuples
params: An object of hyperparameters
Returns an iterable of tf.Examples.
Returns:
An iterable of tf.Examples.
"""
board_size = params.board_size
tf_examples = (make_tf_example(features_lib.extract_features(
......@@ -190,7 +209,8 @@ def make_dataset_from_sgf(board_size, sgf_filename, tf_record):
def _make_tf_example_from_pwc(board_size, position_w_context):
  """Build a tf.Example from one position-with-context record.

  Args:
    board_size: the go board size.
    position_w_context: object exposing `position`, `next_move` and `result`.

  Returns:
    The value returned by make_tf_example for the extracted features, the
    one-hot encoding of the next move, and the recorded result.
  """
  planes = features_lib.extract_features(
      board_size, position_w_context.position)
  flat_move = coords.to_flat(board_size, position_w_context.next_move)
  return make_tf_example(
      planes, _one_hot(board_size, flat_move), position_w_context.result)
......@@ -203,7 +223,7 @@ def shuffle_tf_examples(shuffle_buffer_size, gather_size, records_to_shuffle):
gather_size: The number of tf.Examples to be gathered together
records_to_shuffle: A list of filenames
Returns:
Yields:
An iterator yielding lists of bytes, which are serialized tf.Examples.
"""
dataset = read_tf_records(shuffle_buffer_size, gather_size,
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for preprocessing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import tempfile
import tensorflow as tf # pylint: disable=g-bad-import-order
import coords
import features
import go
import model_params
import numpy as np
import preprocessing
import utils_test
# Only surface TensorFlow log messages at ERROR level while the tests run.
tf.logging.set_verbosity(tf.logging.ERROR)

# Minimal 9x9 game record (SZ[9]) with two moves, B[fd] then W[cf], and
# result RE[W+1.5]; consumed by test_make_dataset_from_sgf.
TEST_SGF = '''(;CA[UTF-8]SZ[9]PB[Murakawa Daisuke]PW[Iyama Yuta]KM[6.5]
HA[0]RE[W+1.5]GM[1];B[fd];W[cf])'''
class TestPreprocessing(utils_test.MiniGoUnitTest):
  """Round-trip tests for writing, shuffling and reading tf.Example records."""

  def create_random_data(self, num_examples):
    """Generate random (feature, pi, value) tuples.

    Args:
      num_examples: number of tuples to generate.

    Returns:
      A list of (feature, pi, value) tuples, where feature is a uint8 array
      of shape [BOARD_SIZE, BOARD_SIZE, NEW_FEATURES_PLANES], pi is a float32
      array of length BOARD_SIZE * BOARD_SIZE + 1 and value is a float.
    """
    raw_data = []
    for _ in range(num_examples):
      # NOTE(review): random() draws from [0, 1), so the uint8 cast appears to
      # yield all-zero planes -- confirm this is intended.
      feature = np.random.random([
          utils_test.BOARD_SIZE, utils_test.BOARD_SIZE,
          features.NEW_FEATURES_PLANES]).astype(np.uint8)
      pi = np.random.random([utils_test.BOARD_SIZE * utils_test.BOARD_SIZE
                             + 1]).astype(np.float32)
      value = np.random.random()
      raw_data.append((feature, pi, value))
    return raw_data

  def extract_data(self, tf_record, filter_amount=1):
    """Read every example back from a record file.

    Args:
      tf_record: filename of the tf record file to read.
      filter_amount: forwarded to preprocessing.get_input_tensors.

    Returns:
      A list of (position, pi, value) tuples, one per evaluated batch of
      size 1, in the order produced by the input pipeline.
    """
    pos_tensor, label_tensors = preprocessing.get_input_tensors(
        model_params.DummyMiniGoParams(), 1, [tf_record], num_repeats=1,
        shuffle_records=False, shuffle_examples=False,
        filter_amount=filter_amount)
    recovered_data = []
    with tf.Session() as sess:
      # num_repeats=1, so the pipeline signals exhaustion with
      # OutOfRangeError; that is the loop's exit condition.
      while True:
        try:
          pos_value, label_values = sess.run([pos_tensor, label_tensors])
          recovered_data.append((
              pos_value,
              label_values['pi_tensor'],
              label_values['value_tensor']))
        except tf.errors.OutOfRangeError:
          break
    return recovered_data

  def assertEqualData(self, data1, data2):
    """Assert that two datasets are equal element by element.

    Both arguments have the form List<Tuple<feature_array, pi_array, value>>.
    """
    self.assertEqual(len(data1), len(data2))
    for datum1, datum2 in zip(data1, data2):
      # feature
      self.assertEqualNPArray(datum1[0], datum2[0])
      # pi
      self.assertEqualNPArray(datum1[1], datum2[1])
      # value
      # NOTE(review): exact equality between a raw python float and the value
      # read back from the record -- confirm precision holds on the round trip.
      self.assertEqual(datum1[2], datum2[2])

  def test_serialize_round_trip(self):
    """Examples written to a record file are read back unchanged."""
    np.random.seed(1)
    raw_data = self.create_random_data(10)
    tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))
    with tempfile.NamedTemporaryFile() as f:
      preprocessing.write_tf_examples(f.name, tfexamples)
      recovered_data = self.extract_data(f.name)
    self.assertEqualData(raw_data, recovered_data)

  def test_filter(self):
    """With filter_amount=.05 fewer than half of 100 examples survive."""
    raw_data = self.create_random_data(100)
    tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))
    with tempfile.NamedTemporaryFile() as f:
      preprocessing.write_tf_examples(f.name, tfexamples)
      recovered_data = self.extract_data(f.name, filter_amount=.05)
    self.assertLess(len(recovered_data), 50)

  def test_serialize_round_trip_no_parse(self):
    """Shuffled, already-serialized examples can be rewritten and re-read."""
    np.random.seed(1)
    raw_data = self.create_random_data(10)
    tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))
    with tempfile.NamedTemporaryFile() as start_file, \
        tempfile.NamedTemporaryFile() as rewritten_file:
      preprocessing.write_tf_examples(start_file.name, tfexamples)
      # We want to test that the rewritten, shuffled file contains correctly
      # serialized tf.Examples.
      batch_size = 4
      batches = list(preprocessing.shuffle_tf_examples(
          1000, batch_size, [start_file.name]))
      # 2 batches of 4, 1 incomplete batch of 2.
      self.assertEqual(len(batches), 3)
      # concatenate list of lists into one list
      all_batches = list(itertools.chain.from_iterable(batches))
      # A single write is enough: the payload is the full concatenated list,
      # so writing it once per batch only rewrote the same file repeatedly.
      preprocessing.write_tf_examples(
          rewritten_file.name, all_batches, serialize=False)
      original_data = self.extract_data(start_file.name)
      recovered_data = self.extract_data(rewritten_file.name)

    # stuff is shuffled, so sort before checking equality
    def sort_key(nparray_tuple):
      return nparray_tuple[2]
    original_data = sorted(original_data, key=sort_key)
    recovered_data = sorted(recovered_data, key=sort_key)
    self.assertEqualData(original_data, recovered_data)

  def test_make_dataset_from_sgf(self):
    """make_dataset_from_sgf emits one example per move of TEST_SGF."""
    with tempfile.NamedTemporaryFile() as sgf_file, \
        tempfile.NamedTemporaryFile() as record_file:
      sgf_file.write(TEST_SGF.encode('utf8'))
      sgf_file.seek(0)
      preprocessing.make_dataset_from_sgf(
          utils_test.BOARD_SIZE, sgf_file.name, record_file.name)
      recovered_data = self.extract_data(record_file.name)
    start_pos = go.Position(utils_test.BOARD_SIZE)
    first_move = coords.from_sgf('fd')
    next_pos = start_pos.play_move(first_move)
    second_move = coords.from_sgf('cf')
    # Both examples are expected to carry -1 as their value label.
    expected_data = [
        (
            features.extract_features(utils_test.BOARD_SIZE, start_pos),
            preprocessing._one_hot(utils_test.BOARD_SIZE, coords.to_flat(
                utils_test.BOARD_SIZE, first_move)), -1
        ),
        (
            features.extract_features(utils_test.BOARD_SIZE, next_pos),
            preprocessing._one_hot(utils_test.BOARD_SIZE, coords.to_flat(
                utils_test.BOARD_SIZE, second_move)), -1
        )
    ]
    self.assertEqualData(expected_data, recovered_data)
if __name__ == '__main__':
tf.test.main()
......@@ -24,8 +24,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import coords
import go
from go import Position, PositionWithContext
......@@ -33,49 +31,49 @@ import numpy as np
import sgf
import utils
SGF_TEMPLATE = """(;GM[1]FF[4]CA[UTF-8]AP[Minigo_sgfgenerator]RU[{ruleset}]
SGF_TEMPLATE = '''(;GM[1]FF[4]CA[UTF-8]AP[Minigo_sgfgenerator]RU[{ruleset}]
SZ[{boardsize}]KM[{komi}]PW[{white_name}]PB[{black_name}]RE[{result}]
{game_moves})"""
{game_moves})'''
PROGRAM_IDENTIFIER = "Minigo"
PROGRAM_IDENTIFIER = 'Minigo'
def translate_sgf_move_qs(player_move, q):
  """Translate a move into an SGF node annotated with its Q value.

  Args:
    player_move: the move to translate; forwarded to translate_sgf_move.
    q: number rendered with four decimals inside a C[...] comment property.

  Returns:
    The SGF move string followed by 'C[<q>]'.
  """
  # translate_sgf_move takes a mandatory `comment` argument; pass None so it
  # emits no comment node of its own and the Q annotation is the only one.
  return '{move}C[{q:.4f}]'.format(
      move=translate_sgf_move(player_move, None), q=q)
def translate_sgf_move(player_move, comment):
  """Render one move as an SGF node string.

  Args:
    player_move: object with `color` and `move` attributes.
    comment: text for a C[...] property, or None for no comment. Any ']' in
      the text is backslash-escaped.

  Returns:
    ';<B|W>[<coords>]' optionally followed by 'C[<comment>]'.

  Raises:
    ValueError: if the move's color is neither go.BLACK nor go.WHITE.
  """
  stone = player_move.color
  if stone not in (go.BLACK, go.WHITE):
    raise ValueError(
        "Can't translate color {} to sgf".format(stone))
  sgf_point = coords.to_sgf(player_move.move)
  if stone == go.BLACK:
    tag = 'B'
  else:
    tag = 'W'
  if comment is None:
    note = ''
  else:
    note = 'C[{}]'.format(comment.replace(']', r'\]'))
  return ';{color}[{coords}]{comment_node}'.format(
      color=tag, coords=sgf_point, comment_node=note)
# pylint: disable=unused-argument
# pylint: disable=unused-variable
def make_sgf(board_size, move_history, result_string, ruleset='Chinese',
komi=7.5, white_name=PROGRAM_IDENTIFIER,
black_name=PROGRAM_IDENTIFIER, comments=[]):
"""Turn a game into SGF.
Doesn't handle handicap games or positions with incomplete history.
Args:
move_history: iterable of PlayerMoves
board_size: the go board size.
move_history: iterable of PlayerMoves.
result_string: "B+R", "W+0.5", etc.
ruleset: the rule set of go game
komi: komi score
white_name: the name of white player
black_name: the name of black player
comments: iterable of string/None. Will be zipped with move_history.
"""
try:
......@@ -87,14 +85,14 @@ def make_sgf(board_size,
from itertools import zip_longest
boardsize = board_size
game_moves = ''.join(translate_sgf_move(*z)
for z in zip_longest(move_history, comments))
game_moves = ''.join(translate_sgf_move(*z) for z in zip_longest(
move_history, comments))
result = result_string
return SGF_TEMPLATE.format(**locals())
def sgf_prop(value_list):
'Converts raw sgf library output to sensible value'
"""Converts raw sgf library output to sensible value."""
if value_list is None:
return None
if len(value_list) == 1:
......@@ -108,78 +106,82 @@ def sgf_prop_get(props, key, default):
def handle_node(board_size, pos, node):
  """Apply one SGF node to a position.

  A node can either add B+W stones, play as B, or play as W.

  Args:
    board_size: the go board size.
    pos: the position the node is applied to.
    node: SGF node whose `properties` mapping is inspected.

  Returns:
    The position after added stones (AB/AW) or after the B or W move; `pos`
    itself when the node carries none of those properties.
  """
  props = node.properties
  setup_black = list(map(coords.from_sgf, props.get('AB', [])))
  setup_white = list(map(coords.from_sgf, props.get('AW', [])))
  if setup_black or setup_white:
    return add_stones(board_size, pos, setup_black, setup_white)
  # If B/W props are not present, then there is no move. But if it is present
  # and equal to the empty string, then the move was a pass.
  for key, stone in (('B', go.BLACK), ('W', go.WHITE)):
    if key in props:
      played = coords.from_sgf(props[key][0])
      return pos.play_move(played, color=stone)
  return pos
def add_stones(board_size, pos, black_stones_added, white_stones_added):
  """Return a new Position with extra stones placed on a copy of the board.

  Args:
    board_size: the go board size.
    pos: source position; its board is copied, not modified.
    black_stones_added: coordinates handed to go.place_stones for black.
    white_stones_added: coordinates handed to go.place_stones for white.

  Returns:
    A Position on the updated board that carries over n, komi, caps, ko,
    recent and to_play from `pos`.
  """
  board = np.copy(pos.board)
  placements = (
      (go.BLACK, black_stones_added),
      (go.WHITE, white_stones_added),
  )
  for stone, points in placements:
    go.place_stones(board, stone, points)
  carried = dict(
      n=pos.n, komi=pos.komi, caps=pos.caps, ko=pos.ko,
      recent=pos.recent, to_play=pos.to_play)
  return Position(board_size, board=board, **carried)
def get_next_move(node):
  """Return the move stored on the node that follows `node`.

  Args:
    node: SGF node whose `next` node holds the move.

  Returns:
    coords.from_sgf of the first 'W' value when the next node has a 'W'
    property, otherwise of its first 'B' value.
  """
  props = node.next.properties
  key = 'W' if 'W' in props else 'B'
  return coords.from_sgf(props[key][0])
def maybe_correct_next(pos, next_node):
  """Flip the side to move when it disagrees with the upcoming SGF node.

  Args:
    pos: position, mutated in place via flip_playerturn(mutate=True) when a
      correction is needed.
    next_node: SGF node whose 'B' / 'W' properties name the mover.
  """
  props = next_node.properties
  black_out_of_turn = 'B' in props and pos.to_play != go.BLACK
  white_out_of_turn = 'W' in props and pos.to_play != go.WHITE
  if black_out_of_turn or white_out_of_turn:
    pos.flip_playerturn(mutate=True)
def replay_sgf(board_size, sgf_contents):
  """Wrapper for sgf files.

  Only the first game of the parsed collection is replayed.

  It does NOT return the very final position, as there is no follow up.
  To get the final position, call pwc.position.play_move(pwc.next_move)
  on the last PositionWithContext returned.

  Example usage:
    with open(filename) as f:
      for position_w_context in replay_sgf(board_size, f.read()):
        print(position_w_context.position)

  Args:
    board_size: the go board size.
    sgf_contents: the content in sgf.

  Yields:
    The go.PositionWithContext instances.

  Raises:
    AssertionError: if the GM property is present and is not 1.
  """
  collection = sgf.parse(sgf_contents)
  game = collection.children[0]
  props = game.root.properties
  assert int(sgf_prop(props.get('GM', ['1']))) == 1, 'Not a Go SGF!'

  # Komi defaults to 0 when the record has no KM property.
  komi = 0
  if props.get('KM') is not None:
    komi = float(sgf_prop(props.get('KM')))
  result = utils.parse_game_result(sgf_prop(props.get('RE')))

  pos = Position(board_size, komi=komi)
  current_node = game.root
  while pos is not None and current_node.next is not None:
    pos = handle_node(board_size, pos, current_node)
    # Make the side to move agree with the color of the upcoming node.
    maybe_correct_next(pos, current_node.next)
    next_move = get_next_move(current_node)
    yield PositionWithContext(pos, next_move, result)
    current_node = current_node.next
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for sgf_wrapper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
import coords
import go
from sgf_wrapper import replay_sgf, translate_sgf_move, make_sgf
import utils_test
# 9x9 record whose two handicap stones are given as an AB setup property
# (AB[gc][cg]) on the root node, followed by W[ee] and B[dg].
JAPANESE_HANDICAP_SGF = '''(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Japanese]
SZ[9]HA[2]RE[Void]KM[5.50]PW[test_white]PB[test_black]AB[gc][cg];W[ee];B[dg])'''

# 9x9 record whose handicap is given as two consecutive black moves
# (B[gc];B[cg]). Note the root carries RE twice: RE[Void] and RE[B+39.50].
CHINESE_HANDICAP_SGF = '''(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Chinese]SZ[9]
HA[2]RE[Void]KM[5.50]PW[test_white]PB[test_black]RE[B+39.50];B[gc];B[cg];W[ee];
B[gg];W[eg];B[ge];W[ce];B[ec];W[cc];B[dd];W[de];B[cd];W[bd];B[bc];W[bb];B[be];
W[ac];B[bf];W[dh];B[ch];W[ci];B[bi];W[di];B[ah];W[gh];B[hh];W[fh];B[hg];W[gi];
B[fg];W[dg];B[ei];W[cf];B[ef];W[ff];B[fe];W[bg];B[bh];W[af];B[ag];W[ae];B[ad];
W[ae];B[ed];W[db];B[df];W[eb];B[fb];W[ea];B[fa])'''

# 9x9 record without handicap (HA[0]); used by test_make_sgf for a
# replay -> make_sgf -> replay round trip.
NO_HANDICAP_SGF = '''(;CA[UTF-8]SZ[9]PB[Murakawa Daisuke]PW[Iyama Yuta]KM[6.5]
HA[0]RE[W+1.5]GM[1];B[fd];W[cf];B[eg];W[dd];B[dc];W[cc];B[de];W[cd];B[ed];W[he];
B[ce];W[be];B[df];W[bf];B[hd];W[ge];B[gd];W[gg];B[db];W[cb];B[cg];W[bg];B[gh];
W[fh];B[hh];W[fg];B[eh];W[ei];B[di];W[fi];B[hg];W[dh];B[ch];W[ci];B[bh];W[ff];
B[fe];W[hf];B[id];W[bi];B[ah];W[ef];B[dg];W[ee];B[di];W[ig];B[ai];W[ih];B[fb];
W[hi];B[ag];W[ab];B[bd];W[bc];B[ae];W[ad];B[af];W[bd];B[ca];W[ba];B[da];W[ie])
'''

# Only surface TensorFlow log messages at ERROR level while the tests run.
tf.logging.set_verbosity(tf.logging.ERROR)
class TestSgfGeneration(utils_test.MiniGoUnitTest):
  """Tests for turning moves and whole games into SGF text."""

  def test_translate_sgf_move(self):
    """Moves, passes and comments translate to the expected SGF nodes."""
    self.assertEqual(
        ';B[db]',
        translate_sgf_move(go.PlayerMove(go.BLACK, (1, 3)), None))
    self.assertEqual(
        ';W[aa]',
        translate_sgf_move(go.PlayerMove(go.WHITE, (0, 0)), None))
    # A move of None yields an empty coordinate.
    self.assertEqual(
        ';W[]',
        translate_sgf_move(go.PlayerMove(go.WHITE, None), None))
    self.assertEqual(
        ';B[db]C[comment]',
        translate_sgf_move(go.PlayerMove(go.BLACK, (1, 3)), 'comment'))

  def test_make_sgf(self):
    """Replaying the output of make_sgf reproduces the final position."""
    all_pwcs = list(replay_sgf(utils_test.BOARD_SIZE, NO_HANDICAP_SGF))
    # replay_sgf stops one move short, so play the last move by hand.
    second_last_position, last_move, _ = all_pwcs[-1]
    last_position = second_last_position.play_move(last_move)
    back_to_sgf = make_sgf(
        utils_test.BOARD_SIZE,
        last_position.recent,
        last_position.score(),
        komi=last_position.komi,
    )
    reconstructed_positions = list(replay_sgf(
        utils_test.BOARD_SIZE, back_to_sgf))
    second_last_position2, last_move2, _ = reconstructed_positions[-1]
    last_position2 = second_last_position2.play_move(last_move2)
    self.assertEqualPositions(last_position, last_position2)
class TestSgfWrapper(utils_test.MiniGoUnitTest):
  """Tests for replaying SGF records into positions."""

  def test_sgf_props(self):
    """Result and komi are read from the root properties."""
    sgf_replayer = replay_sgf(utils_test.BOARD_SIZE, CHINESE_HANDICAP_SGF)
    initial = next(sgf_replayer)
    self.assertEqual(initial.result, go.BLACK)
    self.assertEqual(initial.position.komi, 5.5)

  def test_japanese_handicap_handling(self):
    """Handicap stones given via AB setup are placed before the first move."""
    # Expected board after the two AB stones and W[ee].
    intermediate_board = utils_test.load_board('''
.........
.........
......X..
.........
....O....
.........
..X......
.........
.........
''')
    intermediate_position = go.Position(
        utils_test.BOARD_SIZE,
        intermediate_board,
        n=1,
        komi=5.5,
        caps=(0, 0),
        recent=(go.PlayerMove(go.WHITE, coords.from_kgs(
            utils_test.BOARD_SIZE, 'E5')),),
        to_play=go.BLACK,
    )
    # Expected board after the final B[dg].
    final_board = utils_test.load_board('''
.........
.........
......X..
.........
....O....
.........
..XX.....
.........
.........
''')
    final_position = go.Position(
        utils_test.BOARD_SIZE,
        final_board,
        n=2,
        komi=5.5,
        caps=(0, 0),
        recent=(
            go.PlayerMove(go.WHITE, coords.from_kgs(
                utils_test.BOARD_SIZE, 'E5')),
            go.PlayerMove(go.BLACK, coords.from_kgs(
                utils_test.BOARD_SIZE, 'D3')),),
        to_play=go.WHITE,
    )
    positions_w_context = list(replay_sgf(
        utils_test.BOARD_SIZE, JAPANESE_HANDICAP_SGF))
    self.assertEqualPositions(
        intermediate_position, positions_w_context[1].position)
    # replay_sgf stops one move short, so play the last move by hand.
    final_replayed_position = positions_w_context[-1].position.play_move(
        positions_w_context[-1].next_move)
    self.assertEqualPositions(final_position, final_replayed_position)

  def test_chinese_handicap_handling(self):
    """Handicap given as consecutive black moves replays correctly."""
    # Expected board after the first black move only.
    intermediate_board = utils_test.load_board('''
.........
.........
......X..
.........
.........
.........
.........
.........
.........
''')
    # Black is still to play here because the record's second node is
    # another black move.
    intermediate_position = go.Position(
        utils_test.BOARD_SIZE,
        intermediate_board,
        n=1,
        komi=5.5,
        caps=(0, 0),
        recent=(go.PlayerMove(go.BLACK, coords.from_kgs(
            utils_test.BOARD_SIZE, 'G7')),),
        to_play=go.BLACK,
    )
    # Expected board once every move of the record has been played.
    final_board = utils_test.load_board('''
....OX...
.O.OOX...
O.O.X.X..
.OXXX....
OX...XX..
.X.XXO...
X.XOOXXX.
XXXO.OOX.
.XOOX.O..
''')
    final_position = go.Position(
        utils_test.BOARD_SIZE,
        final_board,
        n=50,
        komi=5.5,
        caps=(7, 2),
        ko=None,
        recent=(
            go.PlayerMove(
                go.WHITE, coords.from_kgs(utils_test.BOARD_SIZE, 'E9')),
            go.PlayerMove(
                go.BLACK, coords.from_kgs(utils_test.BOARD_SIZE, 'F9')),),
        to_play=go.WHITE
    )
    positions_w_context = list(replay_sgf(
        utils_test.BOARD_SIZE, CHINESE_HANDICAP_SGF))
    self.assertEqualPositions(
        intermediate_position, positions_w_context[1].position)
    self.assertEqual(
        positions_w_context[1].next_move, coords.from_kgs(
            utils_test.BOARD_SIZE, 'C3'))
    # replay_sgf stops one move short, so play the last move by hand.
    final_replayed_position = positions_w_context[-1].position.play_move(
        positions_w_context[-1].next_move)
    self.assertEqualPositions(final_position, final_replayed_position)
if __name__ == '__main__':
tf.test.main()
......@@ -31,15 +31,15 @@ import sgf_wrapper
def time_recommendation(move_num, seconds_per_move=5, time_limit=15*60,
decay_factor=0.98):
"""
Given current move number and "desired" seconds per move,
return how much time should actually be used. To be used specifically
for CGOS time controls, which are absolute 15 minute time.
"""Compute the time can be used."""
The strategy is to spend the maximum time possible using seconds_per_move,
and then switch to an exponentially decaying time usage, calibrated so that
we have enough time for an infinite number of moves.
"""
# Given current move number and "desired" seconds per move,
# return how much time should actually be used. To be used specifically
# for CGOS time controls, which are absolute 15 minute time.
# The strategy is to spend the maximum time possible using seconds_per_move,
# and then switch to an exponentially decaying time usage, calibrated so that
# we have enough time for an infinite number of moves.
# divide by two since you only play half the moves in a game.
player_move_num = move_num / 2
......@@ -69,6 +69,7 @@ def _get_temperature_cutoff(board_size):
class MCTSPlayerMixin(object):
# If 'simulations_per_move' is nonzero, it will perform that many reads
# before playing. Otherwise, it uses 'seconds_per_move' of wall time.
def __init__(self, board_size, network, seconds_per_move=5,
......@@ -92,7 +93,6 @@ class MCTSPlayerMixin(object):
self.result = 0
self.result_string = None
self.resign_threshold = -abs(resign_threshold)
super(MCTSPlayerMixin, self).__init__(board_size)
def initialize_game(self, position=None):
if position is None:
......@@ -105,11 +105,10 @@ class MCTSPlayerMixin(object):
self.qs = []
def suggest_move(self, position):
""" Used for playing a single game.
""" Used for playing a single game."""
# For parallel play, use initialize_move, select_leaf,
# incorporate_results, and pick_move
For parallel play, use initialize_move, select_leaf,
incorporate_results, and pick_move
"""
start = time.time()
if self.simulations_per_move == 0:
......@@ -120,7 +119,7 @@ class MCTSPlayerMixin(object):
while self.root.N < current_readouts + self.simulations_per_move:
self.tree_search()
if self.verbosity > 0:
print("%d: Searched %d times in %s seconds\n\n" % (
print('%d: Searched %d times in %s seconds\n\n' % (
position.n, self.simulations_per_move, time.time() - start),
file=sys.stderr)
......@@ -134,13 +133,13 @@ class MCTSPlayerMixin(object):
return self.pick_move()
def play_move(self, c):
"""
Notable side effects:
- finalizes the probability distribution according to
this roots visit counts into the class' running tally, `searches_pi`
- Makes the node associated with this move the root, for future
`inject_noise` calls.
"""
"""Play a move."""
# Notable side effects:
# - finalizes the probability distribution according to
# this roots visit counts into the class' running tally, `searches_pi`
# - Makes the node associated with this move the root, for future
# `inject_noise` calls.
if not self.two_player_mode:
self.searches_pi.append(
self.root.children_as_pi(self.root.position.n < self.temp_threshold))
......@@ -155,7 +154,8 @@ class MCTSPlayerMixin(object):
"""Picks a move to play, based on MCTS readout statistics.
Highest N is most robust indicator. In the early stage of the game, pick
a move weighted by visit count; later on, pick the absolute max."""
a move weighted by visit count; later on, pick the absolute max.
"""
if self.root.position.n > self.temp_threshold:
fcoord = np.argmax(self.root.child_N)
else:
......@@ -191,20 +191,20 @@ class MCTSPlayerMixin(object):
leaf.incorporate_results(move_prob, value, up_to=self.root)
def show_path_to_root(self, node):
MAX_DEPTH = (self.board_size ** 2) * 1.4 # 505 moves for 19x19, 113 for 9x9
max_depth = (self.board_size ** 2) * 1.4 # 505 moves for 19x19, 113 for 9x9
pos = node.position
diff = node.position.n - self.root.position.n
if len(pos.recent) == 0:
if pos.recent is None:
return
def fmt(move):
return "{}-{}".format('b' if move.color == 1 else 'w',
return '{}-{}'.format('b' if move.color == 1 else 'w',
coords.to_kgs(self.board_size, move.move))
path = " ".join(fmt(move) for move in pos.recent[-diff:])
if node.position.n >= MAX_DEPTH:
path += " (depth cutoff reached) %0.1f" % node.position.score()
path = ' '.join(fmt(move) for move in pos.recent[-diff:])
if node.position.n >= max_depth:
path += ' (depth cutoff reached) %0.1f' % node.position.score()
elif node.position.is_game_over():
path += " (game over) %0.1f" % node.position.score()
path += ' (game over) %0.1f' % node.position.score()
return path
def should_resign(self):
......@@ -217,7 +217,7 @@ class MCTSPlayerMixin(object):
def set_result(self, winner, was_resign):
self.result = winner
if was_resign:
string = "B+R" if winner == go.BLACK else "W+R"
string = 'B+R' if winner == go.BLACK else 'W+R'
else:
string = self.root.position.result_string()
self.result_string = string
......@@ -227,14 +227,14 @@ class MCTSPlayerMixin(object):
pos = self.root.position
if use_comments:
comments = self.comments or ['No comments.']
comments[0] = ("Resign Threshold: %0.3f\n" %
comments[0] = ('Resign Threshold: %0.3f\n' %
self.resign_threshold) + comments[0]
else:
comments = []
return sgf_wrapper.make_sgf(
self.board_size, pos.recent, self.result_string,
white_name=os.path.basename(self.network.save_file) or "Unknown",
black_name=os.path.basename(self.network.save_file) or "Unknown",
white_name=os.path.basename(self.network.save_file) or 'Unknown',
black_name=os.path.basename(self.network.save_file) or 'Unknown',
comments=comments)
def is_done(self):
......@@ -255,8 +255,8 @@ class MCTSPlayerMixin(object):
if 'winrate' in text.lower():
wr = (abs(self.root.Q) + 1.0) / 2.0
color = "Black" if self.root.Q > 0 else "White"
return "{:s} {:.2f}%".format(color, wr * 100.0)
color = 'Black' if self.root.Q > 0 else 'White'
return '{:s} {:.2f}%'.format(color, wr * 100.0)
elif 'nextplay' in text.lower():
return "I'm thinking... " + self.root.most_visited_path()
elif 'fortune' in text.lower():
......@@ -269,6 +269,7 @@ class MCTSPlayerMixin(object):
class CGOSPlayerMixin(MCTSPlayerMixin):
  """MCTS player that refreshes its per-move time budget before searching."""

  def suggest_move(self, position):
    """Set seconds_per_move from time_recommendation, then search as usual.

    Args:
      position: position to move from; its move number `n` drives the budget.

    Returns:
      Whatever the parent class's suggest_move returns for `position`.
    """
    budget = time_recommendation(position.n)
    self.seconds_per_move = budget
    return super(CGOSPlayerMixin, self).suggest_move(position)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for strategies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order
import coords
import go
import numpy as np
from strategies import MCTSPlayerMixin, time_recommendation
import utils_test
# 9x9 fixture board, parsed by utils_test.load_board. Used as the board of
# SEND_TWO_RETURN_ONE below.
ALMOST_DONE_BOARD = utils_test.load_board('''
.XO.XO.OO
X.XXOOOO.
XXXXXOOOO
XXXXXOOOO
.XXXXOOO.
XXXXXOOOO
.XXXXOOO.
XXXXXOOOO
XXXXOOOOO
''')
# Tromp taylor means black can win if we hit the move limit.
TT_FTW_BOARD = utils_test.load_board('''
.XXOOOOOO
X.XOO...O
.XXOO...O
X.XOO...O
.XXOO..OO
X.XOOOOOO
.XXOOOOOO
X.XXXXXXX
XXXXXXXXX
''')
# Position on ALMOST_DONE_BOARD at move 70 with Black to play. The recorded
# history is Black at (0, 1) followed by White at (0, 8).
SEND_TWO_RETURN_ONE = go.Position(
    utils_test.BOARD_SIZE,
    board=ALMOST_DONE_BOARD,
    n=70,
    komi=2.5,
    caps=(1, 4),
    ko=None,
    recent=(go.PlayerMove(go.BLACK, (0, 1)),
            go.PlayerMove(go.WHITE, (0, 8))),
    to_play=go.BLACK
)
# 505 moves for 19x19, 113 for 9x9
# NOTE: the 1.4 factor makes this a float (81 * 1.4 = 113.4 for a 9x9 board),
# so values derived from it (e.g. MAX_DEPTH - 2) are floats as well.
MAX_DEPTH = (utils_test.BOARD_SIZE ** 2) * 1.4
class DummyNet(object):
  """Stand-in network that answers every query with fixed priors and value."""

  def __init__(self, fake_priors=None, fake_value=0):
    """Store the canned outputs.

    Args:
      fake_priors: priors to return from every call. When None, a uniform
        vector with BOARD_SIZE ** 2 + 1 entries is used.
      fake_value: value to return from every call.
    """
    self.fake_value = fake_value
    if fake_priors is not None:
      self.fake_priors = fake_priors
      return
    num_moves = utils_test.BOARD_SIZE ** 2 + 1
    self.fake_priors = np.ones(num_moves) / num_moves

  def run(self, position):
    """Return the canned (priors, value) pair; `position` is ignored."""
    return self.fake_priors, self.fake_value

  def run_many(self, positions):
    """Return one canned priors entry and one value per given position.

    Raises:
      ValueError: if `positions` is empty.
    """
    if not positions:
      raise ValueError(
          "No positions passed! (Tensorflow would have failed here.")
    count = len(positions)
    return [self.fake_priors] * count, [self.fake_value] * count
def initialize_basic_player():
  """Build an MCTS player on a fresh game with one evaluation incorporated.

  A leaf is selected from the root and the DummyNet output for the root
  position is fed into it, so the returned player's tree is already seeded.
  """
  player = MCTSPlayerMixin(utils_test.BOARD_SIZE, DummyNet())
  player.initialize_game()
  leaf = player.root.select_leaf()
  priors, value = player.network.run(player.root.position)
  leaf.incorporate_results(priors, value, up_to=player.root)
  return player
def initialize_almost_done_player():
  """Build an MCTS player whose game starts at SEND_TWO_RETURN_ONE.

  The dummy network's priors are tiny everywhere except for flat indices
  2..4 and the final (pass) entry.
  """
  num_moves = utils_test.BOARD_SIZE * utils_test.BOARD_SIZE + 1
  priors = np.array([.001] * num_moves)
  priors[2:5] = 0.2  # some legal moves along the top.
  priors[-1] = 0.2  # passing is also ok
  player = MCTSPlayerMixin(
      utils_test.BOARD_SIZE, DummyNet(fake_priors=priors))
  player.initialize_game(SEND_TWO_RETURN_ONE)
  return player
class TestMCTSPlayerMixin(utils_test.MiniGoUnitTest):
  """Tests for MCTSPlayerMixin and time_recommendation using a DummyNet."""

  def test_time_controls(self):
    """Summed time recommendations use 95%..100% of the time limit."""
    secs_per_move = 5
    for time_limit in (10, 100, 1000):
      # in the worst case imaginable, let's say a game goes 1000 moves long
      move_numbers = range(0, 1000, 2)
      total_time_spent = sum(
          time_recommendation(move_num, secs_per_move,
                              time_limit=time_limit)
          for move_num in move_numbers)
      # we should not exceed available game time
      self.assertLess(total_time_spent, time_limit)
      # but we should have used at least 95% of our time by the end.
      self.assertGreater(total_time_spent, time_limit * 0.95)

  def test_inject_noise(self):
    """inject_noise keeps priors normalized but makes them non-uniform."""
    player = initialize_basic_player()
    sum_priors = np.sum(player.root.child_prior)
    # dummyNet should return normalized priors.
    self.assertAlmostEqual(sum_priors, 1)
    self.assertTrue(np.all(player.root.child_U == player.root.child_U[0]))

    player.root.inject_noise()
    new_sum_priors = np.sum(player.root.child_prior)
    # priors should still be normalized after injecting noise
    self.assertAlmostEqual(sum_priors, new_sum_priors)

    # With Dirichlet noise, majority of density should be in one node.
    max_p = np.max(player.root.child_prior)
    self.assertGreater(max_p, 3/(utils_test.BOARD_SIZE ** 2 + 1))

  def test_pick_moves(self):
    """pick_move is deterministic late in the game and sampled early on."""
    # Import the submodule explicitly: a bare `import unittest` does not
    # guarantee that `unittest.mock` has been loaded.
    from unittest import mock  # pylint: disable=g-import-not-at-top

    player = initialize_basic_player()
    root = player.root
    root.child_N[coords.to_flat(utils_test.BOARD_SIZE, (2, 0))] = 10
    root.child_N[coords.to_flat(utils_test.BOARD_SIZE, (1, 0))] = 5
    root.child_N[coords.to_flat(utils_test.BOARD_SIZE, (3, 0))] = 1

    # move 81, or 361, or... Endgame.
    root.position.n = utils_test.BOARD_SIZE ** 2
    # Assert we're picking deterministically
    self.assertGreater(root.position.n, player.temp_threshold)
    move = player.pick_move()
    self.assertEqual(move, (2, 0))

    # But if we're in the early part of the game, pick randomly
    root.position.n = 3
    self.assertLessEqual(player.root.position.n, player.temp_threshold)

    with mock.patch('random.random', lambda: .5):
      move = player.pick_move()
      self.assertEqual(move, (2, 0))

    with mock.patch('random.random', lambda: .99):
      move = player.pick_move()
      self.assertEqual(move, (3, 0))

  def test_dont_pass_if_losing(self):
    """Sequential search from the almost-done position settles on D9."""
    player = initialize_almost_done_player()

    # check -- white is losing.
    self.assertEqual(player.root.position.score(), -0.5)

    for _ in range(20):
      player.tree_search()
    # uncomment to debug this test
    # print(player.root.describe())

    # Search should converge on D9 as only winning move.
    flattened = coords.to_flat(utils_test.BOARD_SIZE, coords.from_kgs(
        utils_test.BOARD_SIZE, 'D9'))
    best_move = np.argmax(player.root.child_N)
    self.assertEqual(best_move, flattened)
    # D9 should have a positive value
    self.assertGreater(player.root.children[flattened].Q, 0)
    self.assertGreaterEqual(player.root.N, 20)
    # passing should be ineffective.
    self.assertLess(player.root.child_Q[-1], 0)
    # no virtual losses should be pending
    self.assertNoPendingVirtualLosses(player.root)

  def test_parallel_tree_search(self):
    """Parallel search from the almost-done position also settles on D9."""
    player = initialize_almost_done_player()
    # check -- white is losing.
    self.assertEqual(player.root.position.score(), -0.5)
    # initialize the tree so that the root node has populated children.
    player.tree_search(num_parallel=1)
    # virtual losses should enable multiple searches to happen simultaneously
    # without throwing an error...
    for _ in range(5):
      player.tree_search(num_parallel=4)
    # uncomment to debug this test
    # print(player.root.describe())

    # Search should converge on D9 as only winning move.
    flattened = coords.to_flat(utils_test.BOARD_SIZE, coords.from_kgs(
        utils_test.BOARD_SIZE, 'D9'))
    best_move = np.argmax(player.root.child_N)
    self.assertEqual(best_move, flattened)
    # D9 should have a positive value
    self.assertGreater(player.root.children[flattened].Q, 0)
    self.assertGreaterEqual(player.root.N, 20)
    # passing should be ineffective.
    self.assertLess(player.root.child_Q[-1], 0)
    # no virtual losses should be pending
    self.assertNoPendingVirtualLosses(player.root)

  def test_ridiculously_parallel_tree_search(self):
    """Search works when num_parallel exceeds the number of legal moves."""
    player = initialize_almost_done_player()
    # Test that an almost complete game
    # will tree search with # parallelism > # legal moves.
    for _ in range(10):
      player.tree_search(num_parallel=50)
    self.assertNoPendingVirtualLosses(player.root)

  def test_long_game_tree_search(self):
    """Search from a position two moves short of MAX_DEPTH stays consistent."""
    player = MCTSPlayerMixin(utils_test.BOARD_SIZE, DummyNet())
    endgame = go.Position(
        utils_test.BOARD_SIZE,
        board=TT_FTW_BOARD,
        n=MAX_DEPTH-2,
        komi=2.5,
        ko=None,
        recent=(go.PlayerMove(go.BLACK, (0, 1)),
                go.PlayerMove(go.WHITE, (0, 8))),
        to_play=go.BLACK
    )
    player.initialize_game(endgame)

    # Repeated parallel searches on this nearly finished game must leave no
    # virtual losses behind and must rate the root positively.
    for _ in range(10):
      player.tree_search(num_parallel=8)
    self.assertNoPendingVirtualLosses(player.root)
    self.assertGreater(player.root.Q, 0)

  def test_cold_start_parallel_tree_search(self):
    """Parallel search on an unexpanded root counts a single visit."""
    # Test that parallel tree search doesn't trip on an empty tree
    player = MCTSPlayerMixin(utils_test.BOARD_SIZE, DummyNet(fake_value=0.17))
    player.initialize_game()
    self.assertEqual(player.root.N, 0)
    self.assertFalse(player.root.is_expanded)
    player.tree_search(num_parallel=4)
    self.assertNoPendingVirtualLosses(player.root)
    # Even though the root gets selected 4 times by tree search, its
    # final visit count should just be 1.
    self.assertEqual(player.root.N, 1)
    # 0.085 = average(0, 0.17), since 0 is the prior on the root.
    self.assertAlmostEqual(player.root.Q, 0.085)

  def test_tree_search_failsafe(self):
    """Search terminates cleanly when the net always wants to pass."""
    # Test that the failsafe works correctly. It can trigger if the MCTS
    # repeatedly visits a finished game state.
    probs = np.array([.001] * (
        utils_test.BOARD_SIZE * utils_test.BOARD_SIZE + 1))
    probs[-1] = 1  # Make the dummy net always want to pass
    player = MCTSPlayerMixin(utils_test.BOARD_SIZE, DummyNet(fake_priors=probs))
    pass_position = go.Position(utils_test.BOARD_SIZE).pass_move()
    player.initialize_game(pass_position)
    player.tree_search(num_parallel=1)
    self.assertNoPendingVirtualLosses(player.root)

  def test_only_check_game_end_once(self):
    """After an opponent pass, the pass child is visited exactly once."""
    # When presented with a situation where the last move was a pass,
    # and we have to decide whether to pass, it should be the first thing
    # we check, but not more than that.
    white_passed_pos = (
        go.Position(utils_test.BOARD_SIZE)
        .play_move((3, 3))  # b plays
        .play_move((3, 4))  # w plays
        .play_move((4, 3))  # b plays
        .pass_move())  # w passes - if B passes too, B would lose by komi.

    player = MCTSPlayerMixin(utils_test.BOARD_SIZE, DummyNet())
    player.initialize_game(white_passed_pos)
    # initialize the root
    player.tree_search()
    # explore a child - should be a pass move.
    player.tree_search()
    pass_move = utils_test.BOARD_SIZE * utils_test.BOARD_SIZE
    self.assertEqual(player.root.children[pass_move].N, 1)
    self.assertEqual(player.root.child_N[pass_move], 1)
    player.tree_search()
    # check that we didn't visit the pass node any more times.
    self.assertEqual(player.root.child_N[pass_move], 1)

  def test_extract_data_normal_end(self):
    """Two consecutive passes end the game; White wins by komi."""
    player = MCTSPlayerMixin(utils_test.BOARD_SIZE, DummyNet())
    player.initialize_game()
    player.tree_search()
    player.play_move(None)
    player.tree_search()
    player.play_move(None)
    self.assertTrue(player.root.is_done())
    player.set_result(player.root.position.result(), was_resign=False)

    data = list(player.extract_data())
    self.assertEqual(len(data), 2)
    _, _, result = data[0]
    # White wins by komi
    self.assertEqual(result, go.WHITE)
    self.assertEqual(player.result_string, 'W+{}'.format(
        player.root.position.komi))

  def test_extract_data_resign_end(self):
    """A resignation result overrides the result implied by the board."""
    player = MCTSPlayerMixin(utils_test.BOARD_SIZE, DummyNet())
    player.initialize_game()
    player.tree_search()
    player.play_move((0, 0))
    player.tree_search()
    player.play_move(None)
    player.tree_search()
    # Black is winning on the board
    self.assertEqual(player.root.position.result(), go.BLACK)
    # But if Black resigns
    player.set_result(go.WHITE, was_resign=True)

    data = list(player.extract_data())
    _, _, result = data[0]
    # Result should say White is the winner
    self.assertEqual(result, go.WHITE)
    self.assertEqual(player.result_string, 'W+R')
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for symmetries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import tensorflow as tf # pylint: disable=g-bad-import-order
import coords
import numpy as np
import symmetries
import utils_test
# Set TensorFlow's logging verbosity to ERROR for this test module.
tf.logging.set_verbosity(tf.logging.ERROR)
class TestSymmetryOperations(utils_test.MiniGoUnitTest):
  """Checks inversion, composition and uniqueness of the board symmetries."""

  def setUp(self):
    """Create seeded random feature and policy arrays for each test."""
    size = utils_test.BOARD_SIZE
    np.random.seed(1)
    self.feat = np.random.random([size, size, 3])
    self.pi = np.random.random([size ** 2 + 1])
    super().setUp()

  def test_inversions(self):
    """A symmetry and its inverse cancel out, whichever is applied first."""
    size = utils_test.BOARD_SIZE
    for sym in symmetries.SYMMETRIES:
      with self.subTest(symmetry=sym):
        # Features: inverse first, then the symmetry.
        feat_step = symmetries.apply_symmetry_feat(
            symmetries.invert_symmetry(sym), self.feat)
        self.assertEqualNPArray(
            self.feat, symmetries.apply_symmetry_feat(sym, feat_step))

        # Features: symmetry first, then the inverse.
        inverse = symmetries.invert_symmetry(sym)
        feat_step = symmetries.apply_symmetry_feat(sym, self.feat)
        self.assertEqualNPArray(
            self.feat, symmetries.apply_symmetry_feat(inverse, feat_step))

        # Policy: inverse first, then the symmetry.
        pi_step = symmetries.apply_symmetry_pi(
            size, symmetries.invert_symmetry(sym), self.pi)
        self.assertEqualNPArray(
            self.pi, symmetries.apply_symmetry_pi(size, sym, pi_step))

        # Policy: symmetry first, then the inverse.
        inverse = symmetries.invert_symmetry(sym)
        pi_step = symmetries.apply_symmetry_pi(size, sym, self.pi)
        self.assertEqualNPArray(
            self.pi, symmetries.apply_symmetry_pi(size, inverse, pi_step))

  def test_compositions(self):
    """Applying two symmetries in sequence equals the named composition."""
    size = utils_test.BOARD_SIZE
    # (first symmetry, second symmetry, expected combined symmetry)
    cases = [
        ('rot90', 'rot90', 'rot180'),
        ('rot90', 'rot180', 'rot270'),
        ('identity', 'rot90', 'rot90'),
        ('fliprot90', 'rot90', 'fliprot180'),
        ('rot90', 'rot270', 'identity'),
    ]
    for first, second, combined in cases:
      with self.subTest(s1=first, s2=second, composed=combined):
        feat_direct = symmetries.apply_symmetry_feat(combined, self.feat)
        feat_chained = symmetries.apply_symmetry_feat(
            second, symmetries.apply_symmetry_feat(first, self.feat))
        self.assertEqualNPArray(feat_direct, feat_chained)

        pi_direct = symmetries.apply_symmetry_pi(size, combined, self.pi)
        pi_chained = symmetries.apply_symmetry_pi(
            size, second, symmetries.apply_symmetry_pi(size, first, self.pi))
        self.assertEqualNPArray(pi_direct, pi_chained)

  def test_uniqueness(self):
    """No two symmetries produce the same transformed array."""
    feat_variants = []
    for sym in symmetries.SYMMETRIES:
      feat_variants.append(symmetries.apply_symmetry_feat(sym, self.feat))
    pi_variants = []
    for sym in symmetries.SYMMETRIES:
      pi_variants.append(symmetries.apply_symmetry_pi(
          utils_test.BOARD_SIZE, sym, self.pi))

    for left, right in itertools.combinations(feat_variants, 2):
      self.assertNotEqualNPArray(left, right)
    for left, right in itertools.combinations(pi_variants, 2):
      self.assertNotEqualNPArray(left, right)

  def test_proper_move_transform(self):
    """Policy and feature symmetries agree with coords.from_flat."""
    # Check that the reinterpretation of 362 = 19*19 + 1 during symmetry
    # application is consistent with coords.from_flat
    size = utils_test.BOARD_SIZE
    flat_moves = np.arange(size ** 2 + 1)
    # Board whose every point holds its own flat index.
    index_board = np.zeros([size, size])
    for flat in range(size ** 2):
      index_board[coords.from_flat(size, flat)] = flat

    for sym in symmetries.SYMMETRIES:
      with self.subTest(symmetry=sym):
        moved = symmetries.apply_symmetry_pi(size, sym, flat_moves)
        board = symmetries.apply_symmetry_feat(sym, index_board)
        # The last policy entry is left out of the board comparison.
        for new_flat, old_flat in enumerate(moved[:-1]):
          self.assertEqual(
              old_flat, board[coords.from_flat(size, new_flat)])
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utils, and base class for other unit tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import re
import tempfile
import time
import tensorflow as tf # pylint: disable=g-bad-import-order
import go
import numpy as np
import utils
# Set TensorFlow's logging verbosity to ERROR for this test module.
tf.logging.set_verbosity(tf.logging.ERROR)
# Side length of the board used throughout the unit tests.
BOARD_SIZE = 9
# All-zero int8 board of BOARD_SIZE x BOARD_SIZE.
EMPTY_BOARD = np.zeros((BOARD_SIZE, BOARD_SIZE), dtype=np.int8)
# Every (row, col) pair on the board, in row-major order.
ALL_COORDS = [
    (row, col) for row in range(BOARD_SIZE) for col in range(BOARD_SIZE)
]


def _check_bounds(c):
  """Return True when both components of `c` lie in [0, BOARD_SIZE)."""
  first, second = c[0], c[1]
  return 0 <= first < BOARD_SIZE and 0 <= second < BOARD_SIZE


# Maps each coordinate to its on-board orthogonal neighbours, listed in the
# order (x+1, y), (x-1, y), (x, y+1), (x, y-1).
NEIGHBORS = {
    (x, y): [
        point
        for point in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1))
        if _check_bounds(point)
    ]
    for x, y in ALL_COORDS
}
def load_board(string):
  """Parse a text diagram into a BOARD_SIZE x BOARD_SIZE int8 board array.

  Args:
    string: board diagram using 'X', 'O', '.', '#', '*' and '?' for the
      go.BLACK, go.WHITE, go.EMPTY, go.FILL, go.KO and go.UNKNOWN values.
      Every other character (whitespace, newlines, ...) is ignored.

  Returns:
    An np.int8 array of shape [BOARD_SIZE, BOARD_SIZE], filled in row-major
    order from the diagram's symbols.

  Raises:
    ValueError: if the diagram does not hold exactly BOARD_SIZE ** 2 symbols.
  """
  reverse_map = {
      'X': go.BLACK,
      'O': go.WHITE,
      '.': go.EMPTY,
      '#': go.FILL,
      '*': go.KO,
      '?': go.UNKNOWN
  }
  # Drop everything that is not a board symbol. '*' and '?' have to survive
  # this step, otherwise the KO and UNKNOWN entries above are unreachable and
  # diagrams containing them fail the length check below.
  string = re.sub(r'[^XO.#*?]+', '', string)
  if len(string) != BOARD_SIZE ** 2:
    raise ValueError("Board to load didn't have right dimensions")
  cells = [reverse_map[char] for char in string]
  return np.array(cells, dtype=np.int8).reshape([BOARD_SIZE, BOARD_SIZE])
class TestUtils(tf.test.TestCase):
  """Unit tests for the helpers in the utils module."""

  def test_bootstrap_name(self):
    """Model number 0 gets a name containing 'bootstrap'."""
    self.assertIn('bootstrap', utils.generate_model_name(0))

  def test_generate_model_name(self):
    """The generated name contains the zero-padded model number."""
    self.assertIn('000017', utils.generate_model_name(17))

  def test_detect_name(self):
    """detect_model_name returns the file name without its extension."""
    filename = '000017-model.index'
    self.assertEqual(utils.detect_model_name(filename), '000017-model')

  def test_detect_num(self):
    """detect_model_num returns the leading model number as an int."""
    filename = '000017-model.index'
    self.assertEqual(utils.detect_model_num(filename), 17)

  def test_get_models(self):
    """get_models lists (number, name) pairs in ascending order."""
    with tempfile.TemporaryDirectory() as models_dir:
      for meta_name in ('000013-model.meta', '000017-model.meta'):
        # Create an empty placeholder file for the model.
        with open(os.path.join(models_dir, meta_name), 'w'):
          pass
      found = utils.get_models(models_dir)
      self.assertEqual(len(found), 2)
      self.assertEqual(found[0], (13, '000013-model'))
      self.assertEqual(found[1], (17, '000017-model'))

  def test_get_latest_model(self):
    """get_latest_model returns the highest-numbered model."""
    with tempfile.TemporaryDirectory() as models_dir:
      for meta_name in ('000013-model.meta', '000017-model.meta'):
        # Create an empty placeholder file for the model.
        with open(os.path.join(models_dir, meta_name), 'w'):
          pass
      self.assertEqual(
          utils.get_latest_model(models_dir), (17, '000017-model'))

  def test_round_power_of_two(self):
    """Values are rounded to the nearest power of two, down or up."""
    for value, expected in ((84, 64), (120, 128)):
      self.assertEqual(utils.round_power_of_two(value), expected)

  def test_shuffler(self):
    """shuffler yields every item exactly once, in a changed order."""
    random.seed(1)
    source = (i for i in range(10))
    result = list(utils.shuffler(source, pool_size=5, refill_threshold=0.8))
    self.assertEqual(len(result), 10)
    self.assertNotEqual(result, list(range(10)))

  def test_parse_game_result(self):
    """Result strings map to go.BLACK, go.WHITE, or 0 for anything else."""
    for text, expected in (
        ('B+3.5', go.BLACK), ('W+T', go.WHITE), ('Void', 0)):
      self.assertEqual(utils.parse_game_result(text), expected)
class MiniGoUnitTest(tf.test.TestCase):
  """Base class for MiniGo tests: per-class timing plus Go-specific asserts."""

  @classmethod
  def setUpClass(cls):
    """Record when this test class started running."""
    cls.start_time = time.time()

  @classmethod
  def tearDownClass(cls):
    """Print the wall-clock time the whole test class took."""
    print('\n%s.%s: %.3f seconds' %
          (cls.__module__, cls.__name__, time.time() - cls.start_time))

  def assertEqualNPArray(self, array1, array2):
    """Fail unless both arrays have the same shape and the same elements.

    np.array_equal is used instead of `np.all(array1 == array2)` because the
    latter broadcasts: arrays of different shapes could compare as equal.

    Raises:
      AssertionError: if the arrays differ.
    """
    if not np.array_equal(array1, array2):
      raise AssertionError(
          'Arrays differed in one or more locations:\n%s\n%s' % (array1, array2)
      )

  def assertNotEqualNPArray(self, array1, array2):
    """Fail if both arrays have the same shape and the same elements.

    Raises:
      AssertionError: if the arrays are identical.
    """
    if np.array_equal(array1, array2):
      raise AssertionError('Arrays were identical:\n%s' % array1)

  def assertEqualLibTracker(self, lib_tracker1, lib_tracker2):
    """Assert two liberty trackers describe the same groups and liberties."""
    # A lib tracker may have differently numbered groups yet still
    # represent the same set of groups.
    # "Sort" the group_ids to ensure they are the same.
    def find_group_mapping(lib_tracker):
      # Number the groups by order of first appearance in the raveled
      # group_index, which does not depend on the original ids.
      current_gid = 0
      mapping = {}
      for group_id in lib_tracker.group_index.ravel().tolist():
        if group_id == go.MISSING_GROUP_ID:
          continue
        if group_id not in mapping:
          mapping[group_id] = current_gid
          current_gid += 1
      return mapping

    lt1_mapping = find_group_mapping(lib_tracker1)
    lt2_mapping = find_group_mapping(lib_tracker2)

    remapped_group_index1 = [
        lt1_mapping.get(gid, go.MISSING_GROUP_ID)
        for gid in lib_tracker1.group_index.ravel().tolist()]
    remapped_group_index2 = [
        lt2_mapping.get(gid, go.MISSING_GROUP_ID)
        for gid in lib_tracker2.group_index.ravel().tolist()]
    self.assertEqual(remapped_group_index1, remapped_group_index2)

    remapped_groups1 = {
        lt1_mapping.get(gid): group
        for gid, group in lib_tracker1.groups.items()}
    remapped_groups2 = {
        lt2_mapping.get(gid): group
        for gid, group in lib_tracker2.groups.items()}
    self.assertEqual(remapped_groups1, remapped_groups2)

    self.assertEqualNPArray(
        lib_tracker1.liberty_cache, lib_tracker2.liberty_cache)

  def assertEqualPositions(self, pos1, pos2):
    """Assert two positions agree on board, trackers, counters and history."""
    self.assertEqualNPArray(pos1.board, pos2.board)
    self.assertEqualLibTracker(pos1.lib_tracker, pos2.lib_tracker)
    self.assertEqual(pos1.n, pos2.n)
    self.assertEqual(pos1.caps, pos2.caps)
    self.assertEqual(pos1.ko, pos2.ko)
    # Only the most recent moves that both histories contain are compared.
    r_len = min(len(pos1.recent), len(pos2.recent))
    if r_len > 0:  # if a position has no history, then don't bother testing
      self.assertEqual(pos1.recent[-r_len:], pos2.recent[-r_len:])
    self.assertEqual(pos1.to_play, pos2.to_play)

  def assertNoPendingVirtualLosses(self, root):
    """Raise an error if any node in this subtree has vlosses pending."""
    queue = [root]
    while queue:
      current = queue.pop()
      self.assertEqual(current.losses_applied, 0)
      queue.extend(current.children.values())
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment