OpenDAS / deepspeed · Commits · 20557f70

Unverified commit 20557f70, authored Mar 26, 2020 by Shaden Smith, committed by GitHub on Mar 26, 2020.
Fix ThroughputTimer with hybrid parallelism. (#171)
Parent: a76572dc

Showing 1 changed file with 64 additions and 58 deletions (+64, -58)
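The fix turns on the distinction between the global world size and the data-parallel world size: under hybrid (model + data) parallelism, several ranks share one model replica, so only the data-parallel dimension contributes independent samples. A rough sketch of the relationship, using hypothetical sizes rather than anything from this commit:

# Illustrative only: hypothetical sizes for a hybrid-parallel job.
world_size = 16               # total ranks
model_parallel_size = 4       # ranks that jointly hold one model replica
dp_world_size = world_size // model_parallel_size   # 4 independent replicas

micro_batch_per_gpu = 8
samples_per_step = micro_batch_per_gpu * dp_world_size   # 32 distinct samples
naive_estimate = micro_batch_per_gpu * world_size        # 128, a 4x overcount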
deepspeed/pt/deepspeed_light.py (+64, -58), view file @ 20557f70
@@ -142,18 +142,18 @@ class DeepSpeedLight(Module):
         self._init_distributed(dist_init_required)
 
+        # Configure distributed model
+        self._configure_distributed_model(model)
+
         # Throughput timer
         self.tput_timer = ThroughputTimer(
             batch_size=self.train_micro_batch_size_per_gpu(),
-            num_workers=self.world_size,
+            num_workers=self.dp_world_size,
             monitor_memory=False)
 
         self.training_dataloader = self.deepspeed_io(
             training_data) if training_data else None
 
-        # Configure distributed model
-        self._configure_distributed_model(model)
-
         # Configure optimizer and scheduler
         self.optimizer = None
         self.lr_scheduler = None
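The timer is now built after `_configure_distributed_model` and is given `dp_world_size`, so throughput is extrapolated from the number of independent data-parallel workers rather than every rank in the job. A simplified stand-in for that arithmetic (not the actual ThroughputTimer API, just an illustration of why `num_workers` should be the data-parallel size):

import time

class SimpleThroughputTimer:
    # Minimal sketch; the real ThroughputTimer also tracks steps, memory, etc.
    def __init__(self, batch_size, num_workers):
        self.batch_size = batch_size      # per-rank micro batch
        self.num_workers = num_workers    # independent data-parallel workers
        self._start = None

    def start(self):
        self._start = time.time()

    def stop(self):
        elapsed = time.time() - self._start
        # Global samples/sec: per-rank batch times the number of data-parallel
        # replicas. Using the full world size would overcount whenever several
        # ranks cooperate on a single replica.
        return self.batch_size * self.num_workers / elapsed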
@@ -324,17 +324,19 @@ class DeepSpeedLight(Module):
     def _configure_checkpointing(self, dist_init_required):
 
-        dp_rank = torch.distributed.get_rank(
-        ) if self.mpu is None else self.mpu.get_data_parallel_rank()
+        dp_rank = self.global_rank
+        if self.mpu:
+            dp_rank = self.mpu.get_data_parallel_rank()
 
         #only the first data parallel process needs to store the model checkpoint
-        self.save_non_zero_checkpoint = True if dp_rank == 0 else False
+        self.save_non_zero_checkpoint = (dp_rank == 0)
 
         if self.zero_optimization():
             pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
-            #only the first parameter parallel process needs to store the optimizer state checkpoints for zero
-            self.save_zero_checkpoint = True if pp_rank == dp_rank else False
+            # Only the first parameter parallel process needs to store the
+            # optimizer state checkpoints for zero
+            self.save_zero_checkpoint = (pp_rank == dp_rank)
 
     def _scheduler_from_config(self, optimizer):
         scheduler_name = self.scheduler_name()
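The rewritten `_configure_checkpointing` takes the data-parallel rank from the model-parallel unit when one is present and otherwise falls back to the engine's cached global rank; only data-parallel rank 0 writes the model checkpoint. A hedged sketch of that selection outside the engine (the `mpu` method name mirrors the diff; the standalone helper is hypothetical):

def should_save_model_checkpoint(global_rank, mpu=None):
    # Ranks that hold different shards of the same replica share a
    # data-parallel rank; only one of them needs to save the model weights.
    dp_rank = global_rank
    if mpu is not None:
        dp_rank = mpu.get_data_parallel_rank()
    return dp_rank == 0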
@@ -621,11 +623,12 @@ class DeepSpeedLight(Module):
             allreduce_gradients: If this is False, then gradient averaging will be skipped. Default is True.
         """
-        if self.is_gradient_accumulation_boundary() and self.tensorboard_enabled(
-        ) and torch.distributed.get_rank(
-        ) == 0:
-            # deepspeed tensorboard support for loss
-            self.sample_count += (self.train_micro_batch_size_per_gpu() *
-                                  torch.distributed.get_world_size() *
-                                  self.gradient_accumulation_steps())
-            self.summary_events = [
-                (f'Train/Samples/train_loss',
+        # Log training Loss
+        if self.tensorboard_enabled():
+            if self.is_gradient_accumulation_boundary():
+                if self.global_rank == 0:
+                    self.sample_count += (self.train_micro_batch_size_per_gpu() *
+                                          self.dp_world_size *
+                                          self.gradient_accumulation_steps())
+                    self.summary_events = [
+                        (f'Train/Samples/train_loss',
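The sample counter now advances by one effective batch per accumulation boundary: micro-batch size times the data-parallel world size times the number of accumulation steps, rather than scaling by the full `torch.distributed` world size. A small worked example with made-up numbers:

# Hypothetical values, just to make the counter update above concrete.
train_micro_batch_size_per_gpu = 4
dp_world_size = 8                    # data-parallel replicas, not total ranks
gradient_accumulation_steps = 16

samples_per_boundary = (train_micro_batch_size_per_gpu *
                        dp_world_size *
                        gradient_accumulation_steps)
assert samples_per_boundary == 512   # 4 * 8 * 16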
@@ -712,8 +715,10 @@ class DeepSpeedLight(Module):
         self.tput_timer.stop(report_progress)
 
-        if self.is_gradient_accumulation_boundary() and self.tensorboard_enabled(
-        ) and torch.distributed.get_rank() == 0:
-            # deepspeed tensorboard support for lr
-            self.summary_events = [(f'Train/Samples/lr',
-                                    self.get_lr()[0],
-                                    self.sample_count)]
+        # Log learning rate
+        if self.tensorboard_enabled():
+            if self.is_gradient_accumulation_boundary():
+                if self.global_rank == 0:
+                    self.summary_events = [(f'Train/Samples/lr',
+                                            self.get_lr()[0],
+                                            self.sample_count)]
@@ -731,9 +736,10 @@ class DeepSpeedLight(Module):
                 'backward_allreduce_microstep',
                 'step_microstep'
             ])
-        if self.is_gradient_accumulation_boundary():
-            if self.tensorboard_enabled() and torch.distributed.get_rank(
-            ) == 0:
-                self.summary_events = [(f'Train/Samples/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
-                    (f'Train/Samples/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
-                    (f'Train/Samples/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
+        # Log timing
+        if self.tensorboard_enabled():
+            if self.is_gradient_accumulation_boundary():
+                if self.global_rank == 0:
+                    # this is done before the log because log resets timers
+                    self.summary_events = [(f'Train/Samples/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
+                        (f'Train/Samples/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
+                        (f'Train/Samples/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
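This hunk and the previous one restructure the logging guards the same way: check TensorBoard first, then the accumulation boundary, then gate on the engine's cached `global_rank` instead of calling `torch.distributed.get_rank()` inline. A minimal sketch of the guard pattern; `engine` and `events` are placeholders, and the writer is a standard `torch.utils.tensorboard.SummaryWriter`:

from torch.utils.tensorboard import SummaryWriter

def maybe_log(engine, events, writer: SummaryWriter):
    # Mirrors the nested conditions above; `engine` is any object exposing
    # the same three predicates/attributes used in the diff.
    if engine.tensorboard_enabled():
        if engine.is_gradient_accumulation_boundary():
            if engine.global_rank == 0:
                for tag, value, step in events:
                    writer.add_scalar(tag, value, step)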
@@ -870,7 +876,7 @@ class DeepSpeedLight(Module):
         return csr
 
     def csr_all_gather(self, value):
-        my_size = torch.LongTensor([value.size()[0]]).cuda()
+        my_size = torch.LongTensor([value.size()[0]]).to(self.device)
         all_sizes = self.all_gather_scalar(my_size)
         max_size = torch.cat(all_sizes).max()
         fill_size = (max_size - my_size)
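Swapping `.cuda()` for `.to(self.device)` keeps the helper device-agnostic: the size tensor follows whatever device the engine was configured with rather than the process-default GPU. A short sketch of the difference, with a hypothetical device choice:

import torch

# Hypothetical device selection; an engine would carry this as self.device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

my_size = torch.LongTensor([3])
# .cuda() always targets a CUDA device and raises on CPU-only machines;
# .to(device) works for either and respects an explicit device index.
my_size = my_size.to(device)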
@@ -879,22 +885,22 @@ class DeepSpeedLight(Module):
         if value.dim() == 1:
             if fill_size > 0:
                 value = torch.cat([value, value.new_zeros(fill_size)])
-            tensor_list = [
-                value.new_zeros(max_size) for _ in range(dist.get_world_size())
-            ]
+            tensor_list = [value.new_zeros(max_size) for _ in range(self.dp_world_size)]
         else:
             if fill_size > 0:
                 value = torch.cat([value, value.new_zeros(fill_size, value.size()[1])])
             tensor_list = [
                 value.new_zeros(max_size,
-                                value.size()[1]) for _ in range(dist.get_world_size())
+                                value.size()[1]) for _ in range(self.dp_world_size)
             ]
 
         dist.all_gather(tensor_list, value, group=self.data_parallel_group)
         tensors = []
         for dev_idx, t in enumerate(tensor_list):
             size = all_sizes[dev_idx][0]
-            tensors.append(t.index_select(0, torch.LongTensor(range(size)).cuda()))
+            tensors.append(t.index_select(0, torch.LongTensor(range(size)).to(self.device)))
 
         return tensors
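The gather itself runs over `self.data_parallel_group`, so the receive list needs exactly one buffer per data-parallel rank; sizing it with the global world size would give `all_gather` a list whose length does not match the group under model parallelism. A single-process sketch of the padding and buffer-allocation shape logic (no process group is created here, and the helper name is made up):

import torch

def pad_and_alloc_buffers(value, all_sizes, dp_world_size):
    # value: this rank's 1-D tensor; all_sizes: lengths reported by every
    # data-parallel peer. Pad to the common maximum, then allocate one
    # receive buffer per data-parallel rank for the group-wide all_gather.
    max_size = int(max(all_sizes))
    fill_size = max_size - value.numel()
    if fill_size > 0:
        value = torch.cat([value, value.new_zeros(fill_size)])
    tensor_list = [value.new_zeros(max_size) for _ in range(dp_world_size)]
    return value, tensor_list

padded, buffers = pad_and_alloc_buffers(torch.arange(3), [3, 5], dp_world_size=2)
assert padded.numel() == 5 and len(buffers) == 2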
@@ -1036,8 +1042,8 @@ class DeepSpeedLight(Module):
     def _create_checkpoint_files(self, save_dir, tag):
         #checkpoint files are created sequentially
-        for rank in range(dist.get_world_size()):
-            if rank == dist.get_rank():
+        for rank in range(self.world_size):
+            if rank == self.global_rank:
                 try:
                     if self.save_non_zero_checkpoint:
                         checkpoint_name = self._get_ckpt_name(save_dir, tag)
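Checkpoint files are created one rank at a time, and the loop now compares against the engine's cached `world_size` and `global_rank` instead of querying `torch.distributed` directly. A minimal sketch of the take-turns pattern; the `barrier` callable is a hypothetical stand-in for whatever collective serializes the turns:

import os

def create_ckpt_dirs_sequentially(global_rank, world_size, save_dir, barrier):
    # Each rank acts only on its own turn, so directory creation never races.
    for rank in range(world_size):
        if rank == global_rank:
            os.makedirs(save_dir, exist_ok=True)
        barrier()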