Unverified Commit 27d83851 authored by Shaden Smith, committed by GitHub

Use torch.cuda.device_count() (#126)

parent 8ad8a262
@@ -5,7 +5,6 @@ Copyright 2020 The Microsoft DeepSpeed Team
 import os
 import sys
 import json
-import pynvml
 import shutil
 import base64
 import logging
@@ -13,6 +12,9 @@ import argparse
 import subprocess
 import collections
 from copy import deepcopy
+
+import torch.cuda
+
 from deepspeed.pt.deepspeed_constants import TORCH_DISTRIBUTED_DEFAULT_PORT
 
 DLTS_HOSTFILE = "/job/hostfile"
@@ -213,19 +215,6 @@ def parse_inclusion_exclusion(resource_pool, inclusion, exclusion):
                                  exclude_str=exclusion)
 
 
-def local_gpu_count():
-    device_count = None
-    try:
-        pynvml.nvmlInit()
-        device_count = pynvml.nvmlDeviceGetCount()
-        print("device count", device_count)
-        return device_count
-    except pynvml.NVMLError:
-        logging.error("Unable to get GPU count information, perhaps there are "
-                      "no GPUs on this host?")
-        return device_count
-
-
 def encode_world_info(world_info):
     world_info_json = json.dumps(world_info).encode('utf-8')
     world_info_base64 = base64.urlsafe_b64encode(world_info_json).decode('utf-8')
@@ -243,8 +232,8 @@ def main(args=None):
     resource_pool = fetch_hostfile(args.hostfile)
     if not resource_pool:
         resource_pool = {}
-        device_count = local_gpu_count()
-        if device_count is None:
+        device_count = torch.cuda.device_count()
+        if device_count == 0:
             raise RuntimeError("Unable to proceed, no GPU resources available")
         resource_pool['localhost'] = device_count
         args.master_addr = "127.0.0.1"
...
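Note on the change above: torch.cuda.device_count() returns 0 (rather than raising) when no usable CUDA device is present, which is why the "is None" check becomes "== 0"; it also honors CUDA_VISIBLE_DEVICES, whereas the removed pynvml query counted every physical GPU on the host. Below is a minimal sketch of the hostfile-less fallback path, using a hypothetical helper name build_local_resource_pool purely for illustration; it is not part of the DeepSpeed launcher.

# Sketch of the localhost fallback after this commit: count local GPUs with
# torch and expose them as a one-host resource pool. The helper name and the
# __main__ usage below are illustrative assumptions, not DeepSpeed code.
import torch.cuda


def build_local_resource_pool():
    # torch.cuda.device_count() returns 0 when no CUDA device is visible,
    # so no try/except around an NVML call is needed anymore.
    device_count = torch.cuda.device_count()
    if device_count == 0:
        raise RuntimeError("Unable to proceed, no GPU resources available")
    # Same {hostname: gpu_count} shape the launcher builds for 'localhost'.
    return {'localhost': device_count}


if __name__ == '__main__':
    print(build_local_resource_pool())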