Commit deb8370c authored by hepj's avatar hepj
Browse files

Initial commit

parents
Pipeline #2198 canceled with stages
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron grad scaler."""
from abc import ABC, abstractmethod
from typing import Dict
import torch
class MegatronGradScaler(ABC):
def __init__(self, initial_scale: float):
"""Initialize scale value with the input initial scale."""
assert initial_scale > 0.0
self._scale = torch.tensor([initial_scale], dtype=torch.float, device='cuda')
@property
def scale(self):
return self._scale
@property
def inv_scale(self):
return self._scale.double().reciprocal().float()
@abstractmethod
def update(self, found_inf: bool):
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict: Dict):
pass
class ConstantGradScaler(MegatronGradScaler):
"""
Constant grad scaler (loss scale is never adjusted regardless of NaNs seen in gradients).
"""
def update(self, found_inf: bool):
pass
def state_dict(self):
return dict()
def load_state_dict(self, state_dict):
pass
class DynamicGradScaler(MegatronGradScaler):
"""
Grad scaler with dynamic scale that gets adjusted during training.
Reduces loss scale by `backoff_factor` if `hysteresis` number of NaNs are seen in a row. Increases
loss scale by `growth_factor` if NaNs are not seen for `growth_interval` iterations.
"""
def __init__(
self,
initial_scale: float,
min_scale: float,
growth_factor: float,
backoff_factor: float,
growth_interval: int,
hysteresis: int,
):
"""
Grad scaler with dynamic scale that gets adjusted during training.
Args:
initial_scale (float): Initial loss scale value.
min_scale (float): Minimum loss scale value.
growth_factor (float): Factor to grow loss scale by if NaNs are not seen in `growth_interval`
training iterations. Must be greater than 1.
backoff_factor (float): Factor to decrease loss scale by if NaNs are seen in `hysteresis`
consecutive training iterations. Must be between 0 and 1.
growth_interval (int): Number of training iterations of no NaNs before loss scale is increased.
hysteresis (int): Number of training iterations of consecutive NaNs before loss scale is decreased.
"""
super(DynamicGradScaler, self).__init__(initial_scale)
# Lower bound on the scale.
assert min_scale > 0.0
assert min_scale <= initial_scale
self.min_scale = torch.tensor([min_scale], dtype=torch.float, device='cuda')
# Growth and backoff factors for the scale.
assert growth_factor > 1.0
self.growth_factor = torch.tensor([growth_factor], dtype=torch.float, device='cuda')
assert backoff_factor < 1.0
assert backoff_factor > 0.0
self.backoff_factor = torch.tensor([backoff_factor], dtype=torch.float, device='cuda')
# Interval over which if we don't see any inf/nan,
# we will scale the grad scale by the growth factor.
assert growth_interval > 0
self.growth_interval = growth_interval
# Number of inf/nans we should see before scaling down
# the grad scale by the backoff factor.
assert hysteresis > 0
self.hysteresis = hysteresis
# Trackers.
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
def update(self, found_inf: bool):
"""
Updates internal state in grad scaler based on whether NaNs are seen in grads or not.
"""
# If we have an inf/nan, growth tracker is set to 0
# and hysterisis tracker is reduced by 1.
if found_inf:
self._growth_tracker = 0
self._hysteresis_tracker -= 1
# Now if we are out of hysteresis count, scale down the loss.
if self._hysteresis_tracker <= 0:
self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
else:
# If there is no nan/inf, increment the growth tracker.
self._growth_tracker += 1
# If we have had enough consequitive intervals with no nan/inf:
if self._growth_tracker == self.growth_interval:
# Reset the tracker and hysteresis trackers,
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
# and scale up the loss scale.
self._scale = self._scale * self.growth_factor
def state_dict(self):
state_dict = {}
state_dict['scale'] = self._scale
state_dict['growth_tracker'] = self._growth_tracker
state_dict['hysteresis_tracker'] = self._hysteresis_tracker
return state_dict
def load_state_dict(self, state_dict: Dict):
self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
self._growth_tracker = state_dict['growth_tracker']
self._hysteresis_tracker = state_dict['hysteresis_tracker']
from .cpu_adam import CPUAdam
from .hybrid_adam import HybridAdam
__all__ = [
'CPUAdam',
'HybridAdam'
]
\ No newline at end of file
# Copyright (c) 2024 Alibaba PAI, ColossalAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List
import os
import subprocess
import re
def check_cuda_availability():
"""
Check if CUDA is available on the system.
Returns:
A boolean value. True if CUDA is available and False otherwise.
"""
import torch
return torch.cuda.is_available()
def set_cuda_arch_list(cuda_dir):
"""
This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation.
Ahead-of-time compilation occurs when BUILD_EXT=1 is set when running 'pip install'.
"""
cuda_available = check_cuda_availability()
# we only need to set this when CUDA is not available for cross-compilation
if not cuda_available:
warnings.warn(
"\n[extension] PyTorch did not find available GPUs on this system.\n"
"If your intention is to cross-compile, this is not an error.\n"
"By default, Colossal-AI will cross-compile for \n"
"1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
"2. Volta (compute capability 7.0)\n"
"3. Turing (compute capability 7.5),\n"
"4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n"
"\nIf you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n'
)
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"]
if int(bare_metal_major) == 11:
if int(bare_metal_minor) == 0:
arch_list.append("8.0")
else:
arch_list.append("8.0")
arch_list.append("8.6")
arch_list_str = ";".join(arch_list)
os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str
return False
return True
def check_pytorch_version(min_major_version, min_minor_version) -> bool:
"""
Compare the current PyTorch version with the minium required version.
Args:
min_major_version (int): the minimum major version of PyTorch required
min_minor_version (int): the minimum minor version of PyTorch required
Returns:
A boolean value. The value is True if the current pytorch version is acceptable and False otherwise.
"""
# get pytorch version
torch_major, torch_minor, _ = get_pytorch_version()
# if the
if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version):
raise RuntimeError(
f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n"
"The latest stable release can be obtained from https://pytorch.org/get-started/locally/"
)
def get_cuda_version_in_pytorch() -> List[int]:
"""
This function returns the CUDA version in the PyTorch build.
Returns:
The CUDA version required by PyTorch, in the form of tuple (major, minor).
"""
import torch
try:
torch_cuda_major = torch.version.cuda.split(".")[0]
torch_cuda_minor = torch.version.cuda.split(".")[1]
except:
raise ValueError(
"[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda"
)
return torch_cuda_major, torch_cuda_minor
def check_system_pytorch_cuda_match(cuda_dir):
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch()
if bare_metal_major != torch_cuda_major:
raise Exception(
f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) "
f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})."
"Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ."
)
if bare_metal_minor != torch_cuda_minor:
warnings.warn(
f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
"The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. "
"If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions"
)
return True
def get_pytorch_version() -> List[int]:
"""
This functions finds the PyTorch version.
Returns:
A tuple of integers in the form of (major, minor, patch).
"""
import torch
torch_version = torch.__version__.split("+")[0]
TORCH_MAJOR = int(torch_version.split(".")[0])
TORCH_MINOR = int(torch_version.split(".")[1])
TORCH_PATCH = int(torch_version.split(".")[2], 16)
return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH
def get_cuda_bare_metal_version(cuda_dir) -> List[int]:
"""
Get the System CUDA version from nvcc.
Args:
cuda_dir (str): the directory for CUDA Toolkit.
Returns:
The CUDA version required by PyTorch, in the form of tuple (major, minor).
"""
nvcc_path = os.path.join(cuda_dir, "bin/nvcc")
if cuda_dir is None:
raise ValueError(
f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly."
)
# check for nvcc path
if not os.path.exists(nvcc_path):
raise FileNotFoundError(
f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME."
)
# parse the nvcc -v output to obtain the system cuda version
try:
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
except:
raise ValueError(
f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}"
)
return bare_metal_major, bare_metal_minor
def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
"""
This function appends the threads flag to your nvcc args.
Returns:
The nvcc compilation flags including the threads flag.
"""
from torch.utils.cpp_extension import CUDA_HOME
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
def get_cuda_cc_flag() -> List[str]:
"""
This function produces the cc flags for your GPU arch
Returns:
The CUDA cc flags for compilation.
"""
# only import torch when needed
# this is to avoid importing torch when building on a machine without torch pre-installed
# one case is to build wheel for pypi release
import torch
cc_flag = []
max_arch = "".join(str(i) for i in torch.cuda.get_device_capability())
for arch in torch.cuda.get_arch_list():
res = re.search(r"sm_(\d+)", arch)
if res:
arch_cap = res[1]
if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch):
cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"])
return cc_flag
# Copyright (c) 2024 Alibaba PAI, ColossalAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .kernel_loader import CPUAdamLoader
from .nvme_optimizer import NVMeOptimizer
class CPUAdam(NVMeOptimizer):
"""
Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depending on the device of parameters.
But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
`CPUAdam` requires CUDA extensions which can be built during installation or runtime.
This version of CPU Adam accelerates parameters updating on CPU with SIMD.
Support of AVX2 or AVX512 is required.
The GPU part is implemented in an naive way.
CPU Adam also supports the hybrid precision calculation, eg. fp32 parameters and fp16 gradients.
:class:`colossalai.nn.optimizer.CPUAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adamw_mode=False``
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
model_params (iterable): iterable of parameters of dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True)
simd_log (boolean, optional): whether to show if you are using SIMD to
accelerate. (default: False)
nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
If it's ``None``, a random temporary directory will be used. Defaults to None.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
# Number of fp32 shards for per parameter
# Param weight, grad, momentum and variance
num_fp32_shards_per_param = 4
def __init__(
self,
model_params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
adamw_mode=True,
nvme_offload_fraction: float = 0.0,
nvme_offload_dir: Optional[str] = None,
):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
cpu_adam = CPUAdamLoader().load()
# if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
def torch_adam_update(
self,
data,
grad,
exp_avg,
exp_avg_sq,
lr,
beta1,
beta2,
eps,
weight_decay,
bias_correction1,
bias_correction2,
use_adamw=False,
):
grad = grad.to(data.dtype)
if weight_decay != 0:
if use_adamw:
data.mul_(1 - lr * weight_decay)
else:
grad = grad.add(data, alpha=weight_decay)
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# TODO(jiaruifang) dose not support amsgrad
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
step_size = lr / bias_correction1
data.addcdiv_(exp_avg, denom, value=-step_size)
@torch.no_grad()
def step(self, closure=None, div_scale: float = -1):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
self._pre_step("exp_avg", "exp_avg_sq")
for _, group in enumerate(self.param_groups):
for _, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
target_device = p.device
if len(state) == 0:
state["step"] = 0
# gradient momentums
state["exp_avg"] = torch.zeros_like(p, device=target_device)
# gradient variances
state["exp_avg_sq"] = torch.zeros_like(p, device=target_device)
self._post_state_init(p)
state["step"] += 1
beta1, beta2 = group["betas"]
if target_device.type == "cpu":
assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size"
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
self.torch_adam_update(
p.data,
p.grad.data,
state["exp_avg"],
state["exp_avg_sq"],
group["lr"],
beta1,
beta2,
group["eps"],
group["weight_decay"],
bias_correction1,
bias_correction2,
self.adamw_mode,
)
else:
self.cpu_adam_op.step(
state["step"],
group["lr"],
beta1,
beta2,
group["eps"],
group["weight_decay"],
group["bias_correction"],
p.data,
p.grad.data,
state["exp_avg"],
state["exp_avg_sq"],
div_scale,
)
self._post_update(p, "exp_avg", "exp_avg_sq")
elif target_device.type == "cuda":
assert div_scale == -1, "div_scale should remain default"
assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda"
assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda"
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
# adam on cuda
self.torch_adam_update(
p.data,
p.grad.data,
state["exp_avg"],
state["exp_avg_sq"],
group["lr"],
beta1,
beta2,
group["eps"],
group["weight_decay"],
bias_correction1,
bias_correction2,
self.adamw_mode,
)
else:
raise RuntimeError
self._post_step()
return loss
/*
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"
#include <math.h>
#include <omp.h>
#include <string.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
// C++ interface
void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
AVX_Data grad_4;
this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4);
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
}
AVX_Data momentum_4;
this->simd_load(momentum_half_precision, _exp_avg + i,
momentum_cast_h + i, momentum_4);
AVX_Data variance_4;
this->simd_load(variance_half_precision, _exp_avg_sq + i,
variance_cast_h + i, variance_4);
AVX_Data param_4;
this->simd_load(param_half_precision, _params + i, params_cast_h + i,
param_4);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
}
momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
momentum_4.data =
SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
variance_4.data =
SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
grad_4.data = SIMD_SQRT(variance_4.data);
grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
if (_weight_decay > 0 && _adamw_mode) {
param_4.data =
SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data);
}
param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
this->simd_store(param_half_precision, _params + i, params_cast_h + i,
param_4);
this->simd_store(momentum_half_precision, _exp_avg + i,
momentum_cast_h + i, momentum_4);
this->simd_store(variance_half_precision, _exp_avg_sq + i,
variance_cast_h + i, variance_4);
}
}
#endif
if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k];
if (loss_scale > 0) {
grad /= loss_scale;
}
float param =
param_half_precision ? (float)params_cast_h[k] : _params[k];
float momentum =
momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
float variance = variance_half_precision ? (float)variance_cast_h[k]
: _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) {
grad = param * _weight_decay + grad;
}
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) {
param += w_decay * param;
}
param = grad * step_size + param;
if (param_half_precision)
params_cast_h[k] = (__half)param;
else
_params[k] = param;
if (momentum_half_precision)
momentum_cast_h[k] = (__half)(momentum);
else
_exp_avg[k] = momentum;
if (variance_half_precision)
variance_cast_h[k] = (__half)(variance);
else
_exp_avg_sq[k] = variance;
}
}
}
}
void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
float betta2_minus1 = 1 - _betta2;
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
AVX_Data grad_4[4];
AVX_Data momentum_4[4];
AVX_Data variance_4[4];
AVX_Data param_4[4];
#pragma unroll 4
for (int j = 0; j < 4; j++) {
this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}
this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_load(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
}
momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
momentum_4[j].data =
SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
variance_4[j].data =
SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
grad_4[j].data = SIMD_SQRT(variance_4[j].data);
grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
}
param_4[j].data =
SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_store(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
}
}
}
#endif
if (_param_size > rounded_size)
Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size)
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, momentum_half_precision,
variance_half_precision, loss_scale);
}
void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
float betta2_minus1 = 1 - _betta2;
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
AVX_Data grad_4[8];
AVX_Data momentum_4[8];
AVX_Data variance_4[8];
AVX_Data param_4[8];
#pragma unroll 8
for (int j = 0; j < 8; j++) {
this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}
this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_load(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
}
momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
momentum_4[j].data =
SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
variance_4[j].data =
SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
grad_4[j].data = SIMD_SQRT(variance_4[j].data);
grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
}
param_4[j].data =
SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_store(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
}
}
}
#endif
if (_param_size > rounded_size)
Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size)
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, momentum_half_precision,
variance_half_precision, loss_scale);
}
void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
float epsilon, float weight_decay,
bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale) {
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float *params_ptr = (float *)params_c.data_ptr();
float *grads_ptr = (float *)grads_c.data_ptr();
float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
this->IncrementStep(step, beta1, beta2);
this->update_state(lr, epsilon, weight_decay, bias_correction);
this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
params_c.numel(), (params.options().dtype() == at::kHalf),
(grads.options().dtype() == at::kHalf),
(exp_avg.options().dtype() == at::kHalf),
(exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
}
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<Adam_Optimizer>(m, "CPUAdamOptimizer")
.def(py::init<float, float, float, float, float, bool>())
.def("step", &Adam_Optimizer::step);
}
/*
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#pragma once
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <torch/extension.h>
#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
#if defined(__AVX512__)
#define SIMD_WIDTH 16
#define INTV __m256i
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))
#elif defined(__AVX256__) or defined(__AVX2__)
#define SIMD_WIDTH 8
#define INTV __m128i
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))
#endif
union AVX_Data {
#if defined(__AVX512__)
__m512 data;
#elif defined(__AVX256__) or defined(__AVX2__)
__m256 data;
#endif
// float data_f[16];
};
#endif
#define STEP(SPAN) \
void Step_##SPAN( \
float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
size_t _param_size, bool param_half_precision = false, \
bool grad_half_precision = false, bool momentum_half_precision = false, \
bool variance_half_precision = false, float loss_scale = -1);
class Adam_Optimizer {
public:
Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
float eps = 1e-8, float weight_decay = 0,
bool adamw_mode = true)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
_adamw_mode(adamw_mode) {}
~Adam_Optimizer() {}
STEP(1)
STEP(4)
STEP(8)
inline void IncrementStep(size_t step, float beta1, float beta2) {
if (beta1 != _betta1 || beta2 != _betta2) {
_step = step;
_betta1 = beta1;
_betta2 = beta2;
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
} else {
_step++;
if (_step != step) {
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
_step = step;
} else {
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
}
}
inline void update_state(float lr, float epsilon, float weight_decay,
bool bias_correction) {
_alpha = lr;
_eps = epsilon;
_weight_decay = weight_decay;
_bias_correction1 = 1.0f;
_bias_correction2 = 1.0f;
if (bias_correction == 1) {
_bias_correction1 = 1 - _betta1_t;
_bias_correction2 = 1 / sqrt(1 - _betta2_t);
}
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
data.data = SIMD_LOAD_HALF(h_ptr);
} else {
data.data = SIMD_LOAD(ptr);
}
}
inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
SIMD_STORE_HALF(h_ptr, data.data);
} else {
SIMD_STORE(ptr, data.data);
}
}
#endif
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale);
private:
float _alpha;
float _betta1;
float _betta2;
float _eps;
float _weight_decay;
float _betta1_t;
float _betta2_t;
size_t _step;
float _bias_correction1;
float _bias_correction2;
bool _adamw_mode;
};
#include "cpu_adam_arm.h"
void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4 = vdivq_f32(grad_4, loss_scale_vec);
}
float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i);
float32x4_t variance_4 =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i);
float32x4_t param_4 = simd_load_offset(_params, param_dtype, i);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4);
}
momentum_4 = vmulq_f32(momentum_4, betta1_4);
momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4);
variance_4 = vmulq_f32(variance_4, betta2_4);
grad_4 = vmulq_f32(grad_4, grad_4);
variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4);
grad_4 = vsqrtq_f32(variance_4);
grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt);
grad_4 = vdivq_f32(momentum_4, grad_4);
if (_weight_decay > 0 && _adamw_mode) {
param_4 = vfmaq_f32(param_4, param_4, weight_decay_4);
}
param_4 = vfmaq_f32(param_4, grad_4, step_size_4);
simd_store_offset(_params, param_dtype, param_4, i);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i);
}
}
#endif
if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = scalar_load_offset(grads, grad_dtype, k);
if (loss_scale > 0) {
grad /= loss_scale;
}
float param = scalar_load_offset(_params, param_dtype, k);
float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k);
float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k);
if (_weight_decay > 0 && !_adamw_mode) {
grad = param * _weight_decay + grad;
}
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) {
param += w_decay * param;
}
param = grad * step_size + param;
scalar_store_offset(_params, param_dtype, param, k);
scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k);
scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k);
}
}
}
}
void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
float32x4_t grad_4[4];
float32x4_t momentum_4[4];
float32x4_t variance_4[4];
float32x4_t param_4[4];
#pragma unroll 4
for (int j = 0; j < 4; j++) {
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
}
momentum_4[j] =
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
variance_4[j] =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
}
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
grad_4[j] = vsqrtq_f32(variance_4[j]);
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
}
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
i + SIMD_WIDTH * j);
}
}
}
#endif
if (_param_size > rounded_size) {
Step_1(scalar_seek_offset(_params, param_dtype, rounded_size),
scalar_seek_offset(grads, grad_dtype, rounded_size),
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
exp_avg_sq_dtype, loss_scale);
}
}
void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
float32x4_t grad_4[8];
float32x4_t momentum_4[8];
float32x4_t variance_4[8];
float32x4_t param_4[8];
#pragma unroll 4
for (int j = 0; j < 8; j++) {
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
}
momentum_4[j] =
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
variance_4[j] =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
}
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
grad_4[j] = vsqrtq_f32(variance_4[j]);
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
}
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
i + SIMD_WIDTH * j);
}
}
}
#endif
if (_param_size > rounded_size) {
Step_4(scalar_seek_offset(_params, param_dtype, rounded_size),
scalar_seek_offset(grads, grad_dtype, rounded_size),
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
exp_avg_sq_dtype, loss_scale);
}
}
void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2,
float epsilon, float weight_decay,
bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale) {
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
this->IncrementStep(step, beta1, beta2);
this->update_state(lr, epsilon, weight_decay, bias_correction);
this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(),
exp_avg_sq_c.data_ptr(), params_c.numel(),
params_c.scalar_type(), grads_c.scalar_type(),
exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale);
}
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<AdamOptimizer>(m, "CPUAdamOptimizer")
.def(py::init<float, float, float, float, float, bool>())
.def("step", &AdamOptimizer::step);
}
#pragma once
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <cmath>
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)
#if defined(__aarch64__)
#include <arm_neon.h>
#define SIMD_WIDTH 4
inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float: {
auto ptr_f = reinterpret_cast<const float32_t *>(ptr);
return vld1q_f32(ptr_f + offset);
}
case at::ScalarType::Half: {
auto ptr_h = reinterpret_cast<const float16_t *>(ptr);
return vcvt_f32_f16(vld1_f16(ptr_h + offset));
}
// case at::ScalarType::BFloat16: {
// auto ptr_b = reinterpret_cast<const bfloat16_t *>(ptr);
// return vcvt_f32_bf16(vld1_bf16(ptr_b + offset));
// }
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) {
return simd_load_offset(ptr, dtype, 0);
}
inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float: {
auto ptr_f = reinterpret_cast<float32_t *>(ptr);
vst1q_f32(ptr_f + offset, data);
break;
}
case at::ScalarType::Half: {
auto ptr_h = reinterpret_cast<float16_t *>(ptr);
vst1_f16(ptr_h + offset, vcvt_f16_f32(data));
break;
}
// case at::ScalarType::BFloat16: {
// auto ptr_b = reinterpret_cast<bfloat16_t *>(ptr);
// vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data));
// break;
// }
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) {
return simd_store_offset(ptr, dtype, data, 0);
}
inline float32x4_t simd_set(float value) {
auto val = static_cast<float32_t>(value);
return vdupq_n_f32(val);
}
#endif
inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return *(reinterpret_cast<const float *>(ptr) + offset);
case at::ScalarType::Half:
return static_cast<float>(
*(reinterpret_cast<const at::Half *>(ptr) + offset));
// case at::ScalarType::BFloat16:
// return static_cast<float>(
// *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
*(reinterpret_cast<float *>(ptr) + offset) = data;
break;
case at::ScalarType::Half:
*(reinterpret_cast<at::Half *>(ptr) + offset) = data;
break;
// case at::ScalarType::BFloat16:
// *(reinterpret_cast<at::BFloat16 *>(ptr) + offset) = data;
break;
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return reinterpret_cast<float *>(ptr) + offset;
case at::ScalarType::Half:
return reinterpret_cast<at::Half *>(ptr) + offset;
// case at::ScalarType::BFloat16:
// return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
default:
AT_ERROR("Unsupported dtype");
break;
}
}
#define STEP(SPAN) \
void Step_##SPAN(void *_params, void *grads, void *_exp_avg, \
void *_exp_avg_sq, size_t _param_size, \
at::ScalarType param_dtype, at::ScalarType grad_dtype, \
at::ScalarType exp_avg_dtype, \
at::ScalarType exp_avg_sq_dtype, float loss_scale = -1);
class AdamOptimizer {
private:
float _alpha;
float _betta1;
float _betta2;
float _eps;
float _weight_decay;
float _betta1_t;
float _betta2_t;
size_t _step;
float _bias_correction1;
float _bias_correction2;
bool _adamw_mode;
public:
AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
float eps = 1e-8, float weight_decay = 0,
bool adamw_mode = true)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
_adamw_mode(adamw_mode) {}
~AdamOptimizer() {}
STEP(1)
STEP(4)
STEP(8)
inline void IncrementStep(size_t step, float beta1, float beta2) {
if (beta1 != _betta1 || beta2 != _betta2) {
_step = step;
_betta1 = beta1;
_betta2 = beta2;
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
} else {
_step++;
if (_step != step) {
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
_step = step;
} else {
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
}
}
inline void update_state(float lr, float epsilon, float weight_decay,
bool bias_correction) {
_alpha = lr;
_eps = epsilon;
_weight_decay = weight_decay;
_bias_correction1 = 1.0f;
_bias_correction2 = 1.0f;
if (bias_correction == 1) {
_bias_correction1 = 1 - _betta1_t;
_bias_correction2 = 1 / sqrt(1 - _betta2_t);
}
}
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale);
};
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#pragma once
#include <ATen/ATen.h>
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
TYPE, NAME, ...) \
if (HIGH_PRECISION) { \
const bool high_precision = true; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
} else { \
const bool high_precision = false; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Float && \
PTYPE == at::ScalarType::Half) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Half && \
PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = at::Half; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \
using g_scalar_t_##LEVEL = at::Half; \
using p_scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Float && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else { \
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
"'"); \
}
#if defined(COLOSSAL_WITH_CUDA)
#define HOST __host__
#define DEVICE __device__
#define HOSTDEVICE __host__ __device__
#else
#define HOST
#define DEVICE
#define HOSTDEVICE
#endif
\ No newline at end of file
// Copyright (c) 2024 Alibaba PAI, ColossalAI and Nvidia Megatron-LM Team.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "micros.h"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum {
ADAM_MODE_0 = 0, // L2 regularization mode
ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW)
} adamMode_t;
using MATH_T = float;
template <typename T_g, typename T_p>
struct AdamFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
const float beta1, const float beta2, const float beta1_correction,
const float beta2_correction, const float epsilon, const float lr,
adamMode_t mode, const float decay, const float div_scale) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T_g *g = (T_g *)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
T_p *p = (T_p *)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
T_p *m = (T_p *)tl.addresses[2][tensor_loc];
m += chunk_idx * chunk_size;
T_p *v = (T_p *)tl.addresses[3][tensor_loc];
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (div_scale > 0) r_g[ii] /= div_scale;
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int mode,
const int bias_correction, const float weight_decay,
const float div_scale) {
using namespace at;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
DISPATCH_FLOAT_AND_HALF_FOR_G_P(
tensor_lists[0][0].scalar_type(), tensor_lists[1][0].scalar_type(), 0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon,
lr, (adamMode_t)mode, weight_decay, div_scale);)
AT_CUDA_CHECK(cudaGetLastError());
}
\ No newline at end of file
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <c10/cuda/CUDAGuard.h>
#include "micros.h"
// #include <iostream>
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
template <int n>
struct TensorListMetadata {
void *addresses[n][depth_to_max_tensors[n - 1]];
int sizes[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a
// full int.
int start_tensor_this_launch;
};
template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
volatile int *noop_flag, T tl,
U callable, ArgTypes... args) {
// Hand the chunk information to the user-supplied functor to process however
// it likes.
callable(chunk_size, noop_flag, tl, args...);
}
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
int block_size, int chunk_size, const at::Tensor &noop_flag,
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
ArgTypes... args) {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size();
l++) // No range-based for because I need indices
{
TORCH_CHECK(tensor_lists[l].size() == len0,
"Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory =
(contiguous_memory ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
"A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(),
"Size mismatch");
}
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (int t = 0; t < ntensors; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
int chunks_this_tensor =
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) {
// using accscalar_t = acc_type<scalar_t, true>;
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...);
AT_CUDA_CHECK(cudaGetLastError());
// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
if (chunk == chunks_this_tensor - 1) {
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3
// << std::endl;
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
} else {
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3
// << std::endl;
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
for (int d = 0; d < depth; d++)
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}
\ No newline at end of file
// Copyright (c) 2024 Alibaba PAI, ColossalAI and Nvidia Megatron-LM Team.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/hpcaitech/ColossalAI/blob/main/extensions/pybind/optimizer/optimizer.cpp
#include <torch/extension.h>
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int mode,
const int bias_correction, const float weight_decay,
const float div_scale);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer");
}
\ No newline at end of file
# Copyright (c) 2024 Alibaba PAI, ColossalAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
from .kernel_loader import FusedOptimizerLoader
try:
from transformer_engine.pytorch.optimizers import multi_tensor_applier
except ImportError:
try:
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
from megatron.core.utils import local_multi_tensor_applier
multi_tensor_applier = local_multi_tensor_applier
from .cpu_adam import CPUAdam
class HybridAdam(CPUAdam):
"""Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depending on the device of parameters.
But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
`HybridAdam` requires CUDA extensions which can be built during installation or runtime.
This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam.
* For parameters updating on CPU, it uses CPUAdam.
* For parameters updating on GPU, it uses FusedAdam.
* Hybrid precision calculation of fp16 and fp32 is supported, eg fp32 parameters and fp16 gradients.
:class:`colossalai.nn.optimizer.HybridAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adamw_mode=False``
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
model_params (iterable): iterable of parameters of dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True)
simd_log (boolean, optional): whether to show if you are using SIMD to
accelerate. (default: False)
nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
If it's ``None``, a random temporary directory will be used. Defaults to None.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
# Number of fp32 shards for per parameter
# Param weight, grad, momentum and variance
num_fp32_shards_per_param = 4
def __init__(
self,
model_params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
adamw_mode=True,
nvme_offload_fraction: float = 0.0,
nvme_offload_dir: Optional[str] = None,
**defaults: Any,
):
super().__init__(
model_params,
lr,
bias_correction,
betas,
eps,
weight_decay,
adamw_mode,
nvme_offload_fraction,
nvme_offload_dir,
)
if torch.cuda.is_available():
fused_optim = FusedOptimizerLoader().load()
self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=torch.cuda.current_device())
@torch.no_grad()
def step(self, closure=None, div_scale: float = -1):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
self._pre_step("exp_avg", "exp_avg_sq")
for _, group in enumerate(self.param_groups):
g_l, p_l, m_l, v_l = [], [], [], []
group_step = 0
for _, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
target_device = p.device
if len(state) == 0:
state["step"] = 0
# gradient momentums
state["exp_avg"] = torch.zeros_like(p, device=target_device)
# gradient variances
state["exp_avg_sq"] = torch.zeros_like(p, device=target_device)
self._post_state_init(p)
state["step"] += 1
group_step = state["step"]
beta1, beta2 = group["betas"]
if target_device.type == "cpu" or target_device.type == "npu":
assert state["exp_avg"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
if p.grad.dtype is torch.bfloat16 or p.grad.device.type == "npu":
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
self.torch_adam_update(
p.data,
p.grad.data,
state["exp_avg"],
state["exp_avg_sq"],
group["lr"],
beta1,
beta2,
group["eps"],
group["weight_decay"],
bias_correction1,
bias_correction2,
self.adamw_mode,
)
else:
self.cpu_adam_op.step(
state["step"],
group["lr"],
beta1,
beta2,
group["eps"],
group["weight_decay"],
group["bias_correction"],
p.data,
p.grad.data,
state["exp_avg"],
state["exp_avg_sq"],
div_scale,
)
self._post_update(p, "exp_avg", "exp_avg_sq")
elif target_device.type == "cuda":
assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda"
assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda"
# record the state by group and update at once
g_l.append(p.grad.data)
p_l.append(p.data)
m_l.append(state["exp_avg"])
v_l.append(state["exp_avg_sq"])
else:
raise RuntimeError
if len(g_l) > 0:
adamw_mode = 1 if self.adamw_mode else 0
bias_correction = 1 if group["bias_correction"] else 0
multi_tensor_applier(
self.gpu_adam_op,
self._dummy_overflow_buf,
[g_l, p_l, m_l, v_l],
group["lr"],
group["betas"][0],
group["betas"][1],
group["eps"],
group_step,
adamw_mode,
bias_correction,
group["weight_decay"],
div_scale,
)
self._post_step()
return loss
\ No newline at end of file
# Copyright (c) 2024 Alibaba PAI, ColossalAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import importlib
import os
import platform
import time
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, List, Union
from torch.utils.cpp_extension import CppExtension, CUDA_HOME, CUDAExtension
from ._utils import *
class _Extension(ABC):
def __init__(self, name: str, support_aot: bool, support_jit: bool, priority: int = 1):
self._name = name
self._support_aot = support_aot
self._support_jit = support_jit
self.priority = priority
@property
def name(self):
return self._name
@property
def support_aot(self):
return self._support_aot
@property
def support_jit(self):
return self._support_jit
@staticmethod
def get_jit_extension_folder_path(name):
"""
Kernels which are compiled during runtime will be stored in the same cache folder for reuse.
The folder is in the path ~/.cache/megatron_patch/torch_extensions/<cache-folder>.
The name of the <cache-folder> follows a common format:
torch<torch_version_major>.<torch_version_minor>_<device_name><device_version>-<hash>
The <hash> suffix is the hash value of the path of the `megatron_patch` file.
"""
import megatron_patch
import torch
from torch.version import cuda
assert name in ["cpu", "cuda"], f"the argument `name` should be `cpu` or `cuda`!"
# get torch version
torch_version_major = torch.__version__.split(".")[0]
torch_version_minor = torch.__version__.split(".")[1]
# get device version
device_name = name
device_version = cuda if name == 'cuda' else ''
# use colossalai's file path as hash
hash_suffix = hashlib.sha256(megatron_patch.__file__.encode()).hexdigest()
# concat
home_directory = os.path.expanduser("~")
extension_directory = f".cache/megatron_patch/torch_extensions/torch{torch_version_major}.{torch_version_minor}_{device_name}-{device_version}-{hash_suffix}"
cache_directory = os.path.join(home_directory, extension_directory)
return cache_directory
@abstractmethod
def is_available(self) -> bool:
"""
Check if the hardware required by the kernel is available.
"""
@abstractmethod
def assert_compatible(self) -> None:
"""
Check if the hardware required by the kernel is compatible.
"""
@abstractmethod
def build_aot(self) -> Union["CppExtension", "CUDAExtension"]:
pass
@abstractmethod
def build_jit(self) -> Callable:
pass
@abstractmethod
def load(self) -> Callable:
pass
__all__ = [
"CPUAdamLoader",
]
# Some constants for installation checks
MIN_PYTORCH_VERSION_MAJOR = 1
MIN_PYTORCH_VERSION_MINOR = 10
class _CppExtension(_Extension):
def __init__(self, name: str, priority: int = 1):
super().__init__(name, support_aot=True, support_jit=True, priority=priority)
# we store the op as an attribute to avoid repeated building and loading
self.cached_op = None
# build-related variables
self.prebuilt_module_path = "megatron_patch._C"
self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}"
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
def csrc_abs_path(self, path):
return os.path.join(self.relative_to_abs_path("csrc"), path)
def relative_to_abs_path(self, code_path: str) -> str:
"""
This function takes in a path relative to the root directory and return the absolute path.
"""
# get the current file path
# iteratively check the parent directory
# if the parent directory is "hybrid_adam", then the current file path is the root directory
# otherwise, the current file path is inside the root directory
current_file_path = Path(__file__)
while True:
if current_file_path.name == "hybrid_adam":
break
else:
current_file_path = current_file_path.parent
extension_module_path = current_file_path
code_abs_path = extension_module_path.joinpath(code_path)
return str(code_abs_path)
# functions must be overrided over
def strip_empty_entries(self, args):
"""
Drop any empty strings from the list of compile and link flags
"""
return [x for x in args if len(x) > 0]
def import_op(self):
"""
This function will import the op module by its string name.
"""
return importlib.import_module(self.prebuilt_import_path)
def build_aot(self) -> "CppExtension":
return CppExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
)
def build_jit(self) -> None:
from torch.utils.cpp_extension import load
build_directory = _Extension.get_jit_extension_folder_path("cpu")
build_directory = Path(build_directory)
build_directory.mkdir(parents=True, exist_ok=True)
# check if the kernel has been built
compiled_before = False
kernel_file_path = build_directory.joinpath(f"{self.name}.o")
if kernel_file_path.exists():
compiled_before = True
# load the kernel
if compiled_before:
print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
else:
print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")
build_start = time.time()
op_kernel = load(
name=self.name,
sources=self.strip_empty_entries(self.sources_files()),
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
extra_cflags=self.cxx_flags(),
extra_ldflags=[],
build_directory=str(build_directory),
)
build_duration = time.time() - build_start
if compiled_before:
print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
else:
print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")
return op_kernel
# functions must be overrided begin
@abstractmethod
def sources_files(self) -> List[str]:
"""
This function should return a list of source files for extensions.
"""
@abstractmethod
def include_dirs(self) -> List[str]:
"""
This function should return a list of include files for extensions.
"""
return [self.csrc_abs_path("")]
@abstractmethod
def cxx_flags(self) -> List[str]:
"""
This function should return a list of cxx compilation flags for extensions.
"""
def load(self):
try:
op_kernel = self.import_op()
except (ImportError, ModuleNotFoundError):
# if import error occurs, it means that the kernel is not pre-built
# so we build it jit
op_kernel = self.build_jit()
return op_kernel
class _CudaExtension(_CppExtension):
@abstractmethod
def nvcc_flags(self) -> List[str]:
"""
This function should return a list of nvcc compilation flags for extensions.
"""
return ["-DCOLOSSAL_WITH_CUDA"]
def is_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
# torch.cuda.is_available requires a device to exist, allow building with cuda extension on build nodes without a device
# but where cuda is actually available.
cuda_available = torch.cuda.is_available() or bool(os.environ.get("FORCE_CUDA", 0))
except:
cuda_available = False
return cuda_available
def assert_compatible(self) -> None:
from torch.utils.cpp_extension import CUDA_HOME
if not CUDA_HOME:
raise AssertionError(
"[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions"
)
check_system_pytorch_cuda_match(CUDA_HOME)
check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR)
def get_cuda_home_include(self):
"""
return include path inside the cuda home.
"""
from torch.utils.cpp_extension import CUDA_HOME
if CUDA_HOME is None:
raise RuntimeError(
"CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI."
)
cuda_include = os.path.join(CUDA_HOME, "include")
return cuda_include
def include_dirs(self) -> List[str]:
"""
This function should return a list of include files for extensions.
"""
return super().include_dirs() + [self.get_cuda_home_include()]
def build_jit(self) -> None:
from torch.utils.cpp_extension import CUDA_HOME, load
set_cuda_arch_list(CUDA_HOME)
# get build dir
build_directory = _Extension.get_jit_extension_folder_path("cuda")
build_directory = Path(build_directory)
build_directory.mkdir(parents=True, exist_ok=True)
# check if the kernel has been built
compiled_before = False
kernel_file_path = build_directory.joinpath(f"{self.name}.o")
if kernel_file_path.exists():
compiled_before = True
# load the kernel
if compiled_before:
print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
else:
print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")
build_start = time.time()
op_kernel = load(
name=self.name,
sources=self.strip_empty_entries(self.sources_files()),
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
extra_cflags=self.cxx_flags(),
extra_cuda_cflags=self.nvcc_flags(),
extra_ldflags=[],
build_directory=str(build_directory),
)
build_duration = time.time() - build_start
if compiled_before:
print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
else:
print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")
return op_kernel
def build_aot(self) -> "CUDAExtension":
set_cuda_arch_list(CUDA_HOME)
return CUDAExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args={
"cxx": self.strip_empty_entries(self.cxx_flags()),
"nvcc": self.strip_empty_entries(self.nvcc_flags()),
},
)
class CpuAdamX86Extension(_CudaExtension):
def __init__(self):
super().__init__(name="cpu_adam_x86")
def is_available(self) -> bool:
return platform.machine() == "x86_64" and super().is_available()
def assert_compatible(self) -> None:
arch = platform.machine()
assert (
arch == "x86_64"
), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}"
super().assert_compatible()
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path("cpu_adam.cpp"),
]
return ret
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-lcudart",
"-lcublas",
"-g",
"-Wno-reorder",
"-fopenmp",
"-march=native",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
extra_cuda_flags = [
"-std=c++14",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
]
ret = (
["-O3", "--use_fast_math"]
+ self.version_dependent_macros
+ extra_cuda_flags
+ super().nvcc_flags()
)
return append_nvcc_threads(ret)
class CpuAdamArmExtension(_CppExtension):
def __init__(self):
super().__init__(name="cpu_adam_arm")
def is_available(self) -> bool:
# only arm allowed
return platform.machine() == "aarch64"
def assert_compatible(self) -> None:
arch = platform.machine()
assert (
arch == "aarch64"
), f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}"
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path("cpu_adam_arm.cpp"),
]
return ret
def include_dirs(self) -> List[str]:
return super().include_dirs()
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-g",
"-Wno-reorder",
"-fopenmp",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
return []
class FusedOptimizerCudaExtension(_CudaExtension):
def __init__(self):
super().__init__(name="fused_optim_cuda")
def sources_files(self):
ret = [
self.csrc_abs_path(
"cuda/multi_tensor_adam_kernel.cu",
),
self.csrc_abs_path("optimizer.cpp"),
]
return ret
def cxx_flags(self):
version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
return ["-O3"] + version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = ["-lineinfo"]
extra_cuda_flags.extend(get_cuda_cc_flag())
return ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags()
class KernelLoader:
"""
An abstract class which offers encapsulation to the kernel loading process.
Usage:
kernel_loader = KernelLoader()
kernel = kernel_loader.load()
"""
REGISTRY: List[_Extension] = []
@classmethod
def register_extension(cls, extension: _Extension):
"""
This classmethod is an extension point which allows users to register their customized
kernel implementations to the loader.
Args:
extension (_Extension): the extension to be registered.
"""
cls.REGISTRY.append(extension)
def load(self, ext_name: str = None):
"""
Load the kernel according to the current machine.
Args:
ext_name (str): the name of the extension to be loaded. If not specified, the loader
will try to look for an kernel available on the current machine.
"""
exts = [ext_cls() for ext_cls in self.__class__.REGISTRY]
# look for exts which can be built/loaded on the current machine
if ext_name:
usable_exts = list(filter(lambda ext: ext.name == ext_name, exts))
else:
usable_exts = []
for ext in exts:
if ext.is_available():
# make sure the machine is compatible during kernel loading
ext.assert_compatible()
usable_exts.append(ext)
assert (
len(usable_exts) != 0
), f"No usable kernel found for {self.__class__.__name__} on the current machine."
if len(usable_exts) > 1:
# if more than one usable kernel is found, we will try to load the kernel with the highest priority
usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True)
warnings.warn(
f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}"
)
return usable_exts[0].load()
class CPUAdamLoader(KernelLoader):
REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension]
class FusedOptimizerLoader(KernelLoader):
REGISTRY = [FusedOptimizerCudaExtension]
# Copyright (c) 2024 Alibaba PAI, ColossalAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import tempfile
from typing import Callable, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
class NVMeOptimizer(torch.optim.Optimizer):
"""A base class for offloading optimizer states.
Args:
params: parameters
defaults (dict): default dict
nvme_offload_fraction (float, optional): Fraction of params to be offloaded to NVMe. Defaults to 0.0.
offload_dir (Optional[str], optional): Directory to save NVMe offload files.
If it's ``None``, a random temporary directory will be used. Defaults to None.
Raises:
ImportError: Raise if ``tensornvme`` is not installed.
"""
def __init__(
self, params, defaults: dict, nvme_offload_fraction: float = 0.0, offload_dir: Optional[str] = None
) -> None:
assert 0.0 <= nvme_offload_fraction <= 1.0
super().__init__(params, defaults)
self.nvme_offload_fraction = float(nvme_offload_fraction)
if self.nvme_offload_fraction > 0.0:
try:
from tensornvme import DiskOffloader
from tensornvme._C import get_backends
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
self.offload_dir = offload_dir or tempfile.mkdtemp()
backend = "uring" if "uring" in get_backends() else "aio"
self.offloader = DiskOffloader(self.offload_dir, 8, backend=backend)
else:
self.offload_dir = None
self.offloader = None
self.is_on_nvme: Dict[Parameter, bool] = {}
self.offloaded_numel: int = 0
# As param may be not materialized here, these attributes are initialized when the first step
self.total_numel: Optional[int] = None
self.can_offload_numel: Optional[int] = None
self.prefetch_params: List[Parameter] = []
self.param_to_prefetch_idx: Dict[Parameter, int] = {}
def _get_numel(self) -> int:
numel = 0
for group in self.param_groups:
for p in group["params"]:
numel += p.storage().size()
return numel
def _post_state_init(self, param: Parameter) -> None:
numel = param.storage().size()
if (
self.offloader is not None
and param.device.type == "cpu"
and numel + self.offloaded_numel <= self.can_offload_numel
):
self.is_on_nvme[param] = True
self.offloaded_numel += numel
else:
self.is_on_nvme[param] = False
def _setup_prefetch_params(self) -> List[Parameter]:
if self.offloader is None:
return
assert len(self.prefetch_params) == 0 and len(self.param_to_prefetch_idx) == 0
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
if len(self.state[p]) > 0 and self.is_on_nvme[p]:
assert p.device.type == "cpu"
self.param_to_prefetch_idx[p] = len(self.prefetch_params)
self.prefetch_params.append(p)
def _pre_step(self, *state_keys: str) -> None:
if self.total_numel is None:
self.total_numel = self._get_numel()
self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
self._setup_prefetch_params()
if self.offloader is None or len(self.prefetch_params) == 0:
return
state = self.state[self.prefetch_params[0]]
for key in state_keys:
self.offloader.async_read(state[key])
def _pre_update(self, param: Parameter, *state_keys: str) -> None:
if self.offloader is None or param not in self.param_to_prefetch_idx:
return
self.offloader.sync_read_events()
idx = self.param_to_prefetch_idx[param]
if idx + 1 < len(self.prefetch_params):
state = self.state[self.prefetch_params[idx + 1]]
for key in state_keys:
self.offloader.async_read(state[key])
def _post_update(self, param: Parameter, *state_keys: str) -> None:
if self.offloader is None:
return
self.offloader.sync_write_events()
if self.is_on_nvme[param]:
state = self.state[param]
for key in state_keys:
self.offloader.async_write(state[key])
def _post_step(self) -> None:
if self.offloader is not None:
self.offloader.synchronize()
self.prefetch_params.clear()
self.param_to_prefetch_idx.clear()
def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]:
"""Performs a single optimization step (parameter update).
Example:
>>> self._pre_step('exp_avg', 'exp_avg_sq')
>>> for group in self.param_groups:
>>> for p in group['params']:
>>> if p.grad is None:
>>> continue
>>> state = self.state[p]
>>> if len(state) == 0:
>>> state['exp_avg'] = ...
>>> state['exp_avg_sq'] = ...
>>> self._post_state_init(p)
>>> if p.device.type == 'cpu':
>>> self._pre_update(p, 'exp_avg', 'exp_avg_sq')
>>> adam()
>>> self._post_update(p, 'exp_avg', 'exp_avg_sq')
>>> else:
>>> ...
>>> self._post_step()
Args:
closure (Optional[Callable[[], float]], optional): A closure that reevaluates the model and
returns the loss. Optional for most optimizers.
"""
raise NotImplementedError
def state_dict(self) -> dict:
# TODO(ver217): design a new method to save state_dict. When using NVMe offload, this method may lead to OOM.
if self.offloader is not None:
raise NotImplementedError
return super().state_dict()
def load_state_dict(self, state_dict: dict) -> None:
# TODO(ver217): design a new method to load state_dict. When using NVMe offload, whole state_dict may not be able to fit in memory.
if self.offloader is not None:
raise NotImplementedError
super().load_state_dict(state_dict)
def __del__(self) -> None:
if getattr(self, "offloader", None) is not None:
del self.offloader
if os.path.exists(self.offload_dir):
try:
os.rmdir(self.offload_dir)
except OSError:
pass
from typing import List
from megatron.core.optimizer.optimizer import MegatronOptimizer
import torch
import math
from .optimizer import ChainedOptimizer, multi_tensor_applier, multi_tensor_scale_impl
from .offload_distrib_optimizer import OffloadDistributedOptimizer
class ChainedOffloadOptimizer(ChainedOptimizer):
def __init__(self, chained_optimizers: List[MegatronOptimizer]):
for optimizer in chained_optimizers:
if not isinstance(optimizer, OffloadDistributedOptimizer):
raise ValueError(
"ChainedOffloadOptimizer should only be used with OffloadDistributedOptimizer!"
)
self.chained_optimizers: List[OffloadDistributedOptimizer] = chained_optimizers
@torch.no_grad()
def prepare_grads(self, mem_stats) -> bool:
"""Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
found_inf_flag = False
for optimizer in self.chained_optimizers:
optimizer._mem_stats = mem_stats
found_inf_flag |= optimizer.prepare_grads()
return found_inf_flag
@torch.no_grad()
def step(self, mem_stats=None):
"""ChainedOptimizer will step all optimizers one by one."""
found_inf_flag = self.prepare_grads(mem_stats)
if found_inf_flag:
return False, None, None
# Get grad norm.
grad_norms = []
for optimizer in self.chained_optimizers:
_grad_norm = optimizer.get_grad_norm()
grad_norms += [_grad_norm if _grad_norm else 0.0]
grad_norm = math.sqrt(sum([x**2 for x in grad_norms]))
# Clip gradients.
for optimizer in self.chained_optimizers:
if optimizer.config.clip_grad > 0.0:
grads = []
for g in optimizer._main_grads:
assert g.type() == 'torch.cuda.FloatTensor'
grads.append(g.detach())
# Scale.
clip_coeff = optimizer.config.clip_grad / (grad_norm + 1.0e-6)
if clip_coeff < 1.0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
multi_tensor_applier(
multi_tensor_scale_impl, dummy_overflow_buf, [grads, grads], clip_coeff
)
# Count the zeros in the grads.
num_zeros_in_grad = 0
for optimizer in self.chained_optimizers:
num_zeros_in_grad += (
optimizer.count_zeros() if optimizer.config.log_num_zeros_in_grad else 0
)
update_successful = self.step_with_ready_grads()
return update_successful, grad_norm, num_zeros_in_grad
\ No newline at end of file
# Copyright (c) 2024 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from typing import *
from megatron.training.memory_tracer import MemStats
from .. import tensor_parallel
from ..distributed import ParamAndGradBuffer
from ..transformer.module import param_is_not_shared
from .chunk import ChunkManager
from .chunk.manager import get_rank
from .clip_grads import get_grad_norm_fp32
from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import MegatronGradScaler
from .hybrid_adam import CPUAdam
from .optimizer_config import OptimizerConfig
__all__ = ['OffloadDistributedOptimizer']
class OffloadDistributedOptimizer(DistributedOptimizer):
def _build_model_and_main_param_groups(self, *args, **kwargs):
"""
This function overrides DO._build_model_and_main_param_groups
"""
return None, None, None, None, None
def _build_model_and_main_param_groups_actual(
self,
gbuf_ranges: List[Dict],
param_gbuf_map: Dict[torch.nn.Parameter, Tuple],
opt_group_ranges: List,
):
"""
Create main parameter groups needed for the optimizer step.
These groups encompass both: 1) groups used by this class, for
reducing/gather, and 2) groups used by the inner optimizer for the
parameter update. Given that the conceptual grad buffer partitioning
(created in earlier method) doesn't respect parameter boundaries,
the optimizer operates on shards of the model parameters, rather than
the full parameters.
"""
# Parameter groups:
# model_float16_groups: original float16 parameters
# model_fp32_groups: original fp32 parameters
# shard_float16_groups: shards of original float16 parameters
# shard_fp32_groups: shards of original fp32 parameters
# shard_fp32_from_float16_groups: fp32 copy of float16 parameters
model_float16_groups = []
model_fp32_groups = []
shard_float16_groups = []
shard_fp32_groups = []
shard_fp32_from_float16_groups = []
shard_fp32_from_float32_groups = []
# Allocate (or slice) each group's param shard.
for group_range in opt_group_ranges:
# Params of this group.
model_float16_params_this_group = []
model_fp32_params_this_group = []
shard_float16_params_this_group = []
shard_fp32_params_this_group = []
shard_fp32_from_float16_params_this_group = []
shard_fp32_from_float32_params_this_group = []
model_float16_groups.append(model_float16_params_this_group)
model_fp32_groups.append(model_fp32_params_this_group)
# Views of each sharded parameters
shard_float16_groups.append(shard_float16_params_this_group)
shard_fp32_groups.append(shard_fp32_params_this_group)
# Hybrid FP32 copies of sharded parameters
shard_fp32_from_float16_groups.append(shard_fp32_from_float16_params_this_group)
shard_fp32_from_float32_groups.append(shard_fp32_from_float32_params_this_group)
for model_param in group_range["params"]:
assert model_param.requires_grad
gbuf_index, dtype, bucket_index = param_gbuf_map[model_param]
gbuf_range = gbuf_ranges[gbuf_index][dtype][bucket_index]
param_range = gbuf_range["param_map"][model_param]["param"]
# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
# Clone model -> main.
shard_model_param = model_param.detach().view(-1)[
param_range.start : param_range.end
]
shard_main_param = shard_model_param.clone().float()
self.chunk_manager.register_tensor(
shard_main_param, 'shard_fp32_from_float16_params'
)
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_model_param, model_param
)
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_main_param, model_param
)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
shard_main_param.shared = model_param.shared
# Add to group.
model_float16_params_this_group.append(model_param)
shard_float16_params_this_group.append(shard_model_param)
# NOTE: view of shard params, possible on CPU or CUDA
shard_fp32_from_float16_params_this_group.append(shard_main_param)
# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
shard_main_param = shard_model_param.clone()
self.chunk_manager.register_tensor(
shard_main_param.clone().float(), 'shard_fp32_from_float16_params'
)
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_model_param, model_param
)
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_main_param, model_param
)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
shard_main_param.shared = model_param.shared
model_fp32_params_this_group.append(model_param)
shard_fp32_params_this_group.append(shard_model_param)
shard_fp32_from_float32_params_this_group.append(shard_main_param)
else:
raise TypeError(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(model_param.type())
)
# Update optimizer's params. [Hybrid]
group_range["orig_group"]["params"] = [
*shard_fp32_from_float32_params_this_group,
*shard_fp32_from_float16_params_this_group,
]
return (
model_float16_groups,
model_fp32_groups,
shard_float16_groups,
shard_fp32_groups,
shard_fp32_from_float16_groups,
shard_fp32_from_float32_groups,
)
def collect_shard_param_numel(
self,
gbuf_ranges: List[Dict],
param_gbuf_map: Dict[torch.nn.Parameter, Tuple],
opt_group_ranges: List,
):
numels = np.zeros([sum(len(group_range["params"]) for group_range in opt_group_ranges)])
ptr = 0
for group_range in opt_group_ranges:
for model_param in group_range["params"]:
assert model_param.requires_grad
gbuf_index, dtype, bucket_index = param_gbuf_map[model_param]
gbuf_range = gbuf_ranges[gbuf_index][dtype][bucket_index]
param_range = gbuf_range["param_map"][model_param]["param"]
numels[ptr] = param_range.end - param_range.start
ptr += 1
return numels
def __init__(
self,
optimizer: torch.optim.Optimizer,
config: OptimizerConfig,
grad_scaler: MegatronGradScaler,
init_state_fn: Optional[Callable],
per_model_buffers: Dict[int, List[ParamAndGradBuffer]],
data_parallel_group: torch.distributed.ProcessGroup,
data_parallel_group_gloo: torch.distributed.ProcessGroup,
data_parallel_group_idx: int,
):
assert (
config.optimizer_offload_auto_threshold % (1024**2) == 0
and config.optimizer_offload_auto_threshold > 0
), "auto offload threshold should be divided by 2**20"
assert 0 <= config.optimizer_offload_fraction <= 1, "Offload fraction should be in [0, 1] !"
assert config.optimizer_offload_policy in [
'static',
'auto',
], "Only support static or auto placement policy!"
self.optimizer_offload_fraction = config.optimizer_offload_fraction
self.optimizer_offload_auto_threshold: int = config.optimizer_offload_auto_threshold
self.policy = config.optimizer_offload_policy
assert isinstance(
optimizer, CPUAdam
), "Only CPUAdam currently supported, due to checkpointing requirements."
super().__init__(
optimizer,
config,
grad_scaler,
init_state_fn,
per_model_buffers,
data_parallel_group,
data_parallel_group_gloo,
data_parallel_group_idx,
)
# In bf16 model training
self.grad_dtype_in_buffer = None
for _, buffers in per_model_buffers.items():
for buffer in buffers:
if self.grad_dtype_in_buffer is not None:
assert (
buffer.grad_dtype == self.grad_dtype_in_buffer
), "Currently only support consistent grad dtype!"
self.grad_dtype_in_buffer = buffer.grad_dtype
self.chunk_manager = ChunkManager(
chunk_size=(
config.optimizer_offload_chunk_size
if config.optimizer_offload_chunk_size > 0
else ChunkManager.find_best_chunk_size(
self.collect_shard_param_numel(
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges
),
512, # NOTE: search chunk size in [32MB, 544MB]
)
),
init_device='cpu',
is_fp32_grad=self.grad_dtype_in_buffer == torch.float32,
)
# NOTE: Allocate main param shards, all buffer will be on cpu.
(
self.model_float16_groups,
self.model_fp32_groups,
self.shard_float16_groups,
self.shard_fp32_groups,
self.shard_fp32_from_float16_groups,
self.shard_fp32_from_float32_groups,
) = self._build_model_and_main_param_groups_actual(
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges
)
self.chunk_manager.close_all_groups()
# Update optimizer groups.
# - Also, leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors.
self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
self.optimizer.load_state_dict(self.optimizer.state_dict())
# NOTE: alloc grad buffer for each parameter
self.chunk_manager.create_grads()
# NOTE: also alloc Adam states for each parameter
exp_avg = self.chunk_manager.alloc_paired_tensors(torch.float32)
exp_avg_sq = self.chunk_manager.alloc_paired_tensors(torch.float32)
for t, chunk_list in self.chunk_manager.paired_chunk_map.items():
assert len(chunk_list) == 2
for group in self.optimizer.param_groups:
for _, p in enumerate(group["params"]):
state = self.state[p]
assert len(state) == 0
state["step"] = 0
# gradient momentums
state["exp_avg"] = exp_avg[p]
# gradient variances
state["exp_avg_sq"] = exp_avg_sq[p]
self.optimizer._post_state_init(p)
if self.policy == 'static':
# NOTE: select partial chunks to GPU
total_memory = self.chunk_manager.total_mem['cpu']
budget = round((1 - self.optimizer_offload_fraction) * total_memory)
if budget > 0:
for _, chunks in self.chunk_manager.chunk_groups.items():
for chunk in chunks:
self.chunk_manager.move_chunk(chunk, torch.cuda.current_device(), True)
if self.chunk_manager.total_mem['cuda'] >= budget:
break
if self.chunk_manager.total_mem['cuda'] >= budget:
break
# Total: (2 + 4 + 4) = 10M or (2 + 4 + 4 + 4) = 14M [if an extra fp32 grad chunk is required]
print('After initialization, parameter chunks use mem: ', self.chunk_manager.total_mem)
def zero_grad(self, set_to_none=True):
"""
Zeroes grads for the model related parameters, i.e., model_float16_groups
and model_fp32_groups. We additionally zero the remaining groups as a
memory optimization to reduce fragmentation; in the case of
set_to_none==True, the space used by this field can be safely deallocated.
Args:
set_to_none (bool): if true, set grads to None.
"""
from .optimizer import (
_zero_grad_group_helper,
)
for groups in (
self.model_float16_groups,
self.model_fp32_groups,
self.shard_float16_groups, # grad empty/unused here?
self.shard_fp32_groups, # throws grad-access warning
self.shard_fp32_from_float16_groups,
self.shard_fp32_from_float32_groups,
):
for group in groups:
_zero_grad_group_helper(group, set_to_none=set_to_none)
# If overlapping param all-gather with forward compute, launch all-gather
# for first accessed bucket here before forward compute is initiated.
# The all-gather for the next bucket will be launched in the forward
# pre-hook when this all-gather finishes (to ensure that the communication
# kernels don't head-of-line block the compute kernels since we run with
# CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence parallelism).
if self.overlap_param_gather:
self._dispatch_gather_model_params(all_gather_handle_index=0)
def _get_model_and_main_params_data_float32(self):
"""
Get aligned list of model and main params.
"""
model_data = []
main_data = []
for model_group, main_group in zip(
self.shard_float16_groups, self.shard_fp32_from_float32_groups
):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _collect_grads(self):
shard_main_param_id_to_shard_main_grad_mapping = {}
shard_main_grads = []
# Utility method for copying group grads.
def collect_group_grads(model_groups, shard_main_groups):
for model_group, shard_main_group in zip(model_groups, shard_main_groups):
for model_param, shard_main_param in zip(model_group, shard_main_group):
param_range_map = self._get_model_param_range_map(model_param)
param_range = param_range_map["param"]
assert param_range.size == shard_main_param.nelement()
model_grad = model_param.main_grad
shard_model_grad = model_grad.view(-1)[param_range.start : param_range.end]
shard_main_grads.append(shard_model_grad.float())
shard_main_param_id_to_shard_main_grad_mapping[id(shard_main_param)] = (
shard_main_grads[-1]
)
# Copy model groups to shard groups.
collect_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups)
collect_group_grads(self.model_fp32_groups, self.shard_fp32_from_float32_groups)
return shard_main_grads, shard_main_param_id_to_shard_main_grad_mapping
def _dispatch_grads(self, params, main_param_id_to_main_grad_mapping):
if params is None:
params = self.get_parameters()
for param in params:
if id(param) in main_param_id_to_main_grad_mapping:
if param.grad is None:
param.grad = main_param_id_to_main_grad_mapping[id(param)].to(
param.device, non_blocking=True
)
else:
param.grad.data.copy_(main_param_id_to_main_grad_mapping[id(param)])
def _copy_main_params_to_model_params(self):
"""
Copy main params to model params.
Since this step is followed by an all-gather through the DDP's grad
buffer, this method is responsible for copying the updated params
from the main shards into the correct position in the grad buffer.
"""
# Utility method for copying group params.
def copy_group_params(shard_main_groups, model_groups):
for shard_main_group, model_group in zip(shard_main_groups, model_groups):
for shard_main_param, model_param in zip(shard_main_group, model_group):
param_range_map = self._get_model_param_range_map(model_param)
world_range = param_range_map["gbuf_world_in_bucket"]
assert world_range.size == shard_main_param.nelement()
gbuf_index, _, bucket_id = self.model_param_gbuf_map[model_param]
model_param_buffer = self.buffers[gbuf_index].buckets[bucket_id].param_data
shard_model_param = model_param_buffer.view(-1)[
world_range.start : world_range.end
]
shard_model_param.data.copy_(shard_main_param)
# Copy shard groups to model groups.
copy_group_params(self.shard_fp32_from_float16_groups, self.model_float16_groups)
copy_group_params(self.shard_fp32_from_float32_groups, self.model_fp32_groups)
def _copy_model_params_to_main_params(self):
"""
Copy model params to main params.
During finetuning, this method is used to reload the main params from
the model params. This copy does not make use of the grad buffer as
an intermediary.
"""
# Utility method for copying group params.
def copy_group_params(model_groups, shard_main_groups):
for model_group, shard_main_group in zip(model_groups, shard_main_groups):
for model_param, shard_main_param in zip(model_group, shard_main_group):
param_range_map = self._get_model_param_range_map(model_param)
param_range = param_range_map["param"]
assert param_range.size == shard_main_param.nelement()
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
shard_main_param.data.copy_(shard_model_param)
# Copy model groups to shard groups.
copy_group_params(self.model_float16_groups, self.shard_fp32_from_float16_groups)
copy_group_params(self.model_fp32_groups, self.shard_fp32_from_float32_groups)
@torch.no_grad()
def step_with_ready_grads(self) -> bool:
"""Step the optimizer with ready gradients, return successful.
Under the hood, either launch synchronous param all-gathers or get ready to launch
asynchorous all-gathers that get overlapped with the next forward pass.
"""
self.update_successful = super().step_with_ready_grads()
timers = self.config.timers
if timers is not None:
timers('params-all-gather', log_level=1).start(barrier=self.config.barrier_with_L1_time)
# If not overlapping all-gather for parameters, launch synchronous all-gather
# communication calls here. If overlapping all-gather for parameters, the following
# call to _gather_all_model_params is a no-op: the first all-gather is launched
# asynchronously in the next optimizer.zero_grad() call and subsequent all-gathers
# are launched in the forward pre-hook.
self._reset_metadata_and_sync_gather_all_model_params(force_sync=False)
if timers is not None:
timers('params-all-gather').stop()
return self.update_successful
def update_layout(self, mem_stats: MemStats = None, threshold: int = None):
if mem_stats is None:
return
if threshold is None:
threshold = self.optimizer_offload_auto_threshold
# NOTE: assume in optimizer.step(), we need less non-model data
# than forward-backward step, therefore make
# [chunk mem in CUDA] + threshold <= available space
model_data = mem_stats._prev_md_cuda
chunk_mem = self.chunk_manager.total_mem['cuda']
non_model_data = mem_stats.max_non_model_data('cuda')
current_usage = torch.cuda.memory_reserved() - model_data - non_model_data
available_space = torch.cuda.mem_get_info()[0] + current_usage - threshold
# NOTE: small chunks are preferred to being moved.
# We find this strategy is more stable than random select,
if available_space < 0:
for _, chunk_group in self.chunk_manager.chunk_groups.items():
for chunk in chunk_group:
if chunk.device_type == 'cpu':
continue
released_mem = self.chunk_manager.calc_size_in_device(chunk, 'cuda')
self.chunk_manager.move_chunk(chunk, 'cpu', async_move=False)
available_space += released_mem
if available_space >= 0:
break
if available_space >= 0:
break
# otherwise try to move chunk to CUDA without violating memory constraints
chunk_and_its_size = []
for _, chunk_group in self.chunk_manager.chunk_groups.items():
for chunk in chunk_group:
if chunk.device_type == 'cuda':
continue
required_mem = self.chunk_manager.calc_size_in_device(
chunk, 'cuda')
chunk_and_its_size.append(
(chunk, required_mem)
)
chunk_and_its_size.sort(key=lambda x: x[1])
for chunk, required_mem in chunk_and_its_size:
if required_mem < available_space:
self.chunk_manager.move_chunk(
chunk, torch.cuda.current_device()
)
available_space -= required_mem
def prepare_grads(self) -> bool:
timers = self.config.timers
if timers is not None:
timers('optimizer-update-layout', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
if self.policy == 'auto':
self.update_layout(self._mem_stats)
self._mem_stats = None
if timers is not None:
timers('optimizer-update-layout').stop()
(
self._main_grads,
self._main_param_id_to_main_grad_mapping
) = self._collect_grads()
# 2. unscale / check inf
# Reset found inf.
if self.grad_scaler:
if timers is not None:
timers('optimizer-unscale-and-check-inf', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
self._main_grads, self.found_inf, self.grad_scaler.inv_scale
)
# Update across all model parallel instances.
torch.distributed.all_reduce(
self.found_inf,
op=torch.distributed.ReduceOp.MAX,
group=self.get_model_parallel_group(),
)
# Check for nan.
found_inf_flag = self.found_inf.item() > 0
if timers is not None:
timers('optimizer-unscale-and-check-inf').stop()
if found_inf_flag:
self._main_grads = None
self._main_param_id_to_main_grad_mapping = None
return found_inf_flag
return False
def get_main_grads_for_grad_norm(self):
main_param_id_to_main_grad_mapping = \
self._main_param_id_to_main_grad_mapping
params = self.get_parameters()
grads_for_norm = []
for param in params:
# O(n) to O(n^2)
if id(param) not in main_param_id_to_main_grad_mapping:
continue
grad = main_param_id_to_main_grad_mapping[id(param)]
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
return grads_for_norm
def clip_grad_norm(self, clip_grad: float) -> float:
grads_for_norm = self.get_main_grads_for_grad_norm()
total_norm = get_grad_norm_fp32(
grads_for_norm, model_parallel_group=self.get_model_parallel_group()
)
from .optimizer import (
multi_tensor_applier,
multi_tensor_scale_impl,
)
# Grads.
grads = []
for g in self._main_grads:
assert g.type() == 'torch.cuda.FloatTensor'
grads.append(g.detach())
# Scale.
clip_coeff = clip_grad / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
multi_tensor_applier(
multi_tensor_scale_impl, dummy_overflow_buf, [grads, grads], clip_coeff
)
return total_norm
def count_zeros(self) -> float:
main_param_id_to_main_grad_mapping = \
self._main_param_id_to_main_grad_mapping
params = self.get_parameters()
total_num_zeros = torch.tensor([0.0], dtype=torch.float, device='cuda')
for param in params:
# O(n) to O(n^2)
if id(param) not in main_param_id_to_main_grad_mapping:
continue
grad = main_param_id_to_main_grad_mapping[id(param)]
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
grad = grad.detach()
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(
total_num_zeros,
op=torch.distributed.ReduceOp.SUM,
group=self.get_model_parallel_group()
)
total_num_zeros = total_num_zeros.item()
return total_num_zeros
def step_with_ready_grads(self):
timers = self.config.timers
if timers is not None:
timers('optimizer-copy-grad-to-cpu-and-gpu', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
# 4. move these grads to CPU
self.chunk_manager.attach_grad()
params = self.get_parameters()
self._dispatch_grads(
params, self._main_param_id_to_main_grad_mapping
)
self._main_param_id_to_main_grad_mapping = None
self._main_grads = None
if timers is not None:
timers('optimizer-copy-grad-to-cpu-and-gpu').stop()
return super().step_with_ready_grads()
def step(self, mem_stats=None):
self._mem_stats = mem_stats
return super().step()
\ No newline at end of file
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron optimizer."""
import copy
import math
from abc import ABC, abstractmethod
from itertools import chain
from logging import getLogger
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
try:
from transformer_engine.pytorch.optimizers import multi_tensor_applier, multi_tensor_scale
multi_tensor_scale_impl = multi_tensor_scale
except ImportError:
try:
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
from megatron.core.utils import local_multi_tensor_applier
multi_tensor_applier = local_multi_tensor_applier
try:
import amp_C
l2_norm_impl = amp_C.multi_tensor_l2norm
multi_tensor_scale_impl = amp_C.multi_tensor_scale
except ImportError:
from megatron.core.utils import local_multi_tensor_l2_norm, local_multi_tensor_scale
l2_norm_impl = local_multi_tensor_l2_norm
multi_tensor_scale_impl = local_multi_tensor_scale
from .. import parallel_state, tensor_parallel
from ..dist_checkpointing.mapping import ShardedStateDict
from ..dist_checkpointing.optimizer import (
get_param_id_to_sharded_param_map,
make_sharded_optimizer_tensor,
optim_state_to_sharding_state,
)
from ..dist_checkpointing.utils import add_prefix_for_sharding
from ..transformer.module import param_is_not_shared
from .clip_grads import clip_grad_by_total_norm_fp32, count_zeros_fp32, get_grad_norm_fp32
from .grad_scaler import MegatronGradScaler
from .optimizer_config import OptimizerConfig
logger = getLogger(__name__)
def _zero_grad_group_helper(group: List[torch.nn.Parameter], set_to_none: bool):
"""
Zero out the gradient for a group of parameters.
Note: copied from torch.optim.optimizer.
"""
for param in group:
if param.grad is not None:
if set_to_none:
param.grad = None
else:
if param.grad.grad_fn is not None:
param.grad.detach_()
else:
param.grad.requires_grad_(False)
param.grad.zero_()
def _multi_tensor_copy_this_to_that(
this: List[torch.Tensor], that: List[torch.Tensor], overflow_buf: Optional[torch.Tensor] = None
):
"""
Use multi-tensor-applier to copy values from one list to another.
We don't have a bfloat16 implementation so for now if the overflow_buf
is not provided, we default back to simple loop copy to be compatible
with bfloat16.
"""
if overflow_buf:
overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(multi_tensor_scale_impl, overflow_buf, [this, that], 1.0)
else:
for this_, that_ in zip(this, that):
that_.copy_(this_)
class MegatronOptimizer(ABC):
"""
Base class for all Megatron optimizers.
Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
config (OptimizerConfig): configuration object for optimizer.
init_state_fn (Callable, optional): function to initialize state in the optimizer.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
config: OptimizerConfig,
init_state_fn: Callable = lambda x: None,
):
"""Input optimizer is the base optimizer (e.g., Adam)."""
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
self.config = config
self.init_state_fn = init_state_fn
def get_parameters(self) -> List[torch.nn.Parameter]:
"""
Get list of parameters wrapped in optimizer.
"""
params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
return params
def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:
"""
Get main_grads that should be taken into account to compute the grad norm.
Filter parameters based on:
- grad should not be None.
- parameter should not be shared (i.e., grads shouldn't be double counted while
computing norms).
- should not be a replica due to tensor model parallelism.
"""
params = self.get_parameters()
grads_for_norm = []
for param in params:
grad = param.grad
grad_not_none = grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
return grads_for_norm
def get_model_parallel_group(self) -> torch.distributed.ProcessGroup:
"""Default returned here, but the distributed optimizer overrides this."""
if hasattr(self, 'model_parallel_group'):
return self.model_parallel_group
return parallel_state.get_model_parallel_group()
@abstractmethod
def prepare_grads(self) -> bool:
"""Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
return False
@abstractmethod
def step_with_ready_grads(self) -> bool:
"""Step the optimizer with ready gradients, return successful."""
return True
@torch.no_grad()
def get_grad_norm(self):
grads_for_norm = self.get_main_grads_for_grad_norm()
total_norm = get_grad_norm_fp32(
grads_for_norm,
model_parallel_group=self.get_model_parallel_group(),
)
return total_norm
def clip_grad_norm(self, clip_grad: float) -> float:
"""Compute grad norm."""
params = self.get_parameters()
grads_for_norm = self.get_main_grads_for_grad_norm()
grad_norm = get_grad_norm_fp32(
grads_for_norm, model_parallel_group=self.get_model_parallel_group()
)
clip_grad_by_total_norm_fp32(params, clip_grad, grad_norm)
return grad_norm
def count_zeros(self) -> float:
"""Count number of zeros in model's gradients."""
params = self.get_parameters()
return count_zeros_fp32(params, model_parallel_group=self.get_model_parallel_group())
@abstractmethod
def zero_grad(self, set_to_none: bool = True):
pass
@abstractmethod
def get_loss_scale(self) -> torch.Tensor:
"""
Get current loss scale factor.
NOTE: The output should be a CUDA tensor of size 1.
"""
pass
def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
"""Simple scaling."""
return self.get_loss_scale() * loss
def finish_param_sync(self, model_index: int):
"""
Finish parameter synchronization for all optimizers.
This is a no-op for all non-distributed optimizers.
"""
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
Call whenever the parameters are changed outside of the optimizer.
For example, when we load a model from a checkpoint without loading
the optimizer, the model parameters are updated but for fp16 optimizer
with main parameters, the main parameters need to also be updated."""
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict):
pass
# Promote state so it can be retrieved or set via
# "optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via
# "optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
@abstractmethod
def step(self):
"""Step the optimizer."""
pass
@abstractmethod
def sharded_state_dict(
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
) -> ShardedStateDict:
"""Builds sharded state dict for the optimizer, based on model's sharded state dict.
Args:
model_sharded_state_dict (ShardedStateDict): sharded state dict of the model
is_loading (bool, optional): flag indicating whether the state dict will be used to save or load the optimizer state.
Defaults to False.
Returns: optimizer sharded state dict
"""
@staticmethod
def _extract_common_per_param_step(state_dict) -> Union[int, torch.Tensor]:
common_step = None
for param_idx, param_state in state_dict['state'].items():
param_step = param_state.get('step', None)
if param_step is not None:
if common_step is None:
common_step = param_step
elif common_step != param_step:
raise ValueError(
"The optimizer step differs per parameter. Mcore only supports "
"optimizers whose step is shared across all parameters."
)
return common_step
@staticmethod
def _restore_common_per_param_step(state_dict: Dict, step: Union[int, torch.Tensor]):
for param_idx, param_state in state_dict['state'].items():
param_state['step'] = copy.deepcopy(step)
class MixedPrecisionOptimizer(MegatronOptimizer):
"""Base class for both the float-16 and the distributed optimizer.
Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
config (OptimizerConfig): configuration object for optimizer.
grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
this can be None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constant gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
init_state_fn (Callable, optional): function to initialize state in the optimizer.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
config: OptimizerConfig,
grad_scaler: Optional[MegatronGradScaler],
init_state_fn: Callable,
):
super().__init__(
optimizer,
config,
init_state_fn,
)
self.grad_scaler = grad_scaler
# None grad scaler is only supported for bf16.
if self.grad_scaler is None:
assert not self.config.fp16, 'fp16 expects a grad scaler.'
# Tensor used to determine if a nan/if has happend.
# Any non-zero value indicates inf/nan.
# Note that we keep this for the cases that grad scaler is none.
# We still record nan/inf if we have a bfloat16 with a grad scaler.
if self.grad_scaler:
self.found_inf = torch.tensor([0.0], dtype=torch.float, device='cuda')
# Dummy tensor needed for apex multi-apply tensor.
# For bfloat, we don't have multi-tensor apply and for now
# we set it to none so the multi-tensor apply gets ignored.
if self.config.bf16:
self._dummy_overflow_buf = None
else:
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
# In case grad scaler is not passed, define the unity scale.
if self.grad_scaler is None:
self._scale_one = torch.tensor([1.0], dtype=torch.float, device='cuda')
def get_loss_scale(self):
if self.grad_scaler is None:
return self._scale_one
return self.grad_scaler.scale
def reload_model_params(self):
self._copy_model_params_to_main_params()
def _unscale_main_grads_and_check_for_nan(self):
# Collect main grads.
main_grads = self._collect_main_grad_data_for_unscaling()
# Reset found inf.
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale
)
# Update across all model parallel instances.
torch.distributed.all_reduce(
self.found_inf, op=torch.distributed.ReduceOp.MAX, group=self.get_model_parallel_group()
)
# Check for nan.
found_inf_flag = self.found_inf.item() > 0
return found_inf_flag
@torch.no_grad()
def prepare_grads(self) -> bool:
"""Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
timers = self.config.timers
# Copy gradients from model params to main params.
if timers is not None:
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self._copy_model_grads_to_main_grads()
if timers is not None:
timers('optimizer-copy-to-main-grad').stop()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
# Unscale and check for inf/nan.
if timers is not None:
timers('optimizer-unscale-and-check-inf', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
if timers is not None:
timers('optimizer-unscale-and-check-inf').stop()
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
return found_inf_flag
return False
@torch.no_grad()
def step_with_ready_grads(self) -> bool:
"""Step the optimizer with ready gradients, return successful."""
timers = self.config.timers
# Step the optimizer.
if timers is not None:
timers('optimizer-inner-step', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self.optimizer.step()
if timers is not None:
timers('optimizer-inner-step').stop()
# Update params from main params.
if timers is not None:
timers('optimizer-copy-main-to-model-params', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self._copy_main_params_to_model_params()
if timers is not None:
timers('optimizer-copy-main-to-model-params').stop()
return True
@torch.no_grad()
def step(self):
timers = self.config.timers
found_inf_flag = self.prepare_grads()
if found_inf_flag:
return False, None, None
# Clip the main gradients.
if timers is not None:
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
grad_norm = None
if self.config.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.config.clip_grad)
if timers is not None:
timers('optimizer-clip-main-grad').stop()
# Count the zeros in the grads.
if timers is not None:
timers('optimizer-count-zeros', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
if timers is not None:
timers('optimizer-count-zeros').stop()
success = self.step_with_ready_grads()
# Successful update.
return success, grad_norm, num_zeros_in_grad
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
"""Float16 optimizer for fp16 and bf16 data types.
Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
config (OptimizerConfig): configuration object for optimizer.
grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
this can be None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constant gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
init_state_fn (Callable, optional): function to initialize state in the optimizer.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
config: OptimizerConfig,
grad_scaler: MegatronGradScaler,
init_state_fn: Callable,
):
super().__init__(
optimizer,
config,
grad_scaler,
init_state_fn,
)
# Handle main parameters.
# Three groups of parameters:
# float16_groups: original float16 parameters
# fp32_from_float16_groups: fp32 copy of float16 parameters
# fp32_from_fp32_groups: original fp32 parameters
self.float16_groups = []
self.fp32_from_float16_groups = []
self.fp32_from_fp32_groups = []
# For all the groups in the original optimizer:
for param_group in self.optimizer.param_groups:
float16_params_this_group = []
fp32_params_this_group = []
fp32_from_float16_params_this_group = []
# For all the parameters in this group:
for i, param in enumerate(param_group['params']):
if param.requires_grad:
# float16 params:
if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
float16_params_this_group.append(param)
# Create a copy
main_param = param.detach().clone().float()
# Copy tensor model parallel attributes.
tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
self.optimizer.state[main_param] = self.optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(param.type())
)
self.float16_groups.append(float16_params_this_group)
self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
float16_groups & fp32_from_fp32_groups. We additionally zero
fp32_from_float16_groups as a memory optimization to reduce
fragmentation; in the case of set_to_none==True, the space
used by this field can be safely deallocated at this point."""
for group in self.float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_fp32_groups:
_zero_grad_group_helper(group, set_to_none)
def _collect_main_grad_data_for_unscaling(self):
main_grads = []
# fp32 params from float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
return main_grads
def _get_model_and_main_params_data_float16(self):
model_data = []
main_data = []
for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_model_grads_to_main_grads(self):
# This only needs to be done for the float16 group.
for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
if hasattr(model_param, 'main_grad'):
main_param.grad = model_param.main_grad.float()
else:
if model_param.grad is not None:
main_param.grad = model_param.grad.float()
# Safe to deallocate model's grad/main_grad after copying.
# (If using contiguous buffers, main_grad's memory should
# persist and therefore should not be deallocated.)
model_param.grad = None
# For fp32 grads, we need to reset the grads to main grad.
for model_group in self.fp32_from_fp32_groups:
for model_param in model_group:
model_param.grad = model_param.main_grad
def _copy_main_params_to_model_params(self):
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(
this=main_data, that=model_data, overflow_buf=self._dummy_overflow_buf
)
def _copy_model_params_to_main_params(self):
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(
this=model_data, that=main_data, overflow_buf=self._dummy_overflow_buf
)
def state_dict(self):
state_dict = {}
state_dict['optimizer'] = self.optimizer.state_dict()
if self.grad_scaler:
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
return state_dict
def sharded_state_dict(
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
):
if is_loading:
self.init_state_fn(self.optimizer)
state_dict = self.state_dict()
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
model_sharded_state_dict, chain.from_iterable(g for g in self.float16_groups)
)
# Convert fp32_from_fp16_params
assert len(state_dict['fp32_from_fp16_params']) == len(
state_dict['optimizer']['param_groups']
)
state_dict['fp32_from_fp16_params'] = [
[
make_sharded_optimizer_tensor(
id_to_sharded_param_map[param_id],
fp32_param,
prefix=f'optimizer.state.fp32_param',
)
for param_id, fp32_param in zip(state_group['params'], fp32_group)
]
for fp32_group, state_group in zip(
state_dict['fp32_from_fp16_params'], state_dict['optimizer']['param_groups']
)
]
step = self._extract_common_per_param_step(state_dict['optimizer'])
# Convert regular optimizer state
# all optimizer parameters passed to optim_state_to_sharding_state are
# expected to have the same shape as the model parameters,
# so we save the step separately and ignore it here
optim_state_to_sharding_state(
state_dict['optimizer'], id_to_sharded_param_map, exclude_keys="step"
)
# save step as a shared step among all parameters. Separate per-parameter
# steps are not supported
state_dict['optimizer']['state']['common_step'] = step
return state_dict
def load_state_dict(self, state_dict):
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
# Optimizer.
optimizer_key = 'optimizer'
if optimizer_key not in state_dict:
optimizer_key = 'optimizer_state_dict'
logger.info('***WARNING*** loading optimizer from ' 'an old checkpoint ...')
if 'common_step' in state_dict[optimizer_key]['state']:
common_step = state_dict[optimizer_key]['state'].pop('common_step')
self._restore_common_per_param_step(state_dict[optimizer_key], common_step)
self.optimizer.load_state_dict(state_dict[optimizer_key])
# Grad scaler.
if 'grad_scaler' not in state_dict:
if self.config.fp16:
logger.info(
'***WARNING*** found an old checkpoint, will not ' 'load grad scaler ...'
)
else:
if self.grad_scaler:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
else:
logger.info(
'***WARNING*** fould the grad scaler in the '
'checkpoint but it is None in the class. '
'Skipping loading grad scaler ...'
)
# Copy data for the main params.
fp32_from_float16_params_key = 'fp32_from_fp16_params'
if fp32_from_float16_params_key not in state_dict:
fp32_from_float16_params_key = 'fp32_from_fp16'
for current_group, saved_group in zip(
self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]
):
for current_param, saved_param in zip(current_group, saved_group):
current_param.data.copy_(saved_param.data)
class FP32Optimizer(MegatronOptimizer):
"""Float32 optimizer.
Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
config (OptimizerConfig): configuration object for optimizer.
init_state_fn (Callable, optional): function to initialize state in the optimizer.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
config: OptimizerConfig,
init_state_fn: Callable,
):
super(FP32Optimizer, self).__init__(
optimizer,
config,
init_state_fn,
)
self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda')
def zero_grad(self, set_to_none=True):
"""Copied from torch.optim.optimizer"""
for group in self.optimizer.param_groups:
_zero_grad_group_helper(group['params'], set_to_none)
def get_loss_scale(self):
"""FP32 optimizer does not do any scaling."""
return self._scale
@torch.no_grad()
def prepare_grads(self) -> bool:
"""Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
timers = self.config.timers
# Copy main_grads to grads.
if timers is not None:
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
param.grad = param.main_grad
if timers is not None:
timers('optimizer-copy-to-main-grad').stop()
return False
@torch.no_grad()
def step_with_ready_grads(self) -> bool:
"""Step the optimizer with ready gradients, return successful."""
timers = self.config.timers
# Update parameters.
if timers is not None:
timers('optimizer-inner-step', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self.optimizer.step()
if timers is not None:
timers('optimizer-inner-step').stop()
return True
@torch.no_grad()
def step(self):
"""Clip gradients (if needed) and step the base optimizer.
Always return successful since there is no overflow."""
timers = self.config.timers
found_inf_flag = self.prepare_grads()
if found_inf_flag:
return False, None, None
# Clip gradients.
if timers is not None:
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
grad_norm = None
if self.config.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.config.clip_grad)
if timers is not None:
timers('optimizer-clip-main-grad').stop()
# Count the zeros in the grads.
if timers is not None:
timers('optimizer-count-zeros', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
if timers is not None:
timers('optimizer-count-zeros').stop()
success = self.step_with_ready_grads()
# No overflow for FP32 optimizer.
return success, grad_norm, num_zeros_in_grad
def reload_model_params(self):
pass
def state_dict(self):
return self.optimizer.state_dict()
def load_state_dict(self, state_dict):
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
if 'common_step' in state_dict['state']:
common_step = state_dict['state'].pop('common_step')
self._restore_common_per_param_step(state_dict, common_step)
self.optimizer.load_state_dict(state_dict)
def sharded_state_dict(
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
):
if is_loading:
self.init_state_fn(self.optimizer)
state_dict = self.state_dict()
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
model_sharded_state_dict, self.get_parameters()
)
step = self._extract_common_per_param_step(state_dict)
# all optimizer parameters passed to optim_state_to_sharding_state are
# expected to have the same shape as the model parameters,
# so we save the step separately and ignore it here
optim_state_to_sharding_state(state_dict, id_to_sharded_param_map, exclude_keys="step")
# save step as a shared step among all parameters. Separate per-parameter
# steps are not supported
state_dict['state']['common_step'] = step
return state_dict
class ProxyDict:
"""
A dictionary-like object that proxies to a list of dictionaries.
e.g., ProxyDict([{'a': 1}, {'b': 2}]) behaves like:
{
(0, 'a'): 1,
(1, 'b'): 2,
}
We use tuples as keys to avoid ambiguity with the keys of the inner dicts.
"""
def __init__(self, inner_dicts: List[dict]):
self._inner_dicts = inner_dicts
def __getitem__(self, key: Tuple[int, str]):
idx, inner_key = key
return self._inner_dicts[idx].get(inner_key)
def __setitem__(self, key: Tuple[int, str], value: Any):
idx, inner_key = key
self._inner_dicts[idx][inner_key] = value
def __len__(self) -> int:
return sum([len(inner_dict) for inner_dict in self._inner_dicts])
def __iter__(self):
for idx, inner_dict in enumerate(self._inner_dicts):
for inner_key in inner_dict:
yield (idx, inner_key)
def items(self):
for idx, inner_dict in enumerate(self._inner_dicts):
for inner_key, value in inner_dict.items():
yield (idx, inner_key), value
class ChainedOptimizer(MegatronOptimizer):
"""ChainedOptimizer is designed for a collection of optimizers.
These optimizers are responsible for different parts of multiple models for
a training task and will be executed one-by-one when the model is updated.
Args:
chained_optimizers: a list of optimizers.
"""
def __init__(self, chained_optimizers: List[MegatronOptimizer]):
self.chained_optimizers = chained_optimizers
@property
def param_groups(self) -> List[dict]:
param_groups = []
for optimizer in self.chained_optimizers:
param_groups += optimizer.param_groups
return param_groups
@property
def state(self) -> ProxyDict:
"""
Return optimizer state with tuple keys, where the first element is the
index of the optimizer in the list of chained optimizers.
"""
return ProxyDict([opt.state for opt in self.chained_optimizers])
def zero_grad(self, set_to_none=True):
for optimizer in self.chained_optimizers:
optimizer.zero_grad(set_to_none)
def get_loss_scale(self):
return self.chained_optimizers[0].get_loss_scale()
def reload_model_params(self):
for optimizer in self.chained_optimizers:
optimizer.reload_model_params()
def state_dict(self):
return [optimizer.state_dict() for optimizer in self.chained_optimizers]
def sharded_state_dict(
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False, **kwargs
):
sharded_state_dict = {}
for optimizer_idx, optimizer in enumerate(self.chained_optimizers):
optim_state_dict = optimizer.sharded_state_dict(
model_sharded_state_dict, is_loading, **kwargs
)
add_prefix_for_sharding(optim_state_dict, f'chained_{optimizer_idx}.')
sharded_state_dict[optimizer_idx] = optim_state_dict
return sharded_state_dict
def load_state_dict(self, state_dict):
if len(self.chained_optimizers) != len(state_dict):
raise RuntimeError(
f'Expected {len(self.chained_optimizers)} entries'
f' in state dict, but got {len(state_dict)}.'
)
if isinstance(state_dict, dict):
state_dict = (v for k, v in sorted(state_dict.items()))
for optimizer, state in zip(self.chained_optimizers, state_dict):
optimizer.load_state_dict(state)
@torch.no_grad()
def prepare_grads(self) -> bool:
"""Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
found_inf_flag = False
for optimizer in self.chained_optimizers:
found_inf_flag |= optimizer.prepare_grads()
return found_inf_flag
@torch.no_grad()
def step_with_ready_grads(self) -> bool:
"""Step the optimizer with ready gradients, return successful."""
success = True
for optimizer in self.chained_optimizers:
success &= optimizer.step_with_ready_grads()
return success
def disable_pre_hook(self):
for optimizer in self.chained_optimizers:
if (
not optimizer.config.use_distributed_optimizer
or not optimizer.config.overlap_param_gather
):
raise ValueError(
"disable_pre_hook should only be called with 'use_distributed_optimizer' "
"and 'overlap_param_gather' both enabled."
)
optimizer.disable_pre_hook()
def enable_pre_hook(self):
for optimizer in self.chained_optimizers:
if (
not optimizer.config.use_distributed_optimizer
or not optimizer.config.overlap_param_gather
):
raise ValueError(
"enable_pre_hook should only be called with 'use_distributed_optimizer' "
"and 'overlap_param_gather' both enabled."
)
optimizer.enable_pre_hook()
@torch.no_grad()
def step(self):
"""ChainedOptimizer will step all optimizers one by one."""
found_inf_flag = self.prepare_grads()
if found_inf_flag:
return False, None, None
# Get grad norm.
grad_norms = []
for optimizer in self.chained_optimizers:
_grad_norm = optimizer.get_grad_norm()
grad_norms += [_grad_norm if _grad_norm else 0.0]
grad_norm = math.sqrt(sum([x**2 for x in grad_norms]))
# Clip gradients.
for optimizer in self.chained_optimizers:
if optimizer.config.clip_grad > 0.0:
clip_grad_by_total_norm_fp32(
optimizer.get_parameters(),
max_norm=optimizer.config.clip_grad,
total_norm=grad_norm,
)
# Count the zeros in the grads.
num_zeros_in_grad = 0
for optimizer in self.chained_optimizers:
num_zeros_in_grad += (
optimizer.count_zeros() if optimizer.config.log_num_zeros_in_grad else 0
)
update_successful = self.step_with_ready_grads()
return update_successful, grad_norm, num_zeros_in_grad
def save_parameter_state(self, filename: str):
"""Save the distributed parameter states of all optimizers to a file.
Args:
filename (str): path to save parameter state to.
"""
save_states = False
states = []
for optimizer in self.chained_optimizers:
if hasattr(optimizer, 'get_parameter_state_dp_zero'):
state_dict = optimizer.get_parameter_state_dp_zero()
# Save checkpoint economically, only when DP rank = 0, state dict
# needs to be saved.
if torch.distributed.get_rank(optimizer.data_parallel_group) == 0:
states.append(state_dict)
save_states = True
else:
states.append(None)
else:
states.append(None)
if save_states:
torch.save(states, filename)
def load_parameter_state(self, filename: str):
"""Load the distributed parameter states of all optimizers from a file.
Args:
filename (str): path to load parameter state from.
"""
states = None
for idx, optimizer in enumerate(self.chained_optimizers):
if not hasattr(optimizer, 'load_parameter_state_from_dp_zero'):
continue
# Lazy loading checkpoint, state dict is needed only when DP rank = 0.
if torch.distributed.get_rank(optimizer.data_parallel_group) == 0 and states is None:
states = torch.load(filename)
state_dict = states[idx] if states else None
optimizer.load_parameter_state_from_dp_zero(state_dict)
def finish_param_sync(self, model_index: int):
"""Finish parameter synchronization for all optimizers."""
for optimizer in self.chained_optimizers:
optimizer.finish_param_sync(model_index)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable, Optional
import torch
@dataclass
class OptimizerConfig:
"""Configuration for optimizer."""
##############
# General
##############
optimizer: str = 'adam'
"""Optimizer to use (one of Adam or SGD)."""
lr: Optional[float] = None
"""Initial learning rate. Depending on decay style and initial warmup, the learning rate at each
iteration would be different.
"""
min_lr: Optional[float] = None
"""Minumum value for learning rate. The scheduler clip values below this threshold."""
decoupled_lr: Optional[float] = None
"""Separate learning rate for the input and output layer."""
decoupled_min_lr: Optional[float] = None
"""Minimum value for learning rate for the input and output layer. The scheduler clip values
below this threshold.
"""
weight_decay: float = 0.01
"""Weight decay coefficient for L2 regularization."""
##############
# Precision
##############
fp16: bool = False
"""If true, train with fp16 mixed precision training. Defaults to False."""
bf16: bool = False
"""If true, train with bf16 mixed precision training. Defaults to False."""
params_dtype: torch.dtype = torch.float32
"""dtype used when intializing the weights. Defaults to torch.float32."""
###############
# Loss scaling
###############
loss_scale: Optional[float] = None
"""Static loss scaling, positive power of 2 values can improve fp16 convergence. If None,
dynamic loss scaling is used.
"""
initial_loss_scale: float = 2**32
"""Initial loss-scale for dynamic loss scaling."""
min_loss_scale: float = 1.0
"""Minimum loss scale for dynamic loss scaling."""
loss_scale_window: float = 1000
"""Window over which to raise/lower dynamic scale."""
hysteresis: int = 2
"""Hysteresis for dynamic loss scaling."""
##############
# Optimizer
##############
# Adam
adam_beta1: float = 0.9
"""First coefficient for computing running averages of gradient and its square in Adam
optimizer.
"""
adam_beta2: float = 0.999
"""Second coefficient for computing running averages of gradient and its square in Adam
optimizer.
"""
adam_eps: float = 1e-08
"""Term added to the denominator to improve numerical stability in Adam optimizer."""
# SGD.
sgd_momentum: float = 0.9
"""Momentum factor for SGD optimizer."""
#######################
# Distributed optimizer
#######################
use_distributed_optimizer: bool = False
"""Distribute optimizer state over data-parallel replicas."""
overlap_grad_reduce: bool = False
"""If true, overlap grad reduce-scatter with backward compute in distributed optimizer."""
overlap_param_gather: bool = False
"""If true, overlap param all-gather with forward compute in distributed optimizer."""
#######################
# Optimizer Offloading
#######################
optimizer_offload_policy: str = 'static'
"""CPU Offload Policy used by OffloadDistributedOptimizer, valid if base optimizer is HybridAdam"""
optimizer_offload_fraction: float = 0.0
"""CPU Offload Fraction used by static offload policy, valid if base optimizer is HybridAdam"""
optimizer_offload_chunk_size: int = 0
"""Chunk Size used by CPU offload Chunk Manager (bytes), automatically search if value is 0 (default)"""
optimizer_offload_auto_threshold: int = 2048 * 1024**2
"""threshold for auto optimizer offload (bytes) should be larger if OOM"""
################
# Miscellaneous
################
clip_grad: float = 1.0
"""Gradient clipping based on global L2 norm."""
log_num_zeros_in_grad: bool = False
"""If true, calculate and log the number of zeros in gradient."""
barrier_with_L1_time: bool = False
"""If true, use barrier with level 1 time measurements."""
timers: Callable = None
"""Function to get timers."""
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
MAJOR = 0
MINOR = 9
PATCH = 0
PRE_RELEASE = 'rc0'
# Use the following formatting: (major, minor, patch, pre-release)
VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
__shortversion__ = '.'.join(map(str, VERSION[:3]))
__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:])
__package_name__ = 'megatron_core'
__contact_names__ = 'NVIDIA'
__contact_emails__ = 'nemo-toolkit@nvidia.com' # use NeMo Email
__homepage__ = (
'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/' # use NeMo homepage
)
__repository_url__ = 'https://github.com/NVIDIA/Megatron-LM/megatron/core'
__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
__description__ = (
'Megatron Core - a library for efficient and scalable training of transformer based models'
)
__license__ = 'BSD-3'
__keywords__ = (
'deep learning, machine learning, gpu, NLP, NLU, language, transformer, nvidia, pytorch, torch'
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment