"docs/vscode:/vscode.git/clone" did not exist on "fe2b6ca6e8cdf652e36d48f5a88c58f13c53ad8c"
Commit 3d98a379 authored by chenych's avatar chenych
Browse files

Fix device recognition

parent 20247eb8
@@ -135,6 +135,14 @@ class Worker(WorkerHelper):
         cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
         torch.cuda.set_device(int(cuda_visible_devices))
+        ## For DCU K100_AI: get device_name via torch.cuda.get_device_name()
+        if "K500SM_AI" in torch.cuda.get_device_name():
+            print("Init DCU Devices")
+            os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("HIP_VISIBLE_DEVICES")
+            os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
+            cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
+            torch.cuda.set_device(int(cuda_visible_devices))
         master_addr = os.getenv("MASTER_ADDR")
         master_port = os.getenv("MASTER_PORT")
...
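Context for the hunk above: on Hygon DCU machines, PyTorch's CUDA API is backed by ROCm/HIP, and Ray exposes the GPU assignment through HIP_VISIBLE_DEVICES and RAY_LOCAL_RANK rather than the CUDA-style variables this worker reads. Below is a minimal standalone sketch of the same remap; the helper name and the fallback handling are illustrative, not part of the commit.

import os

import torch


def remap_dcu_env() -> None:
    # Hypothetical helper mirroring the hunk above: detect a DCU by its
    # reported device name, then copy the HIP/Ray environment variables
    # into the CUDA-style ones the rest of the code expects.
    if "K500SM_AI" in torch.cuda.get_device_name():
        print("Init DCU Devices")
        # Unlike the diff, fall back to existing values so a missing
        # HIP_VISIBLE_DEVICES does not raise (os.environ[...] = None
        # would be a TypeError).
        os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv(
            "HIP_VISIBLE_DEVICES", os.getenv("CUDA_VISIBLE_DEVICES", "0")
        )
        os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK", "0")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))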
@@ -15,6 +15,8 @@
 The main entry point to run the PPO algorithm
 """
+import os
+
 from typing import Literal, Optional, Union
 import numpy as np
@@ -71,7 +73,9 @@ class FSDPWorker(Worker):
         self.role = role
         if not dist.is_initialized():
+            self.print_rank0("Initializing distributed process group...")
             dist.init_process_group(backend="nccl")
+            print(f"!!! Rank {dist.get_rank()} initialized successfully!")
         # improve numerical stability
         torch.backends.cuda.matmul.allow_tf32 = False
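The two prints added in the hunk above bracket the collective initialization: if a rank hangs inside init_process_group (a common symptom of a mismatched MASTER_ADDR/MASTER_PORT or a broken NCCL/RCCL transport), the per-rank logs show which side of the call each rank stopped on. A minimal sketch of the same pattern, with an illustrative helper name in place of the worker's constructor:

import torch.distributed as dist


def ensure_process_group() -> None:
    # Guarded init, as in the hunk above. On ROCm/DCU builds of PyTorch
    # the backend name "nccl" is kept but is backed by RCCL.
    if not dist.is_initialized():
        print("Initializing distributed process group...")
        dist.init_process_group(backend="nccl")
        print(f"!!! Rank {dist.get_rank()} initialized successfully!")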
@@ -265,7 +269,9 @@ class FSDPWorker(Worker):
         # rank = torch.cuda.set_device(self.rank)
         # model = model.to(rank)
-        print(f"!!! local_rank={self.rank}, torch.cuda.current_device()={torch.cuda.current_device()}")
+        local_rank = int(os.environ["LOCAL_RANK"])
+        print(f"!!! rank={self.rank}, local_rank={local_rank}, torch.cuda.current_device()={torch.cuda.current_device()}")
+        print(f"self.device_mesh = {self.device_mesh}")
         self.fsdp_module = FSDP(
             model,
             sharding_strategy=sharding_strategy,
...
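The debug prints added here confirm that each rank's LOCAL_RANK matches torch.cuda.current_device() before FSDP materializes shards; on DCU the two can disagree when the HIP/Ray variables were not remapped, which is what the first hunk fixes. A hedged sketch of the wrapping step, where FULL_SHARD and device_id are assumed defaults rather than what the trainer necessarily passes:

import os

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy


def wrap_with_fsdp(model: torch.nn.Module) -> FSDP:
    # Sanity-check the rank/device pinning that the prints above expose
    # before FSDP allocates shards on the current device.
    local_rank = int(os.environ["LOCAL_RANK"])
    assert torch.cuda.current_device() == local_rank, "device/rank mismatch"
    return FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # assumed; the diff passes a variable
        device_id=local_rank,
    )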