Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
b652395e
Unverified
Commit
b652395e
authored
May 29, 2020
by
Chunyang Wen
Committed by
GitHub
May 28, 2020
Browse files
fix: typo (#238)
* fix: typo in code docs * more pythonic code
parent
6fe0edb8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
16 deletions
+11
-16
deepspeed/pt/deepspeed_checkpointing.py
deepspeed/pt/deepspeed_checkpointing.py
+3
-6
deepspeed/pt/deepspeed_config_utils.py
deepspeed/pt/deepspeed_config_utils.py
+8
-10
No files found.
deepspeed/pt/deepspeed_checkpointing.py
View file @
b652395e
...
@@ -14,6 +14,7 @@ b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
...
@@ -14,6 +14,7 @@ b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
# Parts of the code here are adapted from PyTorch
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
# repo: https://github.com/pytorch/pytorch
import
contextlib
import
contextlib
import
copy
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch
import
torch
from
torch
import
_C
from
torch
import
_C
...
@@ -108,7 +109,7 @@ def detach_variable(inputs, device=None):
...
@@ -108,7 +109,7 @@ def detach_variable(inputs, device=None):
def
_set_cuda_rng_state
(
new_state
,
device
=-
1
):
def
_set_cuda_rng_state
(
new_state
,
device
=-
1
):
"""Sets the random number generator state of the current GPU.
"""Sets the random number generator state of the current GPU.
Arguments
s
:
Arguments:
new_state (torch.ByteTensor): The desired state
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
with a single change: the input state is not cloned. Cloning caused
...
@@ -160,10 +161,7 @@ class CudaRNGStatesTracker:
...
@@ -160,10 +161,7 @@ class CudaRNGStatesTracker:
def
get_states
(
self
):
def
get_states
(
self
):
"""Get rng states. Copy the dictionary so we have direct
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
pointers to the states, not just a pointer to the dictionary."""
states
=
{}
return
copy
.
copy
(
self
.
states_
)
for
name
in
self
.
states_
:
states
[
name
]
=
self
.
states_
[
name
]
return
states
def
set_states
(
self
,
states
):
def
set_states
(
self
,
states
):
"""Set the rng states. For efficiency purposes, we do not check
"""Set the rng states. For efficiency purposes, we do not check
...
@@ -720,5 +718,4 @@ def is_configured():
...
@@ -720,5 +718,4 @@ def is_configured():
Return:
Return:
True of configured, else False
True of configured, else False
"""
"""
global
deepspeed_checkpointing_enabled
return
deepspeed_checkpointing_enabled
return
deepspeed_checkpointing_enabled
deepspeed/pt/deepspeed_config_utils.py
View file @
b652395e
...
@@ -6,20 +6,18 @@ Licensed under the MIT license.
...
@@ -6,20 +6,18 @@ Licensed under the MIT license.
Collection of DeepSpeed configuration utilities
Collection of DeepSpeed configuration utilities
"""
"""
from
collections
import
Counter
def
get_scalar_param
(
param_dict
,
param_name
,
param_default_value
):
def
get_scalar_param
(
param_dict
,
param_name
,
param_default_value
):
if
param_name
in
param_dict
.
keys
():
return
param_dict
.
get
(
param_name
,
param_default_value
)
return
param_dict
[
param_name
]
else
:
return
param_default_value
def
dict_raise_error_on_duplicate_keys
(
ordered_pairs
):
def
dict_raise_error_on_duplicate_keys
(
ordered_pairs
):
"""Reject duplicate keys."""
"""Reject duplicate keys."""
d
=
{}
d
=
dict
((
k
,
v
)
for
k
,
v
in
ordered_pairs
)
for
k
,
v
in
ordered_pairs
:
if
len
(
d
)
!=
len
(
ordered_pairs
):
if
k
in
d
:
counter
=
Counter
([
pair
[
0
]
for
pair
in
ordered_pairs
])
raise
ValueError
(
"Duplicate key in DeepSpeed config: %r"
%
(
k
,
))
keys
=
[
key
for
key
,
value
in
counter
.
items
()
if
value
>
1
]
else
:
raise
ValueError
(
"Duplicate keys in DeepSpeed config: {}"
.
format
(
keys
))
d
[
k
]
=
v
return
d
return
d
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment