Unverified commit 3defa32a authored by Frank Lee, committed by GitHub

Support TP-compatible Torch AMP and Update trainer API (#27)



* Add gradient accumulation, fix lr scheduler

* fixed FP16 optimizer and adapted torch amp to work with tensor parallel (#18)

* fixed compatibility bugs between torch amp and tensor parallel, along with some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
parent 2b05de4c
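For context, AMP_TYPE.TORCH selects PyTorch's native autocast/GradScaler machinery, which this commit makes usable together with tensor parallelism; gradient accumulation is the other feature the commit message mentions. A minimal stand-alone sketch of that underlying mechanism in plain PyTorch (illustrative only, not the ColossalAI wrapper; the accumulation factor and dummy data are assumptions):

import torch

# Dummy data: 8 batches of (inputs, integer labels) -- illustrative only.
loader = [(torch.randn(16, 32), torch.randint(0, 10, (16,))) for _ in range(8)]

model = torch.nn.Linear(32, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()
accum_steps = 4  # accumulate gradients over 4 micro-batches (assumed value)

for step, (data, label) in enumerate(loader):
    with torch.cuda.amp.autocast():
        loss = criterion(model(data.cuda()), label.cuda()) / accum_steps
    scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)         # unscale gradients, then optimizer.step()
        scaler.update()
        optimizer.zero_grad()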
@@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
 loss = dict(type='CrossEntropyLoss')
 fp16 = dict(mode=AMP_TYPE.APEX)
-# set_device_func = lambda global_rank, world_size: global_rank % 4
-seed = 1024
@@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
 loss = dict(type='CrossEntropyLoss')
 fp16 = dict(mode=AMP_TYPE.TORCH)
-# set_device_func = lambda global_rank, world_size: global_rank % 4
-seed = 1024
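The two config hunks above differ only in the AMP backend selected by the fp16 dict. A minimal sketch of such a config file in full (the AMP_TYPE import path is an assumption, since it has moved between releases; only the keys visible in these hunks come from the commit):

from colossalai.engine import AMP_TYPE   # assumed import location

optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')

# AMP_TYPE.APEX selects NVIDIA Apex AMP; AMP_TYPE.TORCH selects native torch.cuda.amp,
# which this commit makes compatible with tensor parallelism.
fp16 = dict(mode=AMP_TYPE.TORCH)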
@@ -38,11 +38,9 @@ parallel = dict(
     tensor=dict(size=1, mode=None)
 )
-schedule = dict(
-    num_microbatches=4
-)
-num_pipeling_batches = 2
-seed = 1024
-lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5)
+engine = dict(
+    schedule=dict(
+        num_microbatches=4
+    )
+)
 num_epochs = 10
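With this change the micro-batch schedule is configured under an engine key instead of a top-level schedule dict. A sketch of the resulting pipeline test config (only the keys shown in the hunk come from the commit; the pipeline size is an assumed value):

parallel = dict(
    pipeline=dict(size=2),   # assumed value, not visible in this hunk
    tensor=dict(size=1, mode=None)
)

# The schedule now nests under `engine` rather than standing alone at top level.
engine = dict(
    schedule=dict(
        num_microbatches=4
    )
)

num_epochs = 10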
@@ -8,7 +8,6 @@ import torch
 from colossalai import initialize
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import report_memory_usage
@@ -24,20 +23,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_apex_am
 def run_no_pipeline(config):
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
     logger = get_global_dist_logger()
     rank = torch.distributed.get_rank()
-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    schedule=schedule)
     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    output, label, loss = engine.step()
+    output, label, loss = engine.step(iter(train_dataloader))
     logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())
     gpc.destroy()
     logger.info('Test engine finished')
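The test above now relies on initialize() returning a fully built engine rather than the individual components, and on engine.step() consuming a data iterator. A condensed sketch of the new call pattern (the config path is illustrative; like the tests, it must be launched as a distributed job):

from colossalai import initialize

# initialize() now builds model, optimizer, criterion and schedule into one engine.
engine, train_dataloader, test_dataloader = initialize('./config.py')  # path is illustrative

engine.train()                                  # put the engine into training mode
data_iter = iter(train_dataloader)
output, label, loss = engine.step(data_iter)    # one forward/backward/update step
print(loss.item())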
@@ -8,7 +8,6 @@ import torch
 from colossalai import initialize
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import report_memory_usage
@@ -26,21 +25,14 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet.py')
 def test_no_pipeline(config):
     print('Test no pipeline engine start')
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
     logger = get_global_dist_logger()
     rank = torch.distributed.get_rank()
-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    schedule=schedule)
     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    output, label, loss = engine.step()
+    output, label, loss = engine.step(iter(train_dataloader))
     logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())
     gpc.destroy()
     logger.info('Test engine finished')
@@ -8,7 +8,6 @@ import torch
 from colossalai import initialize
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import report_memory_usage
@@ -26,21 +25,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_torch_a
 def test_no_pipeline(config):
     print('Test no pipeline engine start')
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
    logger = get_global_dist_logger()
     rank = torch.distributed.get_rank()
-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    schedule=schedule)
     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    output, label, loss = engine.step()
+    output, label, loss = engine.step(iter(train_dataloader))
     logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())
     gpc.destroy()
     logger.info('Test engine finished')
@@ -5,6 +5,7 @@ import os.path as osp
 import pytest
+from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import initialize
 from colossalai.logging import get_global_dist_logger
@@ -22,13 +23,25 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
 @pytest.mark.skip("This test should be invoked using the test.sh provided")
 @pytest.mark.dist
 def test_schedule():
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH)
+    engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)
     logger = get_global_dist_logger()
-    schedule.zero_grad()
-    output, label, losses = schedule.forward_backward_step(forward_only=False)
-    schedule.step()
-    logger.info('losses: {}'.format([loss.item() for loss in losses]))
+    model = engine.model
+    optimizer = engine.optimizer
+    criterion = engine.criterion
+    schedule = engine._schedule
+    output, label, loss = schedule.forward_backward_step(
+        data_iter=iter(train_dataloader),
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        forward_only=False
+    )
+    schedule.optimizer_step(model, optimizer)
+    if gpc.is_last_rank(ParallelMode.PIPELINE):
+        logger.info('losses: {}'.format(loss))
     gpc.destroy()
     logger.info('training finished')
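For the pipeline case the test now drives the schedule object directly, pulling the components off the engine. A condensed sketch of that call sequence (names taken from the hunk above; _schedule is a private attribute the test reaches into, and the config path is illustrative):

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import initialize

engine, train_dataloader, test_dataloader = initialize('./pipeline_config.py')  # path is illustrative

schedule = engine._schedule          # private attribute, accessed here for testing only
output, label, loss = schedule.forward_backward_step(
    data_iter=iter(train_dataloader),
    model=engine.model,
    optimizer=engine.optimizer,
    criterion=engine.criterion,
    forward_only=False
)
schedule.optimizer_step(engine.model, engine.optimizer)

# Only the last pipeline stage holds the final loss value.
if gpc.is_last_rank(ParallelMode.PIPELINE):
    print(loss)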
@@ -9,7 +9,6 @@ import torch
 from colossalai import initialize
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 NUM_BATCH = 128
@@ -23,22 +22,14 @@ PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
 def run_pipeline(config):
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
     logger = get_global_dist_logger()
     rank = torch.distributed.get_rank()
-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    schedule=schedule)
     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    outputs, labels, loss = engine.step()
+    outputs, labels, loss = engine.step(iter(train_dataloader))
     if gpc.is_last_rank(ParallelMode.PIPELINE):
         logger.info('losses: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())
     gpc.destroy()
     logger.info('Test engine pipeline finished')
(Two further file diffs are collapsed in this view.)