Unverified Commit e79ea442 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[fp16] refactored fp16 optimizer (#392)

parent f8a0e7fb
import inspect
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.utils import is_no_pp_or_last_stage from colossalai.utils import is_no_pp_or_last_stage
from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
from .grad_scaler import DynamicGradScaler, ConstantGradScaler
def convert_to_naive_amp(model: nn.Module, def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
optimizer: Optimizer,
amp_config):
"""A helper function to wrap training components with naive AMP modules """A helper function to wrap training components with naive AMP modules
:param model: your model object :param model: your model object
...@@ -31,7 +30,19 @@ def convert_to_naive_amp(model: nn.Module, ...@@ -31,7 +30,19 @@ def convert_to_naive_amp(model: nn.Module,
output_to_fp32 = is_no_pp_or_last_stage() output_to_fp32 = is_no_pp_or_last_stage()
model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
optimizer = NaiveAMPOptimizer(optimizer, **amp_config) use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
if use_dynamic_grad_scaler:
scaler_class = DynamicGradScaler
else:
scaler_class = ConstantGradScaler
sig = inspect.signature(scaler_class.__init__)
kwargs = dict()
for param in sig.parameters.values():
if param.name in amp_config:
kwargs[param.name] = amp_config.pop(param.name)
grad_scaler = scaler_class(**kwargs)
optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
return model, optimizer return model, optimizer
......
This diff is collapsed.
from typing import List
from torch import Tensor
def has_inf_or_nan(tensor):
    """Return ``True`` if ``tensor`` contains any ``inf`` or ``NaN`` value.

    :param tensor: tensor to inspect (any dtype/device)
    :return: ``True`` when an ``inf``/``NaN`` entry is present, else ``False``
    """
    try:
        # Up-cast to float32 first: summing an fp16 tensor directly may
        # produce an fp16 result that overflows even when every entry is
        # finite. The .float() copy costs memory but keeps the check safe.
        total = float(tensor.float().sum())
    except RuntimeError as err:
        # A "value cannot be converted" RuntimeError signals an overflow
        # during the conversion, i.e. an inf is present. Any other
        # RuntimeError is a genuine failure and must propagate.
        if "value cannot be converted" not in err.args[0]:
            raise
        return True
    # inf survives the sum unchanged; NaN is the only value that is
    # unequal to itself.
    return total == float('inf') or total == -float('inf') or total != total
def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
    """Reset the ``.grad`` attribute of every tensor in ``tensor_list``.

    Mirrors ``torch.optim.Optimizer.zero_grad``: either drop the gradient
    entirely (``set_to_none=True``) or zero it in place.
    Note: copied from torch.optim.optimizer.
    (NOTE(review): "gard" in the name looks like a typo for "grad", but the
    name is public API and is kept as-is.)

    :param tensor_list: tensors whose gradients should be cleared
    :param set_to_none: when ``True`` set ``.grad = None`` instead of zeroing
    """
    for tensor in tensor_list:
        grad = tensor.grad
        if grad is None:
            continue
        if set_to_none:
            tensor.grad = None
            continue
        # Detach the gradient from any autograd graph before mutating it;
        # otherwise just make sure it no longer requires grad.
        if grad.grad_fn is not None:
            grad.detach_()
        else:
            grad.requires_grad_(False)
        grad.zero_()
...@@ -28,12 +28,10 @@ class BaseGradScaler(ABC): ...@@ -28,12 +28,10 @@ class BaseGradScaler(ABC):
def inv_scale(self) -> Tensor: def inv_scale(self) -> Tensor:
return self._scale.double().reciprocal().float() return self._scale.double().reciprocal().float()
@abstractmethod
def state_dict(self) -> Dict: def state_dict(self) -> Dict:
state_dict = dict() state_dict = dict()
state_dict['scale'] = self.scale state_dict['scale'] = self.scale
@abstractmethod
def load_state_dict(self, state_dict: Dict) -> None: def load_state_dict(self, state_dict: Dict) -> None:
self._scale = state_dict['scale'] self._scale = state_dict['scale']
......
...@@ -16,11 +16,19 @@ class DynamicGradScaler(BaseGradScaler): ...@@ -16,11 +16,19 @@ class DynamicGradScaler(BaseGradScaler):
growth_interval: int = 1000, growth_interval: int = 1000,
min_scale: int = None, min_scale: int = None,
max_scale: int = None, max_scale: int = None,
hysteresis: int = None, hysteresis: int = 2,
verbose: bool = False): verbose: bool = False):
super().__init__(initial_scale, verbose) super().__init__(initial_scale, verbose)
self._min_scale = min_scale if min_scale:
self._max_scale = max_scale self._min_scale = torch.cuda.FloatTensor([min_scale])
else:
self._min_scale = None
if max_scale:
self._max_scale = torch.cuda.FloatTensor([max_scale])
else:
self._max_scale = None
self._growth_factor = growth_factor self._growth_factor = growth_factor
self._backoff_factor = backoff_factor self._backoff_factor = backoff_factor
self._growth_interval = growth_interval self._growth_interval = growth_interval
......
...@@ -26,17 +26,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer): ...@@ -26,17 +26,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
""" """
def __init__(self, optim: Optimizer, *args, **kwargs): def __init__(self, optim: Optimizer, *args, **kwargs):
optim = FP16Optimizer(optimizer=optim, *args, **kwargs) optim = FP16Optimizer(optim, *args, **kwargs)
super().__init__(optim) super().__init__(optim)
def backward(self, loss: Tensor): def backward(self, loss: Tensor):
"""Backward with gradient scaler self.optim.backward(loss)
:param loss: loss computed by a loss function
:type loss: torch.Tensor
"""
loss = self.optim.scale_loss(loss)
loss.backward()
def step(self): def step(self):
return self.optim.step() return self.optim.step()
......
...@@ -304,7 +304,7 @@ def initialize(model: nn.Module, ...@@ -304,7 +304,7 @@ def initialize(model: nn.Module,
if is_using_pp(): if is_using_pp():
assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently' assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
if amp_mode == AMP_TYPE.NAIVE: if amp_mode == AMP_TYPE.NAIVE:
cfg_['clip_grad'] = clip_grad_norm cfg_['clip_grad_norm'] = clip_grad_norm
model, optimizer, criterion = convert_to_amp(model=model, model, optimizer, criterion = convert_to_amp(model=model,
optimizer=optimizer, optimizer=optimizer,
criterion=criterion, criterion=criterion,
......
from itertools import groupby
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -7,7 +6,7 @@ from torch.optim import Optimizer ...@@ -7,7 +6,7 @@ from torch.optim import Optimizer
from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor, from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor,
release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan) release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan)
...@@ -16,11 +15,8 @@ from functools import partial ...@@ -16,11 +15,8 @@ from functools import partial
class ShardedOptimizer(ColossalaiOptimizer): class ShardedOptimizer(ColossalaiOptimizer):
def __init__( def __init__(self,
self,
optimizer: Optimizer, optimizer: Optimizer,
# grad scaler config
initial_scale=2**32, initial_scale=2**32,
min_scale=1, min_scale=1,
growth_factor=2, growth_factor=2,
...@@ -28,23 +24,14 @@ class ShardedOptimizer(ColossalaiOptimizer): ...@@ -28,23 +24,14 @@ class ShardedOptimizer(ColossalaiOptimizer):
growth_interval=1000, growth_interval=1000,
hysteresis=2, hysteresis=2,
max_scale: int = 2**32, max_scale: int = 2**32,
# grad clipping
clip_grad_norm=2.0, clip_grad_norm=2.0,
verbose=False, verbose=False,
# communication
reduce_bucket_size=500000000, reduce_bucket_size=500000000,
communication_dtype=torch.float16, communication_dtype=torch.float16,
overlap_communication=False, overlap_communication=False,
# stage 2
partition_grad=False, partition_grad=False,
dp_parallel_mode=ParallelMode.DATA, dp_parallel_mode=ParallelMode.DATA,
mp_parallel_mode=ParallelMode.MODEL, mp_parallel_mode=ParallelMode.MODEL,
# cpu offload
cpu_offload=False, cpu_offload=False,
cpu_fp16_param=False, cpu_fp16_param=False,
cpu_fp16_grad=False): cpu_fp16_grad=False):
...@@ -263,6 +250,7 @@ class ShardedOptimizer(ColossalaiOptimizer): ...@@ -263,6 +250,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
# args here is not grad, but allow_unreacable and accumulate_grad # args here is not grad, but allow_unreacable and accumulate_grad
def reduce_grad_hook(*args): def reduce_grad_hook(*args):
reduction_func() reduction_func()
accum_grad_obj.register_hook(reduce_grad_hook) accum_grad_obj.register_hook(reduce_grad_hook)
_define_and_attach(param, reduce_rank) _define_and_attach(param, reduce_rank)
...@@ -444,7 +432,6 @@ class ShardedOptimizer(ColossalaiOptimizer): ...@@ -444,7 +432,6 @@ class ShardedOptimizer(ColossalaiOptimizer):
self._grad_store._averaged_gradients[group_id] = [] self._grad_store._averaged_gradients[group_id] = []
self._grad_store._averaged_gradients[group_id] = [] self._grad_store._averaged_gradients[group_id] = []
# unscale and clip grads # unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups) global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm) self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
......
...@@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Union ...@@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
......
import torch
import colossalai
import copy
import pytest
import torch.multiprocessing as mp
from colossalai.amp import convert_to_naive_amp
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.utils import free_port
from functools import partial
def check_equal(a, b):
    """
    This function checks if two tensors are equal within tolerance
    """
    # Compare in fp32 so fp16/fp32 pairs can be checked against each other.
    is_close = torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3)
    assert is_close, f'a = {a}, b = {b}'
def run_naive_amp():
    """
    In this test, we compare the naive fp16 optimizer implemented in colossalai
    and fp32 torch optimizer
    """
    # create layer
    test_models = ['repeated_computed_layers', 'nested_model']
    for test_name in test_models:
        get_component_func = non_distributed_component_funcs.get_callable(test_name)
        model_builder, train_dataloader, _, optim_builder, _ = get_component_func()

        # create model
        amp_model = model_builder(checkpoint=True).cuda()
        torch_model = copy.deepcopy(amp_model)

        # create optimizer
        amp_optimizer = optim_builder(amp_model)
        torch_optimizer = optim_builder(torch_model)

        # inject naive amp; initial_scale=1 keeps fp16 grads directly
        # comparable to the fp32 reference without unscaling
        amp_config = dict(initial_scale=1)
        amp_model, amp_optimizer = convert_to_naive_amp(amp_model, amp_optimizer, amp_config)

        # create data
        data_iter = iter(train_dataloader)
        data, label = next(data_iter)
        data = data.cuda()

        # forward pass
        amp_output = amp_model(data)
        torch_output = torch_model(data)
        assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}'

        # backward
        amp_optimizer.backward(amp_output.mean())
        torch_output.mean().backward()

        # check grad
        # BUG FIX: torch.allclose returns a bool that was previously
        # discarded, so a gradient mismatch silently passed — assert it.
        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
            assert torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3)

        # step
        amp_optimizer.step()
        torch_optimizer.step()

        # check updated param
        # BUG FIX: same as above — the comparison result must be asserted.
        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
            assert torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3)
def run_dist(rank, world_size, port):
    """Per-process entry point: initialise the colossalai runtime, then run
    the naive AMP comparison on this rank."""
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      port=port,
                      host='localhost')
    run_naive_amp()
@pytest.mark.dist
def test_naive_amp():
    """Spawn a single worker process and run the naive AMP correctness check."""
    world_size = 1
    spawn_fn = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(spawn_fn, nprocs=world_size)
# Allow running this test file directly, without invoking pytest.
if __name__ == '__main__':
    test_naive_amp()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment