Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
7bf1e98b
Unverified
Commit
7bf1e98b
authored
Jan 17, 2022
by
ver217
Committed by
GitHub
Jan 17, 2022
Browse files
pipeline last stage supports multi output (#151)
parent
1ff5be36
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
18 deletions
+35
-18
colossalai/engine/schedule/_pipeline_schedule.py
colossalai/engine/schedule/_pipeline_schedule.py
+34
-17
tests/test_trainer/test_pipeline/model/resnet.py
tests/test_trainer/test_pipeline/model/resnet.py
+1
-1
No files found.
colossalai/engine/schedule/_pipeline_schedule.py
View file @
7bf1e98b
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
from
typing
import
List
,
Tuple
,
Union
,
Callable
from
typing
import
List
,
Tuple
,
Union
,
Callable
import
inspect
import
inspect
import
torch.cuda
import
torch.cuda
from
torch
import
Tensor
import
colossalai.communication
as
comm
import
colossalai.communication
as
comm
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.context.parallel_mode
import
ParallelMode
...
@@ -14,14 +13,27 @@ from colossalai.utils.cuda import get_current_device
...
@@ -14,14 +13,27 @@ from colossalai.utils.cuda import get_current_device
from
colossalai.zero
import
(
ZeroRedundancyOptimizer_Level_2
,
from
colossalai.zero
import
(
ZeroRedundancyOptimizer_Level_2
,
ZeroRedundancyOptimizer_Level_3
)
ZeroRedundancyOptimizer_Level_3
)
from
colossalai.utils
import
switch_virtual_pipeline_parallel_rank
from
colossalai.utils
import
switch_virtual_pipeline_parallel_rank
from
colossalai.logging
import
get_dist_logger
from
._base_schedule
import
BaseSchedule
from
._base_schedule
import
BaseSchedule
def
squeeze
(
x
:
Union
[
Tensor
,
tuple
,
list
]):
def
pack_return_tensors
(
return_tensors
):
if
isinstance
(
x
,
(
tuple
,
list
)):
output
,
label
=
tuple
(
zip
(
*
return_tensors
))
return
x
[
0
]
if
isinstance
(
output
[
0
],
torch
.
Tensor
):
output
=
torch
.
cat
(
output
,
dim
=
0
)
elif
isinstance
(
output
[
0
],
(
list
,
tuple
)):
output
=
tuple
(
torch
.
cat
(
tensors
,
dim
=
0
)
for
tensors
in
zip
(
*
output
))
else
:
else
:
return
x
raise
TypeError
(
f
'Output of model must be tensor or list/tuple of tensors'
)
if
isinstance
(
label
[
0
],
torch
.
Tensor
):
label
=
torch
.
cat
(
label
,
dim
=
0
)
else
:
merged_label
=
{
k
:
[]
for
k
in
label
[
0
].
keys
()}
for
d
in
label
:
for
k
,
v
in
d
.
items
():
merged_label
[
k
].
append
(
v
)
label
=
{
k
:
torch
.
cat
(
v
,
dim
=
0
)
for
k
,
v
in
merged_label
.
items
()}
return
output
,
label
class
PipelineSchedule
(
BaseSchedule
):
class
PipelineSchedule
(
BaseSchedule
):
...
@@ -49,6 +61,7 @@ class PipelineSchedule(BaseSchedule):
...
@@ -49,6 +61,7 @@ class PipelineSchedule(BaseSchedule):
self
.
scatter_gather_tensors
=
False
self
.
scatter_gather_tensors
=
False
if
gpc
.
is_initialized
(
ParallelMode
.
PARALLEL_1D
)
and
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)
>
1
:
if
gpc
.
is_initialized
(
ParallelMode
.
PARALLEL_1D
)
and
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)
>
1
:
self
.
scatter_gather_tensors
=
scatter_gather_tensors
self
.
scatter_gather_tensors
=
scatter_gather_tensors
self
.
_logger
=
get_dist_logger
()
def
load_batch
(
self
,
data_iter
):
def
load_batch
(
self
,
data_iter
):
# Pipeline schedule just puts data in memory
# Pipeline schedule just puts data in memory
...
@@ -129,7 +142,6 @@ class PipelineSchedule(BaseSchedule):
...
@@ -129,7 +142,6 @@ class PipelineSchedule(BaseSchedule):
"""
"""
data
,
label
=
self
.
load_micro_batch
()
data
,
label
=
self
.
load_micro_batch
()
output_tensor
=
self
.
_call_engine
(
engine
.
model
,
input_tensor
,
data
)
output_tensor
=
self
.
_call_engine
(
engine
.
model
,
input_tensor
,
data
)
output_tensor
=
squeeze
(
output_tensor
)
if
gpc
.
is_last_rank
(
ParallelMode
.
PIPELINE
):
if
gpc
.
is_last_rank
(
ParallelMode
.
PIPELINE
):
if
return_output_label
:
if
return_output_label
:
...
@@ -139,8 +151,13 @@ class PipelineSchedule(BaseSchedule):
...
@@ -139,8 +151,13 @@ class PipelineSchedule(BaseSchedule):
accum_loss
.
add_
(
loss_reduced
.
detach
())
accum_loss
.
add_
(
loss_reduced
.
detach
())
return
loss_reduced
return
loss_reduced
else
:
else
:
# forward only, it's useless since backward is not needed
return
output_tensor
return
output_tensor
else
:
else
:
assert
isinstance
(
output_tensor
,
torch
.
Tensor
),
'Output of model using pipeline parallelism must be a tensor (except the last stage).'
self
.
_logger
.
debug
(
f
'Global rank
{
gpc
.
get_global_rank
()
}
, pipeline rank
{
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
}
forward output tensor
{
output_tensor
.
shape
}
, dtype
{
output_tensor
.
dtype
}
'
)
return
output_tensor
return
output_tensor
def
backward_step
(
self
,
engine
,
input_tensor
,
output_tensor
,
output_tensor_grad
):
def
backward_step
(
self
,
engine
,
input_tensor
,
output_tensor
,
output_tensor_grad
):
...
@@ -319,12 +336,10 @@ class PipelineSchedule(BaseSchedule):
...
@@ -319,12 +336,10 @@ class PipelineSchedule(BaseSchedule):
comm
.
send_backward
(
input_tensor_grad
,
scatter_gather_tensors
=
self
.
scatter_gather_tensors
)
comm
.
send_backward
(
input_tensor_grad
,
scatter_gather_tensors
=
self
.
scatter_gather_tensors
)
if
len
(
return_tensors
)
>
0
:
if
len
(
return_tensors
)
>
0
:
output
,
label
=
tuple
(
map
(
list
,
zip
(
*
return_tensors
)))
output
,
label
=
pack_return_tensors
(
return_tensors
)
return
(
torch
.
cat
(
output
,
dim
=
0
),
return
output
,
label
,
accum_loss
torch
.
cat
(
label
,
dim
=
0
),
accum_loss
)
else
:
else
:
return
tuple
((
None
,
None
,
accum_loss
))
return
None
,
None
,
accum_loss
class
InterleavedPipelineSchedule
(
PipelineSchedule
):
class
InterleavedPipelineSchedule
(
PipelineSchedule
):
...
@@ -389,7 +404,6 @@ class InterleavedPipelineSchedule(PipelineSchedule):
...
@@ -389,7 +404,6 @@ class InterleavedPipelineSchedule(PipelineSchedule):
"""
"""
data
,
label
=
self
.
load_micro_batch
(
model_chunk_id
)
data
,
label
=
self
.
load_micro_batch
(
model_chunk_id
)
output_tensor
=
self
.
_call_engine
(
engine
.
model
[
model_chunk_id
],
input_tensor
,
data
)
output_tensor
=
self
.
_call_engine
(
engine
.
model
[
model_chunk_id
],
input_tensor
,
data
)
output_tensor
=
squeeze
(
output_tensor
)
if
gpc
.
is_pipeline_last_stage
():
if
gpc
.
is_pipeline_last_stage
():
if
return_output_label
:
if
return_output_label
:
...
@@ -399,8 +413,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
...
@@ -399,8 +413,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
accum_loss
.
add_
(
loss_reduced
.
detach
())
accum_loss
.
add_
(
loss_reduced
.
detach
())
return
loss_reduced
return
loss_reduced
else
:
else
:
# forward only, it's useless since backward is not needed
return
output_tensor
return
output_tensor
else
:
else
:
assert
isinstance
(
output_tensor
,
torch
.
Tensor
),
'Output of model using pipeline parallelism must be a tensor (except the last stage).'
self
.
_logger
.
debug
(
f
'Global rank
{
gpc
.
get_global_rank
()
}
, pipeline rank
{
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
}
forward output tensor
{
output_tensor
.
shape
}
, dtype
{
output_tensor
.
dtype
}
'
)
return
output_tensor
return
output_tensor
def
forward_backward_step
(
self
,
engine
,
data_iter
,
forward_only
=
False
,
return_loss
=
True
,
return_output_label
=
True
):
def
forward_backward_step
(
self
,
engine
,
data_iter
,
forward_only
=
False
,
return_loss
=
True
,
return_output_label
=
True
):
...
@@ -665,9 +684,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
...
@@ -665,9 +684,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
scatter_gather_tensors
=
self
.
scatter_gather_tensors
))
scatter_gather_tensors
=
self
.
scatter_gather_tensors
))
if
len
(
return_tensors
)
>
0
:
if
len
(
return_tensors
)
>
0
:
output
,
label
=
tuple
(
map
(
list
,
zip
(
*
return_tensors
)))
output
,
label
=
pack_return_tensors
(
return_tensors
)
return
(
torch
.
cat
(
output
,
dim
=
0
),
return
output
,
label
,
accum_loss
torch
.
cat
(
label
,
dim
=
0
),
accum_loss
)
else
:
else
:
return
tuple
((
None
,
None
,
accum_loss
))
return
None
,
None
,
accum_loss
tests/test_trainer/test_pipeline/model/resnet.py
View file @
7bf1e98b
...
@@ -139,7 +139,7 @@ class VanillaResNet(ModelFromConfig):
...
@@ -139,7 +139,7 @@ class VanillaResNet(ModelFromConfig):
def
forward
(
self
,
x
:
Tensor
):
def
forward
(
self
,
x
:
Tensor
):
for
layer
in
self
.
layers
:
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
layer
(
x
)
return
x
,
return
x
def
init_weights
(
self
):
def
init_weights
(
self
):
for
m
in
self
.
modules
():
for
m
in
self
.
modules
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment