Unverified Commit 3defa32a authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

Support TP-compatible Torch AMP and Update trainer API (#27)



* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatarver217 <lhx0217@gmail.com>
parent 2b05de4c
......@@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')
fp16 = dict(mode=AMP_TYPE.APEX)
# set_device_func = lambda global_rank, world_size: global_rank % 4
seed = 1024
......@@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')
fp16 = dict(mode=AMP_TYPE.TORCH)
# set_device_func = lambda global_rank, world_size: global_rank % 4
seed = 1024
......@@ -38,11 +38,9 @@ parallel = dict(
tensor=dict(size=1, mode=None)
)
schedule = dict(
num_microbatches=4
engine = dict(
schedule=dict(
num_microbatches=4
)
)
num_pipeling_batches = 2
seed = 1024
lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5)
num_epochs = 10
......@@ -8,7 +8,6 @@ import torch
from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage
......@@ -24,20 +23,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_apex_am
def run_no_pipeline(config):
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
output, label, loss = engine.step()
output, label, loss = engine.step(iter(train_dataloader))
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine finished')
......
......@@ -8,7 +8,6 @@ import torch
from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage
......@@ -26,21 +25,14 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet.py')
def test_no_pipeline(config):
print('Test no pipeline engine start')
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
output, label, loss = engine.step()
output, label, loss = engine.step(iter(train_dataloader))
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine finished')
......
......@@ -8,7 +8,6 @@ import torch
from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage
......@@ -26,21 +25,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_torch_a
def test_no_pipeline(config):
print('Test no pipeline engine start')
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
output, label, loss = engine.step()
output, label, loss = engine.step(iter(train_dataloader))
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine finished')
......
......@@ -5,6 +5,7 @@ import os.path as osp
import pytest
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import initialize
from colossalai.logging import get_global_dist_logger
......@@ -22,13 +23,25 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_schedule():
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH)
engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)
logger = get_global_dist_logger()
schedule.zero_grad()
output, label, losses = schedule.forward_backward_step(forward_only=False)
schedule.step()
logger.info('losses: {}'.format([loss.item() for loss in losses]))
model = engine.model
optimizer = engine.optimizer
criterion = engine.criterion
schedule = engine._schedule
output, label, loss = schedule.forward_backward_step(
data_iter=iter(train_dataloader),
model=model,
optimizer=optimizer,
criterion=criterion,
forward_only=False
)
schedule.optimizer_step(model, optimizer)
if gpc.is_last_rank(ParallelMode.PIPELINE):
logger.info('losses: {}'.format(loss))
gpc.destroy()
logger.info('training finished')
......
......@@ -9,7 +9,6 @@ import torch
from colossalai import initialize
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
NUM_BATCH = 128
......@@ -23,22 +22,14 @@ PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
def run_pipeline(config):
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
outputs, labels, loss = engine.step()
outputs, labels, loss = engine.step(iter(train_dataloader))
if gpc.is_last_rank(ParallelMode.PIPELINE):
logger.info('losses: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine pipeline finished')
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment