Unverified Commit 71c1617b authored by Yuting Jiang's avatar Yuting Jiang Committed by GitHub
Browse files

Utils: Code Revision - Update network common utils (#118)


Update network common utils. Add get_ib_devices in network common utils and move get_free_port from test utils to network common utils
parent 9c984c7e
......@@ -15,6 +15,7 @@
'create_sb_output_dir',
'get_sb_config',
'logger',
'network',
'nv_helper',
'rotate_dir',
]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Network Utility."""
import socket
import re
from pathlib import Path
def get_free_port():
"""Get a free port in current system.
Return:
port (int): a free port in current system.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
except OSError:
return None
finally:
s.close()
def get_ib_devices():
"""Get available IB devices with available ports in the system and filter ethernet devices.
Return:
ib_devices_port (list): IB devices with available ports in current system.
"""
devices = list(p.name for p in Path('/sys/class/infiniband').glob('*'))
ib_devices_port_dict = {}
for device in devices:
ports = list(p.name for p in (Path('/sys/class/infiniband') / device / 'ports').glob('*'))
ports.sort(key=lambda s: [int(ch) if ch.isdigit() else ch for ch in re.split(r'(\d+)', s)])
for port in ports:
with (Path('/sys/class/infiniband') / device / 'ports' / port / 'link_layer').open('r') as f:
# Filter 'InfiniBand' devices by link_layer
if f.read().strip() == 'InfiniBand':
if device not in ib_devices_port_dict:
ib_devices_port_dict[device] = [port]
else:
ib_devices_port_dict[device].append(port)
ib_devices = list(ib_devices_port_dict.keys())
ib_devices.sort(key=lambda s: [int(ch) if ch.isdigit() else ch for ch in re.split(r'(\d+)', s)])
ib_devices_port = []
for device in ib_devices:
ib_devices_port.append(device + ':' + ','.join(ib_devices_port_dict[device]))
return ib_devices_port
......@@ -10,6 +10,7 @@
from superbench.benchmarks import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode
from superbench.benchmarks.micro_benchmarks.computation_communication_overlap \
import ComputationCommunicationOverlap, ComputationKernelType
from superbench.common.utils import network
# TODO - replace unittest.skip("no multiple GPUs") to decorator of skipIfNoMultiGPUS
......@@ -26,6 +27,7 @@ def test_pytorch_computation_communication_overlap_normal():
world_size = 2
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
results = utils.simulated_ddp_distributed_benchmark(context, world_size)
assert (results)
for benchmark in results:
# Check basic information.
assert (benchmark)
......@@ -56,7 +58,9 @@ def test_pytorch_computation_communication_overlap_fake_distributed():
parameters='--num_warmup 5 --num_steps 10 --ratio 5',
framework=Framework.PYTORCH
)
utils.setup_simulated_ddp_distributed_env(1, 0, utils.get_free_port())
port = network.get_free_port()
assert (port)
utils.setup_simulated_ddp_distributed_env(1, 0, port)
benchmark = BenchmarkRegistry.launch_benchmark(context)
# Check basic information.
......
......@@ -7,6 +7,7 @@
from tests.helper import decorator
from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
from superbench.benchmarks.micro_benchmarks.sharding_matmul import ShardingMatmul, ShardingMode
from superbench.common.utils import network
@decorator.cuda_test
......@@ -22,7 +23,9 @@ def test_pytorch_sharding_matmul():
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
utils.setup_simulated_ddp_distributed_env(1, 0, utils.get_free_port())
port = network.get_free_port()
assert (port)
utils.setup_simulated_ddp_distributed_env(1, 0, port)
benchmark = BenchmarkRegistry.launch_benchmark(context)
# Check basic information.
......
......@@ -4,12 +4,11 @@
"""Utilities for benchmark tests."""
import os
import socket
from contextlib import closing
import multiprocessing as multiprocessing
from multiprocessing import Process
from superbench.benchmarks import BenchmarkRegistry
from superbench.common.utils import network
def clean_simulated_ddp_distributed_env():
......@@ -21,18 +20,6 @@ def clean_simulated_ddp_distributed_env():
os.environ.pop('MASTER_PORT')
def get_free_port():
"""Get a free port in current system.
Return:
port (int): a free port in current system.
"""
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
def setup_simulated_ddp_distributed_env(world_size, local_rank, port):
"""Function to setup the simulated DDP distributed envionment variables."""
os.environ['WORLD_SIZE'] = str(world_size)
......@@ -58,7 +45,9 @@ def simulated_ddp_distributed_benchmark(context, world_size):
Return:
results (list): list of benchmark results from #world_size number of processes.
"""
port = get_free_port()
port = network.get_free_port()
if not port:
return None
process_list = []
multiprocessing.set_start_method('spawn')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment