Unverified commit e79ea442 authored by Frank Lee, committed by GitHub

[fp16] refactored fp16 optimizer (#392)

parent f8a0e7fb
+import inspect
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils import is_no_pp_or_last_stage
from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
+from .grad_scaler import DynamicGradScaler, ConstantGradScaler


-def convert_to_naive_amp(model: nn.Module,
-                         optimizer: Optimizer,
-                         amp_config):
+def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    """A helper function to wrap training components with naive AMP modules

    :param model: your model object
@@ -31,7 +30,19 @@ def convert_to_naive_amp(model: nn.Module,
        output_to_fp32 = is_no_pp_or_last_stage()
        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

-    optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
+    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
+    if use_dynamic_grad_scaler:
+        scaler_class = DynamicGradScaler
+    else:
+        scaler_class = ConstantGradScaler
+
+    sig = inspect.signature(scaler_class.__init__)
+    kwargs = dict()
+    for param in sig.parameters.values():
+        if param.name in amp_config:
+            kwargs[param.name] = amp_config.pop(param.name)
+    grad_scaler = scaler_class(**kwargs)
+
+    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
    return model, optimizer
...
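For orientation, here is a hypothetical usage sketch of the refactored helper (not part of the commit; it assumes an already-launched ColossalAI context and a CUDA device, as in the unit test at the bottom of this page):

import torch.nn as nn
from torch.optim import Adam

from colossalai.amp import convert_to_naive_amp

model = nn.Linear(16, 16)
optimizer = Adam(model.parameters(), lr=1e-3)

# keys matching the chosen scaler's constructor are split out of amp_config;
# the remaining keys are forwarded to NaiveAMPOptimizer
amp_config = dict(
    dynamic_grad_scale=True,   # False selects ConstantGradScaler instead
    initial_scale=2**16,       # consumed by the grad scaler, not the optimizer wrapper
)
model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)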
from typing import List

from torch import Tensor


def has_inf_or_nan(tensor):
    try:
        # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
        # Pytorch's .sum() creates a one-element tensor of the same type as tensor
        # (which is true for some recent version of pytorch).
        tensor_sum = float(tensor.float().sum())
        # More efficient version that can be used if .sum() returns a Python scalar
        # tensor_sum = float(tensor.sum())
    except RuntimeError as instance:
        # We want to check if instance is actually an overflow exception.
        # RuntimeError could come from a different error.
        # If so, we still want the exception to propagate.
        if "value cannot be converted" not in instance.args[0]:
            raise
        return True
    else:
        if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
            return True
        return False
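A quick illustration of the helper's behaviour (a sketch, assuming has_inf_or_nan from above is in scope):

import torch

healthy = torch.tensor([1.0, 2.0], dtype=torch.float16)
overflowed = torch.tensor([float('inf'), 2.0], dtype=torch.float16)

assert not has_inf_or_nan(healthy)      # finite sum -> False
assert has_inf_or_nan(overflowed)       # inf propagates through .sum() -> True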
def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
    """
    Clear the gradient of a list of tensors.

    Note: copied from torch.optim.optimizer.
    """
    for param in tensor_list:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()
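And a minimal usage sketch (illustrative only): clearing gradients on an explicit list of parameters without going through an optimizer's zero_grad:

import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
for p in params:
    p.grad = torch.ones_like(p)

zero_gard_by_list(params, set_to_none=True)
assert all(p.grad is None for p in params)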
@@ -28,12 +28,10 @@ class BaseGradScaler(ABC):
    def inv_scale(self) -> Tensor:
        return self._scale.double().reciprocal().float()

-    @abstractmethod
    def state_dict(self) -> Dict:
        state_dict = dict()
        state_dict['scale'] = self.scale

-    @abstractmethod
    def load_state_dict(self, state_dict: Dict) -> None:
        self._scale = state_dict['scale']
...
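For orientation (a sketch, not from the commit): inv_scale is the factor the fp16 optimizer multiplies gradients by to undo loss scaling. The snippet assumes a CUDA device, that scale and inv_scale are exposed as properties, and that DynamicGradScaler accepts initial_scale as a keyword; the hunk below suggests this but does not show the full signature.

import torch
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

scaler = DynamicGradScaler(initial_scale=2**16)

# a gradient computed from a scaled loss is 2**16 times too large;
# multiplying by inv_scale recovers the true gradient
scaled_grad = torch.full((4,), 2.0**16, device='cuda')
true_grad = scaled_grad * scaler.inv_scale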
@@ -16,11 +16,19 @@ class DynamicGradScaler(BaseGradScaler):
                 growth_interval: int = 1000,
                 min_scale: int = None,
                 max_scale: int = None,
-                 hysteresis: int = None,
+                 hysteresis: int = 2,
                 verbose: bool = False):
        super().__init__(initial_scale, verbose)
-        self._min_scale = min_scale
-        self._max_scale = max_scale
+        if min_scale:
+            self._min_scale = torch.cuda.FloatTensor([min_scale])
+        else:
+            self._min_scale = None
+
+        if max_scale:
+            self._max_scale = torch.cuda.FloatTensor([max_scale])
+        else:
+            self._max_scale = None
+
        self._growth_factor = growth_factor
        self._backoff_factor = backoff_factor
        self._growth_interval = growth_interval
...
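The hunk above only touches the constructor state; for intuition, the generic dynamic loss-scaling rule that these parameters control is sketched below. This is illustrative pseudologic, not the class's actual update method, and details may differ.

class ToyDynamicScale:
    """Grow the scale after a run of overflow-free steps, back it off on repeated overflow."""

    def __init__(self, scale=2.0**16, growth_factor=2.0, backoff_factor=0.5,
                 growth_interval=1000, hysteresis=2):
        self.scale = scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis
        self._hysteresis_left = hysteresis
        self._good_steps = 0

    def update(self, overflow: bool) -> None:
        if overflow:
            self._good_steps = 0
            self._hysteresis_left -= 1
            if self._hysteresis_left <= 0:          # back off only after `hysteresis` overflows
                self.scale *= self.backoff_factor
                self._hysteresis_left = self.hysteresis
        else:
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:   # grow after a clean run
                self.scale *= self.growth_factor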
@@ -26,17 +26,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
    """

    def __init__(self, optim: Optimizer, *args, **kwargs):
-        optim = FP16Optimizer(optimizer=optim, *args, **kwargs)
+        optim = FP16Optimizer(optim, *args, **kwargs)
        super().__init__(optim)

    def backward(self, loss: Tensor):
-        """Backward with gradient scaler
-
-        :param loss: loss computed by a loss function
-        :type loss: torch.Tensor
-        """
-        loss = self.optim.scale_loss(loss)
-        loss.backward()
+        self.optim.backward(loss)

    def step(self):
        return self.optim.step()
...
@@ -304,7 +304,7 @@ def initialize(model: nn.Module,
    if is_using_pp():
        assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
    if amp_mode == AMP_TYPE.NAIVE:
-        cfg_['clip_grad'] = clip_grad_norm
+        cfg_['clip_grad_norm'] = clip_grad_norm

    model, optimizer, criterion = convert_to_amp(model=model,
                                                 optimizer=optimizer,
                                                 criterion=criterion,
...
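With the key renamed to clip_grad_norm, a hypothetical naive-AMP config with gradient clipping might look like the sketch below; everything except the key names visible in this diff is an assumption.

from colossalai.amp import AMP_TYPE

# entries in a colossalai config file (hypothetical example)
fp16 = dict(
    mode=AMP_TYPE.NAIVE,
    initial_scale=2**16,
)
clip_grad_norm = 1.0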
-from itertools import groupby
from colossalai.utils.cuda import get_current_device
import torch
import torch.distributed as dist
@@ -7,7 +6,7 @@
from torch.optim import Optimizer
from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer
from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor,
                     release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan)
@@ -16,38 +15,26 @@ from functools import partial

class ShardedOptimizer(ColossalaiOptimizer):

-    def __init__(
-            self,
-            optimizer: Optimizer,
-
-            # grad scaler config
-            initial_scale=2**32,
-            min_scale=1,
-            growth_factor=2,
-            backoff_factor=0.5,
-            growth_interval=1000,
-            hysteresis=2,
-            max_scale: int = 2**32,
-
-            # grad clipping
-            clip_grad_norm=2.0,
-            verbose=False,
-
-            # communication
-            reduce_bucket_size=500000000,
-            communication_dtype=torch.float16,
-            overlap_communication=False,
-
-            # stage 2
-            partition_grad=False,
-            dp_parallel_mode=ParallelMode.DATA,
-            mp_parallel_mode=ParallelMode.MODEL,
-
-            # cpu offload
-            cpu_offload=False,
-            cpu_fp16_param=False,
-            cpu_fp16_grad=False):
+    def __init__(self,
+                 optimizer: Optimizer,
+                 initial_scale=2**32,
+                 min_scale=1,
+                 growth_factor=2,
+                 backoff_factor=0.5,
+                 growth_interval=1000,
+                 hysteresis=2,
+                 max_scale: int = 2**32,
+                 clip_grad_norm=2.0,
+                 verbose=False,
+                 reduce_bucket_size=500000000,
+                 communication_dtype=torch.float16,
+                 overlap_communication=False,
+                 partition_grad=False,
+                 dp_parallel_mode=ParallelMode.DATA,
+                 mp_parallel_mode=ParallelMode.MODEL,
+                 cpu_offload=False,
+                 cpu_fp16_param=False,
+                 cpu_fp16_grad=False):

        # TODO: add support for
        # 1. fp16 master weights
@@ -257,12 +244,13 @@ class ShardedOptimizer(ColossalaiOptimizer):
                reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
                                         param=param,
                                         reduce_rank=reduce_rank)

                # define hook
                # NOT IMPORTANT BUT GOOD TO KNOW:
                # args here is not grad, but allow_unreachable and accumulate_grad
                def reduce_grad_hook(*args):
                    reduction_func()

                accum_grad_obj.register_hook(reduce_grad_hook)

            _define_and_attach(param, reduce_rank)
@@ -293,8 +281,8 @@ class ShardedOptimizer(ColossalaiOptimizer):
    def _reduce_grads_in_bucket(self, reduce_rank=None):
        # reduce grads
        self._reduce_grads_by_rank(reduce_rank=reduce_rank,
                                   grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
                                   bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))

        # use communication stream if overlapping
        # communication with computation
@@ -323,7 +311,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
            # we do not keep the gradient after reduction
            if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
                if self._overlap_communication:
                    # we need to keep this gradient for now as reduction may not
                    # be completed yet since it is using a different cuda stream
                    self._param_store.add_previous_reduced_param(param)
                else:
@@ -444,7 +432,6 @@ class ShardedOptimizer(ColossalaiOptimizer):
            self._grad_store._averaged_gradients[group_id] = []
            self._grad_store._averaged_gradients[group_id] = []

-        # unscale and clip grads
        global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
        self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
@@ -501,7 +488,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
        # compute combined scale factor for this group
        combined_scale = self.loss_scale

        if self._clip_grad_norm > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
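For intuition, a toy calculation of the combined scale factor follows; the final branch is paraphrased because the tail of the method falls outside this hunk.

# illustrative numbers: loss_scale = 1024, true grad norm = 8, clip threshold = 2
loss_scale = 1024.0
total_norm = 8.0 * loss_scale   # norms are computed on scaled grads ("norm is in fact norm*scale")
clip_grad_norm = 2.0

clip = ((total_norm / loss_scale) + 1e-6) / clip_grad_norm   # ~4.0: the true norm is 4x over the threshold
combined_scale = clip * loss_scale if clip > 1 else loss_scale
# dividing the flat grads by combined_scale both unscales them and clips the norm to ~clip_grad_norm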
@@ -562,7 +549,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
            for param in param_group:
                if param.grad is not None:
                    self._reduce_and_remove_grads_by_bucket(param)

        # we need to reduce the gradients
        # left in the communication bucket
        self._reduce_grads_in_bucket()
...
@@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
-from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColossalaiOptimizer
...
import torch
import colossalai
import copy
import pytest
import torch.multiprocessing as mp
from colossalai.amp import convert_to_naive_amp
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.utils import free_port
from functools import partial


def check_equal(a, b):
    """
    This function checks if two tensors are equal within tolerance
    """
    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'


def run_naive_amp():
    """
    In this test, we compare the naive fp16 optimizer implemented in colossalai
    and fp32 torch optimizer
    """
    # create layer
    test_models = ['repeated_computed_layers', 'nested_model']
    for test_name in test_models:
        get_component_func = non_distributed_component_funcs.get_callable(test_name)
        model_builder, train_dataloader, _, optim_builder, _ = get_component_func()

        # create model
        amp_model = model_builder(checkpoint=True).cuda()
        torch_model = copy.deepcopy(amp_model)

        # create optimizer
        amp_optimizer = optim_builder(amp_model)
        torch_optimizer = optim_builder(torch_model)

        # inject naive amp
        amp_config = dict(initial_scale=1)
        amp_model, amp_optimizer = convert_to_naive_amp(amp_model, amp_optimizer, amp_config)

        # create data
        data_iter = iter(train_dataloader)
        data, label = next(data_iter)
        data = data.cuda()

        # forward pass
        amp_output = amp_model(data)
        torch_output = torch_model(data)
        assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}'
        # backward
        amp_optimizer.backward(amp_output.mean())
        torch_output.mean().backward()

        # check grad
        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
            assert torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3)

        # step
        amp_optimizer.step()
        torch_optimizer.step()

        # check updated param
        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
            assert torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3)
def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    run_naive_amp()


@pytest.mark.dist
def test_naive_amp():
    world_size = 1
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_naive_amp()