".github/vscode:/vscode.git/clone" did not exist on "8b805d90fe49a8452b330eabb72f6fb9b1fa81cd"
Unverified Commit c98e68be authored by goldenfox2025, committed by GitHub

Merge branch 'main' into issue180

parents d7c12d52 125afeb5
@@ -175,6 +175,10 @@ options:
{
    "clangd.arguments": [
        "--compile-commands-dir=.vscode"
-    ]
+    ],
+    "xmake.additionalConfigArguments": [
+        // Configure XMAKE_CONFIG_FLAGS here
+        "--nv-gpu=y"
+    ],
}
```
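The `xmake.additionalConfigArguments` entries are presumably forwarded to `xmake config` by the VS Code xmake extension, so `--nv-gpu=y` here selects the same NVIDIA GPU build option that passing the flag on the command line would.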
@@ -9,6 +9,7 @@ DECLARE_INFINIOP_TEST(gemm)
DECLARE_INFINIOP_TEST(random_sample)
DECLARE_INFINIOP_TEST(mul)
DECLARE_INFINIOP_TEST(clip)
+DECLARE_INFINIOP_TEST(swiglu)

#define REGISTER_INFINIOP_TEST(name) \
    { \
@@ -28,6 +29,7 @@ DECLARE_INFINIOP_TEST(clip)
    REGISTER_INFINIOP_TEST(random_sample) \
    REGISTER_INFINIOP_TEST(mul) \
    REGISTER_INFINIOP_TEST(clip) \
+    REGISTER_INFINIOP_TEST(swiglu) \
    }

namespace infiniop_test {
...
#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>
namespace infiniop_test::swiglu {
struct Test::Attributes {
std::shared_ptr<Tensor> a;
std::shared_ptr<Tensor> b;
std::shared_ptr<Tensor> ans;
std::shared_ptr<Tensor> c;
};
std::shared_ptr<Test> Test::build(
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
double rtol, double atol) {
auto test = std::shared_ptr<Test>(new Test(rtol, atol));
test->_attributes = new Attributes();
if (tensors.find("a") == tensors.end()
|| tensors.find("b") == tensors.end()
|| tensors.find("c") == tensors.end()
|| tensors.find("ans") == tensors.end()) {
throw std::runtime_error("Invalid Test");
}
test->_attributes->a = tensors["a"];
test->_attributes->b = tensors["b"];
test->_attributes->c = tensors["c"];
test->_attributes->ans = tensors["ans"];
return test;
}
std::shared_ptr<infiniop_test::Result> Test::run(
infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {
infiniopSwiGLUDescriptor_t op_desc;
auto a = _attributes->a->to(device, device_id);
auto b = _attributes->b->to(device, device_id);
auto c = _attributes->c->to(device, device_id);
CHECK_OR(infiniopCreateSwiGLUDescriptor(handle, &op_desc,
c->desc(),
a->desc(),
b->desc()),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));
size_t workspace_size;
CHECK_OR(infiniopGetSwiGLUWorkspaceSize(op_desc, &workspace_size),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));
void *workspace;
CHECK_OR(infinirtMalloc(&workspace, workspace_size),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
CHECK_OR(infiniopSwiGLU(op_desc, workspace, workspace_size, c->data(), a->data(), b->data(), nullptr),
return TEST_FAILED(OP_CREATION_FAILED, "Failed during execution."));
try {
allClose(c, _attributes->ans, _rtol, _atol);
} catch (const std::exception &e) {
return TEST_FAILED(RESULT_INCORRECT, e.what());
}
double elapsed_time = 0.;
elapsed_time = benchmark(
[=]() {
infiniopSwiGLU(
op_desc,
workspace,
workspace_size,
c->data(),
a->data(),
b->data(),
nullptr);
},
warm_ups, iterations);
return TEST_PASSED(elapsed_time);
}
std::vector<std::string> Test::attribute_names() {
return {};
}
std::vector<std::string> Test::tensor_names() {
return {"a", "b", "c", "ans"};
}
std::string Test::toString() const {
std::ostringstream oss;
oss << op_name() << std::endl;
oss << "- a: " << _attributes->a->info() << std::endl;
oss << "- b: " << _attributes->b->info() << std::endl;
oss << "- c: " << _attributes->c->info() << std::endl;
oss << std::scientific << std::setprecision(2);
oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
return oss.str();
}
Test::~Test() {
delete _attributes;
}
} // namespace infiniop_test::swiglu
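For context, a minimal host-side sketch of the semantics this test compares against `ans` (a sketch, not part of the commit; treating `b` as the gate operand is an assumption, since the diff does not say which of `a`/`b` gates the other):

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical reference: SwiGLU(a, b) = a * silu(b) elementwise,
// where silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
inline void swigluReference(const float *a, const float *b, float *c, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        c[i] = a[i] * (b[i] / (1.0f + std::exp(-b[i])));
    }
}
```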
@@ -16,7 +16,7 @@ typedef XPUStream kunlunStream_t;
typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t;

-#define CHECK_XDNN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
+#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)

namespace device::kunlun {
...
-#ifndef __INFINIOP_KUNLUN_COMMON_H__
-#define __INFINIOP_KUNLUN_COMMON_H__
+#ifndef __INFINIOP_KUNLUN_KERNEL_COMMON_H__
+#define __INFINIOP_KUNLUN_KERNEL_COMMON_H__

// This header file will only be included by .xpu files
+#include "kunlun_kernel_dtype.h"
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"

+namespace device::kunlun::kernel {

// Get the mask for Kunlun XPU 512-bit register calculation:
// if the data does not fill 512 bits, pad it with zeros and use
// the mask to identify the real data
@@ -26,6 +28,37 @@ inline __device__ void atomicAddF32(__shared_ptr__ float *ptr, float value) {
    }
}
inline __device__ size_t indexToReducedOffset(
size_t flat_index,
size_t ndim,
const _ptrdiff_t *broadcasted_strides,
const _ptrdiff_t *target_strides) {
size_t res = 0;
for (size_t i = 0; i < ndim; ++i) {
res += flat_index / broadcasted_strides[i].value * target_strides[i].value;
flat_index %= broadcasted_strides[i].value;
mfence();
}
return res;
}
inline __device__ size_t indexToOffset(
size_t flat_index,
size_t ndim,
const _size_t *shape,
const _ptrdiff_t *strides) {
size_t res = 0;
for (size_t i = ndim; i-- > 0;) {
res += (flat_index % shape[i].value) * strides[i].value;
flat_index /= shape[i].value;
mfence();
}
return res;
}
} // namespace device::kunlun::kernel
// TODO: atomicAddF16
// TODO: atomicAddI8

#endif
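To make the index arithmetic above concrete, here is a host-side analogue of `indexToOffset` using plain types (no `_size_t`/`_ptrdiff_t` wrappers, no `mfence`), plus a worked example; it is an illustration, not part of the commit:

```cpp
#include <cstddef>

// Maps a flat element index to a memory offset for a tensor with the given
// shape and (possibly non-contiguous) strides, iterating dimensions innermost-first.
inline size_t indexToOffsetHost(size_t flat_index, size_t ndim,
                                const size_t *shape, const ptrdiff_t *strides) {
    size_t res = 0;
    for (size_t i = ndim; i-- > 0;) {
        res += (flat_index % shape[i]) * strides[i];
        flat_index /= shape[i];
    }
    return res;
}

// Example: shape {2, 3} with row-major strides {3, 1} maps flat index 4 to
// (row 1, col 1) -> 1 * 3 + 1 * 1 = 4; with strides {1, 2} it yields 1 * 2 + 1 * 1 = 3.
```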
#ifndef __INFINIOP_KUNLUN_DTYPE_H__
#define __INFINIOP_KUNLUN_DTYPE_H__
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
// On the Kunlun device, _ptrdiff_t is used to hold a ptrdiff_t array
// copied from the host
typedef struct _ptrdiff_t {
long value; // 32 bit
long padding; // 32 bit
} _ptrdiff_t;
// Same padded layout as _ptrdiff_t, but for size_t
typedef struct _size_t {
size_t value;
size_t padding;
} _size_t;
#endif
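The padding exists because `long` is 32 bits wide in the XPU kernel toolchain, while the host's `ptrdiff_t`/`size_t` are 64 bits: the host copies its arrays to the device unchanged, and the kernel reads each 64-bit element as a {low 32-bit value, padding} pair. A host-side sketch of that layout equivalence (assuming a little-endian host and values that fit in 32 bits):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Host mirror of the device-side _ptrdiff_t: two 32-bit halves of one 64-bit value.
struct Padded32 {
    int32_t value;   // low half: the actual stride
    int32_t padding; // high half: sign/zero extension
};

int main() {
    ptrdiff_t stride = 42;
    Padded32 p;
    static_assert(sizeof(p) == sizeof(stride), "host and device layouts must agree");
    std::memcpy(&p, &stride, sizeof(p)); // the reinterpretation the kernel performs
    assert(p.value == 42 && p.padding == 0);
}
```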
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_H__
#include "../../../utils.h"
#include "../../devices/kunlun/kunlun_handle.h"
#include "elementwise_kunlun_api.h"
namespace op::elementwise::kunlun {
struct DeviceImpl::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::kunlun::Handle::Internal> &internal_)
: internal(internal_) {}
template <size_t N, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
kunlunStream_t stream,
Args &&...args) {
auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}
// Device pointers
const void **d_inputs_arr = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t *d_input_shapes = nullptr;
const ptrdiff_t *d_input_strides = nullptr;
CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_output_strides,
d_input_shapes, d_input_strides));
Op::template launch<Tdata>(
output_size,
info.getNdim(),
info.isOutputContiguous(),
reinterpret_cast<const void *>(d_input_contiguous),
reinterpret_cast<const void *>(d_input_broadcasted),
reinterpret_cast<const void *>(d_output_shape),
reinterpret_cast<const void *>(d_input_shapes),
reinterpret_cast<const void *>(d_output_strides),
reinterpret_cast<const void *>(d_input_strides),
output,
reinterpret_cast<const void *const *>(d_inputs_arr),
stream,
args...);
return INFINI_STATUS_SUCCESS;
}
private:
template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
const void *const *h_inputs_arr,
const void **&d_inputs_arr,
const bool *&d_input_contiguous,
const bool *&d_input_broadcasted,
const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides,
const size_t *&d_input_shapes,
const ptrdiff_t *&d_input_strides) const {
constexpr auto input_size = N;
const auto ndim = info.getNdim();
constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
const int8_t *info_meta_start = info.getMetaStart();
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device
CHECK_KUNLUN(xpu_memcpy(workspace, h_inputs_arr, input_arr_size, XPU_HOST_TO_DEVICE));
CHECK_KUNLUN(xpu_memcpy((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), XPU_HOST_TO_DEVICE));
// offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace);
d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
return INFINI_STATUS_SUCCESS;
}
};
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
return _opaque->calculateImpl<N, Op, Tdata>(
info, workspace, output, inputs,
reinterpret_cast<kunlunStream_t>(stream),
std::forward<Args>(args)...);
}
} // namespace op::elementwise::kunlun
// Template for kunlun kernel interface declaration
#define LAUNCH_ELEMENTWISE_KERNEL(OpName) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args);
#endif
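For reference, the single workspace buffer that `infoToDevice` carves up has the following layout, inferred from its pointer arithmetic:

```cpp
// Workspace layout (offsets inferred from infoToDevice):
//   [0, N * sizeof(void *))        d_inputs_arr        input pointer array
// then, from d_meta_start = workspace + N * sizeof(void *):
//   ndim     * sizeof(size_t)      d_output_shape
//   ndim     * sizeof(ptrdiff_t)   d_output_strides
//   N * ndim * sizeof(size_t)      d_input_shapes
//   N * ndim * sizeof(ptrdiff_t)   d_input_strides
//   N        * sizeof(bool)        d_input_contiguous
//   N        * sizeof(bool)        d_input_broadcasted
// which matches workspace_size = info.getMetaMemSize()
//                              + info.getInputSize() * sizeof(void *)
// in CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR below.
```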
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#include "../elementwise.h"
namespace op::elementwise::kunlun {
class DeviceImpl final {
struct Opaque;
std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static utils::Result<DeviceImpl *> create(Args &&...args);
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
} // namespace op::elementwise::kunlun
#define CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::kunlun::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#define __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#include "../../devices/kunlun/kunlun_kernel_common.h"
using namespace device::kunlun::kernel;
/**
* @brief Computes input tile offset
*/
struct InputIndexer {
size_t idx;
size_t ndim;
const bool *input_contiguous;
const bool *input_broadcasted;
const _size_t *input_shapes;
const _ptrdiff_t *input_strides;
const _ptrdiff_t *output_strides;
__device__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}
};
/**
* @brief Computes the output index in memory, accounting for strides if non-contiguous.
*
* @param idx Linear index.
* @param is_contiguous Whether the output tensor is contiguous.
* @param ndim Number of dimensions.
* @param shape Shape of the output tensor.
* @param strides Strides of the output tensor.
* @return Memory offset index.
*/
inline __device__ size_t
getOutputIndex(size_t idx,
bool is_contiguous,
size_t ndim,
const _size_t *shape,
const _ptrdiff_t *strides) {
return is_contiguous ? idx : indexToOffset(idx, ndim, shape, strides);
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__device__ void launchOp(
__global_ptr__ Tdata **typed_inputs, // gm pointer
__global_ptr__ Tdata *output, // gm pointer output
Tdata *inputs_buf, // local mem buffer
size_t *input_indexes,
size_t output_index,
Args... args) {
static_assert(N == Op::num_inputs, "template N is not equal to Op::num_inputs!\n");
// Copy inputs to buf
#pragma unroll
for (size_t i = 0; i < N; i++) {
auto gm = typed_inputs[i] + input_indexes[i];
auto lm = inputs_buf + i;
GM2LM_ASYNC(gm, lm, 1 * sizeof(Tdata));
}
mfence();
// Calculate elementwise
// Inputs save all operands
Tdata out = Op{}(inputs_buf, args...);
// Copy out to gm
LM2GM_ASYNC(&out, output + output_index, 1 * sizeof(Tdata));
mfence();
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__global__ void elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *input_contiguous_gm,
const bool *input_broadcasted_gm,
const _size_t *output_shape_gm,
const _size_t *input_shapes_gm,
const _ptrdiff_t *output_strides_gm,
const _ptrdiff_t *input_strides_gm,
Tdata *output,
const void *const *inputs,
Args... args) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
// Cast input gm pointer type
auto typed_inputs = reinterpret_cast<const __global_ptr__ Tdata *const __global_ptr__ *>(inputs);
const int BUFF_SIZE = 64;
// Input data cache
__local__ Tdata inputs_buf[N];
// Input contiguous/broadcasted flags
__local__ bool input_contiguous[N];
__local__ bool input_broadcasted[N];
// Input shape/strides
__local__ _size_t input_shapes[N * ndim];
__local__ _ptrdiff_t input_strides[N * ndim];
// Output shape/strides
__local__ _size_t output_shape[ndim];
__local__ _ptrdiff_t output_strides[ndim];
// Inputs gm ptr buf
__local__ __global_ptr__ Tdata *typed_inputs_ptr[N];
// Load from gm
GM2LM_ASYNC(input_contiguous_gm, input_contiguous, N * sizeof(bool));
GM2LM_ASYNC(input_broadcasted_gm, input_broadcasted, N * sizeof(bool));
GM2LM_ASYNC(input_shapes_gm, input_shapes, N * ndim * sizeof(_size_t));
GM2LM_ASYNC(input_strides_gm, input_strides, N * ndim * sizeof(_ptrdiff_t));
GM2LM_ASYNC(output_shape_gm, output_shape, ndim * sizeof(_size_t));
GM2LM_ASYNC(output_strides_gm, output_strides, ndim * sizeof(_ptrdiff_t));
GM2LM_ASYNC(typed_inputs, typed_inputs_ptr, N * sizeof(__global_ptr__ Tdata *));
mfence();
int len_per_loop = min(BUFF_SIZE, roundup_div(output_size, nthreads));
for (int start = thread_id * len_per_loop; start < output_size; start += nthreads * len_per_loop) {
size_t read_len = min(len_per_loop, output_size - start);
for (int idx = start; idx < start + read_len; ++idx) {
size_t out_idx = getOutputIndex(static_cast<size_t>(idx), output_contiguous,
ndim, output_shape, output_strides);
InputIndexer indexer{static_cast<size_t>(idx), ndim, input_contiguous, input_broadcasted,
input_shapes, input_strides, output_strides};
// Get index offset for every operand
size_t indexes[N];
for (size_t i = 0; i < N; i++) {
indexes[i] = indexer(i);
}
// Launch the operator
launchOp<N, Op, Tdata>(&typed_inputs_ptr[0], output, inputs_buf, indexes, out_idx, args...);
}
}
sync_cluster();
}
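// To make the tiling arithmetic concrete: with output_size = 10000 and the
// <<<8, 64, stream>>> launch below (nthreads = 8 * 64 = 512),
// roundup_div(10000, 512) = 20, so len_per_loop = min(BUFF_SIZE, 20) = 20 and
// each core handles one 20-element tile starting at thread_id * 20; larger
// outputs are covered by the nthreads * len_per_loop stride.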
#define LAUNCH_ELEMENTWISE_KERNEL_IMPL(OpName, Op) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args) { \
elementwiseKernel<Op::num_inputs, Op, Tdata><<<8, 64, stream>>>( \
output_size, ndim, output_contiguous, \
reinterpret_cast<const bool *>(input_contiguous), \
reinterpret_cast<const bool *>(input_broadcasted), \
reinterpret_cast<const _size_t *>(output_shape), \
reinterpret_cast<const _size_t *>(input_shapes), \
reinterpret_cast<const _ptrdiff_t *>(output_strides), \
reinterpret_cast<const _ptrdiff_t *>(input_strides), \
reinterpret_cast<Tdata *>(output), inputs, args...); \
}
#define LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(OpName, T, ...) \
template void launch##OpName##Kernel<T, ##__VA_ARGS__>( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
##__VA_ARGS__);
#endif
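As a usage sketch: the SwiGLU port later in this commit invokes the declaration macro from its host-side file, and the matching `*_IMPL`/`*_INSTANTIATE` pair presumably lives in the corresponding .xpu translation unit (not shown in this excerpt); the functor name below is illustrative:

```cpp
// Host-visible declaration (as in the swiglu kunlun sources):
LAUNCH_ELEMENTWISE_KERNEL(SwiGLU)

// In the .xpu file, with an illustrative device-side functor SwiGLUFunctor:
LAUNCH_ELEMENTWISE_KERNEL_IMPL(SwiGLU, SwiGLUFunctor)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(SwiGLU, float)
```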
@@ -62,7 +62,7 @@ infiniStatus_t calculate(
    (kunlunStream_t)stream,
    [&](xdnnHandle_t handle) {
        for (size_t i = 0; i < info.batch; i++) {
-            CHECK_XDNN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
+            CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
                handle,
                (Tdata *)((char *)a + i * info.a_matrix.stride * unit),
                (Tdata *)((char *)b + i * info.b_matrix.stride * unit),
...
#include "random_sample_cpu.h" #include "random_sample_cpu.h"
#include "../../../devices/cpu/common_cpu.h" #include "../../../devices/cpu/common_cpu.h"
#include "../../../devices/cpu/cpu_handle.h" #include "../info.h"
#include "../../../tensor.h" #include "infinicore.h"
#include <algorithm> #include <algorithm>
namespace op::random_sample::cpu { namespace op::random_sample::cpu {
...@@ -15,29 +15,14 @@ infiniStatus_t Descriptor::create( ...@@ -15,29 +15,14 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t probs_desc) { infiniopTensorDescriptor_t probs_desc) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto dt_i = result_desc->dtype(); auto result = RandomSampleInfo::create(result_desc, probs_desc);
auto dt_p = probs_desc->dtype(); CHECK_RESULT(result);
CHECK_DTYPE(dt_i,
INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64,
INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
CHECK_DTYPE(dt_p, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_API_OR(result_desc->ndim(), 0,
return INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_API_OR(probs_desc->ndim(), 1,
return INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_API_OR(probs_desc->stride(0), 1,
return INFINI_STATUS_BAD_TENSOR_STRIDES);
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dt_i, result.take(),
dt_p,
probs_desc->dim(0),
0, 0,
nullptr, nullptr,
handle->device, handle->device, handle->device_id);
handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
@@ -55,36 +40,42 @@ struct ComputeType<fp16_t> {
    using type = float;
};

-template <class Tidx, class Tval>
-struct Scheme {
-    using Tcompute = typename ComputeType<Tval>::type;
+struct Algo {

-    static Tcompute get(void const *ptr, size_t i) {
-        return utils::cast<Tcompute, Tval>(reinterpret_cast<Tval const *>(ptr)[i]);
+    template <class Tidx, class Tval>
+    static auto get(void const *ptr, size_t i) {
+        return utils::cast<typename ComputeType<Tval>::type, Tval>(reinterpret_cast<Tval const *>(ptr)[i]);
    }

-    static void argmax(
-        void *result, void const *probs, size_t n) {
+    template <class Tidx, class Tval>
+    infiniStatus_t argmax(
+        void *workspace, size_t workspace_size,
+        void *result, void const *probs, size_t n,
+        void *stream) {
        auto idx = reinterpret_cast<Tidx *>(result);
        *idx = 0;
-        auto max_val = get(probs, 0);
+        auto max_val = get<Tidx, Tval>(probs, 0);
        for (size_t i = 0; i < n; i++) {
-            if (auto val = get(probs, i); val > max_val) {
+            if (auto val = get<Tidx, Tval>(probs, i); val > max_val) {
                max_val = val;
                *idx = static_cast<Tidx>(i);
            }
        }
+        return INFINI_STATUS_SUCCESS;
    }

-    static void random(
+    template <class Tidx, class Tval>
+    infiniStatus_t random(
+        void *workspace, size_t workspace_size,
        void *result, void const *probs, size_t n,
-        float random_val, float topp, int topk, float temperature) {
+        float random_val, float topp, int topk, float temperature,
+        void *stream) {
        struct KVPair {
            Tidx idx;
-            Tcompute val;
+            typename ComputeType<Tval>::type val;

            bool operator<(const KVPair &other) const {
                return val > other.val;
@@ -95,7 +86,7 @@ struct Scheme {
        // build & sort
        std::vector<KVPair> pairs(n);
        for (size_t i = 0; i < n; i++) {
-            pairs[i] = {static_cast<Tidx>(i), get(probs, i)};
+            pairs[i] = {static_cast<Tidx>(i), get<Tidx, Tval>(probs, i)};
        }
        std::sort(pairs.begin(), pairs.end());
        // softmax & sum
@@ -115,68 +106,10 @@ struct Scheme {
                break;
            }
        }
-    }
-};
+
+        return INFINI_STATUS_SUCCESS;
+    }
+};

-template <class Tidx, class Tval>
-void switch_f(
-    size_t n,
-    void *result, const void *probs,
-    float random_val, float topp, int topk, float temperature) {
-    if (random_val == 0 || topp == 0 || topk == 1 || temperature == 0) {
-        Scheme<Tidx, Tval>::argmax(result, probs, n);
-    } else {
-        Scheme<Tidx, Tval>::random(result, probs, n, random_val, topp, topk, temperature);
-    }
-}

-template <class Tidx>
-void switch_val(
-    infiniDtype_t dt_p, size_t n,
-    void *result, void const *probs,
-    float random_val, float topp, int topk, float temperature) {
-    switch (dt_p) {
-    case INFINI_DTYPE_F16:
-        switch_f<Tidx, fp16_t>(n, result, probs, random_val, topp, topk, temperature);
-        break;
-    case INFINI_DTYPE_F32:
-        switch_f<Tidx, float>(n, result, probs, random_val, topp, topk, temperature);
-        break;
-    case INFINI_DTYPE_F64:
-        switch_f<Tidx, double>(n, result, probs, random_val, topp, topk, temperature);
-        break;
-    default:
-        // unreachable
-        std::abort();
-    }
-}

-void switch_idx(
-    infiniDtype_t dt_i, infiniDtype_t dt_p, size_t n,
-    void *result, void const *probs,
-    float random_val, float topp, int topk, float temperature) {
-#define CASE(DT_VAL, DT_TYP) \
-    case DT_VAL: \
-        switch_val<DT_TYP>(dt_p, n, result, probs, random_val, topp, topk, temperature); \
-        break
-
-    switch (dt_i) {
-        CASE(INFINI_DTYPE_I8, int8_t);
-        CASE(INFINI_DTYPE_I16, int16_t);
-        CASE(INFINI_DTYPE_I32, int32_t);
-        CASE(INFINI_DTYPE_I64, int64_t);
-        CASE(INFINI_DTYPE_U8, uint8_t);
-        CASE(INFINI_DTYPE_U16, uint16_t);
-        CASE(INFINI_DTYPE_U32, uint32_t);
-        CASE(INFINI_DTYPE_U64, uint64_t);
-    default:
-        // unreachable
-        std::abort();
-    }
-
-#undef CASE
-}

infiniStatus_t Descriptor::calculate(
    void *workspace,
@@ -189,7 +122,11 @@ infiniStatus_t Descriptor::calculate(
    float temperature,
    void *stream) const {
-    switch_idx(_dt_i, _dt_p, _n, result, probs, random_val, topp, topk, temperature);
+    Calculate::calculate<Algo>(
+        Algo{}, _info, workspace, workspace_size,
+        result, probs,
+        random_val, topp, topk, temperature,
+        stream);
    return INFINI_STATUS_SUCCESS;
}
...
#include "../../../devices/cuda/cuda_handle.cuh"
#include "../info.h"
#include "random_sample_cuda.cuh"
#include "random_sample_kernel.cuh"
namespace op::random_sample::cuda {
struct Descriptor::Opaque {
std::shared_ptr<device::cuda::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t result_desc,
infiniopTensorDescriptor_t probs_desc) {
auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
auto result = RandomSampleInfo::create(result_desc, probs_desc);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size;
#define CASE_P(CASE, Tidx, Tval) \
case CASE: \
workspace_size = calculateWorkspace<Tidx, Tval>(info.n); \
break
#define CASE_I(CASE, Tidx) \
case CASE: \
switch (info.dt_p) { \
CASE_P(INFINI_DTYPE_F16, Tidx, half); \
CASE_P(INFINI_DTYPE_F32, Tidx, float); \
CASE_P(INFINI_DTYPE_F64, Tidx, double); \
default: \
abort(); \
} \
break
switch (info.dt_i) {
CASE_I(INFINI_DTYPE_I8, int8_t);
CASE_I(INFINI_DTYPE_I16, int16_t);
CASE_I(INFINI_DTYPE_I32, int32_t);
CASE_I(INFINI_DTYPE_I64, int64_t);
CASE_I(INFINI_DTYPE_U8, uint8_t);
CASE_I(INFINI_DTYPE_U16, uint16_t);
CASE_I(INFINI_DTYPE_U32, uint32_t);
CASE_I(INFINI_DTYPE_U64, uint64_t);
default:
abort();
}
#undef CASE_I
#undef CASE_P
*desc_ptr = new Descriptor(
info,
workspace_size,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
size_t Descriptor::minWorkspaceSize() const {
return _min_workspace_size;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *result,
const void *probs,
float random_val,
float topp,
int topk,
float temperature,
void *stream) const {
if (workspace_size < _min_workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto block_size = _opaque->internal->blockSizeX();
Calculate::calculate<Algo>(
Algo{block_size}, _info, workspace, workspace_size,
result, probs,
random_val, topp, topk, temperature,
stream);
return INFINI_STATUS_SUCCESS;
}
} // namespace op::random_sample::cuda
#ifndef __RANDOM_SAMPLE_CUDA_CUH__
#define __RANDOM_SAMPLE_CUDA_CUH__
#include "../random_sample.h"
DESCRIPTOR(cuda)
#endif // __RANDOM_SAMPLE_CUDA_CUH__
#include "../../../devices/cuda/cuda_kernel_common.cuh"
#include "infinicore.h"
#include <cub/device/device_radix_sort.cuh>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_scan.cuh>
namespace op::random_sample::cuda {
// ↓↓↓ Re-wrap the cub API to reduce template parameters and simplify call sites
template <class T>
static cudaError argMax_(
cub::KeyValuePair<int, T> *kv_pair,
const T *logits,
int n,
void *workspace_ptr,
size_t &workspace_len,
cudaStream_t stream) {
return cub::DeviceReduce::ArgMax(
workspace_ptr, workspace_len,
logits, kv_pair, n,
stream);
}
template <class Tval, class Tidx>
static cudaError radixSort(
void *workspace_ptr, size_t &workspace_len,
const Tval *key_in, Tval *key_out,
const Tidx *val_in, Tidx *val_out,
int n,
cudaStream_t stream) {
return cub::DeviceRadixSort::SortPairsDescending(
workspace_ptr, workspace_len,
key_in, key_out,
val_in, val_out,
n,
0, sizeof(Tval) * 8,
stream);
}
template <class T>
static cudaError inclusiveSum(
void *workspace_ptr, size_t &workspace_len,
T *data, int n,
cudaStream_t stream) {
return cub::DeviceScan::InclusiveSum(
workspace_ptr, workspace_len,
data, data, n,
stream);
}
// ↑↑↑ Re-wrap the cub API to reduce template parameters and simplify call sites

// ↓↓↓ Compute the workspace size
// Align addresses to 256 bytes
static constexpr size_t align256(size_t size) {
return (size + 255) & (~255);
}
template <class Tidx, class Tval>
utils::Result<size_t> calculateWorkspace(size_t n_) {
const auto n = static_cast<int>(n_);
size_t argmax;
CHECK_CUDA(argMax_<Tval>(
nullptr, nullptr, n,
nullptr, argmax,
nullptr));
// The first 256 bytes are reserved for the kv pair
argmax += 256;
// indices
size_t size_random = align256(sizeof(Tidx) * n);
// sorted
size_random += align256(sizeof(Tval) * n);
// indices_out
size_random += align256(sizeof(Tidx) * n);
// cub device api
size_t size_radix_sort;
CHECK_CUDA((radixSort<Tval, Tidx>(
nullptr, size_radix_sort,
nullptr, nullptr,
nullptr, nullptr,
n,
nullptr)));
size_t size_inclusive_sum;
CHECK_CUDA(inclusiveSum<Tval>(
nullptr, size_inclusive_sum,
nullptr, n,
nullptr));
size_random += cub::Max()(size_radix_sort, size_inclusive_sum);
return utils::Result<size_t>(cub::Max()(argmax, size_random));
}
// ↑↑↑ Compute the workspace size
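// A worked example of the arithmetic above, assuming Tidx = int32_t,
// Tval = float, n = 1000 (the cub temp-storage sizes are illustrative):
//   argmax path : cub ArgMax temp storage + 256 bytes reserved for the kv pair
//   random path : align256(4 * 1000) = 4096   // indices
//               + align256(4 * 1000) = 4096   // sorted
//               + align256(4 * 1000) = 4096   // indices_out
//               + max(radix_sort_temp, inclusive_sum_temp)
//   result      : max(argmax path, random path)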
// ↓↓↓ Convert fp16_t to half via template specialization
template <class Tval>
struct CudaTval {
using Type = Tval;
};
template <>
struct CudaTval<fp16_t> {
using Type = half;
};
// ↑↑↑ Convert fp16_t to half via template specialization

// ↓↓↓ Small kernels used in the sampling pass
// The cub::DeviceReduce::ArgMax shipped with CUDA Toolkit 11.x only accepts a cub::KeyValuePair<int, Tval> output.
// This kernel extracts the index from that pair
template <class Tidx, class Tval>
static __global__ void castIdx(Tidx *result, const cub::KeyValuePair<int, Tval> *kv_pair) {
*result = kv_pair->key;
}
// Fill the index array required by the sort
template <class Tidx>
static __global__ void fillIndices(Tidx *indices, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
indices[i] = i;
}
}
// The softmax used for random sampling can be reduced to a simple linear mapping.
// Since the data is already sorted, the maximum is the first element.
// The first element is read by multiple blocks, so it must not be written to
template <class T>
static __global__ void partialSoftmaxKernel(
T *__restrict__ data, int n,
float temperature) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (0 < i && i < n) {
float max = __ldg(data);
data[i] = (T)expf(((float)data[i] - max) / temperature);
}
}
// Write the first element as 1, i.e. exp(0)
template <class T>
static __global__ void setSoftmaxMaxKernel(
T *__restrict__ data) {
*data = 1;
}
// Sample by iterating with a plain for loop.
// This kernel exists only to avoid copying the data back to the CPU
template <class Tval, class Tidx>
static __global__ void randomSampleKernel(
Tidx *__restrict__ result,
const Tval *__restrict__ sorted,
const Tidx *__restrict__ indices_out,
size_t n,
float random, float topp, size_t topk) {
topk = cub::Min()(topk, n);
auto p = (Tval)(random * cub::Min()(topp * (float)sorted[n - 1], (float)sorted[topk - 1]));
for (size_t i = 0;; ++i) {
if ((sorted[i]) >= p) {
*result = indices_out[i];
return;
}
}
}
// ↑↑↑ Small kernels used in the sampling pass
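// In effect, after the inclusive sum below, `sorted` holds prefix sums s_i of
// the temperature-scaled softmax, and randomSampleKernel returns the first
// index with s_i >= random * min(topp * s_{n-1}, s_{topk-1}), confining
// sampling to the top-p/top-k mass without a separate renormalization pass.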
struct Algo {
int block_size;
template <class Tidx, class Tval_>
infiniStatus_t argmax(
void *workspace, size_t workspace_size,
void *result, const void *probs, size_t n,
void *stream_) const {
using Tval = typename CudaTval<Tval_>::Type;
auto stream = (cudaStream_t)stream_;
auto logits = (Tval *)probs;
auto kv_pair = (cub::KeyValuePair<int, Tval> *)workspace;
workspace = (void *)((char *)workspace + 256);
workspace_size -= 256;
argMax_(
kv_pair,
logits,
n,
workspace,
workspace_size, stream);
castIdx<<<1, 1, 0, stream>>>((Tidx *)result, kv_pair);
return INFINI_STATUS_SUCCESS;
}
template <class Tidx, class Tval_>
infiniStatus_t random(
void *workspace_, size_t workspace_size,
void *result_, const void *probs, size_t n,
float random_val, float topp, int topk, float temperature,
void *stream_) const {
using Tval = typename CudaTval<Tval_>::Type;
auto stream = (cudaStream_t)stream_;
auto logits = (Tval *)probs;
auto result = (Tidx *)result_;
auto workspace = reinterpret_cast<size_t>(workspace_);
auto workspace_end = workspace + workspace_size;
auto indices = reinterpret_cast<Tidx *>(workspace);
workspace += align256(sizeof(Tidx) * n);
auto sorted = reinterpret_cast<Tval *>(workspace);
workspace += align256(sizeof(Tval) * n);
auto indices_out = reinterpret_cast<Tidx *>(workspace);
workspace += align256(sizeof(Tidx) * n);
workspace_ = reinterpret_cast<void *>(workspace);
workspace_size = workspace_end - workspace;
auto block = cub::Min()((size_t)block_size, n);
auto grid = (n + block - 1) / block;
// sort
fillIndices<<<grid, block, 0, stream>>>(indices, n);
CHECK_CUDA(radixSort(
workspace_, workspace_size,
logits, sorted,
indices, indices_out,
n,
stream));
// softmax
partialSoftmaxKernel<<<grid, block, 0, stream>>>(sorted, n, temperature);
setSoftmaxMaxKernel<<<1, 1, 0, stream>>>(sorted);
// sum
CHECK_CUDA(inclusiveSum(
workspace_, workspace_size,
sorted, n,
stream));
// sample
randomSampleKernel<<<1, 1, 0, stream>>>(
result,
sorted, indices_out, n,
random_val, topp, topk);
return INFINI_STATUS_SUCCESS;
}
};
} // namespace op::random_sample::cuda
#ifndef __RANDOM_SAMPLE_INFO_H__
#define __RANDOM_SAMPLE_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
namespace op::random_sample {
struct RandomSampleInfo {
infiniDtype_t dt_i, dt_p;
size_t n;
static utils::Result<RandomSampleInfo> create(
infiniopTensorDescriptor_t result_desc,
infiniopTensorDescriptor_t probs_desc) {
auto dt_i = result_desc->dtype();
auto dt_p = probs_desc->dtype();
CHECK_DTYPE_ANY_INT(dt_i);
CHECK_DTYPE(dt_p, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_OR_RETURN(result_desc->ndim() == 0, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(probs_desc->ndim() == 1, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(probs_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
return utils::Result<RandomSampleInfo>({dt_i, dt_p, probs_desc->dim(0)});
}
};
} // namespace op::random_sample
#endif // __RANDOM_SAMPLE_INFO_H__
@@ -5,6 +5,9 @@
#ifdef ENABLE_CPU_API
#include "cpu/random_sample_cpu.h"
#endif
+#ifdef ENABLE_CUDA_API
+#include "cuda/random_sample_cuda.cuh"
+#endif

__C infiniStatus_t infiniopCreateRandomSampleDescriptor(
    infiniopHandle_t handle,
@@ -25,6 +28,9 @@ __C infiniStatus_t infiniopCreateRandomSampleDescriptor(
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
+#ifdef ENABLE_CUDA_API
+        CREATE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -38,9 +44,10 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
    size_t *size) {

#define GET(CASE, NAMESPACE) \
-    case CASE: \
+    case CASE: { \
        using Ptr = const op::random_sample::NAMESPACE::Descriptor *; \
        *size = reinterpret_cast<Ptr>(desc)->minWorkspaceSize(); \
+    } \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
@@ -48,6 +55,9 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu);
#endif
+#ifdef ENABLE_CUDA_API
+        GET(INFINI_DEVICE_NVIDIA, cuda);
+#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -82,6 +92,9 @@ __C infiniStatus_t infiniopRandomSample(
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
+#ifdef ENABLE_CUDA_API
+        CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -103,6 +116,9 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor(
#ifdef ENABLE_CPU_API
        DELETE(INFINI_DEVICE_CPU, cpu);
#endif
+#ifdef ENABLE_CUDA_API
+        DELETE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
#ifndef __RANDOM_SAMPLE_H__
#define __RANDOM_SAMPLE_H__

-#include "../../../utils.h"
#include "../../operator.h"
+#include "info.h"

#define DESCRIPTOR(NAMESPACE) \
                              \
@@ -11,22 +11,18 @@
    struct Opaque; \
    Opaque *_opaque; \
                     \
-    infiniDtype_t _dt_i, _dt_p; \
-    size_t _n, _min_workspace_size; \
+    RandomSampleInfo _info; \
+    size_t _min_workspace_size; \
                     \
    Descriptor( \
-        infiniDtype_t dt_i, \
-        infiniDtype_t dt_p, \
-        size_t n, \
+        RandomSampleInfo info, \
        size_t min_workspace_size, \
        Opaque *opaque, \
        infiniDevice_t device_type, \
        int device_id) \
        : InfiniopDescriptor{device_type, device_id}, \
          _opaque(opaque), \
-          _dt_i(dt_i), \
-          _dt_p(dt_p), \
-          _n(n), \
+          _info(info), \
          _min_workspace_size(min_workspace_size) {} \
                     \
public: \
@@ -53,4 +49,96 @@
    }; \
}
namespace op::random_sample {
struct CalculateArgs {
void *workspace;
size_t workspace_size;
void *result;
const void *probs;
float random_val, topp, temperature;
int topk;
void *stream;
};
class Calculate {
template <class Tidx, class Tval, class Algo>
static void switch_f(Algo algo, size_t n, CalculateArgs args) {
if (args.random_val == 0 || args.topp == 0 || args.topk == 1 || args.temperature == 0) {
algo.template argmax<Tidx, Tval>(
args.workspace, args.workspace_size,
args.result, args.probs, n,
args.stream);
} else {
algo.template random<Tidx, Tval>(
args.workspace, args.workspace_size,
args.result, args.probs, n,
args.random_val, args.topp, args.topk, args.temperature,
args.stream);
}
}
template <class Tidx, class Algo>
static void switch_val(
Algo algo,
infiniDtype_t dt_p, size_t n, CalculateArgs args) {
switch (dt_p) {
case INFINI_DTYPE_F16:
switch_f<Tidx, fp16_t>(algo, n, args);
break;
case INFINI_DTYPE_F32:
switch_f<Tidx, float>(algo, n, args);
break;
case INFINI_DTYPE_F64:
switch_f<Tidx, double>(algo, n, args);
break;
default:
// unreachable
std::abort();
}
}
public:
template <class Algo>
static infiniStatus_t calculate(
Algo algo,
RandomSampleInfo info,
void *workspace, size_t workspace_size,
void *result, const void *probs,
float random_val, float topp, int topk, float temperature,
void *stream) {
#define CASE(DT_VAL, DT_TYP) \
case DT_VAL: \
switch_val<DT_TYP>( \
algo, info.dt_p, info.n, \
{workspace, workspace_size, \
result, probs, \
random_val, topp, temperature, topk, \
stream}); \
break
switch (info.dt_i) {
CASE(INFINI_DTYPE_I8, int8_t);
CASE(INFINI_DTYPE_I16, int16_t);
CASE(INFINI_DTYPE_I32, int32_t);
CASE(INFINI_DTYPE_I64, int64_t);
CASE(INFINI_DTYPE_U8, uint8_t);
CASE(INFINI_DTYPE_U16, uint16_t);
CASE(INFINI_DTYPE_U32, uint32_t);
CASE(INFINI_DTYPE_U64, uint64_t);
default:
// unreachable
std::abort();
}
#undef CASE
return INFINI_STATUS_SUCCESS;
}
};
} // namespace op::random_sample
#endif // __RANDOM_SAMPLE_H__
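`Calculate::calculate` treats `Algo` as a policy type: the dtype switches above only dispatch, and each backend supplies the actual kernels. A minimal outline of what a backend's `Algo` must provide, matching the CPU and CUDA implementations in this commit (a sketch, names illustrative):

```cpp
struct MyAlgo {
    // Greedy path: write the index of the largest probability to result.
    template <class Tidx, class Tval>
    infiniStatus_t argmax(void *workspace, size_t workspace_size,
                          void *result, const void *probs, size_t n,
                          void *stream);

    // Sampling path: top-p/top-k/temperature sampling driven by random_val.
    template <class Tidx, class Tval>
    infiniStatus_t random(void *workspace, size_t workspace_size,
                          void *result, const void *probs, size_t n,
                          float random_val, float topp, int topk, float temperature,
                          void *stream);
};
```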
#ifndef __RMS_NORM_KUNLUN_KERNEL_XPU__
#define __RMS_NORM_KUNLUN_KERNEL_XPU__

-#include "../../../devices/kunlun/kunlun_common.h"
+#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"
+using namespace device::kunlun::kernel;

// Element wise mul used in x * w
static inline __device__ void elementwiseMulRms(float *x, float *w, float *y, int count, float rms) {
int remain = count % 16;
...
#include "swiglu_kunlun.h"
// Op interface declare
LAUNCH_ELEMENTWISE_KERNEL(SwiGLU)
namespace op::swiglu::kunlun {
typedef struct SwiGLUOp {
static constexpr size_t num_inputs = 2;
template <typename Tdata, typename... Args>
static infiniStatus_t launch(Args... args) {
launchSwiGLUKernel<Tdata>(args...);
return INFINI_STATUS_SUCCESS;
}
} SwiGLUOp;
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &up_desc = input_desc_vec.at(0);
const auto &gate_desc = input_desc_vec.at(1);
const auto &out_shape = out_desc->shape();
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F32);
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
// create KUNLUN elementwise descriptor
CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F32:
return _device_info->calculate<SwiGLUOp, float>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::swiglu::kunlun
#ifndef __SWIGLU_KUNLUN_H__
#define __SWIGLU_KUNLUN_H__
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
ELEMENTWISE_DESCRIPTOR(swiglu, kunlun)
#endif // __SWIGLU_KUNLUN_H__