OpenDAS / deepspeed · Commits

Commit 001abe23 (unverified)
Authored by Jeff Rasley on Feb 20, 2020; committed by GitHub on Feb 20, 2020.

Refactor simple model test, fix pythonpath issue (#96)
Also a fix for #94
Parent: f2d75135
Showing 7 changed files with 138 additions and 88 deletions (+138 −88):
deepspeed/__init__.py            +4   −0
deepspeed/pt/deepspeed_light.py  +8   −0
deepspeed/pt/deepspeed_run.py    +4   −1
install.sh                       +1   −6
tests/unit/simple_model.py       +46  −0
tests/unit/test_config.py        +54  −0
tests/unit/test_fp16.py          +21  −81
deepspeed/__init__.py

@@ -128,6 +128,10 @@ def _add_core_arguments(parser):
                        type=str,
                        help='DeepSpeed json configuration file.')
+    group.add_argument('--deepscale_config',
+                       default=None,
+                       type=str,
+                       help='Deprecated DeepSpeed json configuration file.')
     return parser
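For reference, a minimal standalone sketch of how the deprecated flag coexists with the new one. The argument-group title and config file name below are illustrative assumptions, not taken from the commit:

import argparse

# Minimal sketch: both flags are registered; --deepscale_config is kept only
# for backward compatibility. The group title here is an assumption.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('DeepSpeed')
group.add_argument('--deepspeed_config', default=None, type=str,
                   help='DeepSpeed json configuration file.')
group.add_argument('--deepscale_config', default=None, type=str,
                   help='Deprecated DeepSpeed json configuration file.')

args = parser.parse_args(['--deepscale_config', 'ds_config.json'])
print(args.deepspeed_config, args.deepscale_config)  # None ds_config.json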
deepspeed/pt/deepspeed_light.py

@@ -322,6 +322,14 @@ class DeepSpeedLight(Module):

     # Validate command line arguments
     def _do_args_sanity_check(self, args):
+        if hasattr(args, 'deepscale_config') and args.deepscale_config is not None:
+            logging.warning(
+                "************ --deepscale_config is deprecated, please use --deepspeed_config ************"
+            )
+            if hasattr(args, 'deepspeed_config'):
+                assert args.deepspeed_config is None, "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
+            args.deepspeed_config = args.deepscale_config
+
         assert hasattr(args, 'local_rank') and type(args.local_rank) == int, \
             'DeepSpeed requires integer command line parameter --local_rank'
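To illustrate the two paths through the new check (a hypothetical namespace, not code from the repository): passing both flags trips the assertion, while the deprecated flag alone is carried over.

import argparse

# Hypothetical reproduction of the sanity check's behavior.
both = argparse.Namespace(deepscale_config='old.json',
                          deepspeed_config='new.json',
                          local_rank=0)
try:
    assert both.deepspeed_config is None, \
        "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
except AssertionError as e:
    print(e)  # both flags given -> hard error

old_only = argparse.Namespace(deepscale_config='old.json',
                              deepspeed_config=None,
                              local_rank=0)
old_only.deepspeed_config = old_only.deepscale_config  # deprecated value carried over
print(old_only.deepspeed_config)  # old.json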
deepspeed/pt/deepspeed_run.py

@@ -306,7 +306,10 @@ def main(args=None):
         num_gpus_per_node = None

     curr_path = os.path.abspath('.')
     if 'PYTHONPATH' in env:
         env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
     else:
         env['PYTHONPATH'] = curr_path

     exports = ""
     for var in env.keys():
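The effect of this hunk, shown standalone (a sketch that works on a copy of the environment, the same way the launcher mutates env): the launch directory is prepended to PYTHONPATH so that worker processes spawned with this environment can import modules from the current working directory.

import os

# Sketch: prepend the current working directory to PYTHONPATH so spawned
# worker processes can import modules that live in the launch directory.
env = dict(os.environ)
curr_path = os.path.abspath('.')
if 'PYTHONPATH' in env:
    env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
else:
    env['PYTHONPATH'] = curr_path

print(env['PYTHONPATH'].split(':')[0])  # the launch directory now comes first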
install.sh

@@ -109,16 +109,11 @@ if [ "$third_party_install" == "1" ]; then
     sudo -H pip install third_party/apex/dist/apex*.whl
 fi
 if [ "$deepspeed_install" == "1" ]; then
-    echo "Installing deepspeed"
+    echo "Building deepspeed wheel"
     python setup.py bdist_wheel
 fi
 if [ "$local_only" == "1" ]; then
     if [ "$third_party_install" == "1" ]; then
         echo "Installing apex locally"
         sudo -H pip uninstall -y apex
         sudo -H pip install third_party/apex/dist/apex*.whl
     fi
     if [ "$deepspeed_install" == "1" ]; then
         echo "Installing deepspeed"
         sudo -H pip uninstall -y deepspeed
...
tests/unit/simple_model.py (new file, mode 100644)

import os
import json
import argparse
import torch


class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_dim, empty_grad=False):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        if empty_grad:
            self.layers2 = torch.nn.ModuleList(
                [torch.nn.Linear(hidden_dim, hidden_dim)])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden_dim = x
        hidden_dim = self.linear(hidden_dim)
        return self.cross_entropy_loss(hidden_dim, y)


def random_dataloader(model, total_samples, hidden_dim, device):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader


def create_config_from_dict(tmpdir, config_dict):
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path


def args_from_dict(tmpdir, config_dict):
    config_path = create_config_from_dict(tmpdir, config_dict)
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepspeed = True
    args.deepspeed_config = config_path
    args.local_rank = 0
    return args
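Taken together, these helpers let each test shrink to one short pattern. A sketch of that pattern follows; the test name is hypothetical, and the config values mirror test_deprecated_deepscale_config in tests/unit/test_config.py below:

import deepspeed
from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict


def test_fp16_roundtrip(tmpdir):  # hypothetical test name
    config_dict = {
        "train_batch_size": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
        "fp16": {"enabled": True}  # random_dataloader yields torch.half inputs
    }
    args = args_from_dict(tmpdir, config_dict)
    model = SimpleModel(hidden_dim=10)

    @distributed_test(world_size=[1])
    def _run(args, model):
        # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler)
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters(),
                                              dist_init_required=False)
        data_loader = random_dataloader(model=model,
                                        total_samples=5,
                                        hidden_dim=10,
                                        device=model.device)
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _run(args=args, model=model)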
tests/unit/test_config.py

import torch
import pytest
import json
import argparse
from common import distributed_test
from simple_model import SimpleModel, create_config_from_dict, random_dataloader
import torch.distributed as dist

# A test on its own

@@ -100,3 +103,54 @@ def test_batch_config(num_ranks, batch, micro_batch, gas, success):
     """Run batch config test """
     _test_batch_config(num_ranks, batch, micro_batch, gas, success)

+
+def test_temp_config_json(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+    }
+    config_path = create_config_from_dict(tmpdir, config_dict)
+    config_json = json.load(open(config_path, 'r'))
+    assert 'train_batch_size' in config_json
+
+
+def test_deprecated_deepscale_config(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+    config_path = create_config_from_dict(tmpdir, config_dict)
+
+    parser = argparse.ArgumentParser()
+    args = parser.parse_args(args='')
+    args.deepscale_config = config_path
+    args.local_rank = 0
+
+    hidden_dim = 10
+    model = SimpleModel(hidden_dim)
+
+    @distributed_test(world_size=[1])
+    def _test_deprecated_deepscale_config(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters(),
+                                              dist_init_required=False)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=5,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_deprecated_deepscale_config(args=args, model=model, hidden_dim=hidden_dim)
tests/unit/test_fp16.py

@@ -5,67 +5,7 @@ import pytest
 import json
 import os
 from common import distributed_test

-def create_config_from_dict(tmpdir, config_dict):
-    config_path = os.path.join(tmpdir, 'temp_config.json')
-    with open(config_path, 'w') as fd:
-        json.dump(config_dict, fd)
-    return config_path
-
-
-class SimpleModel(torch.nn.Module):
-    def __init__(self, hidden_dim, empty_grad=False):
-        super(SimpleModel, self).__init__()
-        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
-        if empty_grad:
-            self.layers2 = torch.nn.ModuleList(
-                [torch.nn.Linear(hidden_dim, hidden_dim)])
-        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
-
-    def forward(self, x, y):
-        hidden_dim = x
-        hidden_dim = self.linear(hidden_dim)
-        return self.cross_entropy_loss(hidden_dim, y)
-
-
-def test_temp_config_json(tmpdir):
-    config_dict = {
-        "train_batch_size": 1,
-    }
-    config_path = create_config_from_dict(tmpdir, config_dict)
-    config_json = json.load(open(config_path, 'r'))
-    assert 'train_batch_size' in config_json
-
-
-def prepare_optimizer_parameters(model):
-    param_optimizer = list(model.named_parameters())
-    optimizer_grouped_parameters = [{
-        'params': [p for n, p in param_optimizer],
-        'weight_decay': 0.0
-    }]
-    return optimizer_grouped_parameters
-
-
-def get_data_loader(model, total_samples, hidden_dim, device):
-    batch_size = model.train_micro_batch_size_per_gpu()
-    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
-    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
-    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
-    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
-    return train_loader
-
-
-def get_args(tmpdir, config_dict):
-    config_path = create_config_from_dict(tmpdir, config_dict)
-    parser = argparse.ArgumentParser()
-    args = parser.parse_args(args='')
-    args.deepspeed = True
-    args.deepspeed_config = config_path
-    args.local_rank = 0
-    return args
+from simple_model import SimpleModel, random_dataloader, args_from_dict

 def test_lamb_fp16_basic(tmpdir):

@@ -83,7 +23,7 @@ def test_lamb_fp16_basic(tmpdir):
             "enabled": True
         }
     }
-    args = get_args(tmpdir, config_dict)
+    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

     model = SimpleModel(hidden_dim, empty_grad=False)

@@ -94,7 +34,7 @@ def test_lamb_fp16_basic(tmpdir):
                                              model=model,
                                              model_parameters=model.parameters(),
                                              dist_init_required=False)
-        data_loader = get_data_loader(model=model,
+        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

@@ -121,7 +61,7 @@ def test_lamb_fp16_empty_grad(tmpdir):
             "enabled": True
         }
     }
-    args = get_args(tmpdir, config_dict)
+    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

     model = SimpleModel(hidden_dim, empty_grad=True)

@@ -132,7 +72,7 @@ def test_lamb_fp16_empty_grad(tmpdir):
                                              model=model,
                                              model_parameters=model.parameters(),
                                              dist_init_required=False)
-        data_loader = get_data_loader(model=model,
+        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

@@ -152,7 +92,7 @@ def test_adamw_fp16_basic(tmpdir):
             "enabled": True
         }
     }
-    args = get_args(tmpdir, config_dict)
+    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

     model = SimpleModel(hidden_dim, empty_grad=False)

@@ -164,7 +104,7 @@ def test_adamw_fp16_basic(tmpdir):
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)
-        data_loader = get_data_loader(model=model,
+        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

@@ -184,7 +124,7 @@ def test_adamw_fp16_empty_grad(tmpdir):
             "enabled": True
         }
     }
-    args = get_args(tmpdir, config_dict)
+    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

     model = SimpleModel(hidden_dim, empty_grad=True)

@@ -196,7 +136,7 @@ def test_adamw_fp16_empty_grad(tmpdir):
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)
-        data_loader = get_data_loader(model=model,
+        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)