Commit b7984a7e authored by zhuwenwen's avatar zhuwenwen
Browse files

fix socker error

parent 89683b9e
...@@ -6,6 +6,7 @@ import uuid ...@@ -6,6 +6,7 @@ import uuid
from platform import uname from platform import uname
from typing import List, Tuple, Union from typing import List, Tuple, Union
from packaging.version import parse, Version from packaging.version import parse, Version
import warnings
import psutil import psutil
import torch import torch
...@@ -170,16 +171,35 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: ...@@ -170,16 +171,35 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
def get_ip() -> str: def get_ip() -> str:
host_ip = os.environ.get("HOST_IP")
if host_ip:
return host_ip
# IP is not set, try to get it from the network interface
# try ipv4 # try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try: try:
s.connect(("dns.google", 80)) # Doesn't need to be reachable s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
return s.getsockname()[0] return s.getsockname()[0]
except OSError: except Exception:
pass
# try ipv6 # try ipv6
try:
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
s.connect(("dns.google", 80)) # Google's public DNS server, see
# https://developers.google.com/speed/public-dns/docs/using#addresses
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
return s.getsockname()[0] return s.getsockname()[0]
except Exception:
pass
warnings.warn(
"Failed to get the IP address, using 0.0.0.0 by default."
"The value can be set by the environment variable HOST_IP.",
stacklevel=2)
return "0.0.0.0"
def get_distributed_init_method(ip: str, port: int) -> str: def get_distributed_init_method(ip: str, port: int) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment