Commit 8b59f4fe authored by Catheriany

Merge remote-tracking branch 'origin/main' into issue/204

parents 16506fc0 df1c6b5d
...@@ -175,6 +175,10 @@ options:
{
    "clangd.arguments": [
        "--compile-commands-dir=.vscode"
-   ]
+   ],
+   "xmake.additionalConfigArguments": [
+       // Configure XMAKE_CONFIG_FLAGS here
+       "--nv-gpu=y"
+   ],
}
```
...@@ -18,6 +18,7 @@ def run_tests(args):
        "rms_norm.py",
        "rope.py",
        "swiglu.py",
+       "attention.py",
    ]:
        result = subprocess.run(
            f"python {test} {args}", text=True, encoding="utf-8", shell=True
...
...@@ -7,8 +7,9 @@
 */
DECLARE_INFINIOP_TEST(gemm)
DECLARE_INFINIOP_TEST(random_sample)
-DECLARE_INFINIOP_TEST(add)
DECLARE_INFINIOP_TEST(mul)
+DECLARE_INFINIOP_TEST(swiglu)
+DECLARE_INFINIOP_TEST(add)

#define REGISTER_INFINIOP_TEST(name) \
    { \
...@@ -28,7 +29,8 @@ DECLARE_INFINIOP_TEST(mul)
        REGISTER_INFINIOP_TEST(random_sample) \
        REGISTER_INFINIOP_TEST(add) \
        REGISTER_INFINIOP_TEST(mul) \
+       REGISTER_INFINIOP_TEST(swiglu) \
    }

namespace infiniop_test {
...
#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>
namespace infiniop_test::swiglu {
struct Test::Attributes {
std::shared_ptr<Tensor> a;
std::shared_ptr<Tensor> b;
std::shared_ptr<Tensor> ans;
std::shared_ptr<Tensor> c;
};
std::shared_ptr<Test> Test::build(
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
double rtol, double atol) {
auto test = std::shared_ptr<Test>(new Test(rtol, atol));
test->_attributes = new Attributes();
if (tensors.find("a") == tensors.end()
|| tensors.find("b") == tensors.end()
|| tensors.find("c") == tensors.end()
|| tensors.find("ans") == tensors.end()) {
throw std::runtime_error("Invalid Test");
}
test->_attributes->a = tensors["a"];
test->_attributes->b = tensors["b"];
test->_attributes->c = tensors["c"];
test->_attributes->ans = tensors["ans"];
return test;
}
std::shared_ptr<infiniop_test::Result> Test::run(
infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {
infiniopSwiGLUDescriptor_t op_desc;
auto a = _attributes->a->to(device, device_id);
auto b = _attributes->b->to(device, device_id);
auto c = _attributes->c->to(device, device_id);
CHECK_OR(infiniopCreateSwiGLUDescriptor(handle, &op_desc,
c->desc(),
a->desc(),
b->desc()),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));
size_t workspace_size;
CHECK_OR(infiniopGetSwiGLUWorkspaceSize(op_desc, &workspace_size),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));
void *workspace;
CHECK_OR(infinirtMalloc(&workspace, workspace_size),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
CHECK_OR(infiniopSwiGLU(op_desc, workspace, workspace_size, c->data(), a->data(), b->data(), nullptr),
return TEST_FAILED(OP_CREATION_FAILED, "Failed during execution."));
try {
allClose(c, _attributes->ans, _rtol, _atol);
} catch (const std::exception &e) {
return TEST_FAILED(RESULT_INCORRECT, e.what());
}
double elapsed_time = 0.;
elapsed_time = benchmark(
[=]() {
infiniopSwiGLU(
op_desc,
workspace,
workspace_size,
c->data(),
a->data(),
b->data(),
nullptr);
},
warm_ups, iterations);
return TEST_PASSED(elapsed_time);
}
std::vector<std::string> Test::attribute_names() {
return {};
}
std::vector<std::string> Test::tensor_names() {
return {"a", "b", "c", "ans"};
}
std::string Test::toString() const {
std::ostringstream oss;
oss << op_name() << std::endl;
oss << "- a: " << _attributes->a->info() << std::endl;
oss << "- b: " << _attributes->b->info() << std::endl;
oss << "- c: " << _attributes->c->info() << std::endl;
oss << std::scientific << std::setprecision(2);
oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
return oss.str();
}
Test::~Test() {
delete _attributes;
}
} // namespace infiniop_test::swiglu
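For context, `ans` above is the reference output that `allClose` checks the device result against. A minimal host-side sketch of such a reference, assuming the usual convention swiglu(a, b) = a · b · sigmoid(b); the exact gating convention and the generator that fills `ans` are not shown in this diff:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference only; illustrates what the "ans" tensor would contain
// under the assumption swiglu(a, b) = a * b * sigmoid(b).
std::vector<float> swigluReference(const std::vector<float> &a, const std::vector<float> &b) {
    std::vector<float> c(a.size());
    for (size_t i = 0; i < a.size(); ++i) {
        float sig = 1.0f / (1.0f + std::exp(-b[i]));
        c[i] = a[i] * b[i] * sig;
    }
    return c;
}
```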
...@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.16.0)
# project information
project(Ascend_C)
set(SOC_VERSION "Ascend910B3" CACHE STRING "system on chip type")
-set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME} CACHE PATH "ASCEND CANN package installation directory")
+set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_TOOLKIT_HOME} CACHE PATH "ASCEND CANN package installation directory")
set(RUN_MODE "npu" CACHE STRING "run mode: npu")
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type Release/Debug (default Debug)" FORCE)
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/out" CACHE STRING "path for install()" FORCE)
...@@ -19,10 +19,14 @@ else()
endif()

include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)

+include_directories(
+    ${CMAKE_SOURCE_DIR}/../../../../include/infiniop/
+)
+
ascendc_library(ascend_kernels STATIC
-    ../../ops/swiglu/ascend/swiglu_kernel.cpp
-    ../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp
-    ../../ops/random_sample/ascend/random_sample_kernel.cpp
+    ../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
+    # ../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp
+    # ../../ops/random_sample/ascend/random_sample_kernel.cpp
)
#ifndef __INFINIOP_ASCEND_KERNEL_COMMON_H__
#define __INFINIOP_ASCEND_KERNEL_COMMON_H__
#include "../../../../include/infinicore.h"
#include "kernel_operator.h"
constexpr int32_t BLOCK_NUM = 8;
constexpr int32_t BUFFER_NUM = 2;
constexpr int32_t BYTE_ALIGN = 32;
#endif
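The constants above fix the launch width, the double-buffering depth, and a 32-byte alignment unit for the Ascend kernels. As a quick illustration of the round-up arithmetic such an alignment constant is normally paired with (the helper name is made up for this sketch):

```cpp
#include <cstdint>

constexpr int32_t BYTE_ALIGN = 32;

// Hypothetical helper: round a byte count up to the next multiple of BYTE_ALIGN.
constexpr int32_t alignUp(int32_t bytes) {
    return (bytes + BYTE_ALIGN - 1) / BYTE_ALIGN * BYTE_ALIGN;
}

static_assert(alignUp(1) == 32 && alignUp(32) == 32 && alignUp(33) == 64, "32-byte round-up");
```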
#include "common_ascend.h" #include "common_ascend.h"
std::vector<int64_t> inferStorageShape(std::vector<int64_t> shape, std::vector<int64_t> strides) { std::vector<int64_t> inferStorageShape(std::vector<int64_t> shape, std::vector<int64_t> strides) {
auto index = std::max_element(strides.begin(), strides.end()); if (shape.size() != strides.size()) {
uint64_t max_stride_index = std::distance(strides.begin(), index); throw std::invalid_argument("Shape and strides must have the same length.");
auto storageShape = std::vector<int64_t>({shape[max_stride_index] * strides[max_stride_index]}); }
int64_t max_offset = 0;
for (size_t i = 0; i < shape.size(); ++i) {
max_offset += (shape[i] - 1) * strides[i];
}
return storageShape; // storage shape is 1D buffer that must cover all accessed elements
return {max_offset + 1};
} }
size_t aclnnTensorDescriptor::numel() const { size_t aclnnTensorDescriptor::numel() const {
...@@ -18,7 +24,7 @@ aclnnTensorDescriptor::aclnnTensorDescriptor(infiniopTensorDescriptor_t desc, vo ...@@ -18,7 +24,7 @@ aclnnTensorDescriptor::aclnnTensorDescriptor(infiniopTensorDescriptor_t desc, vo
this->strides = std::vector<int64_t>(ndim); this->strides = std::vector<int64_t>(ndim);
for (uint64_t i = 0; i < ndim; ++i) { for (uint64_t i = 0; i < ndim; ++i) {
this->shape[i] = static_cast<int64_t>(desc->dim(i)); this->shape[i] = static_cast<int64_t>(desc->dim(i));
this->strides[i] = desc->stride(i); this->strides[i] = static_cast<int64_t>(desc->stride(i));
} }
this->storageShape = inferStorageShape(this->shape, this->strides); this->storageShape = inferStorageShape(this->shape, this->strides);
this->dataType = toAclDataType(desc->dtype()); this->dataType = toAclDataType(desc->dtype());
......
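A quick sanity check of the new storage-shape rule above: the returned size is the largest reachable element offset plus one, so permuted or padded strides are still covered. Worked example (values chosen for illustration):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Worked example for inferStorageShape: shape [2, 3] with padded row stride 4
// reaches offset (2-1)*4 + (3-1)*1 = 6, so the backing buffer needs 7 elements.
int main() {
    std::vector<int64_t> shape = {2, 3};
    std::vector<int64_t> strides = {4, 1};
    int64_t max_offset = 0;
    for (size_t i = 0; i < shape.size(); ++i) {
        max_offset += (shape[i] - 1) * strides[i];
    }
    assert(max_offset + 1 == 7);
    return 0;
}
```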
...@@ -16,7 +16,7 @@ typedef XPUStream kunlunStream_t;
typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t;

-#define CHECK_XDNN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
+#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)

namespace device::kunlun {
...
-#ifndef __INFINIOP_KUNLUN_COMMON_H__
-#define __INFINIOP_KUNLUN_COMMON_H__
+#ifndef __INFINIOP_KUNLUN_KERNEL_COMMON_H__
+#define __INFINIOP_KUNLUN_KERNEL_COMMON_H__

// This header file will only be included by .xpu files
+#include "kunlun_kernel_dtype.h"
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"

+namespace device::kunlun::kernel {

// Get the mask for Kunlun XPU 512-bit register calculation:
// if the data does not fill 512 bits, pad with zeros and use
// the mask to identify the real data
...@@ -26,6 +28,37 @@ inline __device__ void atomicAddF32(__shared_ptr__ float *ptr, float value) {
    }
}
inline __device__ size_t indexToReducedOffset(
size_t flat_index,
size_t ndim,
const _ptrdiff_t *broadcasted_strides,
const _ptrdiff_t *target_strides) {
size_t res = 0;
for (size_t i = 0; i < ndim; ++i) {
res += flat_index / broadcasted_strides[i].value * target_strides[i].value;
flat_index %= broadcasted_strides[i].value;
mfence();
}
return res;
}
inline __device__ size_t indexToOffset(
size_t flat_index,
size_t ndim,
const _size_t *shape,
const _ptrdiff_t *strides) {
size_t res = 0;
for (size_t i = ndim; i-- > 0;) {
res += (flat_index % shape[i].value) * strides[i].value;
flat_index /= shape[i].value;
mfence();
}
return res;
}
} // namespace device::kunlun::kernel
// TODO: atomicAddF16
// TODO: atomicAddI8

#endif
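The two index helpers above map a flat output index back to a strided memory offset: indexToOffset walks dimensions from innermost to outermost, while indexToReducedOffset maps through the broadcasted output strides onto a (possibly broadcast) input. A host-side sketch of the same arithmetic with plain 64-bit types; the device versions additionally unwrap the padded `_size_t`/`_ptrdiff_t` structs and fence between iterations:

```cpp
#include <cassert>
#include <cstddef>

// Host-side mirror of indexToOffset.
size_t indexToOffsetHost(size_t flat, size_t ndim, const size_t *shape, const ptrdiff_t *strides) {
    size_t res = 0;
    for (size_t i = ndim; i-- > 0;) {
        res += (flat % shape[i]) * strides[i];
        flat /= shape[i];
    }
    return res;
}

int main() {
    // Shape [2, 3]; flat index 4 corresponds to logical coordinates (1, 1).
    size_t shape[2] = {2, 3};
    ptrdiff_t contiguous[2] = {3, 1};
    ptrdiff_t transposed[2] = {1, 2}; // same logical tensor stored column-major
    assert(indexToOffsetHost(4, 2, shape, contiguous) == 1 * 3 + 1 * 1);
    assert(indexToOffsetHost(4, 2, shape, transposed) == 1 * 1 + 1 * 2);
    return 0;
}
```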
#ifndef __INFINIOP_KUNLUN_DTYPE_H__
#define __INFINIOP_KUNLUN_DTYPE_H__
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
// On the Kunlun device, `long` is 32-bit, so a ptrdiff_t copied from the host
// occupies two device words; kernels only read the `value` field.
typedef struct _ptrdiff_t {
    long value;   // 32 bit
    long padding; // 32 bit
} _ptrdiff_t;
// Same layout trick for host size_t
typedef struct _size_t {
    size_t value;
    size_t padding;
} _size_t;
#endif
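These padded structs exist because a host stride array is copied to device memory as raw bytes, and each 64-bit host value then spans two 32-bit device words; the kernel reads only `value`. A host-side sketch of the assumed layout (little-endian byte order and 64-bit host types are assumptions of this illustration, not stated by the header):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Host-side stand-in for the device _ptrdiff_t: two 32-bit words per host value.
struct PaddedPtrdiff {
    int32_t value;   // the word the kernel reads
    int32_t padding; // remaining bits of the 64-bit host value
};

int main() {
    static_assert(sizeof(PaddedPtrdiff) == sizeof(int64_t), "one host ptrdiff_t per struct");
    int64_t host_stride = 6; // a typical small stride
    PaddedPtrdiff dev{};
    std::memcpy(&dev, &host_stride, sizeof(dev)); // what the raw device copy effectively does
    assert(dev.value == 6 && dev.padding == 0);
    return 0;
}
```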
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_H__
#include "../../../utils.h"
#include "../../devices/kunlun/kunlun_handle.h"
#include "elementwise_kunlun_api.h"
namespace op::elementwise::kunlun {
struct DeviceImpl::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::kunlun::Handle::Internal> &internal_)
: internal(internal_) {}
template <size_t N, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
kunlunStream_t stream,
Args &&...args) {
auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}
// Device pointers
const void **d_inputs_arr = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t *d_input_shapes = nullptr;
const ptrdiff_t *d_input_strides = nullptr;
CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_output_strides,
d_input_shapes, d_input_strides));
Op::template launch<Tdata>(
output_size,
info.getNdim(),
info.isOutputContiguous(),
reinterpret_cast<const void *>(d_input_contiguous),
reinterpret_cast<const void *>(d_input_broadcasted),
reinterpret_cast<const void *>(d_output_shape),
reinterpret_cast<const void *>(d_input_shapes),
reinterpret_cast<const void *>(d_output_strides),
reinterpret_cast<const void *>(d_input_strides),
output,
reinterpret_cast<const void *const *>(d_inputs_arr),
stream,
args...);
return INFINI_STATUS_SUCCESS;
}
private:
template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
const void *const *h_inputs_arr,
const void **&d_inputs_arr,
const bool *&d_input_contiguous,
const bool *&d_input_broadcasted,
const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides,
const size_t *&d_input_shapes,
const ptrdiff_t *&d_input_strides) const {
constexpr auto input_size = N;
const auto ndim = info.getNdim();
constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
const int8_t *info_meta_start = info.getMetaStart();
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device
CHECK_KUNLUN(xpu_memcpy(workspace, h_inputs_arr, input_arr_size, XPU_HOST_TO_DEVICE));
CHECK_KUNLUN(xpu_memcpy((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), XPU_HOST_TO_DEVICE));
// offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace);
d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
return INFINI_STATUS_SUCCESS;
}
};
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
return _opaque->calculateImpl<N, Op, Tdata>(
info, workspace, output, inputs,
reinterpret_cast<kunlunStream_t>(stream),
std::forward<Args>(args)...);
}
} // namespace op::elementwise::kunlun
// Template for kunlun kernel interface declaration
#define LAUNCH_ELEMENTWISE_KERNEL(OpName) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args);
#endif
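infoToDevice above copies the input-pointer array followed by one contiguous metadata blob into the workspace, then carves the blob into shape/stride/flag views by pointer arithmetic. A rough host-side sketch of the implied byte layout for N inputs and ndim dimensions (field order is taken from the pointer offsets in the code; the authoritative size comes from ElementwiseInfo::getMetaMemSize, which is not shown in this diff):

```cpp
#include <cstddef>
#include <cstdio>

// Illustrative layout of the device workspace built by infoToDevice (N inputs, ndim dims):
//   [ N x void*          ]  input pointer array
//   [ ndim x size_t      ]  output shape
//   [ ndim x ptrdiff_t   ]  output strides
//   [ N*ndim x size_t    ]  input shapes
//   [ N*ndim x ptrdiff_t ]  input strides
//   [ N x bool           ]  input contiguous flags
//   [ N x bool           ]  input broadcasted flags
size_t sketchWorkspaceBytes(size_t N, size_t ndim) {
    size_t meta = ndim * sizeof(size_t) + ndim * sizeof(ptrdiff_t)
                + N * ndim * sizeof(size_t) + N * ndim * sizeof(ptrdiff_t)
                + 2 * N * sizeof(bool);
    return N * sizeof(void *) + meta;
}

int main() {
    // e.g. a binary op (N = 2) on 3-D tensors
    std::printf("approx workspace: %zu bytes\n", sketchWorkspaceBytes(2, 3));
    return 0;
}
```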
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#include "../elementwise.h"
namespace op::elementwise::kunlun {
class DeviceImpl final {
struct Opaque;
std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static utils::Result<DeviceImpl *> create(Args &&...args);
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
} // namespace op::elementwise::kunlun
#define CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::kunlun::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif
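The macro above is intended to be expanded inside an individual operator's create function; it assumes a `desc_ptr` out-parameter is already in scope. A hedged sketch of one possible call site for a hypothetical binary op (the function name and surrounding types are illustrative; this fragment only makes sense together with the headers above, not standalone):

```cpp
// Illustrative only: how a hypothetical op::add::kunlun Descriptor::create might use the macro.
infiniStatus_t exampleCreate(device::kunlun::Handle *HANDLE,
                             Descriptor **desc_ptr,
                             infiniopTensorDescriptor_t out_desc,
                             std::vector<infiniopTensorDescriptor_t> input_descs) {
    auto DTYPE = out_desc->dtype();
    // Builds ElementwiseInfo, sizes the workspace, creates the DeviceImpl,
    // and assigns a new Descriptor to *desc_ptr.
    CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(HANDLE, DTYPE, out_desc, input_descs);
    return INFINI_STATUS_SUCCESS;
}
```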
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#define __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#include "../../devices/kunlun/kunlun_kernel_common.h"
using namespace device::kunlun::kernel;
/**
* @brief Computes input tile offset
*/
struct InputIndexer {
size_t idx;
size_t ndim;
const bool *input_contiguous;
const bool *input_broadcasted;
const _size_t *input_shapes;
const _ptrdiff_t *input_strides;
const _ptrdiff_t *output_strides;
__device__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}
};
/**
* @brief Computes the output index in memory, accounting for strides if non-contiguous.
*
* @param idx Linear index.
* @param is_contiguous Whether the output tensor is contiguous.
* @param ndim Number of dimensions.
* @param shape Shape of the output tensor.
* @param strides Strides of the output tensor.
* @return Memory offset index.
*/
inline __device__ size_t
getOutputIndex(size_t idx,
bool is_contiguous,
size_t ndim,
const _size_t *shape,
const _ptrdiff_t *strides) {
return is_contiguous ? idx : indexToOffset(idx, ndim, shape, strides);
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__device__ void launchOp(
__global_ptr__ Tdata **typed_inputs, // gm pointer
__global_ptr__ Tdata *output, // gm pointer output
Tdata *inputs_buf, // local mem buffer
size_t *input_indexes,
size_t output_index,
Args... args) {
static_assert(N == Op::num_inputs, "template N is not equal to Op::num_inputs!\n");
// Copy each input element into the local-memory buffer
#pragma unroll
for (size_t i = 0; i < N; i++) {
auto gm = typed_inputs[i] + input_indexes[i];
auto lm = inputs_buf + i;
GM2LM_ASYNC(gm, lm, 1 * sizeof(Tdata));
}
mfence();
// Calculate elementwise
// Inputs save all operands
Tdata out = Op{}(inputs_buf, args...);
// Copy out to gm
LM2GM_ASYNC(&out, output + output_index, 1 * sizeof(Tdata));
mfence();
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__global__ void elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *input_contiguous_gm,
const bool *input_broadcasted_gm,
const _size_t *output_shape_gm,
const _size_t *input_shapes_gm,
const _ptrdiff_t *output_strides_gm,
const _ptrdiff_t *input_strides_gm,
Tdata *output,
const void *const *inputs,
Args... args) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
// Cast input gm pointer type
auto typed_inputs = reinterpret_cast<const __global_ptr__ Tdata *const __global_ptr__ *>(inputs);
const int BUFF_SIZE = 64;
// Input data cache
__local__ Tdata inputs_buf[N];
// Input contiguous/broadcasted flags
__local__ bool input_contiguous[N];
__local__ bool input_broadcasted[N];
// Input shape/strides
__local__ _size_t input_shapes[N * ndim];
__local__ _ptrdiff_t input_strides[N * ndim];
// Output shape/strides
__local__ _size_t output_shape[ndim];
__local__ _ptrdiff_t output_strides[ndim];
// Inputs gm ptr buf
__local__ __global_ptr__ Tdata *typed_inputs_ptr[N];
// Load from gm
GM2LM_ASYNC(input_contiguous_gm, input_contiguous, N * sizeof(bool));
GM2LM_ASYNC(input_broadcasted_gm, input_broadcasted, N * sizeof(bool));
GM2LM_ASYNC(input_shapes_gm, input_shapes, N * ndim * sizeof(_size_t));
GM2LM_ASYNC(input_strides_gm, input_strides, N * ndim * sizeof(_ptrdiff_t));
GM2LM_ASYNC(output_shape_gm, output_shape, ndim * sizeof(_size_t));
GM2LM_ASYNC(output_strides_gm, output_strides, ndim * sizeof(_ptrdiff_t));
GM2LM_ASYNC(typed_inputs, typed_inputs_ptr, N * sizeof(__global_ptr__ Tdata *));
mfence();
int len_per_loop = min(BUFF_SIZE, roundup_div(output_size, nthreads));
for (int start = thread_id * len_per_loop; start < output_size; start += nthreads * len_per_loop) {
size_t read_len = min(len_per_loop, output_size - start);
for (int idx = start; idx < start + read_len; ++idx) {
size_t out_idx = getOutputIndex(static_cast<size_t>(idx), output_contiguous,
ndim, output_shape, output_strides);
InputIndexer indexer{static_cast<size_t>(idx), ndim, input_contiguous, input_broadcasted,
input_shapes, input_strides, output_strides};
// Get index offset for every operand
size_t indexes[N];
for (size_t i = 0; i < N; i++) {
indexes[i] = indexer(i);
}
// Launch the operator
launchOp<N, Op, Tdata>(&typed_inputs_ptr[0], output, inputs_buf, indexes, out_idx, args...);
}
}
sync_cluster();
}
#define LAUNCH_ELEMENTWISE_KERNEL_IMPL(OpName, Op) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args) { \
elementwiseKernel<Op::num_inputs, Op, Tdata><<<8, 64, stream>>>( \
output_size, ndim, output_contiguous, \
reinterpret_cast<const bool *>(input_contiguous), \
reinterpret_cast<const bool *>(input_broadcasted), \
reinterpret_cast<const _size_t *>(output_shape), \
reinterpret_cast<const _size_t *>(input_shapes), \
reinterpret_cast<const _ptrdiff_t *>(output_strides), \
reinterpret_cast<const _ptrdiff_t *>(input_strides), \
reinterpret_cast<Tdata *>(output), inputs, args...); \
}
#define LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(OpName, T, ...) \
template void launch##OpName##Kernel<T, ##__VA_ARGS__>( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
##__VA_ARGS__);
#endif
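Taken together, an operator supplies a small functor with a `num_inputs` constant and an `operator()` over the local input buffer, then stamps out the kernel and its explicit instantiation with the two macros above. A sketch with a made-up op name and functor (not an op from this diff); the matching host header would carry `LAUNCH_ELEMENTWISE_KERNEL(Add)` to declare the symbol:

```cpp
// Illustrative .xpu-side usage of the macros defined above.
struct AddOp {
    static constexpr size_t num_inputs = 2;
    template <typename T>
    __device__ T operator()(const T *inputs) const {
        return inputs[0] + inputs[1];
    }
};

LAUNCH_ELEMENTWISE_KERNEL_IMPL(Add, AddOp)
// Emit the concrete symbol for float so the host-side launcher can link against it.
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, float)
```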
#ifndef ATTENTION_H
#define ATTENTION_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::attention::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc); \
}; \
}
#endif // ATTENTION_H
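As elsewhere in this codebase, each backend is expected to instantiate this descriptor by expanding the macro in its own header. A sketch of the expansion site (the `cpu` namespace and file path are assumptions for illustration):

```cpp
// Hypothetical backend header, e.g. ops/attention/cpu/attention_cpu.h:
#include "../attention.h" // path assumed for this sketch

DESCRIPTOR(cpu)
// Declares op::attention::cpu::Descriptor with an opaque impl pointer,
// a cached workspaceSize(), and a static create(handle, &desc, y_desc, x_desc).
```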
#include "../../operator.h"
#include "../../../utils.h"
#include "../../../utils/check.h"
#include "../../handle.h"
#include "../../tensor.h"
#include "infiniop/ops/attention.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/rearrange.h"
#include <cmath>
#include <cstdint>
struct InfiniopAttentionDescriptor {
InfiniopDescriptor _super;
infiniopRearrangeDescriptor_t rearrange_desc_k;
infiniopRearrangeDescriptor_t rearrange_desc_v;
infiniopRearrangeDescriptor_t rearrange_desc_q;
infiniopRearrangeDescriptor_t rearrange_desc_out;
infiniopGemmDescriptor_t matmul_desc1;
infiniopGemmDescriptor_t matmul_desc2;
infiniopCausalSoftmaxDescriptor_t softmax_desc;
uint64_t workspace_size;
uint64_t rearranged_q_size;
uint64_t matmul1_workspace_size;
uint64_t matmul1_tensor_size;
uint64_t matmul2_workspace_size;
uint64_t matmul2_tensor_size;
uint64_t softmax_workspace_size;
uint64_t k_cache_offset;
uint64_t v_cache_offset;
float qk_alpha;
};
__C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t handle,
infiniopAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
uint64_t pos) {
if (out_desc->ndim() != 3 || q_desc->ndim() != 3 || k_desc->ndim() != 3 || v_desc->ndim() != 3 || k_cache_desc->ndim() != 3 || v_cache_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (!out_desc->isContiguous(0, 2)) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (q_desc->strides()[2] != 1 || k_desc->strides()[2] != 1 || v_desc->strides()[2] != 1 || k_cache_desc->strides()[2] != 1 || v_cache_desc->strides()[2] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
uint64_t n_q_head = q_desc->shape()[0];
uint64_t seq_len = q_desc->shape()[1];
uint64_t head_dim = q_desc->shape()[2];
uint64_t hidden_size = n_q_head * head_dim;
uint64_t n_kv_head = k_desc->shape()[0];
uint64_t total_seq_len = seq_len + pos;
uint64_t n_group = n_q_head / n_kv_head;
if (out_desc->shape()[0] != seq_len || out_desc->shape()[1] != n_q_head || out_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// k: [n_kv_head, seq_len, head_dim]
if (k_desc->shape()[0] != n_kv_head || k_desc->shape()[1] != seq_len || k_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// v: [n_kv_head, seq_len, head_dim]
if (v_desc->shape()[0] != n_kv_head || v_desc->shape()[1] != seq_len || v_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// k_cache: [n_kv_head, _, head_dim]
if (k_cache_desc->shape()[0] != n_kv_head || k_cache_desc->shape()[1] < total_seq_len || k_cache_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// v_cache: [n_kv_head, _, head_dim]
if (v_cache_desc->shape()[0] != n_kv_head || v_cache_desc->shape()[1] < total_seq_len || v_cache_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// Rearrange k into k_cache
infiniopTensorDescriptor_t dst_k_desc;
CHECK_STATUS(infiniopCreateTensorDescriptor(&dst_k_desc, 3, k_desc->shape().data(), k_cache_desc->strides().data(), k_cache_desc->dtype()));
infiniopRearrangeDescriptor_t rearrange_desc_k;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_k, dst_k_desc, k_desc));
// Rearrange v into v_cache
infiniopTensorDescriptor_t dst_v_desc;
CHECK_STATUS(infiniopCreateTensorDescriptor(&dst_v_desc, 3, v_desc->shape().data(), v_cache_desc->strides().data(), v_cache_desc->dtype()));
infiniopRearrangeDescriptor_t rearrange_desc_v;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_v, dst_v_desc, v_desc));
infiniopRearrangeDescriptor_t rearrange_desc_q = nullptr;
uint64_t rearranged_q_size = 0;
infiniopTensorDescriptor_t rearranged_q_desc;
// Rearrange q into contiguous
if (!q_desc->isContiguous(0, 1)) {
CHECK_STATUS(infiniopCreateTensorDescriptor(&rearranged_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
rearranged_q_size = rearranged_q_desc->numel() * infiniSizeOf(rearranged_q_desc->dtype());
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_q, rearranged_q_desc, q_desc));
}
// Matmul1: q * full_k
// q: [n_q_head, seq_len, head_dim] -> [n_kv_head, n_group * seq_len, head_dim]
infiniopTensorDescriptor_t reshaped_q_desc;
CHECK_STATUS(infiniopCreateTensorDescriptor(&reshaped_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimSplit(0, {n_kv_head, n_group}));
TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimMerge(1, 2));
// full_k: [n_kv_head, head_dim, total_seq_len]
infiniopTensorDescriptor_t full_k_desc;
uint64_t full_k_shape[3] = {n_kv_head, total_seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&full_k_desc, 3, full_k_shape, k_cache_desc->strides().data(), k_cache_desc->dtype()));
TRANSFORM_TENSOR_DESC(full_k_desc, dimPermute({0, 2, 1}));
// qk: [n_kv_head, n_group * seq_len, total_seq_len]
infiniopTensorDescriptor_t qk_desc;
uint64_t qk_shape[3] = {n_kv_head, n_group * seq_len, total_seq_len};
CHECK_STATUS(infiniopCreateTensorDescriptor(&qk_desc, 3, qk_shape, nullptr, q_desc->dtype()));
// matmul1_desc
// qk_alpha: attention scaling factor 1 / sqrt(head_dim)
float qk_alpha = 1.f / std::sqrt(float(head_dim));
infiniopGemmDescriptor_t matmul1_desc;
CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul1_desc, qk_desc, reshaped_q_desc, full_k_desc));
// matmul1 workspace size
uint64_t matmul1_workspace_size;
CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul1_desc, &matmul1_workspace_size));
// matmul1 tensor size
uint64_t matmul1_tensor_size = qk_desc->numel() * infiniSizeOf(qk_desc->dtype());
// CausalSoftmax: softmax(qk)
// qk: [n_kv_head, n_group * seq_len, total_seq_len] -> [n_q_head, seq_len, total_seq_len]
TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(1, {n_group, seq_len}));
TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(0, 1));
infiniopCausalSoftmaxDescriptor_t softmax_desc;
CHECK_STATUS(infiniopCreateCausalSoftmaxDescriptor(handle, &softmax_desc, qk_desc, qk_desc));
// softmax workspace size
uint64_t softmax_workspace_size;
CHECK_STATUS(infiniopGetCausalSoftmaxWorkspaceSize(softmax_desc, &softmax_workspace_size));
// Matmul2: softmax(qk) * full_v
// softmax(qk): [n_q_head, seq_len, total_seq_len] -> [n_kv_head, n_group * seq_len, total_seq_len]
// full_v: [n_kv_head, total_seq_len, head_dim]
TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(0, {n_kv_head, n_group}));
TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(1, 2));
infiniopTensorDescriptor_t full_v_desc;
uint64_t full_v_shape[3] = {n_kv_head, total_seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&full_v_desc, 3, full_v_shape, v_cache_desc->strides().data(), v_cache_desc->dtype()));
// temp_out: [n_kv_head, n_group * seq_len, head_dim]
infiniopTensorDescriptor_t temp_out_desc;
uint64_t temp_out_shape[3] = {n_kv_head, n_group * seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&temp_out_desc, 3, temp_out_shape, nullptr, q_desc->dtype()));
// matmul2_desc
infiniopGemmDescriptor_t matmul2_desc;
CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul2_desc, temp_out_desc, qk_desc, full_v_desc));
// matmul2 workspace size
uint64_t matmul2_workspace_size;
CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul2_desc, &matmul2_workspace_size));
// matmul2 tensor size
uint64_t matmul2_tensor_size = temp_out_desc->numel() * infiniSizeOf(temp_out_desc->dtype());
// Rearrange temp_out into out
// out: [seq_len, n_q_head, head_dim]
// temp_out: [n_kv_head, n_group * seq_len, head_dim] -> [n_q_head, seq_len, head_dim] -> [seq_len, n_q_head, head_dim]
TRANSFORM_TENSOR_DESC(temp_out_desc, dimSplit(1, {n_group, seq_len}));
TRANSFORM_TENSOR_DESC(temp_out_desc, dimMerge(0, 1));
TRANSFORM_TENSOR_DESC(temp_out_desc, dimPermute({1, 0, 2}));
infiniopRearrangeDescriptor_t rearrange_desc_out;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_out, out_desc, temp_out_desc));
// workspace size
uint64_t workspace_size = rearranged_q_size + std::max(std::max(matmul1_workspace_size + matmul1_tensor_size, matmul1_tensor_size + softmax_workspace_size), matmul1_tensor_size + matmul2_workspace_size + matmul2_tensor_size);
// k_cache_offset
uint64_t k_cache_offset = 0;
if (pos > 0) {
k_cache_offset = pos * k_cache_desc->getByteStrides()[1];
}
// v_cache_offset
uint64_t v_cache_offset = 0;
if (pos > 0) {
v_cache_offset = pos * v_cache_desc->getByteStrides()[1];
}
// create attention descriptor
*(InfiniopAttentionDescriptor **)desc_ptr = new InfiniopAttentionDescriptor{
{handle->device, handle->device_id},
rearrange_desc_k,
rearrange_desc_v,
rearrange_desc_q,
rearrange_desc_out,
matmul1_desc,
matmul2_desc,
softmax_desc,
workspace_size,
rearranged_q_size,
matmul1_workspace_size,
matmul1_tensor_size,
matmul2_workspace_size,
matmul2_tensor_size,
softmax_workspace_size,
k_cache_offset,
v_cache_offset,
qk_alpha,
};
return INFINI_STATUS_SUCCESS;
}
__C __export infiniStatus_t infiniopGetAttentionWorkspaceSize(infiniopAttentionDescriptor_t desc, uint64_t *size) {
*size = ((InfiniopAttentionDescriptor *)desc)->workspace_size;
return INFINI_STATUS_SUCCESS;
}
__C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc_,
void *workspace,
uint64_t workspace_size,
void *out,
void const *q,
void const *k,
void const *v,
void *k_cache,
void *v_cache,
void *stream) {
auto desc = (InfiniopAttentionDescriptor *)desc_;
void *workspace_ = workspace;
if (workspace_size < desc->workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE; // STATUS_MEMORY_NOT_ALLOCATED
}
// concat k and v to k_cache and v_cache
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_k,
(char *)k_cache + desc->k_cache_offset, k, stream));
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_v,
(char *)v_cache + desc->v_cache_offset, v, stream));
// rearrange q into contiguous
void const *_q = q;
if (desc->rearrange_desc_q) {
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_q, (char *)workspace_, q, stream));
_q = workspace_;
workspace_ = (char *)workspace_ + desc->rearranged_q_size;
}
// matmul1: q * full_k
CHECK_STATUS(infiniopGemm(desc->matmul_desc1,
(char *)workspace_ + desc->matmul1_tensor_size, workspace_size - desc->matmul1_tensor_size,
workspace_, _q, k_cache, desc->qk_alpha, 0.0, stream));
// softmax(qk)
CHECK_STATUS(infiniopCausalSoftmax(desc->softmax_desc,
(char *)workspace_ + desc->matmul1_tensor_size, workspace_size - desc->matmul1_tensor_size,
workspace_, workspace_, stream));
// matmul2: softmax(qk) * full_v
CHECK_STATUS(infiniopGemm(desc->matmul_desc2,
(char *)workspace_ + desc->matmul1_tensor_size + desc->matmul2_tensor_size,
workspace_size - desc->matmul1_tensor_size - desc->matmul2_tensor_size,
(char *)workspace_ + desc->matmul1_tensor_size, workspace_, v_cache, 1.0, 0.0, stream));
// rearrange out
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_out, out, (char *)workspace_ + desc->matmul1_tensor_size, stream));
return INFINI_STATUS_SUCCESS;
}
__C __export infiniStatus_t infiniopDestroyAttentionDescriptor(infiniopAttentionDescriptor_t desc_) {
auto desc = (InfiniopAttentionDescriptor *)desc_;
if (desc->rearrange_desc_q) {
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_q));
}
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_k));
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_v));
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_out));
CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc1));
CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc2));
CHECK_STATUS(infiniopDestroyCausalSoftmaxDescriptor(desc->softmax_desc));
delete desc;
return INFINI_STATUS_SUCCESS;
}
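To summarize the workspace bookkeeping in create()/calculate() above: the optional contiguous copy of q sits at the front, followed by a region reused by all three stages, with the qk buffer (matmul1_tensor_size) always at its start; the max() in create() sizes that shared region for the most demanding stage. A small worked sketch with made-up byte counts:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical sizes in bytes, purely for illustration.
    uint64_t rearranged_q = 4096;
    uint64_t mm1_ws = 1024, mm1_qk = 8192;  // gemm1 scratch + qk tensor
    uint64_t softmax_ws = 512;
    uint64_t mm2_ws = 1024, mm2_out = 8192; // gemm2 scratch + temp_out tensor

    // Same expression as infiniopCreateAttentionDescriptor:
    uint64_t ws = rearranged_q
                + std::max(std::max(mm1_ws + mm1_qk, mm1_qk + softmax_ws),
                           mm1_qk + mm2_ws + mm2_out);

    // During calculate(): [ rearranged q | qk buffer | per-stage scratch ... ]
    std::printf("total workspace: %llu bytes\n", (unsigned long long)ws);
    return 0;
}
```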
#include "causal_softmax_aclnn.h" #include "causal_softmax_ascend.h"
#include "../../../devices/ascend/common_ascend.h" #include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_masked_fill_tensor.h> #include <aclnnop/aclnn_masked_fill_tensor.h>
#include <aclnnop/aclnn_softmax.h> #include <aclnnop/aclnn_softmax.h>
...@@ -12,6 +12,8 @@ struct Descriptor::Opaque { ...@@ -12,6 +12,8 @@ struct Descriptor::Opaque {
aclnnTensorDescriptor_t value; aclnnTensorDescriptor_t value;
void *mask_addr; void *mask_addr;
void *value_addr; void *value_addr;
uint64_t workspacesize;
aclOpExecutor *executor;
~Opaque() { ~Opaque() {
delete x; delete x;
...@@ -21,6 +23,9 @@ struct Descriptor::Opaque { ...@@ -21,6 +23,9 @@ struct Descriptor::Opaque {
aclrtFree(mask_addr); aclrtFree(mask_addr);
aclrtFree(value_addr); aclrtFree(value_addr);
// Delete useless executor
aclDestroyAclOpExecutor(executor);
} }
}; };
...@@ -92,18 +97,18 @@ infiniStatus_t Descriptor::create( ...@@ -92,18 +97,18 @@ infiniStatus_t Descriptor::create(
aclTensor *tvalue = value->tensor; aclTensor *tvalue = value->tensor;
CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor)); CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor));
int64_t dim = 2;
int64_t dim = 2;
CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(tx, dim, ty, &workspacesize_softmax, &executor)); CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(tx, dim, ty, &workspacesize_softmax, &executor));
// set executor reusable
aclSetAclOpExecutorRepeatable(executor);
// Create the descriptor // Create the descripto
size_t all_workspacesize = workspacesize_softmax + workspacesize_mask; size_t all_workspacesize = std::max(workspacesize_softmax, workspacesize_mask);
*desc_ptr = new Descriptor(new Opaque{x, mask, y, value, mask_addr, value_addr},
std::move(info), all_workspacesize, handle_ascend->device, handle_ascend->device_id);
// Delete useless executor *desc_ptr = new Descriptor(new Opaque{x, mask, y, value, mask_addr, value_addr,
aclDestroyAclOpExecutor(executor); workspacesize_softmax, executor},
aclDestroyAclOpExecutor(mask_executor); std::move(info), all_workspacesize, handle_ascend->device, handle_ascend->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -116,23 +121,18 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, voi ...@@ -116,23 +121,18 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, voi
auto ty = _opaque->y->tensor; auto ty = _opaque->y->tensor;
auto tmask = _opaque->mask->tensor; auto tmask = _opaque->mask->tensor;
auto tvalue = _opaque->value->tensor; auto tvalue = _opaque->value->tensor;
aclOpExecutor *executor = nullptr;
aclOpExecutor *mask_executor = nullptr; aclOpExecutor *mask_executor = nullptr;
size_t workspacesize_softmax = 0;
size_t workspacesize_mask = 0; size_t workspacesize_mask = 0;
int64_t dim = 2;
AclSetTensorAddr(mask_executor, 0, tx, (void *)x); AclSetTensorAddr(mask_executor, 0, tx, (void *)x);
AclSetTensorAddr(mask_executor, 1, tmask, _opaque->mask_addr); AclSetTensorAddr(mask_executor, 1, tmask, _opaque->mask_addr);
AclSetTensorAddr(mask_executor, 2, tvalue, _opaque->value_addr); AclSetTensorAddr(mask_executor, 2, tvalue, _opaque->value_addr);
CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor)); CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor));
CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, workspacesize_mask, mask_executor, stream)); CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, workspacesize_mask, mask_executor, stream));
CHECK_ACL(aclrtSynchronizeStream(stream));
AclSetTensorAddr(executor, 0, tx, (void *)x); AclSetTensorAddr(_opaque->executor, 0, tx, (void *)x);
AclSetTensorAddr(executor, 1, ty, y); AclSetTensorAddr(_opaque->executor, 1, ty, y);
CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(tx, dim, ty, &workspacesize_softmax, &executor)); CHECK_ACL(aclnnSoftmax(workspace, _opaque->workspacesize, _opaque->executor, stream));
CHECK_ACL(aclnnSoftmax(workspace, workspacesize_softmax, executor, stream));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
...@@ -48,7 +48,7 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *y, const T *x) {
            if constexpr (std::is_same<T, fp16_t>::value) {
                y_[j * info->y_stride_j] = utils::cast<fp16_t>(utils::cast<float>(y_[j * info->y_stride_j]) / sum);
            } else {
-               y_[j * info->y_stride_j] = y_[y_offset + j * info->y_stride_j] / sum;
+               y_[j * info->y_stride_j] = y_[j * info->y_stride_j] / sum;
            }
        }
    }
...
...@@ -9,7 +9,7 @@
#include "cuda/causal_softmax_cuda.cuh"
#endif
#ifdef ENABLE_ASCEND_API
-#include "ascend/causal_softmax_aclnn.h"
+#include "ascend/causal_softmax_ascend.h"
#endif

__C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
...
...@@ -62,7 +62,7 @@ infiniStatus_t calculate(
        (kunlunStream_t)stream,
        [&](xdnnHandle_t handle) {
            for (size_t i = 0; i < info.batch; i++) {
-               CHECK_XDNN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
+               CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
                    handle,
                    (Tdata *)((char *)a + i * info.a_matrix.stride * unit),
                    (Tdata *)((char *)b + i * info.b_matrix.stride * unit),
...