Unverified Commit b652395e authored by Chunyang Wen's avatar Chunyang Wen Committed by GitHub
Browse files

fix: typo (#238)

* fix: typo in code docs

* more pythonic code
parent 6fe0edb8
...@@ -14,6 +14,7 @@ b886b7bb972afe72bac0f5de4f42a4a7bae8ebef ...@@ -14,6 +14,7 @@ b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
# Parts of the code here are adapted from PyTorch # Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch # repo: https://github.com/pytorch/pytorch
import contextlib import contextlib
import copy
import torch.distributed as dist import torch.distributed as dist
import torch import torch
from torch import _C from torch import _C
...@@ -108,7 +109,7 @@ def detach_variable(inputs, device=None): ...@@ -108,7 +109,7 @@ def detach_variable(inputs, device=None):
def _set_cuda_rng_state(new_state, device=-1): def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU. """Sets the random number generator state of the current GPU.
Argumentss: Arguments:
new_state (torch.ByteTensor): The desired state new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state) This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused with a single change: the input state is not cloned. Cloning caused
...@@ -160,10 +161,7 @@ class CudaRNGStatesTracker: ...@@ -160,10 +161,7 @@ class CudaRNGStatesTracker:
def get_states(self):
    """Return a shallow copy of the tracked cuda rng states.

    A copy is returned (rather than ``self.states_`` itself) so callers
    hold direct pointers to the individual states without being able to
    mutate the tracker's own dictionary.
    """
    return self.states_.copy()
def set_states(self, states): def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check """Set the rng states. For efficiency purposes, we do not check
...@@ -720,5 +718,4 @@ def is_configured(): ...@@ -720,5 +718,4 @@ def is_configured():
Return: Return:
True if configured, else False True if configured, else False
""" """
global deepspeed_checkpointing_enabled
return deepspeed_checkpointing_enabled return deepspeed_checkpointing_enabled
...@@ -6,20 +6,18 @@ Licensed under the MIT license. ...@@ -6,20 +6,18 @@ Licensed under the MIT license.
Collection of DeepSpeed configuration utilities Collection of DeepSpeed configuration utilities
""" """
from collections import Counter
def get_scalar_param(param_dict, param_name, param_default_value):
    """Look up *param_name* in *param_dict*.

    Returns the stored value when the key is present (even if that value
    is ``None``), otherwise returns *param_default_value*.
    """
    try:
        return param_dict[param_name]
    except KeyError:
        return param_default_value
def dict_raise_error_on_duplicate_keys(ordered_pairs):
    """Build a dict from ``(key, value)`` pairs, rejecting duplicate keys.

    Intended for use as a ``json.load(..., object_pairs_hook=...)`` hook so
    that a DeepSpeed config with repeated keys fails loudly instead of
    silently keeping the last value.

    Args:
        ordered_pairs: iterable of ``(key, value)`` tuples.

    Returns:
        dict built from the pairs.

    Raises:
        ValueError: listing every key that appears more than once.
    """
    # Materialize once: dict() would consume a generator, after which
    # len()/Counter over the original iterable would be wrong or fail.
    pairs = list(ordered_pairs)
    d = dict(pairs)
    if len(d) != len(pairs):
        counter = Counter(key for key, _ in pairs)
        keys = [key for key, count in counter.items() if count > 1]
        raise ValueError("Duplicate keys in DeepSpeed config: {}".format(keys))
    return d
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment