OpenDAS / ColossalAI · Commits · 708404d5
Unverified commit 708404d5, authored Jan 21, 2022 by ver217, committed by GitHub on Jan 21, 2022
fix pipeline forward return tensors (#176)
parent 6fb550ac
Showing 1 changed file with 8 additions and 7 deletions
colossalai/engine/schedule/_pipeline_schedule.py (view file @ 708404d5): +8 -7
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-from typing import List, Tuple, Union, Callable
 import inspect
-import torch.cuda
+from typing import Callable, List, Tuple, Union
+
 import colossalai.communication as comm
+import torch.cuda
+from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.amp.naive_amp import NaiveAMPModel
+from colossalai.logging import get_dist_logger
+from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)
-from colossalai.utils import switch_virtual_pipeline_parallel_rank
-from colossalai.logging import get_dist_logger
 from ._base_schedule import BaseSchedule
...
@@ -151,7 +152,7 @@ class PipelineSchedule(BaseSchedule):
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:
-                return_tensors.append(tuple((output_tensor, label)))
+                return_tensors.append((output_tensor, label))
             if accum_loss is not None:
                 loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
...
@@ -414,7 +415,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         if gpc.is_pipeline_last_stage():
             if return_output_label:
-                return_tensors.append(tuple(output_tensor, label))
+                return_tensors.append((output_tensor, label))
             if accum_loss is not None:
                 loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
...
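Note on the change: the old InterleavedPipelineSchedule line called Python's built-in tuple() with two positional arguments, which raises a TypeError because tuple() accepts at most one iterable; the old PipelineSchedule line merely wrapped an already-built pair in a redundant tuple() call. Both sites now build the pair with a plain tuple literal. The sketch below is not part of the commit; output_tensor and label are stand-in strings rather than the real pipeline tensors, used only to illustrate the behavior.

# Minimal sketch (not from the commit) of why both call sites were rewritten.
output_tensor, label = "output_tensor", "label"  # stand-ins for real tensors

return_tensors = []

# Old InterleavedPipelineSchedule code: tuple() accepts at most one iterable,
# so two positional arguments raise a TypeError before anything is appended.
try:
    return_tensors.append(tuple(output_tensor, label))
except TypeError as exc:
    print(exc)  # CPython 3 reports: tuple expected at most 1 argument, got 2

# Old PipelineSchedule code: works, but the extra tuple() call is redundant.
return_tensors.append(tuple((output_tensor, label)))

# New code in both schedules: build the pair with a plain tuple literal.
return_tensors.append((output_tensor, label))

assert return_tensors[0] == return_tensors[1]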