Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
7bf1e98b
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "9d3124ac8ba5e0f6a6c2e234b90c7d26bf8cb84f"
Unverified
Commit
7bf1e98b
authored
Jan 17, 2022
by
ver217
Committed by
GitHub
Jan 17, 2022
Browse files
pipeline last stage supports multi output (#151)
parent
1ff5be36
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
18 deletions
+35
-18
colossalai/engine/schedule/_pipeline_schedule.py
colossalai/engine/schedule/_pipeline_schedule.py
+34
-17
tests/test_trainer/test_pipeline/model/resnet.py
tests/test_trainer/test_pipeline/model/resnet.py
+1
-1
No files found.
colossalai/engine/schedule/_pipeline_schedule.py
View file @
7bf1e98b
...
...
@@ -4,7 +4,6 @@
from
typing
import
List
,
Tuple
,
Union
,
Callable
import
inspect
import
torch.cuda
from
torch
import
Tensor
import
colossalai.communication
as
comm
from
colossalai.context.parallel_mode
import
ParallelMode
...
...
@@ -14,14 +13,27 @@ from colossalai.utils.cuda import get_current_device
from
colossalai.zero
import
(
ZeroRedundancyOptimizer_Level_2
,
ZeroRedundancyOptimizer_Level_3
)
from
colossalai.utils
import
switch_virtual_pipeline_parallel_rank
from
colossalai.logging
import
get_dist_logger
from
._base_schedule
import
BaseSchedule
def squeeze(x: Union[Tensor, tuple, list]):
    """Unwrap a tuple/list by returning its first element.

    Any value that is not a tuple or list is passed through unchanged.
    NOTE(review): a multi-element sequence silently loses everything past
    index 0, and an empty sequence raises IndexError — presumably callers
    only ever pass single-element containers; confirm at call sites.
    """
    if not isinstance(x, (tuple, list)):
        return x
    return x[0]
def pack_return_tensors(return_tensors):
    """Merge per-microbatch results into full-batch tensors.

    :param return_tensors: sequence of ``(output, label)`` pairs, one per
        microbatch. ``output`` may be a tensor or a list/tuple of tensors
        (multi-output model); ``label`` may be a tensor or a dict of tensors.
    :return: ``(output, label)`` concatenated along dim 0 — a tensor, or a
        tuple of tensors for multi-output models, or a dict of tensors for
        dict labels.
    :raises TypeError: if outputs or labels are of an unsupported type.
    """
    output, label = tuple(zip(*return_tensors))
    if isinstance(output[0], torch.Tensor):
        output = torch.cat(output, dim=0)
    elif isinstance(output[0], (list, tuple)):
        # Multi-output model: concatenate each output position independently
        # across microbatches.
        output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
    else:
        # No f-prefix: the message contains no placeholders.
        raise TypeError('Output of model must be tensor or list/tuple of tensors')
    if isinstance(label[0], torch.Tensor):
        label = torch.cat(label, dim=0)
    elif isinstance(label[0], dict):
        # Dict labels: gather values per key, then concatenate each key's list.
        merged_label = {k: [] for k in label[0].keys()}
        for d in label:
            for k, v in d.items():
                merged_label[k].append(v)
        label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
    else:
        # Previously an opaque AttributeError on `.keys()`; fail clearly instead.
        raise TypeError('Label must be tensor or dict of tensors')
    return output, label
class
PipelineSchedule
(
BaseSchedule
):
...
...
@@ -49,6 +61,7 @@ class PipelineSchedule(BaseSchedule):
self
.
scatter_gather_tensors
=
False
if
gpc
.
is_initialized
(
ParallelMode
.
PARALLEL_1D
)
and
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)
>
1
:
self
.
scatter_gather_tensors
=
scatter_gather_tensors
self
.
_logger
=
get_dist_logger
()
def
load_batch
(
self
,
data_iter
):
# Pipeline schedule just puts data in memory
...
...
@@ -129,7 +142,6 @@ class PipelineSchedule(BaseSchedule):
"""
data
,
label
=
self
.
load_micro_batch
()
output_tensor
=
self
.
_call_engine
(
engine
.
model
,
input_tensor
,
data
)
output_tensor
=
squeeze
(
output_tensor
)
if
gpc
.
is_last_rank
(
ParallelMode
.
PIPELINE
):
if
return_output_label
:
...
...
@@ -139,8 +151,13 @@ class PipelineSchedule(BaseSchedule):
accum_loss
.
add_
(
loss_reduced
.
detach
())
return
loss_reduced
else
:
# forward only, it's useless since backward is not needed
return
output_tensor
else
:
assert
isinstance
(
output_tensor
,
torch
.
Tensor
),
'Output of model using pipeline parallelism must be a tensor (except the last stage).'
self
.
_logger
.
debug
(
f
'Global rank
{
gpc
.
get_global_rank
()
}
, pipeline rank
{
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
}
forward output tensor
{
output_tensor
.
shape
}
, dtype
{
output_tensor
.
dtype
}
'
)
return
output_tensor
def
backward_step
(
self
,
engine
,
input_tensor
,
output_tensor
,
output_tensor_grad
):
...
...
@@ -319,12 +336,10 @@ class PipelineSchedule(BaseSchedule):
comm
.
send_backward
(
input_tensor_grad
,
scatter_gather_tensors
=
self
.
scatter_gather_tensors
)
if
len
(
return_tensors
)
>
0
:
output
,
label
=
tuple
(
map
(
list
,
zip
(
*
return_tensors
)))
return
(
torch
.
cat
(
output
,
dim
=
0
),
torch
.
cat
(
label
,
dim
=
0
),
accum_loss
)
output
,
label
=
pack_return_tensors
(
return_tensors
)
return
output
,
label
,
accum_loss
else
:
return
tuple
((
None
,
None
,
accum_loss
))
return
None
,
None
,
accum_loss
class
InterleavedPipelineSchedule
(
PipelineSchedule
):
...
...
@@ -389,7 +404,6 @@ class InterleavedPipelineSchedule(PipelineSchedule):
"""
data
,
label
=
self
.
load_micro_batch
(
model_chunk_id
)
output_tensor
=
self
.
_call_engine
(
engine
.
model
[
model_chunk_id
],
input_tensor
,
data
)
output_tensor
=
squeeze
(
output_tensor
)
if
gpc
.
is_pipeline_last_stage
():
if
return_output_label
:
...
...
@@ -399,8 +413,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
accum_loss
.
add_
(
loss_reduced
.
detach
())
return
loss_reduced
else
:
# forward only, it's useless since backward is not needed
return
output_tensor
else
:
assert
isinstance
(
output_tensor
,
torch
.
Tensor
),
'Output of model using pipeline parallelism must be a tensor (except the last stage).'
self
.
_logger
.
debug
(
f
'Global rank
{
gpc
.
get_global_rank
()
}
, pipeline rank
{
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
}
forward output tensor
{
output_tensor
.
shape
}
, dtype
{
output_tensor
.
dtype
}
'
)
return
output_tensor
def
forward_backward_step
(
self
,
engine
,
data_iter
,
forward_only
=
False
,
return_loss
=
True
,
return_output_label
=
True
):
...
...
@@ -665,9 +684,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
scatter_gather_tensors
=
self
.
scatter_gather_tensors
))
if
len
(
return_tensors
)
>
0
:
output
,
label
=
tuple
(
map
(
list
,
zip
(
*
return_tensors
)))
return
(
torch
.
cat
(
output
,
dim
=
0
),
torch
.
cat
(
label
,
dim
=
0
),
accum_loss
)
output
,
label
=
pack_return_tensors
(
return_tensors
)
return
output
,
label
,
accum_loss
else
:
return
tuple
((
None
,
None
,
accum_loss
))
return
None
,
None
,
accum_loss
tests/test_trainer/test_pipeline/model/resnet.py
View file @
7bf1e98b
...
...
@@ -139,7 +139,7 @@ class VanillaResNet(ModelFromConfig):
def forward(self, x: Tensor):
    """Run the input through every layer in ``self.layers`` in order.

    :param x: input tensor fed to the first layer.
    :return: the output of the final layer (a bare tensor, not a tuple).
    """
    for block in self.layers:
        x = block(x)
    return x
def
init_weights
(
self
):
for
m
in
self
.
modules
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment