Unverified commit 841634ca, authored by Alp Dener and committed by GitHub
Browse files

[PyTorch] Check network interface name when initializing Userbuffers (#1175)



* Check if network interface name is valid and show useful warning message when initializing Userbuffers
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix formatting issue in warning message.
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 7e1068b3
...@@ -138,6 +138,8 @@ def initialize_ub( ...@@ -138,6 +138,8 @@ def initialize_ub(
) )
if ifname is not None: if ifname is not None:
# Make sure the ifname found in the environment is a valid network interface
if ifname in [name for _, name in socket.if_nameindex()]:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try: try:
hostname = socket.inet_ntoa( hostname = socket.inet_ntoa(
...@@ -147,6 +149,18 @@ def initialize_ub( ...@@ -147,6 +149,18 @@ def initialize_ub(
) )
except OSError as err: except OSError as err:
raise OSError(f"Invalid network interface: {ifname}") from err raise OSError(f"Invalid network interface: {ifname}") from err
finally:
s.close()
else:
ifname_warning = (
f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will"
" attempt to "
+ "detect ranks on the same node by matching 'socket.gethostname()', which is "
+ "known to fail on virtual clusters like Kubernetes. If Userbuffers "
+ "initialization fails, please set the 'NVTE_UB_SOCKET_IFNAME' variable in "
+ "your environment to the correct network interface."
)
warnings.warn(ifname_warning, UserWarning)
hostnames = [None for _ in range(world_size)] hostnames = [None for _ in range(world_size)]
torch.distributed.all_gather_object(hostnames, hostname, world_group) torch.distributed.all_gather_object(hostnames, hostname, world_group)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment