Unverified Commit e5ce4c8e authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
parent 8d56c9c3
......@@ -8,6 +8,7 @@ import torch
from torch import Tensor
from colossalai.logging import get_dist_logger
from colossalai.utils.device import get_current_device
__all__ = ["BaseGradScaler"]
......@@ -22,7 +23,7 @@ class BaseGradScaler(ABC):
def __init__(self, initial_scale: float, verbose: bool):
assert initial_scale > 0
self._scale = torch.cuda.FloatTensor([initial_scale])
self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
self._verbose = verbose
if self._verbose:
......
......@@ -5,6 +5,8 @@ from typing import Optional
import torch
from colossalai.utils.device import get_current_device
from .base_grad_scaler import BaseGradScaler
__all__ = ["DynamicGradScaler"]
......@@ -37,12 +39,12 @@ class DynamicGradScaler(BaseGradScaler):
):
super().__init__(initial_scale, verbose)
if min_scale:
self._min_scale = torch.cuda.FloatTensor([min_scale])
self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
else:
self._min_scale = None
if max_scale:
self._max_scale = torch.cuda.FloatTensor([max_scale])
self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
else:
self._max_scale = None
......@@ -115,7 +117,7 @@ class DynamicGradScaler(BaseGradScaler):
return state_dict
def load_state_dict(self, state_dict):
self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
self._scale = state_dict["scale"].to(get_current_device())
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._hysteresis = state_dict["hysteresis"]
......@@ -11,7 +11,7 @@ except:
import torch
from torch.fx.node import Node
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from .region import Region
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
......
......@@ -25,6 +25,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device
from colossalai.utils.device import IS_NPU_AVAILABLE
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
......@@ -37,6 +38,7 @@ PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.
......@@ -53,6 +55,8 @@ def get_param_info(optim: Optimizer):
start_index += len(group["params"])
return param_info
class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
......@@ -359,6 +363,8 @@ class GeminiPlugin(DPPluginBase):
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
if IS_NPU_AVAILABLE:
assert placement_policy == "static", "NPU only supports static placement policy"
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
......@@ -437,7 +443,7 @@ class GeminiPlugin(DPPluginBase):
return True
def supported_devices(self) -> List[str]:
return ["cuda"]
return ["cuda", "npu"]
def configure(
self,
......
......@@ -306,7 +306,7 @@ class LowLevelZeroPlugin(DPPluginBase):
return True
def supported_devices(self) -> List[str]:
return ["cuda"]
return ["cuda", "npu"]
def configure(
self,
......
......@@ -11,7 +11,7 @@ import torch.distributed as dist
from colossalai.context import Config
from colossalai.logging import get_dist_logger
from colossalai.utils import set_device, set_seed
from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
def launch(
......@@ -47,12 +47,15 @@ def launch(
if rank == 0:
warnings.warn("`config` is deprecated and will be removed soon.")
if IS_NPU_AVAILABLE and backend == "nccl":
backend = "hccl"
# init default process group
init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# set cuda device
if torch.cuda.is_available():
if torch.cuda.is_available() or IS_NPU_AVAILABLE:
# if local rank is not given, calculate automatically
set_device(local_rank)
......
......@@ -142,6 +142,7 @@ class Adam_Optimizer {
}
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
......@@ -159,6 +160,7 @@ class Adam_Optimizer {
SIMD_STORE(ptr, data.data);
}
}
#endif
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
......
#include "cpu_adam_arm.h"
// Adam / AdamW update over `_param_size` elements.
//
// On aarch64 the leading ROUND_DOWN(_param_size, SIMD_WIDTH) elements are
// updated one NEON vector at a time; every remaining element (all of them on
// other architectures) goes through the scalar loop at the bottom.
//
// _params, _exp_avg and _exp_avg_sq are updated in place, grads is only read.
// Each buffer is decoded according to its own dtype argument. Gradients are
// divided by loss_scale only when loss_scale > 0.
void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
                           void *_exp_avg_sq, size_t _param_size,
                           at::ScalarType param_dtype,
                           at::ScalarType grad_dtype,
                           at::ScalarType exp_avg_dtype,
                           at::ScalarType exp_avg_sq_dtype, float loss_scale) {
  // First element index that the vector path does NOT cover.
  size_t simd_end = 0;
#if defined(__aarch64__)
  simd_end = ROUND_DOWN(_param_size, SIMD_WIDTH);
#endif

  const float one_minus_beta1 = 1 - _betta1;
  const float one_minus_beta2 = 1 - _betta2;
  // Negative so the update can be expressed as a fused multiply-add.
  const float neg_step = -1 * _alpha / _bias_correction1;
  // Decoupled (AdamW) decay factor applied directly to the parameter.
  const float decoupled_decay = -1 * _alpha * _weight_decay;

#if defined(__aarch64__)
  const float32x4_t beta1_v = simd_set(_betta1);
  const float32x4_t beta2_v = simd_set(_betta2);
  const float32x4_t one_minus_beta1_v = simd_set(one_minus_beta1);
  const float32x4_t one_minus_beta2_v = simd_set(one_minus_beta2);
  const float32x4_t bias2_v = simd_set(_bias_correction2);
  const float32x4_t eps_v = simd_set(_eps);
  const float32x4_t neg_step_v = simd_set(neg_step);

  // Only given a value (and only read) when weight decay is enabled.
  float32x4_t decay_v;
  if (_weight_decay > 0) {
    decay_v = simd_set(_adamw_mode ? decoupled_decay : _weight_decay);
  }

  for (size_t tile_begin = 0; tile_begin < simd_end; tile_begin += TILE) {
    const size_t tile_len =
        ((tile_begin + TILE) > simd_end) ? (simd_end - tile_begin) : TILE;
    const size_t tile_end = tile_begin + tile_len;

#pragma omp parallel for
    for (size_t i = tile_begin; i < tile_end; i += SIMD_WIDTH) {
      float32x4_t g = simd_load_offset(grads, grad_dtype, i);
      if (loss_scale > 0) {
        g = vdivq_f32(g, simd_set(loss_scale));
      }
      float32x4_t m = simd_load_offset(_exp_avg, exp_avg_dtype, i);
      float32x4_t v = simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i);
      float32x4_t p = simd_load_offset(_params, param_dtype, i);

      // L2-style decay: fold weight_decay * param into the gradient.
      if (_weight_decay > 0 && !_adamw_mode) {
        g = vfmaq_f32(g, p, decay_v);
      }

      // vfmaq_f32(a, b, c) evaluates a + b * c.
      m = vmulq_f32(m, beta1_v);
      m = vfmaq_f32(m, g, one_minus_beta1_v);
      v = vmulq_f32(v, beta2_v);
      g = vmulq_f32(g, g);
      v = vfmaq_f32(v, g, one_minus_beta2_v);

      // From here on `g` holds the denominator, then the update direction.
      g = vsqrtq_f32(v);
      g = vfmaq_f32(eps_v, g, bias2_v);
      g = vdivq_f32(m, g);

      // AdamW: shrink the parameter itself before applying the step.
      if (_weight_decay > 0 && _adamw_mode) {
        p = vfmaq_f32(p, p, decay_v);
      }
      p = vfmaq_f32(p, g, neg_step_v);

      simd_store_offset(_params, param_dtype, p, i);
      simd_store_offset(_exp_avg, exp_avg_dtype, m, i);
      simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, v, i);
    }
  }
#endif

  // Scalar path for everything the vector loop left behind. The loop does not
  // execute when simd_end already equals _param_size.
  for (size_t tile_begin = simd_end; tile_begin < _param_size;
       tile_begin += TILE) {
    const size_t tile_len = ((tile_begin + TILE) > _param_size)
                                ? (_param_size - tile_begin)
                                : TILE;
    const size_t tile_end = tile_begin + tile_len;

#pragma omp parallel for
    for (size_t k = tile_begin; k < tile_end; k++) {
      float g = scalar_load_offset(grads, grad_dtype, k);
      if (loss_scale > 0) {
        g /= loss_scale;
      }
      float p = scalar_load_offset(_params, param_dtype, k);
      float m = scalar_load_offset(_exp_avg, exp_avg_dtype, k);
      float v = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k);

      if (_weight_decay > 0 && !_adamw_mode) {
        g = p * _weight_decay + g;
      }

      m = m * _betta1;
      m = g * one_minus_beta1 + m;
      v = v * _betta2;
      g = g * g;
      v = g * one_minus_beta2 + v;

      g = sqrt(v);
      g = g * _bias_correction2 + _eps;
      g = m / g;

      if (_weight_decay > 0 && _adamw_mode) {
        p += decoupled_decay * p;
      }
      p = g * neg_step + p;

      scalar_store_offset(_params, param_dtype, p, k);
      scalar_store_offset(_exp_avg, exp_avg_dtype, m, k);
      scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, v, k);
    }
  }
}
// Adam / AdamW update that handles four NEON vectors (4 * SIMD_WIDTH elements)
// per inner-loop iteration on aarch64 and hands whatever is left to Step_1.
//
// _params, _exp_avg and _exp_avg_sq are read and overwritten in place; grads
// is only read. Each buffer is interpreted according to its own dtype argument
// by the simd_* / scalar_* helpers. Gradients are divided by loss_scale only
// when loss_scale > 0.
void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
// Number of leading elements covered by the vector path; stays 0 when not
// building for aarch64, so the whole range is forwarded to Step_1.
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
// Negated so the parameter update below is a single fused multiply-add.
float step_size = -1 * _alpha / _bias_correction1;
// Decoupled (AdamW) decay factor, applied directly to the parameter.
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
// Broadcast the scalar hyper-parameters into vector registers once.
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
// Only assigned (and only read below) when weight decay is enabled.
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
// Walk the vectorised range in chunks of at most TILE elements.
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
// Each iteration touches its own block of 4 * SIMD_WIDTH elements.
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
float32x4_t grad_4[4];
float32x4_t momentum_4[4];
float32x4_t variance_4[4];
float32x4_t param_4[4];
#pragma unroll 4
for (int j = 0; j < 4; j++) {
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
}
momentum_4[j] =
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
variance_4[j] =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
// L2-style decay: grad += param * weight_decay.
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
}
// momentum = momentum * beta1 + grad * (1 - beta1)
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
// variance = variance * beta2 + grad^2 * (1 - beta2)
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
// grad_4[j] is reused: first as the denominator
// eps + sqrt(variance) * _bias_correction2, then as momentum / denominator.
grad_4[j] = vsqrtq_f32(variance_4[j]);
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
// AdamW: param += param * w_decay before the step is applied.
if (_weight_decay > 0 && _adamw_mode) {
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
}
// param += update * step_size (step_size is negative, see above).
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
i + SIMD_WIDTH * j);
}
}
}
#endif
// Forward the tail (fewer than 4 * SIMD_WIDTH elements on aarch64, the whole
// range otherwise) to Step_1, with every pointer advanced by rounded_size
// elements of its own dtype.
if (_param_size > rounded_size) {
Step_1(scalar_seek_offset(_params, param_dtype, rounded_size),
scalar_seek_offset(grads, grad_dtype, rounded_size),
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
exp_avg_sq_dtype, loss_scale);
}
}
// Adam / AdamW update that handles eight NEON vectors (8 * SIMD_WIDTH
// elements) per inner-loop iteration on aarch64 and hands whatever is left to
// Step_4. This is the entry point used by AdamOptimizer::step().
//
// _params, _exp_avg and _exp_avg_sq are read and overwritten in place; grads
// is only read. Each buffer is interpreted according to its own dtype argument
// by the simd_* / scalar_* helpers. Gradients are divided by loss_scale only
// when loss_scale > 0.
void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
// Number of leading elements covered by the vector path; stays 0 when not
// building for aarch64, so the whole range is forwarded to Step_4.
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
// Negated so the parameter update below is a single fused multiply-add.
float step_size = -1 * _alpha / _bias_correction1;
// Decoupled (AdamW) decay factor, applied directly to the parameter.
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
// Broadcast the scalar hyper-parameters into vector registers once.
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
// Only assigned (and only read below) when weight decay is enabled.
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
// Walk the vectorised range in chunks of at most TILE elements.
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
// Each iteration touches its own block of 8 * SIMD_WIDTH elements.
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
float32x4_t grad_4[8];
float32x4_t momentum_4[8];
float32x4_t variance_4[8];
float32x4_t param_4[8];
// NOTE(review): the unroll factor is 4 although this loop runs 8 times;
// possibly carried over from Step_4 -- confirm whether 8 was intended.
#pragma unroll 4
for (int j = 0; j < 8; j++) {
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
}
momentum_4[j] =
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
variance_4[j] =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
// L2-style decay: grad += param * weight_decay.
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
}
// momentum = momentum * beta1 + grad * (1 - beta1)
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
// variance = variance * beta2 + grad^2 * (1 - beta2)
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
// grad_4[j] is reused: first as the denominator
// eps + sqrt(variance) * _bias_correction2, then as momentum / denominator.
grad_4[j] = vsqrtq_f32(variance_4[j]);
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
// AdamW: param += param * w_decay before the step is applied.
if (_weight_decay > 0 && _adamw_mode) {
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
}
// param += update * step_size (step_size is negative, see above).
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
i + SIMD_WIDTH * j);
}
}
}
#endif
// Forward the tail (fewer than 8 * SIMD_WIDTH elements on aarch64, the whole
// range otherwise) to Step_4, with every pointer advanced by rounded_size
// elements of its own dtype.
if (_param_size > rounded_size) {
Step_4(scalar_seek_offset(_params, param_dtype, rounded_size),
scalar_seek_offset(grads, grad_dtype, rounded_size),
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
exp_avg_sq_dtype, loss_scale);
}
}
// Runs one Adam / AdamW step for a single parameter tensor.
//
// step, beta1, beta2       forwarded to IncrementStep() to refresh the
//                          running beta powers.
// lr, epsilon,
// weight_decay,
// bias_correction          forwarded to update_state().
// params, exp_avg,
// exp_avg_sq               updated in place.
// grads                    read only.
// loss_scale               gradients are divided by it when it is > 0.
//
// The kernels operate on raw contiguous buffers and process params.numel()
// elements of every tensor.
void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2,
                         float epsilon, float weight_decay,
                         bool bias_correction, torch::Tensor &params,
                         torch::Tensor &grads, torch::Tensor &exp_avg,
                         torch::Tensor &exp_avg_sq, float loss_scale) {
  auto params_c = params.contiguous();
  auto grads_c = grads.contiguous();
  auto exp_avg_c = exp_avg.contiguous();
  auto exp_avg_sq_c = exp_avg_sq.contiguous();

  this->IncrementStep(step, beta1, beta2);
  this->update_state(lr, epsilon, weight_decay, bias_correction);
  this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(),
               exp_avg_sq_c.data_ptr(), params_c.numel(),
               params_c.scalar_type(), grads_c.scalar_type(),
               exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale);

  // contiguous() returns a fresh copy when the input is not already
  // contiguous, in which case the kernel just updated that temporary. Write
  // the results back so the caller's tensors actually receive the step.
  if (!params.is_contiguous()) {
    params.copy_(params_c);
  }
  if (!exp_avg.is_contiguous()) {
    exp_avg.copy_(exp_avg_c);
  }
  if (!exp_avg_sq.is_contiguous()) {
    exp_avg_sq.copy_(exp_avg_sq_c);
  }
}
namespace py = pybind11;
// Python bindings: exposes AdamOptimizer to Python as "CPUAdamOptimizer" with
// a six-argument constructor (five floats and a bool) and its step() method.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  py::class_<AdamOptimizer> adam_binding(m, "CPUAdamOptimizer");
  adam_binding.def(py::init<float, float, float, float, float, bool>());
  adam_binding.def("step", &AdamOptimizer::step);
}
#pragma once
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <cmath>
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)
#if defined(__aarch64__)
#include <arm_neon.h>
#define SIMD_WIDTH 4
// Loads SIMD_WIDTH consecutive elements starting at element index `offset` of
// `ptr` and returns them as four float32 lanes.
//
// Float buffers are loaded directly; Half buffers are loaded as four float16
// values and widened to float32. Any other dtype (BFloat16 included) raises
// "Unsupported dtype".
inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
                                    size_t offset) {
  if (dtype == at::ScalarType::Float) {
    const float32_t *src = reinterpret_cast<const float32_t *>(ptr);
    return vld1q_f32(src + offset);
  }
  if (dtype == at::ScalarType::Half) {
    const float16_t *src = reinterpret_cast<const float16_t *>(ptr);
    return vcvt_f32_f16(vld1_f16(src + offset));
  }
  AT_ERROR("Unsupported dtype");
}
// Convenience wrapper: loads the first SIMD_WIDTH elements of `ptr`
// (simd_load_offset with an element offset of 0).
inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) {
  const size_t first_element = 0;
  return simd_load_offset(ptr, dtype, first_element);
}
inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float: {
auto ptr_f = reinterpret_cast<float32_t *>(ptr);
vst1q_f32(ptr_f + offset, data);
break;
}
case at::ScalarType::Half: {
auto ptr_h = reinterpret_cast<float16_t *>(ptr);
vst1_f16(ptr_h + offset, vcvt_f16_f32(data));
break;
}
// case at::ScalarType::BFloat16: {
// auto ptr_b = reinterpret_cast<bfloat16_t *>(ptr);
// vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data));
// break;
// }
default:
AT_ERROR("Unsupported dtype");
break;
}
}
// Convenience wrapper: stores `data` at the start of `ptr`
// (simd_store_offset with an element offset of 0).
inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) {
  const size_t first_element = 0;
  simd_store_offset(ptr, dtype, data, first_element);
}
// Broadcasts a scalar into all four lanes of a NEON float32 vector.
inline float32x4_t simd_set(float value) {
  return vdupq_n_f32(static_cast<float32_t>(value));
}
#endif
inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return *(reinterpret_cast<const float *>(ptr) + offset);
case at::ScalarType::Half:
return static_cast<float>(
*(reinterpret_cast<const at::Half *>(ptr) + offset));
// case at::ScalarType::BFloat16:
// return static_cast<float>(
// *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
default:
AT_ERROR("Unsupported dtype");
break;
}
}
// Writes `data` to the element at index `offset` of `ptr`.
//
// Float buffers receive the value unchanged; Half buffers receive it converted
// to at::Half. Any other dtype (BFloat16 included) raises "Unsupported dtype".
inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
                                size_t offset) {
  switch (dtype) {
    case at::ScalarType::Float:
      *(reinterpret_cast<float *>(ptr) + offset) = data;
      break;
    case at::ScalarType::Half:
      *(reinterpret_cast<at::Half *>(ptr) + offset) = data;
      break;
    // BFloat16 is intentionally not handled. The unreachable `break;` that was
    // left behind when that case was commented out has been removed.
    default:
      AT_ERROR("Unsupported dtype");
      break;
  }
}
inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return reinterpret_cast<float *>(ptr) + offset;
case at::ScalarType::Half:
return reinterpret_cast<at::Half *>(ptr) + offset;
// case at::ScalarType::BFloat16:
// return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
default:
AT_ERROR("Unsupported dtype");
break;
}
}
// Declares a function named Step_<SPAN> (e.g. STEP(4) -> Step_4) taking four
// untyped buffers (params, grads, exp_avg, exp_avg_sq), the element count,
// one at::ScalarType per buffer, and a loss scale that defaults to -1.
// The expansion already ends with ';', so it is used without a trailing
// semicolon.
#define STEP(SPAN) \
void Step_##SPAN(void *_params, void *grads, void *_exp_avg, \
void *_exp_avg_sq, size_t _param_size, \
at::ScalarType param_dtype, at::ScalarType grad_dtype, \
at::ScalarType exp_avg_dtype, \
at::ScalarType exp_avg_sq_dtype, float loss_scale = -1);
// Adam / AdamW optimizer state plus the kernels that apply it to raw buffers.
//
// Hyper-parameters passed to the constructor are only initial values: every
// call to step() overwrites the learning rate, epsilon, weight decay and (when
// they change) the betas through IncrementStep() and update_state().
class AdamOptimizer {
 private:
  float _alpha;         // learning rate
  float _betta1;        // decay rate of the first-moment estimate
  float _betta2;        // decay rate of the second-moment estimate
  float _eps;           // added to the denominator of the update
  float _weight_decay;  // weight decay coefficient

  float _betta1_t;  // _betta1 raised to the current step
  float _betta2_t;  // _betta2 raised to the current step
  size_t _step;     // last step number seen by IncrementStep()

  float _bias_correction1;  // 1 - _betta1_t, or 1 when correction is off
  float _bias_correction2;  // 1 / sqrt(1 - _betta2_t), or 1 when off

  bool _adamw_mode;  // true: decoupled decay on the parameter (AdamW);
                     // false: decay folded into the gradient

 public:
  // The bias-correction factors start at the neutral value 1 so that calling
  // a Step_* kernel before update_state() never reads indeterminate values.
  AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                float eps = 1e-8, float weight_decay = 0,
                bool adamw_mode = true)
      : _alpha(alpha),
        _betta1(betta1),
        _betta2(betta2),
        _eps(eps),
        _weight_decay(weight_decay),
        _betta1_t(1.0),
        _betta2_t(1.0),
        _step(0),
        _bias_correction1(1.0f),
        _bias_correction2(1.0f),
        _adamw_mode(adamw_mode) {}
  ~AdamOptimizer() {}

  // Kernel declarations: Step_1, Step_4 and Step_8 (see the STEP macro).
  STEP(1)
  STEP(4)
  STEP(8)

  // Brings _step and the running beta powers in line with `step`.
  //
  // If either beta changed, or `step` is not exactly the previous step + 1,
  // the powers are recomputed from scratch with std::pow; otherwise they are
  // advanced by one multiplication.
  inline void IncrementStep(size_t step, float beta1, float beta2) {
    if (beta1 != _betta1 || beta2 != _betta2) {
      _step = step;
      _betta1 = beta1;
      _betta2 = beta2;
      _betta1_t = std::pow(_betta1, step);
      _betta2_t = std::pow(_betta2, step);
    } else {
      _step++;
      if (_step != step) {
        _betta1_t = std::pow(_betta1, step);
        _betta2_t = std::pow(_betta2, step);
        _step = step;
      } else {
        _betta1_t *= _betta1;
        _betta2_t *= _betta2;
      }
    }
  }

  // Stores the per-call hyper-parameters and derives the bias-correction
  // factors from the current beta powers. With bias_correction == false both
  // factors are 1. Expects IncrementStep() to have been called first so the
  // beta powers are current.
  inline void update_state(float lr, float epsilon, float weight_decay,
                           bool bias_correction) {
    _alpha = lr;
    _eps = epsilon;
    _weight_decay = weight_decay;

    _bias_correction1 = 1.0f;
    _bias_correction2 = 1.0f;
    if (bias_correction) {
      _bias_correction1 = 1 - _betta1_t;
      _bias_correction2 = 1 / sqrt(1 - _betta2_t);
    }
  }

  // Applies one optimizer step to `params` using `grads` and the moment
  // tensors; defined out of line.
  void step(size_t step, float lr, float beta1, float beta2, float epsilon,
            float weight_decay, bool bias_correction, torch::Tensor &params,
            torch::Tensor &grads, torch::Tensor &exp_avg,
            torch::Tensor &exp_avg_sq, float loss_scale);
};
......@@ -5,7 +5,7 @@ import torch
import torch.nn.functional as F
from einops import rearrange
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
class Unpad(torch.autograd.Function):
......
......@@ -12,7 +12,7 @@ from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
from colossalai.logging import get_dist_logger
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from ._base_schedule import BaseSchedule
......
......@@ -9,7 +9,7 @@ import colossalai.legacy.communication.p2p_v2 as comm
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine import Engine
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from ._pipeline_schedule import PipelineSchedule
......
......@@ -22,7 +22,7 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
......
......@@ -18,7 +18,7 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
......
......@@ -19,7 +19,7 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
......
......@@ -27,7 +27,7 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (
......
......@@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter
from colossalai.legacy.context import seed
from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from ..utils import to_2tuple
......
......@@ -3,7 +3,7 @@ import types
from time import time
from typing import List
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from .stateful_tensor import StatefulTensor, TensorState
from .tensor_placement_policy import TensorPlacementPolicy
......
import math
import platform
from typing import Optional
import torch
from colossalai.kernel.op_builder import CPUAdamBuilder
from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder
from .nvme_optimizer import NVMeOptimizer
......@@ -77,7 +78,7 @@ class CPUAdam(NVMeOptimizer):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
cpu_adam = CPUAdamBuilder().load()
cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
# if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
......
......@@ -84,6 +84,7 @@ class HybridAdam(CPUAdam):
nvme_offload_fraction,
nvme_offload_dir,
)
if torch.cuda.is_available():
fused_optim = FusedOptimBuilder().load()
self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
......@@ -118,11 +119,11 @@ class HybridAdam(CPUAdam):
group_step = state["step"]
beta1, beta2 = group["betas"]
if target_device.type == "cpu":
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
if target_device.type == "cpu" or target_device.type == "npu":
assert state["exp_avg"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
if p.grad.dtype is torch.bfloat16:
if p.grad.dtype is torch.bfloat16 or p.grad.device.type == "npu":
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment