Commit 2fe7052f authored by zhangyue's avatar zhangyue
Browse files

issue/404: kunlun_common.h 和 kunlun_handle.h 解耦合

parent 97a3a84e
#include "kunlun_common.h"
#include "../../../utils.h"
#include <functional>
namespace device::kunlun {
// Runs `f` with an xdnn context whose stream has been set to `stream`.
//
// A context is taken from `dnn_handles` when one is available; otherwise a new
// one is obtained from xdnn::create_context(). The context is pushed back into
// the pool once `f` has returned, whether or not `f` succeeded, so a failing
// callback does not drop the context.
//
// Returns whatever CHECK_STATUS yields for a failing `f`, and
// INFINI_STATUS_SUCCESS otherwise.
infiniStatus_t Handle::Internal::useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const {
    auto pooled = dnn_handles.pop();

    // Copy the context out of the pop() result instead of assigning through
    // it: dereferencing an empty result to store a fresh context is undefined.
    xdnnHandle_t ctx;
    if (pooled) {
        ctx = *pooled;
    } else {
        ctx = xdnn::create_context();
    }

    ctx->set_stream(stream);

    // Capture the status first so the context is recycled before any early
    // return performed by CHECK_STATUS.
    const infiniStatus_t status = f(ctx);
    dnn_handles.push(std::move(ctx));
    CHECK_STATUS(status);
    return INFINI_STATUS_SUCCESS;
}
// Allocates a Kunlun handle for `device_id` and stores it in `*handle_ptr`.
// The handle is created with `new`; this function does not release it.
// Always reports INFINI_STATUS_SUCCESS.
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    Handle *created = new Handle(device_id);
    *handle_ptr = created;
    return INFINI_STATUS_SUCCESS;
}
} // namespace device::kunlun
#include "../pool.h" // #include "../pool.h"
#include "kunlun_handle.h" // #include "kunlun_handle.h"
#include "../../../utils.h"
#include <xpu/runtime.h> #include <xpu/runtime.h>
#include <xpu/runtime_ex.h> #include <xpu/runtime_ex.h>
#include <xpu/xdnn.h> #include <xpu/xdnn.h>
...@@ -11,16 +12,3 @@ typedef XPUEvent kunlunEvent_t; ...@@ -11,16 +12,3 @@ typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t; typedef xdnn::Context *xdnnHandle_t;
#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS) #define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
namespace device::kunlun {
// Internal state attached to a Kunlun `Handle`: a pool of xdnn contexts plus
// the accessor used to borrow one of them.
class Handle::Internal {
private:
    // Callback shape accepted by the accessor below: it receives a resource
    // of type T and reports an infiniStatus_t.
    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

    // Pool holding xdnn contexts.
    Pool<xdnnHandle_t> dnn_handles;

public:
    // Invokes `f` with an xdnn context whose stream has been set to `stream`.
    infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const;
};
} // namespace device::kunlun
#include "kunlun_common.h" #include "kunlun_handle.h"
namespace device::kunlun { namespace device::kunlun {
...@@ -10,4 +10,20 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & { ...@@ -10,4 +10,20 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal; return _internal;
} }
// Allocates a Kunlun handle for `device_id` and stores it in `*handle_ptr`.
// The handle is created with `new`; this function does not release it.
// Always reports INFINI_STATUS_SUCCESS.
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    Handle *created = new Handle(device_id);
    *handle_ptr = created;
    return INFINI_STATUS_SUCCESS;
}
// Runs `f` with an xdnn context whose stream has been set to `stream`.
//
// A context is taken from `dnn_handles` when one is available; otherwise a new
// one is obtained from xdnn::create_context(). The context is pushed back into
// the pool once `f` has returned, whether or not `f` succeeded, so a failing
// callback does not drop the context.
//
// Returns whatever CHECK_STATUS yields for a failing `f`, and
// INFINI_STATUS_SUCCESS otherwise.
infiniStatus_t Handle::Internal::useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const {
    auto pooled = dnn_handles.pop();

    // Copy the context out of the pop() result instead of assigning through
    // it: dereferencing an empty result to store a fresh context is undefined.
    xdnnHandle_t ctx;
    if (pooled) {
        ctx = *pooled;
    } else {
        ctx = xdnn::create_context();
    }

    ctx->set_stream(stream);

    // Capture the status first so the context is recycled before any early
    // return performed by CHECK_STATUS.
    const infiniStatus_t status = f(ctx);
    dnn_handles.push(std::move(ctx));
    CHECK_STATUS(status);
    return INFINI_STATUS_SUCCESS;
}
} // namespace device::kunlun } // namespace device::kunlun
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#define __INFINIOP_KUNLUN_HANDLE_H__ #define __INFINIOP_KUNLUN_HANDLE_H__
#include "../../handle.h" #include "../../handle.h"
#include "../pool.h"
#include "kunlun_common.h"
#include <memory> #include <memory>
namespace device::kunlun { namespace device::kunlun {
...@@ -19,6 +21,15 @@ public: ...@@ -19,6 +21,15 @@ public:
static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id); static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
}; };
// Internal state attached to a Kunlun `Handle`: a pool of xdnn contexts plus
// the accessor used to borrow one of them.
class Handle::Internal {
private:
    // Callback shape accepted by the accessor below: it receives a resource
    // of type T and reports an infiniStatus_t.
    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

    // Pool holding xdnn contexts.
    Pool<xdnnHandle_t> dnn_handles;

public:
    // Invokes `f` with an xdnn context whose stream has been set to `stream`.
    infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const;
};
} // namespace device::kunlun } // namespace device::kunlun
#endif // __INFINIOP_KUNLUN_HANDLE_H__ #endif // __INFINIOP_KUNLUN_HANDLE_H__
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "../../../utils.h" #include "../../../utils.h"
#include "../../devices/kunlun/kunlun_common.h" #include "../../devices/kunlun/kunlun_common.h"
#include "../../devices/kunlun/kunlun_handle.h"
#include "../../devices/kunlun/kunlun_kernel_common.h" #include "../../devices/kunlun/kunlun_kernel_common.h"
#include "elementwise_kunlun_api.h" #include "elementwise_kunlun_api.h"
......
#include "../../../devices/kunlun/kunlun_common.h" #include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h" #include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "causal_softmax_kunlun.h" #include "causal_softmax_kunlun.h"
#include "kernel.h" #include "kernel.h"
......
#include "gemm_kunlun.h" #include "gemm_kunlun.h"
#include "../../../../utils.h" #include "../../../../utils.h"
#include "../../../devices/kunlun/kunlun_common.h" #include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
namespace op::gemm::kunlun { namespace op::gemm::kunlun {
......
#include "../../../devices/kunlun/kunlun_common.h" #include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h" #include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "kernel.h" #include "kernel.h"
#include "rms_norm_kunlun.h" #include "rms_norm_kunlun.h"
......
#ifndef __ROPE_KUNLUN_KERNEL_XPU__ #ifndef __ROPE_KUNLUN_KERNEL_XPU__
#define __ROPE_KUNLUN_KERNEL_XPU__ #define __ROPE_KUNLUN_KERNEL_XPU__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../devices/kunlun/kunlun_common.h" #include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h" #include "../../../devices/kunlun/kunlun_handle.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "rope_kunlun.h" #include "rope_kunlun.h"
#include <memory> #include <memory>
template <typename T, typename Tindex> template <typename T, typename Tindex>
__global__ void RoPEKernel(T *destination, const T *source, __global__ void RoPEKernel(T *destination, const T *source,
const Tindex *pos_ids, const T *sin_table, const T *cos_table, const Tindex *pos_ids, const T *sin_table, const T *cos_table,
uint32_t seqlen, uint32_t nhead, uint32_t dhead, uint32_t seqlen, uint32_t nhead, uint32_t dhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead, int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead, int32_t y_stride_seqlen, int32_t y_stride_nhead,
XPUStream stream){ XPUStream stream) {
//ndim = 3 // ndim = 3
uint32_t other_size = seqlen * nhead; uint32_t other_size = seqlen * nhead;
int cid = core_id(); int cid = core_id();
...@@ -41,7 +41,7 @@ __global__ void RoPEKernel(T *destination, const T *source, ...@@ -41,7 +41,7 @@ __global__ void RoPEKernel(T *destination, const T *source,
int remain_dhead = dhead % buf_size; int remain_dhead = dhead % buf_size;
int repeat = (dhead - remain_dhead) / buf_size; int repeat = (dhead - remain_dhead) / buf_size;
for(int i = ind_start; i < ind_start + step; i++){ for (int i = ind_start; i < ind_start + step; i++) {
int ind_i = i; int ind_i = i;
int ind_d = 0; int ind_d = 0;
int ind_s = 0; int ind_s = 0;
...@@ -52,7 +52,7 @@ __global__ void RoPEKernel(T *destination, const T *source, ...@@ -52,7 +52,7 @@ __global__ void RoPEKernel(T *destination, const T *source,
ind_s += (ind_i % seqlen) * x_stride_seqlen; ind_s += (ind_i % seqlen) * x_stride_seqlen;
GM2LM(pos_ids + (ind_i % seqlen), pos_local, 1 * sizeof(Tindex)); GM2LM(pos_ids + (ind_i % seqlen), pos_local, 1 * sizeof(Tindex));
int index = static_cast<int>(pos_local[0]) * dhead / 2; int index = static_cast<int>(pos_local[0]) * dhead / 2;
for(int r = 0; r < repeat + (remain_dhead > 0 ? 1 : 0); r++){ for (int r = 0; r < repeat + (remain_dhead > 0 ? 1 : 0); r++) {
int read_len = (r < repeat ? buf_size : remain_dhead); int read_len = (r < repeat ? buf_size : remain_dhead);
int dk = read_len / 2; int dk = read_len / 2;
int start_d = ind_d + r * buf_size; int start_d = ind_d + r * buf_size;
...@@ -61,14 +61,13 @@ __global__ void RoPEKernel(T *destination, const T *source, ...@@ -61,14 +61,13 @@ __global__ void RoPEKernel(T *destination, const T *source,
GM2LM(source + start_s, x_local, read_len * sizeof(T)); GM2LM(source + start_s, x_local, read_len * sizeof(T));
GM2LM(sin_table + sin_cos_index, sin_local, dk * sizeof(T)); GM2LM(sin_table + sin_cos_index, sin_local, dk * sizeof(T));
GM2LM(cos_table + sin_cos_index, cos_local, dk * sizeof(T)); GM2LM(cos_table + sin_cos_index, cos_local, dk * sizeof(T));
if constexpr (xpu_std::is_same<T, float>::value || xpu_std::is_same<T, half>::value){ if constexpr (xpu_std::is_same<T, float>::value || xpu_std::is_same<T, half>::value) {
for(int k = 0; k < dk; k++){ for (int k = 0; k < dk; k++) {
y_local[2 * k] = x_local[2 * k] * cos_local[k] - x_local[2 * k + 1] * sin_local[k]; y_local[2 * k] = x_local[2 * k] * cos_local[k] - x_local[2 * k + 1] * sin_local[k];
y_local[2 * k + 1] = x_local[2 * k] * sin_local[k] + x_local[2 * k + 1] * cos_local[k]; y_local[2 * k + 1] = x_local[2 * k] * sin_local[k] + x_local[2 * k + 1] * cos_local[k];
} }
} } else if (xpu_std::is_same<T, bfloat16_t>::value) {
else if(xpu_std::is_same<T, bfloat16_t>::value){ for (int k = 0; k < dk; k++) {
for(int k = 0; k < dk; k++){
float x_0 = __bfloat162float(x_local[2 * k]); float x_0 = __bfloat162float(x_local[2 * k]);
float x_1 = __bfloat162float(x_local[2 * k + 1]); float x_1 = __bfloat162float(x_local[2 * k + 1]);
float sin_f = __bfloat162float(sin_local[k]); float sin_f = __bfloat162float(sin_local[k]);
...@@ -84,23 +83,22 @@ __global__ void RoPEKernel(T *destination, const T *source, ...@@ -84,23 +83,22 @@ __global__ void RoPEKernel(T *destination, const T *source,
} }
template <typename T, typename Tindex> template <typename T, typename Tindex>
void RoPE(void *destination, const void *source, void RoPE(void *destination, const void *source,
const void *pos_ids, const void *sin_table, const void *cos_table, const void *pos_ids, const void *sin_table, const void *cos_table,
uint32_t seqlen, uint32_t nhead, uint32_t dhead, uint32_t seqlen, uint32_t nhead, uint32_t dhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead, int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead, int32_t y_stride_seqlen, int32_t y_stride_nhead,
XPUStream stream){ XPUStream stream) {
RoPEKernel<T, Tindex><<<8, 64, stream>>>((T *)destination, (T *)source, RoPEKernel<T, Tindex><<<8, 64, stream>>>((T *)destination, (T *)source,
(Tindex *)pos_ids, (T *)sin_table, (T *)cos_table, (Tindex *)pos_ids, (T *)sin_table, (T *)cos_table,
seqlen, nhead, dhead, seqlen, nhead, dhead,
x_stride_seqlen, x_stride_nhead, x_stride_seqlen, x_stride_nhead,
y_stride_seqlen, y_stride_nhead, stream); y_stride_seqlen, y_stride_nhead, stream);
} }
#define LAUNCH_KERNEL(T, Tindex) \ #define LAUNCH_KERNEL(T, Tindex) \
RoPE<T, Tindex>(y, x, pos_ids, sin_table, cos_table, \ RoPE<T, Tindex>(y, x, pos_ids, sin_table, cos_table, \
seqlen, nhead, dhead, \ seqlen, nhead, dhead, \
x_stride_seqlen, x_stride_nhead, \ x_stride_seqlen, x_stride_nhead, \
y_stride_seqlen, y_stride_nhead, reinterpret_cast<kunlunStream_t>(stream)); y_stride_seqlen, y_stride_nhead, reinterpret_cast<kunlunStream_t>(stream));
namespace op::rope::kunlun { namespace op::rope::kunlun {
......
...@@ -5,9 +5,11 @@ local XTDK_DIR = path.join(KUNLUN_HOME, "xtdk") ...@@ -5,9 +5,11 @@ local XTDK_DIR = path.join(KUNLUN_HOME, "xtdk")
local XDNN_DIR = path.join(KUNLUN_HOME, "xhpc", "xdnn") local XDNN_DIR = path.join(KUNLUN_HOME, "xhpc", "xdnn")
-- Add include dirs -- Add include dirs
add_includedirs(path.join(XRE_DIR, "include")) add_includedirs(path.join(XRE_DIR, "include"), {public = true})
add_includedirs(path.join(XDNN_DIR, "include")) add_includedirs(path.join(XDNN_DIR, "include"), {public = true})
add_includedirs(path.join(XTDK_DIR, "include")) add_includedirs(path.join(XTDK_DIR, "include"), {public = true})
-- Add link dirs
add_linkdirs(path.join(XRE_DIR, "so")) add_linkdirs(path.join(XRE_DIR, "so"))
add_linkdirs(path.join(XDNN_DIR, "so")) add_linkdirs(path.join(XDNN_DIR, "so"))
add_links("xpurt", "xpuapi") add_links("xpurt", "xpuapi")
...@@ -72,7 +74,7 @@ target("infiniop-kunlun") ...@@ -72,7 +74,7 @@ target("infiniop-kunlun")
add_files(xpu_files, { add_files(xpu_files, {
rule = "xpu", rule = "xpu",
includedirs = { includedirs = {
path.join(os.projectdir, "include"), path.join(os.projectdir(), "include"),
path.join(XRE_DIR, "include"), path.join(XRE_DIR, "include"),
path.join(XDNN_DIR, "include"), path.join(XDNN_DIR, "include"),
path.join(XTDK_DIR, "include") path.join(XTDK_DIR, "include")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment