Unverified Commit 71c1617b authored by Yuting Jiang's avatar Yuting Jiang Committed by GitHub
Browse files

Utils: Code Revision - Update network common utils (#118)


Update network common utils. Add get_ib_devices in network common utils and move get_free_port from test utils to network common utils
parent 9c984c7e
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
'create_sb_output_dir', 'create_sb_output_dir',
'get_sb_config', 'get_sb_config',
'logger', 'logger',
'network',
'nv_helper', 'nv_helper',
'rotate_dir', 'rotate_dir',
] ]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Network Utility."""
import socket
import re
from pathlib import Path
def get_free_port():
"""Get a free port in current system.
Return:
port (int): a free port in current system.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
except OSError:
return None
finally:
s.close()
def get_ib_devices():
"""Get available IB devices with available ports in the system and filter ethernet devices.
Return:
ib_devices_port (list): IB devices with available ports in current system.
"""
devices = list(p.name for p in Path('/sys/class/infiniband').glob('*'))
ib_devices_port_dict = {}
for device in devices:
ports = list(p.name for p in (Path('/sys/class/infiniband') / device / 'ports').glob('*'))
ports.sort(key=lambda s: [int(ch) if ch.isdigit() else ch for ch in re.split(r'(\d+)', s)])
for port in ports:
with (Path('/sys/class/infiniband') / device / 'ports' / port / 'link_layer').open('r') as f:
# Filter 'InfiniBand' devices by link_layer
if f.read().strip() == 'InfiniBand':
if device not in ib_devices_port_dict:
ib_devices_port_dict[device] = [port]
else:
ib_devices_port_dict[device].append(port)
ib_devices = list(ib_devices_port_dict.keys())
ib_devices.sort(key=lambda s: [int(ch) if ch.isdigit() else ch for ch in re.split(r'(\d+)', s)])
ib_devices_port = []
for device in ib_devices:
ib_devices_port.append(device + ':' + ','.join(ib_devices_port_dict[device]))
return ib_devices_port
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
from superbench.benchmarks import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode from superbench.benchmarks import BenchmarkRegistry, Framework, BenchmarkType, ReturnCode
from superbench.benchmarks.micro_benchmarks.computation_communication_overlap \ from superbench.benchmarks.micro_benchmarks.computation_communication_overlap \
import ComputationCommunicationOverlap, ComputationKernelType import ComputationCommunicationOverlap, ComputationKernelType
from superbench.common.utils import network
# TODO - replace unittest.skip("no multiple GPUs") to decorator of skipIfNoMultiGPUS # TODO - replace unittest.skip("no multiple GPUs") to decorator of skipIfNoMultiGPUS
...@@ -26,6 +27,7 @@ def test_pytorch_computation_communication_overlap_normal(): ...@@ -26,6 +27,7 @@ def test_pytorch_computation_communication_overlap_normal():
world_size = 2 world_size = 2
assert (BenchmarkRegistry.is_benchmark_context_valid(context)) assert (BenchmarkRegistry.is_benchmark_context_valid(context))
results = utils.simulated_ddp_distributed_benchmark(context, world_size) results = utils.simulated_ddp_distributed_benchmark(context, world_size)
assert (results)
for benchmark in results: for benchmark in results:
# Check basic information. # Check basic information.
assert (benchmark) assert (benchmark)
...@@ -56,7 +58,9 @@ def test_pytorch_computation_communication_overlap_fake_distributed(): ...@@ -56,7 +58,9 @@ def test_pytorch_computation_communication_overlap_fake_distributed():
parameters='--num_warmup 5 --num_steps 10 --ratio 5', parameters='--num_warmup 5 --num_steps 10 --ratio 5',
framework=Framework.PYTORCH framework=Framework.PYTORCH
) )
utils.setup_simulated_ddp_distributed_env(1, 0, utils.get_free_port()) port = network.get_free_port()
assert (port)
utils.setup_simulated_ddp_distributed_env(1, 0, port)
benchmark = BenchmarkRegistry.launch_benchmark(context) benchmark = BenchmarkRegistry.launch_benchmark(context)
# Check basic information. # Check basic information.
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
from tests.helper import decorator from tests.helper import decorator
from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
from superbench.benchmarks.micro_benchmarks.sharding_matmul import ShardingMatmul, ShardingMode from superbench.benchmarks.micro_benchmarks.sharding_matmul import ShardingMatmul, ShardingMode
from superbench.common.utils import network
@decorator.cuda_test @decorator.cuda_test
...@@ -22,7 +23,9 @@ def test_pytorch_sharding_matmul(): ...@@ -22,7 +23,9 @@ def test_pytorch_sharding_matmul():
assert (BenchmarkRegistry.is_benchmark_context_valid(context)) assert (BenchmarkRegistry.is_benchmark_context_valid(context))
utils.setup_simulated_ddp_distributed_env(1, 0, utils.get_free_port()) port = network.get_free_port()
assert (port)
utils.setup_simulated_ddp_distributed_env(1, 0, port)
benchmark = BenchmarkRegistry.launch_benchmark(context) benchmark = BenchmarkRegistry.launch_benchmark(context)
# Check basic information. # Check basic information.
......
...@@ -4,12 +4,11 @@ ...@@ -4,12 +4,11 @@
"""Utilities for benchmark tests.""" """Utilities for benchmark tests."""
import os import os
import socket
from contextlib import closing
import multiprocessing as multiprocessing import multiprocessing as multiprocessing
from multiprocessing import Process from multiprocessing import Process
from superbench.benchmarks import BenchmarkRegistry from superbench.benchmarks import BenchmarkRegistry
from superbench.common.utils import network
def clean_simulated_ddp_distributed_env(): def clean_simulated_ddp_distributed_env():
...@@ -21,18 +20,6 @@ def clean_simulated_ddp_distributed_env(): ...@@ -21,18 +20,6 @@ def clean_simulated_ddp_distributed_env():
os.environ.pop('MASTER_PORT') os.environ.pop('MASTER_PORT')
def get_free_port():
"""Get a free port in current system.
Return:
port (int): a free port in current system.
"""
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
def setup_simulated_ddp_distributed_env(world_size, local_rank, port): def setup_simulated_ddp_distributed_env(world_size, local_rank, port):
"""Function to setup the simulated DDP distributed envionment variables.""" """Function to setup the simulated DDP distributed envionment variables."""
os.environ['WORLD_SIZE'] = str(world_size) os.environ['WORLD_SIZE'] = str(world_size)
...@@ -58,7 +45,9 @@ def simulated_ddp_distributed_benchmark(context, world_size): ...@@ -58,7 +45,9 @@ def simulated_ddp_distributed_benchmark(context, world_size):
Return: Return:
results (list): list of benchmark results from #world_size number of processes. results (list): list of benchmark results from #world_size number of processes.
""" """
port = get_free_port() port = network.get_free_port()
if not port:
return None
process_list = [] process_list = []
multiprocessing.set_start_method('spawn') multiprocessing.set_start_method('spawn')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment