# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time

import pytest
import torch

from fairscale.nn.pipe.microbatch import Batch
from fairscale.nn.pipe.stream import CPUStream
from fairscale.nn.pipe.worker import Task, spawn_workers


class fake_device:
    """A test double for :class:`torch.device`.

    Every fake device instance is distinct from every other one.
    """

    type = "fake"
    index = None


def test_join_running_workers():
    """Leaving ``spawn_workers`` must wait for every submitted task to finish."""
    count = 0
    # The increment below is a read-modify-write executed by ten worker
    # threads that all wake up at about the same time; the lock keeps an
    # update from being lost so the final assertion only reflects joining.
    count_lock = threading.Lock()

    def counter():
        nonlocal count
        time.sleep(0.1)
        with count_lock:
            count += 1
        return Batch((), 0)

    with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, _):

        def call_in_worker(i, f):
            task = Task(CPUStream, compute=f, finalize=None)
            in_queues[i].put(task)

        for i in range(10):
            call_in_worker(i, counter)

    # There's no nondeterminism because 'spawn_workers' joins all running
    # workers.
    assert count == 10


def test_join_running_workers_with_exception():
    """Workers are still joined when the ``with`` body raises."""

    class ExpectedException(Exception):
        pass

    count = 0
    # Guard the shared counter: ten worker threads increment it concurrently.
    count_lock = threading.Lock()

    def counter():
        nonlocal count
        time.sleep(0.1)
        with count_lock:
            count += 1
        return Batch((), 0)

    with pytest.raises(ExpectedException):
        with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, _):

            def call_in_worker(i, f):
                task = Task(CPUStream, compute=f, finalize=None)
                in_queues[i].put(task)

            for i in range(10):
                call_in_worker(i, counter)

            # Raised after all tasks were queued; the exception must propagate
            # out of 'spawn_workers' unchanged.
            raise ExpectedException

    # There's no nondeterminism because only 1 task can be placed in input
    # queues.
    assert count == 10


def test_compute_multithreading():
    """Task.compute should be executed on multiple threads."""
    thread_ids = set()

    def log_thread_id():
        thread_id = threading.current_thread().ident
        thread_ids.add(thread_id)
        return Batch((), 0)

    with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues):
        for i in range(2):
            t = Task(CPUStream, compute=log_thread_id, finalize=None)
            in_queues[i].put(t)
        # Wait for both results so both thread ids have been recorded.
        for i in range(2):
            out_queues[i].get()

    assert len(thread_ids) == 2


def test_compute_success():
    """Task.compute returns (True, (task, batch)) on success."""

    def _42():
        return Batch(torch.tensor(42), 0)

    with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
        t = Task(CPUStream, compute=_42, finalize=None)
        in_queues[0].put(t)
        ok, (task, batch) = out_queues[0].get()

        assert ok
        assert task is t
        assert isinstance(batch, Batch)
        assert batch[0].item() == 42


def test_compute_exception():
    """Task.compute returns (False, exc_info) on failure."""

    def zero_div():
        # Deliberately raises ZeroDivisionError inside the worker.
        0 / 0

    with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
        t = Task(CPUStream, compute=zero_div, finalize=None)
        in_queues[0].put(t)
        ok, exc_info = out_queues[0].get()

        assert not ok
        assert isinstance(exc_info, tuple)
        assert issubclass(exc_info[0], ZeroDivisionError)


@pytest.mark.parametrize("grad_mode", [True, False])
def test_grad_mode(grad_mode):
    """The grad mode active when workers are spawned is seen by Task.compute."""

    def detect_grad_enabled():
        # The tensor requires grad only if grad mode is enabled in the thread
        # that runs this function.
        x = torch.rand(1, requires_grad=torch.is_grad_enabled())
        return Batch(x, 0)

    with torch.set_grad_enabled(grad_mode):
        with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
            task = Task(CPUStream, compute=detect_grad_enabled, finalize=None)
            in_queues[0].put(task)

            ok, (_, batch) = out_queues[0].get()

            assert ok
            assert batch[0].requires_grad == grad_mode


def test_worker_per_device():
    """Queues are shared between the CPU entries but not between fake devices."""
    cpu = torch.device("cpu")
    cpu0 = torch.device("cpu", index=0)
    fake1 = fake_device()
    fake2 = fake_device()

    with spawn_workers([cpu, cpu, cpu0, fake1, fake2]) as (in_queues, out_queues):
        assert len(in_queues) == len(out_queues) == 5

        # 0: cpu, 1: cpu, 2: cpu0
        assert in_queues[0] is in_queues[1] is in_queues[2]
        assert out_queues[0] is out_queues[1] is out_queues[2]

        # 3: fake1, 4: fake2
        assert in_queues[3] is not in_queues[4]
        assert out_queues[3] is not out_queues[4]