Unverified Commit 27d83851 authored by Shaden Smith's avatar Shaden Smith Committed by GitHub
Browse files

Use torch.cuda.device_count() (#126)

parent 8ad8a262
......@@ -5,7 +5,6 @@ Copyright 2020 The Microsoft DeepSpeed Team
import os
import sys
import json
import pynvml
import shutil
import base64
import logging
......@@ -13,6 +12,9 @@ import argparse
import subprocess
import collections
from copy import deepcopy
import torch.cuda
from deepspeed.pt.deepspeed_constants import TORCH_DISTRIBUTED_DEFAULT_PORT
DLTS_HOSTFILE = "/job/hostfile"
......@@ -213,19 +215,6 @@ def parse_inclusion_exclusion(resource_pool, inclusion, exclusion):
exclude_str=exclusion)
def local_gpu_count():
device_count = None
try:
pynvml.nvmlInit()
device_count = pynvml.nvmlDeviceGetCount()
print("device count", device_count)
return device_count
except pynvml.NVMLError:
logging.error("Unable to get GPU count information, perhaps there are "
"no GPUs on this host?")
return device_count
def encode_world_info(world_info):
world_info_json = json.dumps(world_info).encode('utf-8')
world_info_base64 = base64.urlsafe_b64encode(world_info_json).decode('utf-8')
......@@ -243,8 +232,8 @@ def main(args=None):
resource_pool = fetch_hostfile(args.hostfile)
if not resource_pool:
resource_pool = {}
device_count = local_gpu_count()
if device_count is None:
device_count = torch.cuda.device_count()
if device_count == 0:
raise RuntimeError("Unable to proceed, no GPU resources available")
resource_pool['localhost'] = device_count
args.master_addr = "127.0.0.1"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment