Megatron-LM (wuxk1) · Commit 513d7d8e
authored Apr 19, 2023 by Abhinav Khattar

add enable autocast

Signed-off-by: Abhinav Khattar <aklife97@gmail.com>

parent 2699f93e

Showing 1 changed file with 20 additions and 10 deletions:
megatron/core/pipeline_parallel/schedules.py (+20, -10)
@@ -90,6 +90,9 @@ def get_forward_backward_func():
         collect_non_loss_data: TODO
+        enable_autocast (optional, default=False): If True, runs the
+            forward_step_func call inside torch.autocast context
     """
     pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
     if pipeline_model_parallel_size > 1:
@@ -166,7 +169,8 @@ def forward_step(forward_step_func,
                  input_tensor,
                  forward_data_store,
                  timers,
-                 collect_non_loss_data=False):
+                 collect_non_loss_data=False,
+                 enable_autocast=False):
     """Forward step for passed-in model.
 
     If first stage, input tensor is obtained from data_iterator, otherwise
@@ -184,7 +188,7 @@ def forward_step(forward_step_func,
     set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
     set_input_tensor(input_tensor)
 
-    context_manager = torch.autocast("cuda") if torch.is_autocast_enabled() else nullcontext()
+    context_manager = torch.autocast("cuda") if enable_autocast else nullcontext()
     with context_manager:
         output_tensor, loss_func = forward_step_func(data_iterator, model)
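This hunk is the behavioral core of the commit. Previously the wrapper was gated on torch.is_autocast_enabled(), which reports the ambient autocast state and is False unless the caller is already inside an autocast region, so this line on its own could not switch autocast on; the new enable_autocast flag makes the choice explicit. A minimal sketch of the two gates side by side (standalone, assuming the caller is not already inside an autocast region):

    from contextlib import nullcontext

    import torch

    enable_autocast = True

    # Old gate: depends on ambient autocast state; False out here,
    # so the forward pass would stay in full precision.
    old_ctx = torch.autocast("cuda") if torch.is_autocast_enabled() else nullcontext()

    # New gate: driven by the caller-supplied flag.
    new_ctx = torch.autocast("cuda") if enable_autocast else nullcontext()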
@@ -296,7 +300,8 @@ def forward_backward_no_pipelining(*,
         sequence_parallel: bool = False,  # unused
         forward_only: bool = False,
         timers: Callable = None,
-        collect_non_loss_data: bool = False):
+        collect_non_loss_data: bool = False,
+        enable_autocast: bool = False):
     """Run forward and backward passes with no pipeline parallelism
     (no inter-stage communication).
@@ -320,7 +325,7 @@ def forward_backward_no_pipelining(*,
         for i in range(num_microbatches - 1):
             output_tensor = forward_step(forward_step_func, data_iterator,
                                          model, num_microbatches, input_tensor, forward_data_store,
-                                         timers, collect_non_loss_data)
+                                         timers, collect_non_loss_data, enable_autocast)
             if not forward_only:
                 backward_step(grad_scaler, input_tensor, output_tensor,
                               output_tensor_grad, model_type, timers)
@@ -329,7 +334,7 @@ def forward_backward_no_pipelining(*,
     # synchronize gradients).
     output_tensor = forward_step(forward_step_func, data_iterator,
                                  model, num_microbatches, input_tensor, forward_data_store,
-                                 timers, collect_non_loss_data)
+                                 timers, collect_non_loss_data, enable_autocast)
     if not forward_only:
         backward_step(grad_scaler, input_tensor, output_tensor,
@@ -350,7 +355,8 @@ def forward_backward_pipelining_with_interleaving(*,
         sequence_parallel: bool = False,
         forward_only: bool = False,
         timers: Callable = None,
-        collect_non_loss_data: bool = False):
+        collect_non_loss_data: bool = False,
+        enable_autocast: bool = False):
     """Run interleaved 1F1B schedule (model split into model chunks), with
     communication between pipeline stages as needed.
@@ -440,7 +446,8 @@ def forward_backward_pipelining_with_interleaving(*,
                                      input_tensor,
                                      forward_data_store,
                                      timers,
-                                     collect_non_loss_data)
+                                     collect_non_loss_data,
+                                     enable_autocast)
         output_tensors[model_chunk_id].append(output_tensor)
 
         # if forward-only, no need to save tensors for a backward pass
@@ -731,7 +738,8 @@ def forward_backward_pipelining_without_interleaving(*,
         sequence_parallel: bool = False,
         forward_only: bool = False,
         timers: Callable = None,
-        collect_non_loss_data: bool = False):
+        collect_non_loss_data: bool = False,
+        enable_autocast: bool = False):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
     stages.
@@ -775,7 +783,9 @@ def forward_backward_pipelining_without_interleaving(*,
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)
-        output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
-                                     input_tensor, forward_data_store, timers, collect_non_loss_data)
+        output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
+                                     input_tensor, forward_data_store,
+                                     timers, collect_non_loss_data,
+                                     enable_autocast)
         send_forward(output_tensor, send_tensor_shapes, timers=timers)
 
         if not forward_only:
@@ -795,7 +805,7 @@ def forward_backward_pipelining_without_interleaving(*,
         output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
                                      input_tensor, forward_data_store,
-                                     timers, collect_non_loss_data)
+                                     timers, collect_non_loss_data, enable_autocast)
 
         if forward_only:
             send_forward(output_tensor, send_tensor_shapes, timers=timers)
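Taken together, the signature and call-site hunks simply thread the one new keyword from each schedule down to its forward_step calls. A hypothetical miniature of the no-pipelining loop showing that flow end to end (simplified bodies; not Megatron's real implementation, which also handles grad scaling, timers, and tensor bookkeeping):

    from contextlib import nullcontext

    import torch

    def forward_step(forward_step_func, data_iterator, model, enable_autocast=False):
        # Same pattern as the real forward_step: autocast only when asked.
        ctx = torch.autocast("cuda") if enable_autocast else nullcontext()
        with ctx:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
        return loss_func(output_tensor)

    def forward_backward_no_pipelining(*, forward_step_func, data_iterator, model,
                                       num_microbatches, forward_only=False,
                                       enable_autocast=False):
        # The new keyword is accepted here and forwarded unchanged.
        losses = []
        for _ in range(num_microbatches):
            loss = forward_step(forward_step_func, data_iterator, model,
                                enable_autocast)
            losses.append(loss)
            if not forward_only:
                loss.backward()  # grad scaling omitted in this sketch
        return losses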