Commit 9d4f7ef8 authored by Chenggang Zhao's avatar Chenggang Zhao
Browse files

Surpass type checks

parent b56f7c2c
import inspect import inspect
import numpy as np
import os import os
import sys import sys
import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from typing import Optional from typing import Optional
...@@ -23,7 +23,8 @@ def init_dist(local_rank: int, num_local_ranks: int): ...@@ -23,7 +23,8 @@ def init_dist(local_rank: int, num_local_ranks: int):
'rank': node_rank * num_local_ranks + local_rank, 'rank': node_rank * num_local_ranks + local_rank,
} }
if 'device_id' in sig.parameters: if 'device_id' in sig.parameters:
params['device_id'] = torch.device(f"cuda:{local_rank}") # noinspection PyTypeChecker
params['device_id'] = torch.device(f'cuda:{local_rank}')
dist.init_process_group(**params) dist.init_process_group(**params)
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cuda') torch.set_default_device('cuda')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment