Commit 3c57a9d8 authored by lishen's avatar lishen
Browse files

基于nvshmem的环境变量设置

parent 3954264c
...@@ -100,9 +100,16 @@ class Buffer: ...@@ -100,9 +100,16 @@ class Buffer:
root_unique_id = None root_unique_id = None
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode: if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
# Enable IBGDA # Enable IBGDA
self._setup_device_hca_mapping()
assert num_qps_per_rank > 0 assert num_qps_per_rank > 0
os.environ["NVSHMEM_DISABLE_P2P"] = "0" if allow_nvlink_for_low_latency_mode else "1" os.environ["NVSHMEM_DISABLE_P2P"] = "0" if allow_nvlink_for_low_latency_mode else "1"
os.environ["NVSHMEM_IB_ENABLE_IBGDA"] = "1" # os.environ["NVSHMEM_IB_ENABLE_IBGDA"] = "1"
os.environ["NVSHMEM_IB_ENABLE_IBGDA"] = "0" # force_use_ibrc
os.environ["NVSHMEM_IBGDA_NIC_HANDLER"] = "gpu"
os.environ["NVSHMEM_IB_DISABLE_DMABUF"] = "1"
os.environ["NVSHMEM_ENABLE_NIC_PE_MAPPING"] = "1"
os.environ["NVSHMEM_IBGDA_NUM_RC_PER_PE"] = f"{num_qps_per_rank}" os.environ["NVSHMEM_IBGDA_NUM_RC_PER_PE"] = f"{num_qps_per_rank}"
# Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
os.environ["NVSHMEM_QP_DEPTH"] = os.environ.get("NVSHMEM_QP_DEPTH", "1024") os.environ["NVSHMEM_QP_DEPTH"] = os.environ.get("NVSHMEM_QP_DEPTH", "1024")
...@@ -136,6 +143,36 @@ class Buffer: ...@@ -136,6 +143,36 @@ class Buffer:
self.runtime.sync(device_ids, ipc_handles, root_unique_id) self.runtime.sync(device_ids, ipc_handles, root_unique_id)
assert self.runtime.is_available() assert self.runtime.is_available()
def _setup_device_hca_mapping(self):
"""
Set up device to NIC mapping using DEEP_EP_DEVICE_TO_HCA_MAPPING environment variable.
The mapping format is: "0:mlx5_0:1,1:mlx5_1:1,..." where each entry maps a CUDA device ID
to an HCA name separated by colon. HCA name can include additional suffixes like ":1".
"""
if 'DEEP_EP_DEVICE_TO_HCA_MAPPING' in os.environ:
device_mapping = {}
mapping_str = os.environ['DEEP_EP_DEVICE_TO_HCA_MAPPING']
# Parse mapping string like "0:mlx5_0:1,1:mlx5_1:1,..."
for mapping in mapping_str.split(','):
assert ':' in mapping, f"Invalid mapping format '{mapping}' in DEEP_EP_DEVICE_TO_HCA_MAPPING. Expected format: '<device_id>:<hca_name>'"
parts = mapping.split(':', 1) # Split only on first colon
device_id = int(parts[0])
hca_name = parts[1] # Keep the rest as HCA name (including :1)
device_mapping[device_id] = hca_name
# Get current device and set appropriate HCA
current_device = torch.cuda.current_device()
# Translate CUDA_VISIBLE_DEVICES
if 'CUDA_VISIBLE_DEVICES' in os.environ:
visible_devices = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
assert len(visible_devices) > current_device, f"CUDA_VISIBLE_DEVICES has {len(visible_devices)} entries which is fewer than the current device {current_device}"
assert visible_devices[current_device].isdigit(), f"DEEP_EP_DEVICE_TO_HCA_MAPPING requires CUDA_VISIBLE_DEVICES to contain integer indices"
current_device = int(visible_devices[current_device])
assert current_device in device_mapping, f"Current CUDA device {current_device} not found in DEEP_EP_DEVICE_TO_HCA_MAPPING"
os.environ['NVSHMEM_ENABLE_PE_MAPPING'] = '1'
os.environ['NVSHMEM_HCA_LIST'] = device_mapping[current_device]
def destroy(self): def destroy(self):
""" """
Destroy the cpp runtime and release resources. Destroy the cpp runtime and release resources.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment