Commit 9d4f7ef8 authored by Chenggang Zhao's avatar Chenggang Zhao
Browse files

Surpass type checks

parent b56f7c2c
import inspect
import numpy as np
import os
import sys
import numpy as np
import torch
import torch.distributed as dist
from typing import Optional
......@@ -23,7 +23,8 @@ def init_dist(local_rank: int, num_local_ranks: int):
'rank': node_rank * num_local_ranks + local_rank,
}
if 'device_id' in sig.parameters:
params['device_id'] = torch.device(f"cuda:{local_rank}")
# noinspection PyTypeChecker
params['device_id'] = torch.device(f'cuda:{local_rank}')
dist.init_process_group(**params)
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cuda')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment