# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
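
# master.py launches one training run as a set of cooperating processes:
# interactor and evaluator agents, a ValueRL policy learner, and (when the
# config enables a world model) a world-model learner, wired together with
# shared locks and bounded replay-frame queues.
#
# Assumed invocation (sys.argv[2] selects the learner GPU below; sys.argv[1]
# is presumably consumed by the config module):
#   python master.py <config> <gpu_index>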

from builtins import str
from builtins import range

import multiprocessing
import os
import sys
import time

from config import config, log_config
import util

AGENT_COUNT = config["agent_config"]["count"]
EVALUATOR_COUNT = config["evaluator_config"]["count"]
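# A model_config of False means a model-free run; anything else enables the
# world-model learner and its lock/queue below.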
MODEL_AUGMENTED = config["model_config"] is not False
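# When resuming, reuse the existing output directory; otherwise create it
# fresh, wiping any previous contents.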
if config["resume"]:
  ROOT_PATH = "output/" + config["env"]["name"] + "/" + config["name"]
else:
  ROOT_PATH = util.create_and_wipe_directory("output/" + config["env"]["name"] + "/" + config["name"])
log_config()
import learner, agent, valuerl_learner
if MODEL_AUGMENTED: import worldmodel_learner

if __name__ == '__main__':
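  # keep handles to every child process so they can be joined at shutdown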
  all_procs = set()
  interaction_procs = set()

  # locks shared by the learners and the agent processes; the model lock only
  # exists for model-augmented runs
  policy_lock = multiprocessing.Lock()
  model_lock = multiprocessing.Lock() if MODEL_AUGMENTED else None

  # bounded replay-frame queues (maxsize 1): a put from an agent blocks until
  # the corresponding learner has consumed the previous frames
  policy_replay_frame_queue = multiprocessing.Queue(1)
  model_replay_frame_queue = multiprocessing.Queue(1) if MODEL_AUGMENTED else None

  # interactors: agent processes that collect experience (the second
  # agent.main argument is False here; the evaluators below pass True)
  for interact_proc_i in range(AGENT_COUNT):
    interact_proc = multiprocessing.Process(target=agent.main, args=(interact_proc_i, False, policy_replay_frame_queue, model_replay_frame_queue, policy_lock, config))
    all_procs.add(interact_proc)
    interaction_procs.add(interact_proc)

  # evaluators: same entry point, with the evaluation argument set to True
  for interact_proc_i in range(EVALUATOR_COUNT):
    interact_proc = multiprocessing.Process(target=agent.main, args=(interact_proc_i, True, policy_replay_frame_queue, model_replay_frame_queue, policy_lock, config))
    all_procs.add(interact_proc)
    interaction_procs.add(interact_proc)

  # policy learner; it is also handed the model lock (None in model-free runs)
  train_policy_proc = multiprocessing.Process(target=learner.run_learner, args=(valuerl_learner.ValueRLLearner, policy_replay_frame_queue, policy_lock, config, config["env"], config["policy_config"]), kwargs={"model_lock": model_lock})
  all_procs.add(train_policy_proc)

  # model training
  if MODEL_AUGMENTED:
    train_model_proc = multiprocessing.Process(target=learner.run_learner, args=(worldmodel_learner.WorldmodelLearner, model_replay_frame_queue, model_lock, config, config["env"], config["model_config"]))
    all_procs.add(train_model_proc)

  # start all interaction processes; environment variables are set before each
  # start() so the children inherit them: TensorFlow's C++ logging is silenced
  # and GPUs are hidden, keeping the agents CPU-only
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  for proc in interaction_procs:
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    proc.start()

  # the policy learner runs on the GPU given by the second command-line argument
  os.environ['CUDA_VISIBLE_DEVICES'] = str(int(sys.argv[2]))
  train_policy_proc.start()

  if MODEL_AUGMENTED:
    # the world-model learner takes the next GPU index
    os.environ['CUDA_VISIBLE_DEVICES'] = str(1 + int(sys.argv[2]))
    train_model_proc.start()

  # block until interrupted, then wait for every child process to exit; the
  # original busy-looped over a bare try/pass, which spins the CPU and never
  # actually terminates after joining
  try:
    while True:
      time.sleep(1)
  except KeyboardInterrupt:
    for proc in all_procs:
      proc.join()