Unverified Commit 5e689720 authored by Yuting Jiang's avatar Yuting Jiang Committed by GitHub
Browse files

Benchmarks: Revise Test - Revise benchmark test util to support pytorch multi-GPU test (#54)

* Superbenchmark: Revise tests - revise benchmark test util to support multi gpu test

* modify test_sharding_matmul.py to match the tests util
parent cb33c99c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Module for tests."""
...@@ -22,7 +22,7 @@ def test_pytorch_sharding_matmul(): ...@@ -22,7 +22,7 @@ def test_pytorch_sharding_matmul():
assert (BenchmarkRegistry.is_benchmark_context_valid(context)) assert (BenchmarkRegistry.is_benchmark_context_valid(context))
utils.setup_simulated_ddp_distributed_env() utils.setup_simulated_ddp_distributed_env(1, 0, utils.get_free_port())
benchmark = BenchmarkRegistry.launch_benchmark(context) benchmark = BenchmarkRegistry.launch_benchmark(context)
# Check basic information. # Check basic information.
......
...@@ -4,15 +4,12 @@ ...@@ -4,15 +4,12 @@
"""Utilities for benchmark tests.""" """Utilities for benchmark tests."""
import os import os
import socket
from contextlib import closing
import multiprocessing as multiprocessing
from multiprocessing import Process
from superbench.benchmarks import BenchmarkRegistry
def setup_simulated_ddp_distributed_env():
"""Function to setup the simulated DDP distributed envionment variables."""
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'
os.environ['LOCAL_RANK'] = '0'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
def clean_simulated_ddp_distributed_env(): def clean_simulated_ddp_distributed_env():
...@@ -22,3 +19,57 @@ def clean_simulated_ddp_distributed_env(): ...@@ -22,3 +19,57 @@ def clean_simulated_ddp_distributed_env():
os.environ.pop('LOCAL_RANK') os.environ.pop('LOCAL_RANK')
os.environ.pop('MASTER_ADDR') os.environ.pop('MASTER_ADDR')
os.environ.pop('MASTER_PORT') os.environ.pop('MASTER_PORT')
def get_free_port():
"""Get a free port in current system.
Return:
port (int): a free port in current system.
"""
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
def setup_simulated_ddp_distributed_env(world_size, local_rank, port):
"""Function to setup the simulated DDP distributed envionment variables."""
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['RANK'] = str(local_rank)
os.environ['LOCAL_RANK'] = str(local_rank)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = str(port)
def benchmark_in_one_process(context, world_size, local_rank, port, queue):
"""Function to setup env for DDP initialization and run the benchmark in each single process."""
setup_simulated_ddp_distributed_env(world_size, local_rank, port)
benchmark = BenchmarkRegistry.launch_benchmark(context)
# parser object must be removed becaues it can not be serialized.
benchmark._parser = None
queue.put(benchmark)
clean_simulated_ddp_distributed_env()
def simulated_ddp_distributed_benchmark(context, world_size):
"""Function to run the benchmark on #world_size number of processes.
Return:
results (list): list of benchmark results from #world_size number of processes.
"""
port = get_free_port()
process_list = []
multiprocessing.set_start_method('spawn')
queue = multiprocessing.Queue()
for rank in range(world_size):
process = Process(target=benchmark_in_one_process, args=(context, world_size, rank, port, queue))
process.start()
process_list.append(process)
for process in process_list:
process.join()
results = [queue.get(1) for p in process_list]
return results
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment