Commit 2fe7052f authored by zhangyue
Browse files

issue/404: kunlun_common.h 和 kunlun_handle.h 解耦合

parent 97a3a84e
#include "kunlun_common.h"
#include "../../../utils.h"
#include <functional>
namespace device::kunlun {
// Runs `f` with an xdnn context whose stream has been set to `stream`.
//
// A context is taken from `dnn_handles`; if the pool yields none, a new one is
// created with xdnn::create_context(). After `f` has been checked with
// CHECK_STATUS the context is pushed back into the pool.
//
// stream: stream installed on the context via set_stream() before `f` runs.
// f:      callback that receives the context and reports an infiniStatus_t.
// Returns INFINI_STATUS_SUCCESS when execution reaches the end of the function.
infiniStatus_t Handle::Internal::useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const {
    auto handle = dnn_handles.pop();
    if (!handle) {
        // Assign to `handle` itself instead of writing through `*handle`:
        // dereferencing an empty result is undefined behaviour and would leave
        // it empty for the uses below.
        // Assumes Pool::pop() returns std::optional<xdnnHandle_t> - confirm against ../pool.h.
        handle = xdnn::create_context();
    }
    (*handle)->set_stream(stream);
    // NOTE(review): if CHECK_STATUS returns early on failure, the context is
    // not pushed back into the pool on that path - confirm whether that is intended.
    CHECK_STATUS(f(*handle));
    dnn_handles.push(std::move(*handle));
    return INFINI_STATUS_SUCCESS;
}
// Allocates a Handle for `device_id` with `new` and stores its address in
// `*handle_ptr`. Always returns INFINI_STATUS_SUCCESS.
// NOTE(review): the caller presumably takes over destruction of the object - confirm.
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    auto *created = new Handle(device_id);
    *handle_ptr = created;
    return INFINI_STATUS_SUCCESS;
}
} // namespace device::kunlun
#include "../pool.h"
#include "kunlun_handle.h"
// #include "../pool.h"
// #include "kunlun_handle.h"
#include "../../../utils.h"
#include <xpu/runtime.h>
#include <xpu/runtime_ex.h>
#include <xpu/xdnn.h>
......@@ -11,16 +12,3 @@ typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t;
#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
namespace device::kunlun {
// Internal state of a Kunlun Handle: a pool of xdnn contexts and the accessor
// that hands one of them to a callback.
class Handle::Internal {
// Pool of xdnn contexts (xdnnHandle_t is a typedef for xdnn::Context *).
Pool<xdnnHandle_t> dnn_handles;
// Callback type: takes a T and returns an infiniStatus_t.
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
public:
// Invokes `f` with an xdnn context after setting that context's stream to
// `stream`. Declared const although it pops from and pushes to `dnn_handles`.
infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const;
};
} // namespace device::kunlun
#include "kunlun_common.h"
#include "kunlun_handle.h"
namespace device::kunlun {
......@@ -10,4 +10,20 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal;
}
// Allocates a Handle for `device_id` with `new` and stores its address in
// `*handle_ptr`. Always returns INFINI_STATUS_SUCCESS.
// NOTE(review): the caller presumably takes over destruction of the object - confirm.
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    auto *created = new Handle(device_id);
    *handle_ptr = created;
    return INFINI_STATUS_SUCCESS;
}
// Runs `f` with an xdnn context whose stream has been set to `stream`.
//
// A context is taken from `dnn_handles`; if the pool yields none, a new one is
// created with xdnn::create_context(). After `f` has been checked with
// CHECK_STATUS the context is pushed back into the pool.
//
// stream: stream installed on the context via set_stream() before `f` runs.
// f:      callback that receives the context and reports an infiniStatus_t.
// Returns INFINI_STATUS_SUCCESS when execution reaches the end of the function.
infiniStatus_t Handle::Internal::useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const {
    auto handle = dnn_handles.pop();
    if (!handle) {
        // Assign to `handle` itself instead of writing through `*handle`:
        // dereferencing an empty result is undefined behaviour and would leave
        // it empty for the uses below.
        // Assumes Pool::pop() returns std::optional<xdnnHandle_t> - confirm against ../pool.h.
        handle = xdnn::create_context();
    }
    (*handle)->set_stream(stream);
    // NOTE(review): if CHECK_STATUS returns early on failure, the context is
    // not pushed back into the pool on that path - confirm whether that is intended.
    CHECK_STATUS(f(*handle));
    dnn_handles.push(std::move(*handle));
    return INFINI_STATUS_SUCCESS;
}
} // namespace device::kunlun
......@@ -2,6 +2,8 @@
#define __INFINIOP_KUNLUN_HANDLE_H__
#include "../../handle.h"
#include "../pool.h"
#include "kunlun_common.h"
#include <memory>
namespace device::kunlun {
......@@ -19,6 +21,15 @@ public:
static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
};
// Internal state of a Kunlun Handle: a pool of xdnn contexts and the accessor
// that hands one of them to a callback.
class Handle::Internal {
// Pool of xdnn contexts, stored as xdnnHandle_t values.
Pool<xdnnHandle_t> dnn_handles;
// Callback type: takes a T and returns an infiniStatus_t.
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
public:
// Invokes `f` with an xdnn context after setting that context's stream to
// `stream`. Declared const although it pops from and pushes to `dnn_handles`.
infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const;
};
} // namespace device::kunlun
#endif // __INFINIOP_KUNLUN_HANDLE_H__
......@@ -3,6 +3,7 @@
#include "../../../utils.h"
#include "../../devices/kunlun/kunlun_common.h"
#include "../../devices/kunlun/kunlun_handle.h"
#include "../../devices/kunlun/kunlun_kernel_common.h"
#include "elementwise_kunlun_api.h"
......
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "causal_softmax_kunlun.h"
#include "kernel.h"
......
#include "gemm_kunlun.h"
#include "../../../../utils.h"
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
namespace op::gemm::kunlun {
......
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "kernel.h"
#include "rms_norm_kunlun.h"
......
#ifndef __ROPE_KUNLUN_KERNEL_XPU__
#define __ROPE_KUNLUN_KERNEL_XPU__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "rope_kunlun.h"
#include <memory>
template <typename T, typename Tindex>
__global__ void RoPEKernel(T *destination, const T *source,
const Tindex *pos_ids, const T *sin_table, const T *cos_table,
uint32_t seqlen, uint32_t nhead, uint32_t dhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
XPUStream stream){
//ndim = 3
__global__ void RoPEKernel(T *destination, const T *source,
const Tindex *pos_ids, const T *sin_table, const T *cos_table,
uint32_t seqlen, uint32_t nhead, uint32_t dhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
XPUStream stream) {
// ndim = 3
uint32_t other_size = seqlen * nhead;
int cid = core_id();
......@@ -41,7 +41,7 @@ __global__ void RoPEKernel(T *destination, const T *source,
int remain_dhead = dhead % buf_size;
int repeat = (dhead - remain_dhead) / buf_size;
for(int i = ind_start; i < ind_start + step; i++){
for (int i = ind_start; i < ind_start + step; i++) {
int ind_i = i;
int ind_d = 0;
int ind_s = 0;
......@@ -52,7 +52,7 @@ __global__ void RoPEKernel(T *destination, const T *source,
ind_s += (ind_i % seqlen) * x_stride_seqlen;
GM2LM(pos_ids + (ind_i % seqlen), pos_local, 1 * sizeof(Tindex));
int index = static_cast<int>(pos_local[0]) * dhead / 2;
for(int r = 0; r < repeat + (remain_dhead > 0 ? 1 : 0); r++){
for (int r = 0; r < repeat + (remain_dhead > 0 ? 1 : 0); r++) {
int read_len = (r < repeat ? buf_size : remain_dhead);
int dk = read_len / 2;
int start_d = ind_d + r * buf_size;
......@@ -61,14 +61,13 @@ __global__ void RoPEKernel(T *destination, const T *source,
GM2LM(source + start_s, x_local, read_len * sizeof(T));
GM2LM(sin_table + sin_cos_index, sin_local, dk * sizeof(T));
GM2LM(cos_table + sin_cos_index, cos_local, dk * sizeof(T));
if constexpr (xpu_std::is_same<T, float>::value || xpu_std::is_same<T, half>::value){
for(int k = 0; k < dk; k++){
if constexpr (xpu_std::is_same<T, float>::value || xpu_std::is_same<T, half>::value) {
for (int k = 0; k < dk; k++) {
y_local[2 * k] = x_local[2 * k] * cos_local[k] - x_local[2 * k + 1] * sin_local[k];
y_local[2 * k + 1] = x_local[2 * k] * sin_local[k] + x_local[2 * k + 1] * cos_local[k];
}
}
else if(xpu_std::is_same<T, bfloat16_t>::value){
for(int k = 0; k < dk; k++){
} else if (xpu_std::is_same<T, bfloat16_t>::value) {
for (int k = 0; k < dk; k++) {
float x_0 = __bfloat162float(x_local[2 * k]);
float x_1 = __bfloat162float(x_local[2 * k + 1]);
float sin_f = __bfloat162float(sin_local[k]);
......@@ -84,23 +83,22 @@ __global__ void RoPEKernel(T *destination, const T *source,
}
template <typename T, typename Tindex>
void RoPE(void *destination, const void *source,
const void *pos_ids, const void *sin_table, const void *cos_table,
uint32_t seqlen, uint32_t nhead, uint32_t dhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
XPUStream stream){
RoPEKernel<T, Tindex><<<8, 64, stream>>>((T *)destination, (T *)source,
(Tindex *)pos_ids, (T *)sin_table, (T *)cos_table,
seqlen, nhead, dhead,
x_stride_seqlen, x_stride_nhead,
y_stride_seqlen, y_stride_nhead, stream);
// Host-side launcher for RoPEKernel<T, Tindex>.
//
// Casts the untyped pointer arguments to T * / Tindex * and launches the kernel
// with launch configuration <<<8, 64, stream>>>, forwarding the sizes
// (seqlen, nhead, dhead) and the x/y strides unchanged.
// NOTE(review): 8 and 64 are hard-coded launch parameters; presumably the
// cluster and core counts - confirm against the XPU toolchain documentation.
void RoPE(void *destination, const void *source,
const void *pos_ids, const void *sin_table, const void *cos_table,
uint32_t seqlen, uint32_t nhead, uint32_t dhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
XPUStream stream) {
// The C-style casts below also drop the const qualifier from source, pos_ids,
// sin_table and cos_table so they match the pointer types being cast to.
// `stream` is used both in the launch configuration and as the last kernel argument.
RoPEKernel<T, Tindex><<<8, 64, stream>>>((T *)destination, (T *)source,
(Tindex *)pos_ids, (T *)sin_table, (T *)cos_table,
seqlen, nhead, dhead,
x_stride_seqlen, x_stride_nhead,
y_stride_seqlen, y_stride_nhead, stream);
}
#define LAUNCH_KERNEL(T, Tindex) \
#define LAUNCH_KERNEL(T, Tindex) \
RoPE<T, Tindex>(y, x, pos_ids, sin_table, cos_table, \
seqlen, nhead, dhead, \
x_stride_seqlen, x_stride_nhead, \
seqlen, nhead, dhead, \
x_stride_seqlen, x_stride_nhead, \
y_stride_seqlen, y_stride_nhead, reinterpret_cast<kunlunStream_t>(stream));
namespace op::rope::kunlun {
......
......@@ -5,9 +5,11 @@ local XTDK_DIR = path.join(KUNLUN_HOME, "xtdk")
local XDNN_DIR = path.join(KUNLUN_HOME, "xhpc", "xdnn")
-- Add include dirs
add_includedirs(path.join(XRE_DIR, "include"))
add_includedirs(path.join(XDNN_DIR, "include"))
add_includedirs(path.join(XTDK_DIR, "include"))
add_includedirs(path.join(XRE_DIR, "include"), {public = true})
add_includedirs(path.join(XDNN_DIR, "include"), {public = true})
add_includedirs(path.join(XTDK_DIR, "include"), {public = true})
-- Add link dirs
add_linkdirs(path.join(XRE_DIR, "so"))
add_linkdirs(path.join(XDNN_DIR, "so"))
add_links("xpurt", "xpuapi")
......@@ -72,7 +74,7 @@ target("infiniop-kunlun")
add_files(xpu_files, {
rule = "xpu",
includedirs = {
path.join(os.projectdir, "include"),
path.join(os.projectdir(), "include"),
path.join(XRE_DIR, "include"),
path.join(XDNN_DIR, "include"),
path.join(XTDK_DIR, "include")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment