"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3d7eaf83d721ed1137ad1838b73be83c737721d4"
Commit 2eaae45d authored by chenych

support DCU

parent 3d98a379
@@ -129,19 +129,11 @@ class Worker(WorkerHelper):
         self._rank = rank
         self._world_size = world_size
-        os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
-        os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
-        cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
-        torch.cuda.set_device(int(cuda_visible_devices))
-        ## for DCU K100_AI, get device_name via torch.cuda.get_device_name()
-        if "K500SM_AI" in torch.cuda.get_device_name():
-            print("Init DCU Devices")
-            os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("HIP_VISIBLE_DEVICES")
-            os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
-            cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
-            torch.cuda.set_device(int(cuda_visible_devices))
+        if "AMD" in torch.cuda.get_device_name():  # DCU support
+            os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("HIP_VISIBLE_DEVICES")
+            os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
+            cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
+            torch.cuda.set_device(int(cuda_visible_devices))
         master_addr = os.getenv("MASTER_ADDR")
         master_port = os.getenv("MASTER_PORT")
...
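For readers unfamiliar with the DCU path, the device setup introduced above can be read as the following standalone sketch. It is illustrative only: the helper name setup_dcu_visible_devices and the None checks are assumptions rather than code from this repository, and it presumes a ROCm build of PyTorch in which a Hygon DCU reports an AMD device name.

import os

import torch


def setup_dcu_visible_devices() -> None:
    # On DCU (ROCm), Ray exposes the devices assigned to this worker via
    # HIP_VISIBLE_DEVICES; mirror them into the CUDA-style variables that
    # the rest of the training stack reads.
    if "AMD" in torch.cuda.get_device_name():
        hip_devices = os.getenv("HIP_VISIBLE_DEVICES")
        if hip_devices is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = hip_devices
        ray_local_rank = os.getenv("RAY_LOCAL_RANK")
        if ray_local_rank is not None:
            os.environ["LOCAL_RANK"] = ray_local_rank
    # Bind this process to the device matching its local rank.
    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

Matching on "AMD" rather than a specific model string presumably lets the same branch cover K100_AI and other DCU models that report an AMD device name.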
@@ -75,7 +75,6 @@ class FSDPWorker(Worker):
         if not dist.is_initialized():
             self.print_rank0("Initializing distributed process group...")
             dist.init_process_group(backend="nccl")
-            print(f"!!! Rank {dist.get_rank()} initialized successfully!")
         # improve numerical stability
         torch.backends.cuda.matmul.allow_tf32 = False
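The hunk above only drops a debug print; the initialization itself is untouched. As a hedged aside, ROCm builds of PyTorch keep the "nccl" backend name (backed by RCCL), so a minimal, self-contained version of this init looks the same on DCU as on NVIDIA GPUs. The init_distributed name below is made up for illustration, and the rendezvous variables are assumed to be set by the launcher.

import torch
import torch.distributed as dist


def init_distributed() -> None:
    if not dist.is_initialized():
        # Uses the default env:// rendezvous: MASTER_ADDR, MASTER_PORT,
        # RANK and WORLD_SIZE must be provided by the launcher (Ray, torchrun, ...).
        dist.init_process_group(backend="nccl")
    # Disable TF32 matmuls for numerical stability, as in the hunk above.
    torch.backends.cuda.matmul.allow_tf32 = False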
@@ -267,11 +266,6 @@ class FSDPWorker(Worker):
         sync_module_states = False
         param_init_fn = None
-        # rank = torch.cuda.set_device(self.rank)
-        # model = model.to(rank)
-        local_rank = int(os.environ["LOCAL_RANK"])
-        print(f"!!! rank={self.rank}, local_rank={local_rank}, torch.cuda.current_device()={torch.cuda.current_device()}")
-        print(f"self.device_mesh = {self.device_mesh}")
         self.fsdp_module = FSDP(
             model,
             sharding_strategy=sharding_strategy,
...
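The last hunk strips debugging output around the FSDP wrapping without changing the wrapping itself. A minimal sketch of that call follows for orientation; the toy nn.Linear model and the FULL_SHARD strategy are placeholders chosen for illustration, not the configuration used in this repository.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# Assumes the process group has been initialized and the device bound,
# as in the sketches above.
model = nn.Linear(1024, 1024)

fsdp_module = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # placeholder strategy
    device_id=torch.cuda.current_device(),          # shard onto the bound DCU/GPU
    sync_module_states=False,                       # mirrors the values in the hunk
    param_init_fn=None,
)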