Unverified Commit e5ce4c8e authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
parent 8d56c9c3
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
from torch import Tensor from torch import Tensor
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils.device import get_current_device
__all__ = ["BaseGradScaler"] __all__ = ["BaseGradScaler"]
...@@ -22,7 +23,7 @@ class BaseGradScaler(ABC): ...@@ -22,7 +23,7 @@ class BaseGradScaler(ABC):
def __init__(self, initial_scale: float, verbose: bool): def __init__(self, initial_scale: float, verbose: bool):
assert initial_scale > 0 assert initial_scale > 0
self._scale = torch.cuda.FloatTensor([initial_scale]) self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
self._verbose = verbose self._verbose = verbose
if self._verbose: if self._verbose:
......
...@@ -5,6 +5,8 @@ from typing import Optional ...@@ -5,6 +5,8 @@ from typing import Optional
import torch import torch
from colossalai.utils.device import get_current_device
from .base_grad_scaler import BaseGradScaler from .base_grad_scaler import BaseGradScaler
__all__ = ["DynamicGradScaler"] __all__ = ["DynamicGradScaler"]
...@@ -37,12 +39,12 @@ class DynamicGradScaler(BaseGradScaler): ...@@ -37,12 +39,12 @@ class DynamicGradScaler(BaseGradScaler):
): ):
super().__init__(initial_scale, verbose) super().__init__(initial_scale, verbose)
if min_scale: if min_scale:
self._min_scale = torch.cuda.FloatTensor([min_scale]) self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
else: else:
self._min_scale = None self._min_scale = None
if max_scale: if max_scale:
self._max_scale = torch.cuda.FloatTensor([max_scale]) self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
else: else:
self._max_scale = None self._max_scale = None
...@@ -115,7 +117,7 @@ class DynamicGradScaler(BaseGradScaler): ...@@ -115,7 +117,7 @@ class DynamicGradScaler(BaseGradScaler):
return state_dict return state_dict
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
self._scale = state_dict["scale"].cuda(torch.cuda.current_device()) self._scale = state_dict["scale"].to(get_current_device())
self._growth_factor = state_dict["growth_factor"] self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"] self._backoff_factor = state_dict["backoff_factor"]
self._hysteresis = state_dict["hysteresis"] self._hysteresis = state_dict["hysteresis"]
...@@ -11,7 +11,7 @@ except: ...@@ -11,7 +11,7 @@ except:
import torch import torch
from torch.fx.node import Node from torch.fx.node import Node
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from .region import Region from .region import Region
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
......
...@@ -25,6 +25,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh ...@@ -25,6 +25,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.device import IS_NPU_AVAILABLE
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats from colossalai.zero.gemini.memory_tracer import MemStats
...@@ -37,6 +38,7 @@ PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16} ...@@ -37,6 +38,7 @@ PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
def get_param_info(optim: Optimizer): def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes: # Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape. # 1. A mapping from integer param_id to param32 shape.
...@@ -53,6 +55,8 @@ def get_param_info(optim: Optimizer): ...@@ -53,6 +55,8 @@ def get_param_info(optim: Optimizer):
start_index += len(group["params"]) start_index += len(group["params"])
return param_info return param_info
class GeminiCheckpointIO(GeneralCheckpointIO): class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -359,6 +363,8 @@ class GeminiPlugin(DPPluginBase): ...@@ -359,6 +363,8 @@ class GeminiPlugin(DPPluginBase):
) -> None: ) -> None:
super().__init__() super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
if IS_NPU_AVAILABLE:
assert placement_policy == "static", "NPU only supports static placement policy"
self.gemini_config = dict( self.gemini_config = dict(
chunk_config_dict=chunk_config_dict, chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()), chunk_init_device=(chunk_init_device or get_current_device()),
...@@ -437,7 +443,7 @@ class GeminiPlugin(DPPluginBase): ...@@ -437,7 +443,7 @@ class GeminiPlugin(DPPluginBase):
return True return True
def supported_devices(self) -> List[str]: def supported_devices(self) -> List[str]:
return ["cuda"] return ["cuda", "npu"]
def configure( def configure(
self, self,
...@@ -485,4 +491,4 @@ class GeminiPlugin(DPPluginBase): ...@@ -485,4 +491,4 @@ class GeminiPlugin(DPPluginBase):
return GeminiCheckpointIO() return GeminiCheckpointIO()
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError raise NotImplementedError
\ No newline at end of file
...@@ -306,7 +306,7 @@ class LowLevelZeroPlugin(DPPluginBase): ...@@ -306,7 +306,7 @@ class LowLevelZeroPlugin(DPPluginBase):
return True return True
def supported_devices(self) -> List[str]: def supported_devices(self) -> List[str]:
return ["cuda"] return ["cuda", "npu"]
def configure( def configure(
self, self,
......
...@@ -11,7 +11,7 @@ import torch.distributed as dist ...@@ -11,7 +11,7 @@ import torch.distributed as dist
from colossalai.context import Config from colossalai.context import Config
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import set_device, set_seed from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
def launch( def launch(
...@@ -47,12 +47,15 @@ def launch( ...@@ -47,12 +47,15 @@ def launch(
if rank == 0: if rank == 0:
warnings.warn("`config` is deprecated and will be removed soon.") warnings.warn("`config` is deprecated and will be removed soon.")
if IS_NPU_AVAILABLE and backend == "nccl":
backend = "hccl"
# init default process group # init default process group
init_method = f"tcp://[{host}]:{port}" init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# set cuda device # set cuda device
if torch.cuda.is_available(): if torch.cuda.is_available() or IS_NPU_AVAILABLE:
# if local rank is not given, calculate automatically # if local rank is not given, calculate automatically
set_device(local_rank) set_device(local_rank)
......
...@@ -142,6 +142,7 @@ class Adam_Optimizer { ...@@ -142,6 +142,7 @@ class Adam_Optimizer {
} }
} }
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
inline void simd_load(bool is_half, float *ptr, __half *h_ptr, inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) { AVX_Data &data) {
if (is_half) { if (is_half) {
...@@ -159,6 +160,7 @@ class Adam_Optimizer { ...@@ -159,6 +160,7 @@ class Adam_Optimizer {
SIMD_STORE(ptr, data.data); SIMD_STORE(ptr, data.data);
} }
} }
#endif
void step(size_t step, float lr, float beta1, float beta2, float epsilon, void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params, float weight_decay, bool bias_correction, torch::Tensor &params,
......
#include "cpu_adam_arm.h"
void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4 = vdivq_f32(grad_4, loss_scale_vec);
}
float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i);
float32x4_t variance_4 =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i);
float32x4_t param_4 = simd_load_offset(_params, param_dtype, i);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4);
}
momentum_4 = vmulq_f32(momentum_4, betta1_4);
momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4);
variance_4 = vmulq_f32(variance_4, betta2_4);
grad_4 = vmulq_f32(grad_4, grad_4);
variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4);
grad_4 = vsqrtq_f32(variance_4);
grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt);
grad_4 = vdivq_f32(momentum_4, grad_4);
if (_weight_decay > 0 && _adamw_mode) {
param_4 = vfmaq_f32(param_4, param_4, weight_decay_4);
}
param_4 = vfmaq_f32(param_4, grad_4, step_size_4);
simd_store_offset(_params, param_dtype, param_4, i);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i);
}
}
#endif
if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = scalar_load_offset(grads, grad_dtype, k);
if (loss_scale > 0) {
grad /= loss_scale;
}
float param = scalar_load_offset(_params, param_dtype, k);
float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k);
float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k);
if (_weight_decay > 0 && !_adamw_mode) {
grad = param * _weight_decay + grad;
}
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) {
param += w_decay * param;
}
param = grad * step_size + param;
scalar_store_offset(_params, param_dtype, param, k);
scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k);
scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k);
}
}
}
}
void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
float32x4_t grad_4[4];
float32x4_t momentum_4[4];
float32x4_t variance_4[4];
float32x4_t param_4[4];
#pragma unroll 4
for (int j = 0; j < 4; j++) {
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
}
momentum_4[j] =
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
variance_4[j] =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
}
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
grad_4[j] = vsqrtq_f32(variance_4[j]);
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
}
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
i + SIMD_WIDTH * j);
}
}
}
#endif
if (_param_size > rounded_size) {
Step_1(scalar_seek_offset(_params, param_dtype, rounded_size),
scalar_seek_offset(grads, grad_dtype, rounded_size),
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
exp_avg_sq_dtype, loss_scale);
}
}
void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
float32x4_t grad_4[8];
float32x4_t momentum_4[8];
float32x4_t variance_4[8];
float32x4_t param_4[8];
#pragma unroll 4
for (int j = 0; j < 8; j++) {
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
}
momentum_4[j] =
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
variance_4[j] =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
}
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
grad_4[j] = vsqrtq_f32(variance_4[j]);
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
}
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
i + SIMD_WIDTH * j);
}
}
}
#endif
if (_param_size > rounded_size) {
Step_4(scalar_seek_offset(_params, param_dtype, rounded_size),
scalar_seek_offset(grads, grad_dtype, rounded_size),
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
exp_avg_sq_dtype, loss_scale);
}
}
void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2,
float epsilon, float weight_decay,
bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale) {
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
this->IncrementStep(step, beta1, beta2);
this->update_state(lr, epsilon, weight_decay, bias_correction);
this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(),
exp_avg_sq_c.data_ptr(), params_c.numel(),
params_c.scalar_type(), grads_c.scalar_type(),
exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale);
}
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<AdamOptimizer>(m, "CPUAdamOptimizer")
.def(py::init<float, float, float, float, float, bool>())
.def("step", &AdamOptimizer::step);
}
#pragma once
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <cmath>
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)
#if defined(__aarch64__)
#include <arm_neon.h>
#define SIMD_WIDTH 4
inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float: {
auto ptr_f = reinterpret_cast<const float32_t *>(ptr);
return vld1q_f32(ptr_f + offset);
}
case at::ScalarType::Half: {
auto ptr_h = reinterpret_cast<const float16_t *>(ptr);
return vcvt_f32_f16(vld1_f16(ptr_h + offset));
}
// case at::ScalarType::BFloat16: {
// auto ptr_b = reinterpret_cast<const bfloat16_t *>(ptr);
// return vcvt_f32_bf16(vld1_bf16(ptr_b + offset));
// }
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) {
return simd_load_offset(ptr, dtype, 0);
}
inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float: {
auto ptr_f = reinterpret_cast<float32_t *>(ptr);
vst1q_f32(ptr_f + offset, data);
break;
}
case at::ScalarType::Half: {
auto ptr_h = reinterpret_cast<float16_t *>(ptr);
vst1_f16(ptr_h + offset, vcvt_f16_f32(data));
break;
}
// case at::ScalarType::BFloat16: {
// auto ptr_b = reinterpret_cast<bfloat16_t *>(ptr);
// vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data));
// break;
// }
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) {
return simd_store_offset(ptr, dtype, data, 0);
}
inline float32x4_t simd_set(float value) {
auto val = static_cast<float32_t>(value);
return vdupq_n_f32(val);
}
#endif
inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return *(reinterpret_cast<const float *>(ptr) + offset);
case at::ScalarType::Half:
return static_cast<float>(
*(reinterpret_cast<const at::Half *>(ptr) + offset));
// case at::ScalarType::BFloat16:
// return static_cast<float>(
// *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
*(reinterpret_cast<float *>(ptr) + offset) = data;
break;
case at::ScalarType::Half:
*(reinterpret_cast<at::Half *>(ptr) + offset) = data;
break;
// case at::ScalarType::BFloat16:
// *(reinterpret_cast<at::BFloat16 *>(ptr) + offset) = data;
break;
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return reinterpret_cast<float *>(ptr) + offset;
case at::ScalarType::Half:
return reinterpret_cast<at::Half *>(ptr) + offset;
// case at::ScalarType::BFloat16:
// return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
default:
AT_ERROR("Unsupported dtype");
break;
}
}
#define STEP(SPAN) \
void Step_##SPAN(void *_params, void *grads, void *_exp_avg, \
void *_exp_avg_sq, size_t _param_size, \
at::ScalarType param_dtype, at::ScalarType grad_dtype, \
at::ScalarType exp_avg_dtype, \
at::ScalarType exp_avg_sq_dtype, float loss_scale = -1);
class AdamOptimizer {
private:
float _alpha;
float _betta1;
float _betta2;
float _eps;
float _weight_decay;
float _betta1_t;
float _betta2_t;
size_t _step;
float _bias_correction1;
float _bias_correction2;
bool _adamw_mode;
public:
AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
float eps = 1e-8, float weight_decay = 0,
bool adamw_mode = true)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
_adamw_mode(adamw_mode) {}
~AdamOptimizer() {}
STEP(1)
STEP(4)
STEP(8)
inline void IncrementStep(size_t step, float beta1, float beta2) {
if (beta1 != _betta1 || beta2 != _betta2) {
_step = step;
_betta1 = beta1;
_betta2 = beta2;
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
} else {
_step++;
if (_step != step) {
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
_step = step;
} else {
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
}
}
inline void update_state(float lr, float epsilon, float weight_decay,
bool bias_correction) {
_alpha = lr;
_eps = epsilon;
_weight_decay = weight_decay;
_bias_correction1 = 1.0f;
_bias_correction2 = 1.0f;
if (bias_correction == 1) {
_bias_correction1 = 1 - _betta1_t;
_bias_correction2 = 1 / sqrt(1 - _betta2_t);
}
}
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale);
};
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
class Unpad(torch.autograd.Function): class Unpad(torch.autograd.Function):
......
...@@ -12,7 +12,7 @@ from colossalai.legacy.context.parallel_mode import ParallelMode ...@@ -12,7 +12,7 @@ from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from ._base_schedule import BaseSchedule from ._base_schedule import BaseSchedule
......
...@@ -9,7 +9,7 @@ import colossalai.legacy.communication.p2p_v2 as comm ...@@ -9,7 +9,7 @@ import colossalai.legacy.communication.p2p_v2 as comm
from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine import Engine from colossalai.legacy.engine import Engine
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from ._pipeline_schedule import PipelineSchedule from ._pipeline_schedule import PipelineSchedule
......
...@@ -22,7 +22,7 @@ from colossalai.legacy.utils.checkpointing import ( ...@@ -22,7 +22,7 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict, partition_tensor_parallel_state_dict,
) )
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule from ..colossalai_layer._utils import ColossalaiModule
......
...@@ -18,7 +18,7 @@ from colossalai.legacy.utils.checkpointing import ( ...@@ -18,7 +18,7 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict, partition_tensor_parallel_state_dict,
) )
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
......
...@@ -19,7 +19,7 @@ from colossalai.legacy.utils.checkpointing import ( ...@@ -19,7 +19,7 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict, partition_tensor_parallel_state_dict,
) )
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
......
...@@ -27,7 +27,7 @@ from colossalai.legacy.utils.checkpointing import ( ...@@ -27,7 +27,7 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict, partition_tensor_parallel_state_dict,
) )
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import ( from ._operation import (
......
...@@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter ...@@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter
from colossalai.legacy.context import seed from colossalai.legacy.context import seed
from colossalai.legacy.registry import LAYERS from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from ..utils import to_2tuple from ..utils import to_2tuple
......
...@@ -3,7 +3,7 @@ import types ...@@ -3,7 +3,7 @@ import types
from time import time from time import time
from typing import List from typing import List
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from .stateful_tensor import StatefulTensor, TensorState from .stateful_tensor import StatefulTensor, TensorState
from .tensor_placement_policy import TensorPlacementPolicy from .tensor_placement_policy import TensorPlacementPolicy
......
import math import math
import platform
from typing import Optional from typing import Optional
import torch import torch
from colossalai.kernel.op_builder import CPUAdamBuilder from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder
from .nvme_optimizer import NVMeOptimizer from .nvme_optimizer import NVMeOptimizer
...@@ -77,7 +78,7 @@ class CPUAdam(NVMeOptimizer): ...@@ -77,7 +78,7 @@ class CPUAdam(NVMeOptimizer):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode self.adamw_mode = adamw_mode
cpu_adam = CPUAdamBuilder().load() cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
# if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
......
...@@ -84,9 +84,10 @@ class HybridAdam(CPUAdam): ...@@ -84,9 +84,10 @@ class HybridAdam(CPUAdam):
nvme_offload_fraction, nvme_offload_fraction,
nvme_offload_dir, nvme_offload_dir,
) )
fused_optim = FusedOptimBuilder().load() if torch.cuda.is_available():
self.gpu_adam_op = fused_optim.multi_tensor_adam fused_optim = FusedOptimBuilder().load()
self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
@torch.no_grad() @torch.no_grad()
def step(self, closure=None, div_scale: float = -1): def step(self, closure=None, div_scale: float = -1):
...@@ -118,11 +119,11 @@ class HybridAdam(CPUAdam): ...@@ -118,11 +119,11 @@ class HybridAdam(CPUAdam):
group_step = state["step"] group_step = state["step"]
beta1, beta2 = group["betas"] beta1, beta2 = group["betas"]
if target_device.type == "cpu": if target_device.type == "cpu" or target_device.type == "npu":
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" assert state["exp_avg"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" assert state["exp_avg_sq"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq") self._pre_update(p, "exp_avg", "exp_avg_sq")
if p.grad.dtype is torch.bfloat16: if p.grad.dtype is torch.bfloat16 or p.grad.device.type == "npu":
# cpu adam kernel does not support bf16 now # cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"] bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment