"build/CMakeFiles/ContinuousStart.dir/DependInfo.cmake" did not exist on "395d2ce606314a6729939084e5f492f37cd2ff13"
Unverified Commit 2391ec99 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #136 from InfiniTensor/issue/111-rmsnorm-kunlun

issue/111: 添加rmsnorm以及算子编译流程
parents 1d77c986 ff484bc7
#ifndef __INFINIOP_KUNLUN_COMMON_H__
#define __INFINIOP_KUNLUN_COMMON_H__
// This header file will only be include by .xpu file
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
// Get mask for kunlun xpu 512bit register calculation
// if data is not enough to 512bit, padding zero and use
// mask to identify real data
// 0 - i bit 1, others 0
inline __device__ float lowerBitMask(int i) {
return (1 << (i + 1)) - 1;
}
// Atomic add for reduce
inline __device__ void atomicAddF32(__shared_ptr__ float *ptr, float value) {
int success = 1;
while (success) {
// SM2REG read 32bit data to register
float a = SM2REG_atomic(ptr);
a = a + value;
success = REG2SM_atomic(ptr, a);
}
}
// TODO: atomicAddF16
// TODO: atomicAddI8
#endif
#ifndef __RMS_NORM_KUNLUN_KERNEL_XPU__
#define __RMS_NORM_KUNLUN_KERNEL_XPU__
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"
// Element wise mul used in x * w
static inline __device__ void elementwiseMulRms(float *x, float *w, float *y, int count, float rms) {
int remain = count % 16;
int offset_last = count - remain;
// y[i] = w[i] * x[i] * rms for remainder
for (int i = offset_last; i < count; i++) {
*(y + i) = *(w + i) * *(x + i) * rms;
}
mfence();
float32x16_t v_x;
float32x16_t v_w;
// Do x * w * rms
for (int i = 0; i < offset_last; i += 16) {
v_x = vload_lm_float32x16_mz(x + i);
v_w = vload_lm_float32x16_mz(w + i);
v_x = vvmul_float32x16(v_x, v_w);
v_x = svmul_float32x16(rms, v_x);
vstore_lm_float32x16((y + i), v_x);
mfence();
}
}
// RmsNorm main kernel func
// kunlun2 has 8 cluster and 64 core
// Call it by rmsnorm<<<8, 32, stream>>>()
__global__ void rmsNormKernelF32(float *y, long stride_y, const float *x, long stride_x, const float *w, int m, int n, float epsilon) {
// ncores in a cluster
int ncores = core_num();
// get cid of current core
int cid = core_id();
if (cid >= ncores) {
return;
}
// Divide m rows into all clusters equally
// if m % cluster_num() != 0, cluster_id < m % cluster_num() do 1 row more
// [m_start, m_end) is the range of m dim in current cluster
int m_start = m / cluster_num() * cluster_id() + min(m % cluster_num(), cluster_id());
int m_end = m_start + (m / cluster_num()) + (cluster_id() < (m % cluster_num()));
// max_nn is the max number of elements calculated on one core
const int max_nn = 1024;
// max_mm is the max number of rows calculated on one cluster
const int max_mm = 1024;
// LM cache for reduce
__local__ float x_local[max_nn];
// sm_output is shared mem cache for reduce
__shared__ float sm_output[max_mm];
// LM cache for elementwise mul
__local__ float y_local[max_nn];
__local__ float w_local[max_nn];
while (m_start < m_end) {
// init sm_output
for (int i = cid; i < m_end - m_start; i += ncores) {
sm_output[i] = 0.0f;
}
mfence();
sync_cluster();
// mm is the number of rows on current cluster
int mm = min(max_mm, m_end - m_start);
// each row will be devided to several blocks
// total_block is the number of blocks calculated on current cluster
// curr_block is the block calculated on current core
int total_block = mm * roundup_div(n, max_nn);
for (int curr_block = cid; curr_block < total_block; curr_block += ncores) {
// curr_m is the row of curr_block;
// curr_n_start is the first element of current row
// curr_nn is the number of elements of curr_block
int curr_m = curr_block % mm + m_start;
int curr_n_start = (curr_block / mm) * max_nn;
int curr_nn = min(max_nn, n - curr_n_start);
auto x_ptr = x + curr_m * stride_x + curr_n_start;
GM2LM(x_ptr, x_local, curr_nn * sizeof(float));
// do reduce
float ss = op::common_kunlun::reduce_op::sumSquaredF32(x_local, curr_nn);
atomicAddF32(&sm_output[curr_m - m_start], ss);
}
mfence();
sync_cluster();
// do elementwise mul for every line
for (int blk = cid; blk < total_block; blk += ncores) {
int m = blk % mm + m_start;
int n_start = (blk / mm) * max_nn;
int nn = min(max_nn, n - n_start);
auto x_ptr = x + m * stride_x + n_start;
auto w_ptr = w + n_start;
GM2LM(x_ptr, x_local, nn * sizeof(float));
GM2LM(w_ptr, w_local, nn * sizeof(float));
float ss = SM2REG_atomic(sm_output + m - m_start);
float rms = 1.0f / sqrt(ss / n + epsilon);
elementwiseMulRms(x_local, w_local, y_local, nn, rms);
mfence();
auto y_ptr = y + m * stride_y + n_start;
LM2GM(y_local, y_ptr, nn * sizeof(float));
}
mfence();
sync_cluster();
m_start += max_mm;
}
}
void rmsNormF32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream) {
rmsNormKernelF32<<<8, 32, stream>>>((float *)y, stride_y, (const float *)x, stride_x, (const float *)w, m, n, epsilon);
}
#endif
#include "rms_norm_kunlun.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include <memory>
#include <stdint.h>
void rmsNormF32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream);
namespace op::rms_norm::kunlun {
struct Descriptor::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (info.ndim() != 2) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
*desc_ptr = new Descriptor(
new Descriptor::Opaque{static_cast<device::kunlun::Handle *>(handle)->internal()},
info,
0,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t launchKernel(
int m, int n,
void *y, infiniDtype_t atype, ptrdiff_t stride_y,
const void *x, ptrdiff_t stride_x,
const void *w, infiniDtype_t wtype,
float epsilon,
kunlunStream_t stream) {
if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
rmsNormF32(y, static_cast<long>(stride_y), x, static_cast<long>(stride_x), w, m, n, epsilon, stream);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y, const void *x, const void *w, void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto stride_x = _info.x_strides[0];
auto stride_y = _info.y_strides[0];
int n = static_cast<int>(_info.dim());
int m = static_cast<int>(_info.shape[0]);
launchKernel(m, n, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, reinterpret_cast<kunlunStream_t>(stream));
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::kunlun
#ifndef __RMS_NORM_KUNLUN_H__
#define __RMS_NORM_KUNLUN_H__
#include "../rms_norm.h"
DESCRIPTOR(kunlun)
#endif
...@@ -11,6 +11,9 @@ ...@@ -11,6 +11,9 @@
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
#include "ascend/rms_norm_aclnn.h" #include "ascend/rms_norm_aclnn.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rms_norm_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateRMSNormDescriptor( __C infiniStatus_t infiniopCreateRMSNormDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -37,6 +40,9 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor( ...@@ -37,6 +40,9 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda) CREATE(INFINI_DEVICE_NVIDIA, cuda)
#endif #endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
return bangCreateRMSNormDescriptor((BangHandle_t)handle, (RMSNormBangDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon); return bangCreateRMSNormDescriptor((BangHandle_t)handle, (RMSNormBangDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
...@@ -76,6 +82,9 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d ...@@ -76,6 +82,9 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda) GET(INFINI_DEVICE_NVIDIA, cuda)
#endif #endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
return bangGetRMSNormWorkspaceSize((RMSNormBangDescriptor_t)desc, size); return bangGetRMSNormWorkspaceSize((RMSNormBangDescriptor_t)desc, size);
...@@ -116,6 +125,9 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works ...@@ -116,6 +125,9 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda) CALCULATE(INFINI_DEVICE_NVIDIA, cuda)
#endif #endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
return bangRMSNorm((RMSNormBangDescriptor_t)desc, workspace, workspace_size, y, x, w, stream); return bangRMSNorm((RMSNormBangDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
...@@ -155,6 +167,9 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t ...@@ -155,6 +167,9 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
DESTROY(INFINI_DEVICE_NVIDIA, cuda) DESTROY(INFINI_DEVICE_NVIDIA, cuda)
#endif #endif
#ifdef ENABLE_KUNLUN_API
DESTROY(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
return bangDestroyRMSNormDescriptor((RMSNormBangDescriptor_t)desc); return bangDestroyRMSNormDescriptor((RMSNormBangDescriptor_t)desc);
......
#ifndef __INFINIOP_REDUCE_KUNLUN_H__
#define __INFINIOP_REDUCE_KUNLUN_H__
#include "../../devices/kunlun/kunlun_common.h"
namespace op::common_kunlun::reduce_op {
// Use 16 floats instruction to calculate reduce
// data_ptr is the pointer of LM
static inline __device__ float sumSquaredF32(float *data_ptr, int count) {
__local__ float acc_buf[16];
int remain = count % 16;
int offset_last = count - remain;
int mask = lowerBitMask(remain - 1);
// Load last 16 data
float32x16_t v_last = vload_lm_float32x16_mz((data_ptr + offset_last), mask);
// Do v_last * v_last
v_last = vvmul_float32x16(v_last, v_last);
// for every 16 float data
for (int i = 0; i < offset_last; i += 16) {
float32x16_t v_0 = vload_lm_float32x16_mz(data_ptr + i);
// Do v_0 * v_0
v_0 = vvmul_float32x16(v_0, v_0);
// Add to v_last
v_last = vvadd_float32x16(v_last, v_0);
}
vstore_lm_float32x16_mz(acc_buf, v_last);
mfence();
float res = 0.0f;
for (int i = 0; i < 16; ++i) {
res += acc_buf[i];
}
return res;
}
} // namespace op::common_kunlun::reduce_op
#endif
add_defines("ENABLE_KUNLUN_API") add_defines("ENABLE_KUNLUN_API")
local KUNLUN_HOME = os.getenv("KUNLUN_HOME") local KUNLUN_HOME = os.getenv("KUNLUN_HOME")
local XTDK_DIR = path.join(KUNLUN_HOME, "XTDK")
-- Add include dirs -- Add include dirs
add_includedirs(path.join(KUNLUN_HOME, "include"), {public=true}) add_includedirs(path.join(KUNLUN_HOME, "include"), {public=true})
...@@ -7,6 +8,55 @@ add_linkdirs(path.join(KUNLUN_HOME, "lib64")) ...@@ -7,6 +8,55 @@ add_linkdirs(path.join(KUNLUN_HOME, "lib64"))
add_links("xpurt") add_links("xpurt")
add_links("xpuapi") add_links("xpuapi")
rule("xpu")
set_extensions(".xpu")
on_load(function (target)
target:add("includedirs", path.join(os.projectdir(), "include"))
end)
on_build_file(function (target, sourcefile)
local objectfile = target:objectfile(sourcefile)
local basename = objectfile:gsub("%.o$", "")
os.mkdir(path.directory(objectfile))
local cc = path.join(XTDK_DIR, "bin/clang++")
local includedirs = table.concat(target:get("includedirs"), " ")
local arch_map = {
["x86_64"] = "x86_64-linux-gnu",
["arm64"] = "aarch64-linux-gnu"
}
local args = {
"--sysroot=/",
"--target=" .. arch_map[os.arch()],
"-fPIC",
"-pie",
"--xpu-arch=xpu2",
"--basename", basename,
"-std=c++11",
"-O2",
"-fno-builtin",
"-g",
"-c", sourcefile,
"-v"
}
for _, includedir in ipairs(target:get("includedirs")) do
table.insert(args, "-I" .. includedir)
end
-- print(args)
os.execv(cc, args)
table.insert(target:objectfiles(), objectfile)
table.insert(target:objectfiles(), basename .. ".device.bin.o")
print(target:objectfiles())
end)
rule_end()
local src_dir = path.join(os.projectdir(), "src", "infiniop")
target("infiniop-kunlun") target("infiniop-kunlun")
set_kind("static") set_kind("static")
add_deps("infini-utils") add_deps("infini-utils")
...@@ -17,6 +67,11 @@ target("infiniop-kunlun") ...@@ -17,6 +67,11 @@ target("infiniop-kunlun")
set_languages("cxx17") set_languages("cxx17")
add_files("$(projectdir)/src/infiniop/devices/kunlun/*.cc", "$(projectdir)/src/infiniop/ops/*/kunlun/*.cc") add_files("$(projectdir)/src/infiniop/devices/kunlun/*.cc", "$(projectdir)/src/infiniop/ops/*/kunlun/*.cc")
-- compile handwriting kernel
local xpu_files = os.files(src_dir .. "/ops/*/kunlun/*.xpu")
if #xpu_files > 0 then
add_files(xpu_files, {rule = "xpu"})
end
target_end() target_end()
target("infinirt-kunlun") target("infinirt-kunlun")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment