Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
ade05a5d
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "d7352bef2c380b0238babb6fa434b52f5263c5b9"
Unverified
Commit
ade05a5d
authored
Apr 03, 2022
by
YuliangLiu0306
Committed by
GitHub
Apr 03, 2022
Browse files
[refactor] pipeline, put runtime schedule into engine. (#627)
parent
e5d615ae
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
68 additions
and
49 deletions
+68
-49
colossalai/engine/_base_engine.py
colossalai/engine/_base_engine.py
+34
-3
colossalai/engine/schedule/_base_schedule.py
colossalai/engine/schedule/_base_schedule.py
+2
-3
colossalai/engine/schedule/_non_pipeline_schedule.py
colossalai/engine/schedule/_non_pipeline_schedule.py
+1
-2
colossalai/initialize.py
colossalai/initialize.py
+17
-1
colossalai/trainer/_trainer.py
colossalai/trainer/_trainer.py
+6
-30
tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
...e_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
+4
-5
tests/test_trainer/test_pipeline/resnet_config.py
tests/test_trainer/test_pipeline/resnet_config.py
+1
-0
tests/test_trainer/test_pipeline/test_pipeline_schedule.py
tests/test_trainer/test_pipeline/test_pipeline_schedule.py
+1
-2
tests/test_trainer/test_trainer_with_pipe_schedule.py
tests/test_trainer/test_trainer_with_pipe_schedule.py
+2
-3
No files found.
colossalai/engine/_base_engine.py
View file @
ade05a5d
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-
# -*- encoding: utf-8 -*-
from
asyncio.log
import
logger
from
asyncio.log
import
logger
from
typing
import
List
from
typing
import
List
,
Iterable
from
torch.nn
import
Module
from
torch.nn
import
Module
from
torch.nn.modules.loss
import
_Loss
from
torch.nn.modules.loss
import
_Loss
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
...
@@ -10,6 +10,7 @@ from torch.optim import Optimizer
...
@@ -10,6 +10,7 @@ from torch.optim import Optimizer
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
torch
import
Tensor
from
torch
import
Tensor
from
colossalai.engine.ophooks
import
register_ophooks_recursively
,
BaseOpHook
from
colossalai.engine.ophooks
import
register_ophooks_recursively
,
BaseOpHook
from
colossalai.engine.schedule
import
BaseSchedule
,
NonPipelineSchedule
,
PipelineSchedule
,
InterleavedPipelineSchedule
from
typing
import
Optional
,
Type
from
typing
import
Optional
,
Type
from
colossalai.engine.gradient_handler
import
BaseGradientHandler
from
colossalai.engine.gradient_handler
import
BaseGradientHandler
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
...
@@ -27,6 +28,7 @@ class Engine:
...
@@ -27,6 +28,7 @@ class Engine:
clip_grad_norm (float, optional): The norm of gradient clipping.
clip_grad_norm (float, optional): The norm of gradient clipping.
ophook_list (list): List of ophook.
ophook_list (list): List of ophook.
verbose (bool): whether to display log info.
verbose (bool): whether to display log info.
schedule (''BaseSchedule''): Runtime schedule.
Examples:
Examples:
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
...
@@ -59,7 +61,8 @@ class Engine:
...
@@ -59,7 +61,8 @@ class Engine:
gradient_handlers
:
Optional
[
List
[
BaseGradientHandler
]]
=
None
,
gradient_handlers
:
Optional
[
List
[
BaseGradientHandler
]]
=
None
,
clip_grad_norm
:
float
=
0.0
,
clip_grad_norm
:
float
=
0.0
,
ophook_list
:
Optional
[
List
[
BaseOpHook
]]
=
None
,
ophook_list
:
Optional
[
List
[
BaseOpHook
]]
=
None
,
verbose
:
bool
=
True
):
verbose
:
bool
=
True
,
schedule
:
Optional
[
BaseSchedule
]
=
None
):
self
.
_model
=
model
self
.
_model
=
model
self
.
_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
self
.
_criterion
=
criterion
self
.
_criterion
=
criterion
...
@@ -80,6 +83,14 @@ class Engine:
...
@@ -80,6 +83,14 @@ class Engine:
self
.
_ophook_list
=
[]
self
.
_ophook_list
=
[]
else
:
else
:
self
.
_ophook_list
=
ophook_list
self
.
_ophook_list
=
ophook_list
# build schedule
if
schedule
:
self
.
_schedule
=
schedule
else
:
self
.
_schedule
=
NonPipelineSchedule
()
if
self
.
uses_pipeline
:
self
.
_schedule
.
pre_processing
(
self
)
register_ophooks_recursively
(
self
.
_model
,
self
.
_ophook_list
)
register_ophooks_recursively
(
self
.
_model
,
self
.
_ophook_list
)
@
property
@
property
...
@@ -102,6 +113,16 @@ class Engine:
...
@@ -102,6 +113,16 @@ class Engine:
"""Criterion attached to the engine"""
"""Criterion attached to the engine"""
return
self
.
_criterion
return
self
.
_criterion
@
property
def
schedule
(
self
):
"""Schedule attached to the engine"""
return
self
.
_schedule
@
property
def
uses_pipeline
(
self
):
"""show the pipeline parallel used or not"""
return
isinstance
(
self
.
_schedule
,
(
PipelineSchedule
,
InterleavedPipelineSchedule
))
def
add_hook
(
self
,
ophook
:
Type
[
BaseOpHook
])
->
None
:
def
add_hook
(
self
,
ophook
:
Type
[
BaseOpHook
])
->
None
:
"""add necessary hook"""
"""add necessary hook"""
# whether this hook exist
# whether this hook exist
...
@@ -165,6 +186,16 @@ class Engine:
...
@@ -165,6 +186,16 @@ class Engine:
"""
"""
for
handler
in
self
.
_gradient_handlers
:
for
handler
in
self
.
_gradient_handlers
:
handler
.
handle_gradient
()
handler
.
handle_gradient
()
def
execute_schedule
(
self
,
data_iter
:
Iterable
,
**
kwargs
):
"""Run the forward, loss computation, and backward for the model.
Returns a tuple of (output, label, loss).
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
"""
output
,
label
,
loss
=
self
.
_schedule
.
forward_backward_step
(
self
,
data_iter
,
**
kwargs
)
return
output
,
label
,
loss
def
train
(
self
):
def
train
(
self
):
"""Sets the model to training mode.
"""Sets the model to training mode.
...
@@ -176,4 +207,4 @@ class Engine:
...
@@ -176,4 +207,4 @@ class Engine:
"""Sets the model to evaluation mode.
"""Sets the model to evaluation mode.
"""
"""
self
.
training
=
False
self
.
training
=
False
self
.
_model
.
eval
()
self
.
_model
.
eval
()
\ No newline at end of file
colossalai/engine/schedule/_base_schedule.py
View file @
ade05a5d
...
@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
...
@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
import
torch
import
torch
from
typing
import
Iterable
,
Callable
from
typing
import
Iterable
,
Callable
from
.._base_engine
import
Engine
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
...
@@ -75,14 +74,14 @@ class BaseSchedule(ABC):
...
@@ -75,14 +74,14 @@ class BaseSchedule(ABC):
return
self
.
_move_to_device
(
data
),
self
.
_move_to_device
(
label
)
return
self
.
_move_to_device
(
data
),
self
.
_move_to_device
(
label
)
return
data
,
label
return
data
,
label
def
pre_processing
(
self
,
engine
:
Engine
):
def
pre_processing
(
self
,
engine
):
"""To perform actions before running the schedule.
"""To perform actions before running the schedule.
"""
"""
pass
pass
@
abstractmethod
@
abstractmethod
def
forward_backward_step
(
self
,
def
forward_backward_step
(
self
,
engine
:
Engine
,
engine
,
data_iter
:
Iterable
,
data_iter
:
Iterable
,
forward_only
:
bool
,
forward_only
:
bool
,
return_loss
:
bool
=
True
,
return_loss
:
bool
=
True
,
...
...
colossalai/engine/schedule/_non_pipeline_schedule.py
View file @
ade05a5d
...
@@ -5,7 +5,6 @@ from typing import Iterable
...
@@ -5,7 +5,6 @@ from typing import Iterable
import
torch
import
torch
from
colossalai.engine
import
Engine
from
._base_schedule
import
BaseSchedule
from
._base_schedule
import
BaseSchedule
from
colossalai.utils
import
conditional_context
from
colossalai.utils
import
conditional_context
...
@@ -22,7 +21,7 @@ class NonPipelineSchedule(BaseSchedule):
...
@@ -22,7 +21,7 @@ class NonPipelineSchedule(BaseSchedule):
"""
"""
def
forward_backward_step
(
self
,
def
forward_backward_step
(
self
,
engine
:
Engine
,
engine
,
data_iter
:
Iterable
,
data_iter
:
Iterable
,
forward_only
:
bool
=
False
,
forward_only
:
bool
=
False
,
return_loss
:
bool
=
True
,
return_loss
:
bool
=
True
,
...
...
colossalai/initialize.py
View file @
ade05a5d
...
@@ -20,6 +20,7 @@ from colossalai.amp.naive_amp import NaiveAMPModel
...
@@ -20,6 +20,7 @@ from colossalai.amp.naive_amp import NaiveAMPModel
from
colossalai.builder.builder
import
build_gradient_handler
from
colossalai.builder.builder
import
build_gradient_handler
from
colossalai.context
import
Config
,
ConfigException
,
ParallelMode
from
colossalai.context
import
Config
,
ConfigException
,
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.engine.schedule
import
NonPipelineSchedule
,
PipelineSchedule
,
InterleavedPipelineSchedule
from
colossalai.context.moe_context
import
MOE_CONTEXT
from
colossalai.context.moe_context
import
MOE_CONTEXT
from
colossalai.engine
import
Engine
from
colossalai.engine
import
Engine
...
@@ -388,6 +389,20 @@ def initialize(model: nn.Module,
...
@@ -388,6 +389,20 @@ def initialize(model: nn.Module,
if
isinstance
(
model
,
DDP
)
and
isinstance
(
model
.
module
,
NaiveAMPModel
):
if
isinstance
(
model
,
DDP
)
and
isinstance
(
model
.
module
,
NaiveAMPModel
):
model
.
module
.
sync_buffer
=
False
model
.
module
.
sync_buffer
=
False
# initialize schedule for engine
if
is_using_pp
():
tensor_shape
=
getattr
(
gpc
.
config
,
'TENSOR_SHAPE'
,
None
)
use_interleaved
=
hasattr
(
gpc
.
config
,
'model'
)
and
hasattr
(
gpc
.
config
.
model
,
'num_chunks'
)
if
use_interleaved
:
schedule
=
InterleavedPipelineSchedule
(
gpc
.
config
.
NUM_MICRO_BATCHES
,
gpc
.
config
.
model
.
num_chunks
,
tensor_shape
=
tensor_shape
,
scatter_gather_tensors
=
True
)
else
:
schedule
=
PipelineSchedule
(
gpc
.
config
.
NUM_MICRO_BATCHES
,
tensor_shape
=
tensor_shape
,
scatter_gather_tensors
=
True
)
else
:
schedule
=
NonPipelineSchedule
()
if
gradient_handler_cfg
is
None
:
if
gradient_handler_cfg
is
None
:
gradient_handlers
=
None
gradient_handlers
=
None
if
verbose
and
not
isinstance
(
model
,
DDP
):
if
verbose
and
not
isinstance
(
model
,
DDP
):
...
@@ -418,6 +433,7 @@ def initialize(model: nn.Module,
...
@@ -418,6 +433,7 @@ def initialize(model: nn.Module,
criterion
=
criterion
,
criterion
=
criterion
,
gradient_handlers
=
gradient_handlers
,
gradient_handlers
=
gradient_handlers
,
clip_grad_norm
=
clip_grad_norm
,
clip_grad_norm
=
clip_grad_norm
,
ophook_list
=
ophooks
)
ophook_list
=
ophooks
,
schedule
=
schedule
)
return
engine
,
train_dataloader
,
test_dataloader
,
lr_scheduler
return
engine
,
train_dataloader
,
test_dataloader
,
lr_scheduler
colossalai/trainer/_trainer.py
View file @
ade05a5d
...
@@ -9,7 +9,6 @@ from tqdm import tqdm
...
@@ -9,7 +9,6 @@ from tqdm import tqdm
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.engine
import
Engine
from
colossalai.engine
import
Engine
from
colossalai.engine.schedule
import
NonPipelineSchedule
,
BaseSchedule
from
colossalai.logging
import
DistributedLogger
from
colossalai.logging
import
DistributedLogger
from
colossalai.utils
import
MultiTimer
from
colossalai.utils
import
MultiTimer
from
colossalai.utils
import
is_dp_rank_0
,
is_tp_rank_0
,
is_no_pp_or_last_stage
from
colossalai.utils
import
is_dp_rank_0
,
is_tp_rank_0
,
is_no_pp_or_last_stage
...
@@ -23,13 +22,9 @@ class Trainer:
...
@@ -23,13 +22,9 @@ class Trainer:
Args:
Args:
engine (:class:`Engine`): Engine responsible for the process function.
engine (:class:`Engine`): Engine responsible for the process function.
schedule (:class:`BaseSchedule`, optional): Schedule responsible for forward and backward steps.
timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training.
timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training.
logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log.
logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log.
Note:
when `schedule` is None, the ``NonPipelineSchedule`` would be used. If you would like to use pipeline,
you should choose ``PipelineSchedule`` or ``InterleavedPipelineSchedule`` for the `schedule`
Examples:
Examples:
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
...
@@ -42,7 +37,7 @@ class Trainer:
...
@@ -42,7 +37,7 @@ class Trainer:
>>> # Beginning training progress
>>> # Beginning training progress
>>> timier = ...
>>> timier = ...
>>> logger = ...
>>> logger = ...
>>> trainer = Trainer(engine=engine, logger=logger,
schedule=schedule,
timer=timier)
>>> trainer = Trainer(engine=engine, logger=logger, timer=timier)
>>> # add hooks you would like to use here.
>>> # add hooks you would like to use here.
>>> hook_list = []
>>> hook_list = []
>>> trainer.fit(
>>> trainer.fit(
...
@@ -61,7 +56,6 @@ class Trainer:
...
@@ -61,7 +56,6 @@ class Trainer:
def
__init__
(
def
__init__
(
self
,
self
,
engine
:
Engine
,
engine
:
Engine
,
schedule
:
BaseSchedule
=
None
,
timer
:
MultiTimer
=
None
,
timer
:
MultiTimer
=
None
,
logger
:
DistributedLogger
=
None
,
logger
:
DistributedLogger
=
None
,
):
):
...
@@ -86,17 +80,6 @@ class Trainer:
...
@@ -86,17 +80,6 @@ class Trainer:
# multi-timer for time benchmarking
# multi-timer for time benchmarking
self
.
_timer
=
timer
self
.
_timer
=
timer
# set schedule which specifies the training iteration for the engine
if
schedule
is
None
:
schedule
=
NonPipelineSchedule
()
if
(
gpc
.
is_initialized
(
ParallelMode
.
PIPELINE
)
and
gpc
.
get_world_size
(
ParallelMode
.
PIPELINE
)
>
1
):
assert
not
isinstance
(
schedule
,
NonPipelineSchedule
),
"NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead."
self
.
_schedule
=
schedule
self
.
_schedule
.
pre_processing
(
engine
)
@
property
@
property
def
cur_epoch
(
self
):
def
cur_epoch
(
self
):
"""Returns the index of the current epoch."""
"""Returns the index of the current epoch."""
...
@@ -129,10 +112,6 @@ class Trainer:
...
@@ -129,10 +112,6 @@ class Trainer:
def
engine
(
self
):
def
engine
(
self
):
return
self
.
_engine
return
self
.
_engine
@
property
def
schedule
(
self
):
return
self
.
_schedule
def
_set_current_step
(
self
,
epoch
:
int
):
def
_set_current_step
(
self
,
epoch
:
int
):
"""Sets current step number.
"""Sets current step number.
...
@@ -203,8 +182,7 @@ class Trainer:
...
@@ -203,8 +182,7 @@ class Trainer:
# run 1 training step
# run 1 training step
self
.
engine
.
zero_grad
()
self
.
engine
.
zero_grad
()
logits
,
label
,
loss
=
self
.
schedule
.
forward_backward_step
(
logits
,
label
,
loss
=
self
.
engine
.
execute_schedule
(
self
.
engine
,
data_iter
,
data_iter
,
forward_only
=
False
,
forward_only
=
False
,
return_loss
=
True
,
return_loss
=
True
,
...
@@ -260,8 +238,7 @@ class Trainer:
...
@@ -260,8 +238,7 @@ class Trainer:
for
_
in
progress
:
for
_
in
progress
:
self
.
_call_hooks
(
"before_test_iter"
)
self
.
_call_hooks
(
"before_test_iter"
)
self
.
_call_timer
(
action
=
"start"
,
item
=
"Test-step"
)
self
.
_call_timer
(
action
=
"start"
,
item
=
"Test-step"
)
logits
,
label
,
loss
=
self
.
schedule
.
forward_backward_step
(
logits
,
label
,
loss
=
self
.
engine
.
execute_schedule
(
self
.
engine
,
data_iter
,
data_iter
,
forward_only
=
True
,
forward_only
=
True
,
return_loss
=
True
,
return_loss
=
True
,
...
@@ -449,8 +426,7 @@ class Trainer:
...
@@ -449,8 +426,7 @@ class Trainer:
# for compatibility with schedule
# for compatibility with schedule
simple_dataloader
=
[(
data
,
None
)]
simple_dataloader
=
[(
data
,
None
)]
data_iter
=
iter
(
simple_dataloader
)
data_iter
=
iter
(
simple_dataloader
)
output
,
_
,
_
=
self
.
schedule
.
forward_backward_step
(
self
.
engine
,
output
,
_
,
_
=
self
.
engine
.
execute_schedule
(
data_iter
,
data_iter
,
forward_only
=
True
,
forward_only
=
True
,
return_loss
=
False
)
return_loss
=
False
)
return
output
return
output
tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
View file @
ade05a5d
...
@@ -23,9 +23,9 @@ from torchvision.datasets import CIFAR10
...
@@ -23,9 +23,9 @@ from torchvision.datasets import CIFAR10
BATCH_SIZE
=
4
BATCH_SIZE
=
4
NUM_EPOCHS
=
60
NUM_EPOCHS
=
60
WARMUP_EPOCHS
=
5
WARMUP_EPOCHS
=
5
CONFIG
=
dict
(
parallel
=
dict
(
pipeline
=
2
,
tensor
=
dict
(
size
=
2
,
mode
=
'1d'
)),
CONFIG
=
dict
(
NUM_MICRO_BATCHES
=
2
,
parallel
=
dict
(
pipeline
=
2
,
tensor
=
dict
(
size
=
2
,
mode
=
'1d'
)),
fp16
=
dict
(
mode
=
AMP_TYPE
.
NAIVE
),
fp16
=
dict
(
mode
=
AMP_TYPE
.
NAIVE
),
gradient_accumulation
=
2
)
gradient_accumulation
=
2
)
def
run_trainer
(
rank
,
world_size
,
port
):
def
run_trainer
(
rank
,
world_size
,
port
):
...
@@ -63,10 +63,9 @@ def run_trainer(rank, world_size, port):
...
@@ -63,10 +63,9 @@ def run_trainer(rank, world_size, port):
train_dataloader
,
train_dataloader
,
lr_scheduler
=
lr_scheduler
)
lr_scheduler
=
lr_scheduler
)
schedule
=
PipelineSchedule
(
num_microbatches
=
2
)
logger
=
get_dist_logger
()
logger
=
get_dist_logger
()
trainer
=
Trainer
(
engine
=
engine
,
logger
=
logger
,
schedule
=
schedule
)
trainer
=
Trainer
(
engine
=
engine
,
logger
=
logger
)
hook_list
=
[
hook_list
=
[
hooks
.
LRSchedulerHook
(
lr_scheduler
=
lr_scheduler
,
by_epoch
=
False
),
hooks
.
LRSchedulerHook
(
lr_scheduler
=
lr_scheduler
,
by_epoch
=
False
),
...
...
tests/test_trainer/test_pipeline/resnet_config.py
View file @
ade05a5d
...
@@ -7,6 +7,7 @@ IMG_SIZE = 224
...
@@ -7,6 +7,7 @@ IMG_SIZE = 224
DIM
=
768
DIM
=
768
NUM_CLASSES
=
10
NUM_CLASSES
=
10
NUM_ATTN_HEADS
=
12
NUM_ATTN_HEADS
=
12
NUM_MICRO_BATCHES
=
2
# resnet 18
# resnet 18
model
=
dict
(
type
=
'VanillaResNet'
,
model
=
dict
(
type
=
'VanillaResNet'
,
...
...
tests/test_trainer/test_pipeline/test_pipeline_schedule.py
View file @
ade05a5d
...
@@ -19,7 +19,6 @@ from torchvision import transforms
...
@@ -19,7 +19,6 @@ from torchvision import transforms
from
torchvision.datasets
import
CIFAR10
from
torchvision.datasets
import
CIFAR10
BATCH_SIZE
=
4
BATCH_SIZE
=
4
NUM_MICRO
=
2
DIR_PATH
=
osp
.
dirname
(
osp
.
realpath
(
__file__
))
DIR_PATH
=
osp
.
dirname
(
osp
.
realpath
(
__file__
))
CONFIG_PATH
=
osp
.
join
(
DIR_PATH
,
'./resnet_config.py'
)
CONFIG_PATH
=
osp
.
join
(
DIR_PATH
,
'./resnet_config.py'
)
...
@@ -57,7 +56,7 @@ def run_schedule(rank, world_size, port):
...
@@ -57,7 +56,7 @@ def run_schedule(rank, world_size, port):
engine
,
train_dataloader
,
_
,
_
=
colossalai
.
initialize
(
model
,
optimizer
,
criterion
,
train_dataloader
)
engine
,
train_dataloader
,
_
,
_
=
colossalai
.
initialize
(
model
,
optimizer
,
criterion
,
train_dataloader
)
# build pipeline schedule
# build pipeline schedule
schedule
=
Pipel
ine
S
chedule
(
num_microbatches
=
NUM_MICRO
)
schedule
=
eng
ine
.
s
chedule
# run schedule
# run schedule
data_iter
=
iter
(
train_dataloader
)
data_iter
=
iter
(
train_dataloader
)
...
...
tests/test_trainer/test_trainer_with_pipe_schedule.py
View file @
ade05a5d
...
@@ -23,7 +23,7 @@ BATCH_SIZE = 4
...
@@ -23,7 +23,7 @@ BATCH_SIZE = 4
IMG_SIZE
=
32
IMG_SIZE
=
32
NUM_EPOCHS
=
200
NUM_EPOCHS
=
200
CONFIG
=
dict
(
parallel
=
dict
(
pipeline
=
2
),)
CONFIG
=
dict
(
NUM_MICRO_BATCHES
=
2
,
parallel
=
dict
(
pipeline
=
2
),)
def
run_trainer_with_pipeline
(
rank
,
world_size
,
port
):
def
run_trainer_with_pipeline
(
rank
,
world_size
,
port
):
...
@@ -69,9 +69,8 @@ def run_trainer_with_pipeline(rank, world_size, port):
...
@@ -69,9 +69,8 @@ def run_trainer_with_pipeline(rank, world_size, port):
logger
=
get_dist_logger
()
logger
=
get_dist_logger
()
logger
.
info
(
"engine is built"
,
ranks
=
[
0
])
logger
.
info
(
"engine is built"
,
ranks
=
[
0
])
pipe_schedule
=
PipelineSchedule
(
num_microbatches
=
2
)
timer
=
MultiTimer
()
timer
=
MultiTimer
()
trainer
=
Trainer
(
engine
=
engine
,
schedule
=
pipe_schedule
,
logger
=
logger
,
timer
=
timer
)
trainer
=
Trainer
(
engine
=
engine
,
logger
=
logger
,
timer
=
timer
)
logger
.
info
(
"trainer is built"
,
ranks
=
[
0
])
logger
.
info
(
"trainer is built"
,
ranks
=
[
0
])
logger
.
info
(
"start training"
,
ranks
=
[
0
])
logger
.
info
(
"start training"
,
ranks
=
[
0
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment