Commit 6292da00 authored by Zimin Li

issue/127: refactor elementwise framework, complete CUDA implementation, refactor swiglu using the generic elementwise framework
parent 7105d13d
@@ -21,10 +21,43 @@
handle->device, \
handle->device_id);
DEVICE_IMPL(cpu)
namespace op::elementwise::cpu {
class DeviceImpl final {
struct Opaque;
std::shared_ptr<struct Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static infiniStatus_t create(
DeviceImpl **device_info,
Args &&...args);
/* Invoke elementwise operation when all inputs have the same type */
template <typename Op, typename Tdata, typename... Args>
void calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
/* Invoke elementwise operation for different input types */
template <typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
struct DeviceImpl::Opaque {};
template <typename... Args>
@@ -42,13 +75,13 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output,
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data());
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides);
auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i
: (info.input_broadcasted[input_id]
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data())
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data()));
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id])
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id]));
};
out[out_idx] = utils::cast<Tout>(Op{}(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
@@ -60,6 +93,7 @@ template <typename Op, typename Tout, typename... Tin, typename... Args, std::en
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch");
@@ -80,13 +114,13 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data());
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides);
auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i
: (info.input_broadcasted[input_id]
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data())
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data()));
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id])
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id]));
};
if constexpr (std::is_same_v<Tdata, fp16_t>) {
@@ -99,7 +133,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
// Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, Args &&...args) {
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, void *stream, Args &&...args) {
constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
}
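For orientation, a minimal sketch of the operator contract this CPU framework assumes — a static num_inputs plus a callable operator() — using a hypothetical AddOp that is not part of this commit:

#include <cstddef>

// Hypothetical operator (not part of this commit) showing the contract the
// CPU framework expects: a static num_inputs and a callable operator().
struct AddOp {
    static constexpr std::size_t num_inputs = 2;
    template <typename T>
    T operator()(const T &a, const T &b) const { return a + b; }
};

// A descriptor would then dispatch along the lines of:
//   _device_info->calculate<AddOp, float>(_info, output, inputs, stream);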
......
#ifndef __INFINIOP_ELEMENTWISE_CUDA_H__
#define __INFINIOP_ELEMENTWISE_CUDA_H__
#include "../../../utils.h"
#include "../../devices/cuda/cuda_common.cuh"
#include "elementwise_cuda_api.cuh"
namespace op::elementwise::cuda {
/**
* @brief Helper device function to expand a compile-time index sequence into individual constants
* and pass them to a lambda.
*
* @tparam Lambda Type of the lambda function to invoke.
* @tparam Is Index sequence values (automatically deduced).
* @param lambda Lambda to be called with std::integral_constant<size_t, Is>... as arguments.
*/
template <typename Lambda, size_t... Is>
__device__ __forceinline__ void call_expand(Lambda lambda, std::index_sequence<Is...>) {
lambda(std::integral_constant<size_t, Is>{}...);
}
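The same expansion trick works in plain host C++; a self-contained sketch (illustrative only, independent of the CUDA headers):

#include <cstddef>
#include <cstdio>
#include <type_traits>
#include <utility>

// Host-side analogue of call_expand: expands an index sequence into
// std::integral_constant arguments, so each index is a compile-time value.
template <typename Lambda, std::size_t... Is>
void call_expand_host(Lambda lambda, std::index_sequence<Is...>) {
    lambda(std::integral_constant<std::size_t, Is>{}...);
}

int main() {
    int data[3] = {10, 20, 30};
    call_expand_host(
        [&](auto... idxs) {
            // idxs.value is a constant expression, usable e.g. with std::get.
            ((void)std::printf("%d ", data[idxs.value]), ...);
        },
        std::make_index_sequence<3>{}); // prints: 10 20 30
    std::printf("\n");
}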
/**
* @brief CUDA kernel for performing elementwise operations on tensors where all inputs share the same data type.
*
* @tparam Op Operator type implementing operator()(Tdata...).
* @tparam Tdata Common data type for inputs and output.
* @tparam N Number of input tensors.
* @tparam Args Additional arguments to pass to the operator.
*
* @param output_size Total number of output elements.
* @param ndim Number of dimensions in tensors.
* @param output_contiguous Whether the output tensor is contiguous in memory.
* @param input_contiguous Array indicating if each input tensor is contiguous.
* @param input_broadcasted Array indicating if each input tensor is broadcasted.
* @param output_shape Shape of the output tensor.
* @param input_shapes Shapes of the input tensors.
* @param output_strides Strides for the output tensor.
* @param input_strides Strides for each input tensor.
 * @param input_size Number of input tensors.
* @param output Output buffer.
* @param inputs Array of input pointers, all of type Tdata.
* @param offset Linear offset to support partitioned execution.
* @param args Additional arguments passed to the operator.
*/
template <typename Op, typename Tdata, size_t N, typename... Args>
INFINIOP_CUDA_KERNEL elementwise_kernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape,
const size_t *__restrict__ *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ *__restrict__ input_strides,
size_t input_size,
Tdata *output,
const Tdata *const *inputs,
size_t offset,
Args... args) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
if (idx < output_size) {
size_t out_idx = output_contiguous ? idx
: device::cuda::indexToOffset(idx, ndim, output_shape, output_strides);
auto get_input_idx = [&] __device__(size_t input_id) {
return input_contiguous[input_id] ? idx
: (input_broadcasted[input_id]
? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides[input_id])
: device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id]));
};
// Use a helper to expand the index sequence into individual compile-time constants
auto expand_inputs = [&] __device__(auto... idxs) {
if constexpr (std::is_same_v<Tdata, fp16_t>) {
output[out_idx] = utils::cast<fp16_t>(
Op{}(utils::cast<float>(inputs[idxs.value][get_input_idx(idxs.value)])...,
std::forward<Args>(args)...));
} else {
output[out_idx] = Op{}(
inputs[idxs.value][get_input_idx(idxs.value)]...,
std::forward<Args>(args)...);
}
};
call_expand(expand_inputs, std::make_index_sequence<N>{});
}
}
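The kernel defers all indexing to device::cuda::indexToOffset / indexToReducedOffset. The host-side sketch below shows the arithmetic these helpers are presumed to implement; the names and loop structure are assumptions, not the repo's device code:

#include <cstddef>
#include <cstdio>

// Presumed semantics of indexToOffset: decompose a linear index over the
// shape (last dimension fastest), then apply the strides.
inline std::size_t index_to_offset(std::size_t flat, std::size_t ndim,
                                   const std::size_t *shape,
                                   const std::ptrdiff_t *strides) {
    std::size_t offset = 0;
    for (std::size_t d = ndim; d-- > 0;) {
        offset += (flat % shape[d]) * strides[d];
        flat /= shape[d];
    }
    return offset;
}

// Presumed semantics of indexToReducedOffset: recover coordinates through
// the broadcasted (output) strides, then re-apply the input's own strides,
// so dimensions with input stride 0 collapse onto the same element.
inline std::size_t index_to_reduced_offset(std::size_t flat, std::size_t ndim,
                                           const std::ptrdiff_t *broadcast_strides,
                                           const std::ptrdiff_t *target_strides) {
    std::size_t offset = 0;
    for (std::size_t d = 0; d < ndim; ++d) {
        offset += flat / broadcast_strides[d] * target_strides[d];
        flat %= broadcast_strides[d];
    }
    return offset;
}

int main() {
    // 2x3 contiguous output: shape {2,3}, strides {3,1};
    // input broadcast along dim 0: strides {0,1}.
    const std::size_t shape[] = {2, 3};
    const std::ptrdiff_t out_strides[] = {3, 1};
    const std::ptrdiff_t in_strides[] = {0, 1};
    // Flat index 4 -> coordinates (1,1) -> output offset 4, input offset 1.
    std::printf("%zu %zu\n",
                index_to_offset(4, 2, shape, out_strides),
                index_to_reduced_offset(4, 2, out_strides, in_strides));
}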
/**
* @brief Casts an untyped device pointer to a typed pointer of type T.
*
* @tparam T Desired pointer type.
* @param ptr Untyped pointer.
* @return Pointer of type const T*.
*/
template <typename T>
__device__ inline const T *typed_input_ptr(const void *ptr) {
return reinterpret_cast<const T *>(ptr);
}
/**
* @brief Launches a type-safe elementwise operation on a single output element.
*
* @tparam Op Operator type implementing a templated operator() for (Tout, Tin...).
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
* @tparam Is Index sequence corresponding to each input.
*
* @param idx Linear index in the flattened output space.
* @param out_idx Actual output index (may be non-contiguous).
* @param ndim Number of dimensions in the tensors.
* @param input_contiguous Array indicating whether each input is contiguous.
* @param input_broadcasted Array indicating whether each input is broadcasted.
* @param input_shapes Shapes of the input tensors.
 * @param input_strides Strides of the input tensors.
 * @param output_strides Strides of the output tensor (used to decode broadcasted inputs).
* @param inputs Raw pointers to input data.
* @param output Pointer to output data.
* @param ... Index sequence used for unpacking variadic inputs.
*/
template <typename Op, typename Tout, typename... Tin, size_t... Is>
__device__ void launch_op(
size_t idx,
size_t out_idx,
size_t ndim,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ const *__restrict__ input_shapes,
    const ptrdiff_t *__restrict__ const *__restrict__ input_strides,
    const ptrdiff_t *__restrict__ output_strides,
const void *const *__restrict__ inputs,
Tout *output,
std::index_sequence<Is...>) {
auto get_input_idx = [&] __device__(size_t input_id) {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
                       ? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides[input_id])
: device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id]));
};
output[out_idx] = Op{}.template operator()<Tout, Tin...>(
(typed_input_ptr<Tin>(inputs[Is])[get_input_idx(Is)])...);
}
/**
* @brief CUDA kernel for performing an elementwise operation on tensors with support
* for broadcasting and mixed data types.
*
* @tparam Op Operator type implementing a templated operator() for (Tout, Tin...).
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
*
* @param output_size Total number of output elements.
* @param ndim Number of dimensions in the tensors.
* @param output_contiguous Whether the output tensor is contiguous.
* @param input_contiguous Array indicating whether each input is contiguous.
* @param input_broadcasted Array indicating whether each input is broadcasted.
* @param output_shape Shape of the output tensor.
* @param input_shapes Shapes of the input tensors.
* @param output_strides Strides of the output tensor.
* @param input_strides Strides of the input tensors.
 * @param input_size Number of input tensors (unused here, but may be used for validation).
* @param output Pointer to the output buffer.
* @param inputs Array of untyped input pointers.
* @param offset Linear offset into the output for partitioned execution.
*/
template <typename Op, typename Tout, typename... Tin>
INFINIOP_CUDA_KERNEL elementwise_kernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape,
const size_t *__restrict__ const *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ const *__restrict__ input_strides,
size_t input_size,
Tout *output,
const void *const *__restrict__ inputs,
size_t offset) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
if (idx >= output_size) {
return;
}
size_t out_idx = output_contiguous
? idx
: device::cuda::indexToOffset(idx, ndim, output_shape, output_strides);
launch_op<Op, Tout, Tin...>(
idx,
out_idx,
ndim,
input_contiguous,
input_broadcasted,
input_shapes,
        input_strides,
        output_strides,
inputs,
output,
std::index_sequence_for<Tin...>{});
}
struct DeviceImpl::Opaque {
std::shared_ptr<device::cuda::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::cuda::Handle::Internal> &internal)
: internal(internal) {}
/**
* @brief Performs elementwise operations when all inputs and the output share the same data type.
*
* @tparam BLOCK_SIZE The block size for the kernel launch.
* @tparam N The number of input tensors.
* @tparam Op The operation to perform (e.g., addition, multiplication).
* @tparam Tdata The data type of the input and output tensors.
* @tparam Args Additional arguments to be passed to the operation.
* @param info Structure containing elementwise operation information (size, shape, etc.).
* @param output Pointer to the output memory where results will be stored.
* @param inputs Vector of pointers to input tensors.
* @param stream CUDA stream used for asynchronous execution.
* @param args Additional arguments for the operation.
*/
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args, size_t... Is>
void calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
cudaStream_t stream,
Args &&...args) {
if (info.output_size == 0) {
return;
}
// casting the output and the inputs to Tdata pointers
Tdata *out = reinterpret_cast<Tdata *>(output);
const Tdata *inputs_arr[N];
const Tdata **d_inputs_arr = nullptr;
for (size_t i = 0; i < N; ++i) {
inputs_arr[i] = reinterpret_cast<const Tdata *>(inputs[i]);
}
cudaMallocAsync(&d_inputs_arr, N * sizeof(*d_inputs_arr), stream);
cudaMemcpyAsync(d_inputs_arr, inputs_arr, N * sizeof(*d_inputs_arr), cudaMemcpyHostToDevice, stream);
// create and send the info to device
const bool *d_bools = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const int8_t *d_output_shape_strides = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t **d_input_shapes = nullptr;
const ptrdiff_t **d_input_strides = nullptr;
std::vector<const size_t *> tmp_device_ptrs(info.input_size);
std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size);
infoToDevice<N>(info, d_bools, d_input_contiguous,
d_input_broadcasted, d_output_shape_strides, d_output_shape,
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream);
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
size_t step = gridDims.x * blockDims.x;
for (size_t i = 0; i < info.output_size; i += step) {
elementwise_kernel<Op, Tdata, N, Args...><<<gridDims, blockDims, 0, stream>>>(
info.output_size,
info.ndim,
info.output_contiguous,
d_input_contiguous,
d_input_broadcasted,
d_output_shape,
d_input_shapes,
d_output_strides,
d_input_strides,
info.input_size, out, d_inputs_arr, i, std::forward<Args>(args)...);
}
freeAllDevice((const void **)d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream);
}
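To make the launch math above concrete, a host-only sketch with assumed device limits (maxThreadsPerBlock = 1024, gridSizeX = 1024): the kernel is re-launched with a moving offset until every element is covered:

#include <algorithm>
#include <cstddef>
#include <cstdio>

#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

int main() {
    const std::size_t output_size = 1000000;
    // Assumed device limits for this sketch:
    const std::size_t max_threads = 1024; // internal->maxThreadsPerBlock()
    const std::size_t max_grid_x  = 1024; // internal->gridSizeX()
    const std::size_t block = std::min<std::size_t>(256, max_threads);
    const std::size_t grid  = std::min<std::size_t>(CEIL_DIV(output_size, block), max_grid_x);
    const std::size_t step  = grid * block; // elements covered per launch
    std::size_t launches = 0;
    for (std::size_t i = 0; i < output_size; i += step) {
        ++launches; // each launch handles indices [i, min(i + step, output_size))
    }
    std::printf("step=%zu launches=%zu\n", step, launches); // step=262144 launches=4
}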
/**
 * @brief Performs elementwise operations when the inputs and the output have mixed data types (i.e., different dtypes).
*
* @tparam BLOCK_SIZE The block size for the kernel launch.
* @tparam N The number of input tensors.
* @tparam Op The operation to perform (e.g., addition, multiplication).
* @tparam Tout The output data type.
* @tparam Tin The input data types.
* @tparam Args Additional arguments to be passed to the operation.
* @param info Structure containing elementwise operation information (size, shape, etc.).
* @param output Pointer to the output memory where results will be stored.
* @param inputs Vector of pointers to input tensors.
* @param stream CUDA stream used for asynchronous execution.
* @param args Additional arguments for the operation.
*/
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args, size_t... Is,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
cudaStream_t stream,
Args &&...args) {
if (info.output_size == 0) {
return;
}
Tout *out = reinterpret_cast<Tout *>(output);
// Store input pointers with the correct types
const std::tuple<const Tin *...> inputs_arr{reinterpret_cast<const Tin *>(inputs[Is])...};
const void **d_inputs_arr = nullptr;
// Create array of input pointers on host (void*) to copy to device
const void *host_input_ptrs[] = {reinterpret_cast<const void *>(std::get<Is>(inputs_arr))...};
cudaMallocAsync(&d_inputs_arr, N * sizeof(void *), stream);
cudaMemcpyAsync(d_inputs_arr, host_input_ptrs, N * sizeof(void *), cudaMemcpyHostToDevice, stream);
// Device pointers
const bool *d_bools = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const int8_t *d_output_shape_strides = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t **d_input_shapes = nullptr;
const ptrdiff_t **d_input_strides = nullptr;
std::vector<const size_t *> tmp_device_ptrs(info.input_size);
std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size);
infoToDevice<N>(info, d_bools, d_input_contiguous,
d_input_broadcasted, d_output_shape_strides, d_output_shape,
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream);
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
size_t step = gridDims.x * blockDims.x;
for (size_t i = 0; i < info.output_size; i += step) {
elementwise_kernel<Op, Tout, Tin...><<<gridDims, blockDims, 0, stream>>>(
info.output_size,
info.ndim,
info.output_contiguous,
d_input_contiguous,
d_input_broadcasted,
d_output_shape,
d_input_shapes,
d_output_strides,
d_input_strides,
info.input_size, out, reinterpret_cast<const void **>(d_inputs_arr), i);
}
freeAllDevice(d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream);
}
private:
/**
* @brief Transfers elementwise kernel metadata (shapes, strides, flags) from host to device.
*
* @tparam N Number of inputs.
* @param info Structure containing input/output metadata.
* @param d_bools Device pointer for input_contiguous and input_broadcasted flags.
* @param d_input_contiguous Device pointer to input contiguity flags.
* @param d_input_broadcasted Device pointer to input broadcasting flags.
* @param d_output_shape_strides Device buffer containing both output shape and strides.
* @param d_output_shape Device pointer to output shape.
* @param d_output_strides Device pointer to output strides.
* @param tmp_device_ptrs Temporary device pointers for input shapes.
* @param d_input_shapes Device array of pointers to input shapes.
* @param tmp_device_ptrs_strides Temporary device pointers for input strides.
* @param d_input_strides Device array of pointers to input strides.
* @param stream CUDA stream for async allocation and transfers.
* @return infiniStatus_t Status indicating success or failure.
*/
template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
const bool *&d_bools,
const bool *&d_input_contiguous,
const bool *&d_input_broadcasted,
const int8_t *&d_output_shape_strides,
const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides,
std::vector<const size_t *> &tmp_device_ptrs,
const size_t **&d_input_shapes,
std::vector<const ptrdiff_t *> &tmp_device_ptrs_strides,
const ptrdiff_t **&d_input_strides,
cudaStream_t stream) const {
cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream);
cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream);
cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream);
cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync((void *)(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)), info.output_strides, info.ndim * sizeof(*d_output_strides), cudaMemcpyHostToDevice, stream);
cudaMallocAsync(&d_input_shapes, info.input_size * sizeof(*d_input_shapes), stream);
for (size_t i = 0; i < info.input_size; ++i) {
cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(size_t), stream);
cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i],
info.ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream);
}
cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(),
info.input_size * sizeof(size_t *), cudaMemcpyHostToDevice, stream);
cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream);
for (size_t i = 0; i < info.input_size; ++i) {
cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(ptrdiff_t), stream);
cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i],
info.ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream);
}
cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(),
info.input_size * sizeof(ptrdiff_t *), cudaMemcpyHostToDevice, stream);
d_input_contiguous = d_bools;
d_input_broadcasted = d_bools + info.input_size;
d_output_shape = reinterpret_cast<const size_t *>(d_output_shape_strides);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape_strides + info.ndim * sizeof(size_t));
return INFINI_STATUS_SUCCESS;
}
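For reference, the byte layout infoToDevice packs: one allocation for both flag arrays and one for the output shape plus strides. A host-only sketch of the offset arithmetic (no CUDA calls):

#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t input_size = 2; // number of input tensors
    const std::size_t ndim = 3;
    // d_bools: [ input_contiguous[0..N) | input_broadcasted[0..N) ]
    const std::size_t bools_bytes = 2 * input_size * sizeof(bool);
    // d_output_shape_strides: [ output_shape: ndim x size_t | output_strides: ndim x ptrdiff_t ]
    const std::size_t strides_byte_offset = ndim * sizeof(std::size_t);
    const std::size_t shape_strides_bytes = ndim * (sizeof(std::size_t) + sizeof(std::ptrdiff_t));
    std::printf("bools: %zu bytes; shape+strides: %zu bytes; strides start at byte %zu\n",
                bools_bytes, shape_strides_bytes, strides_byte_offset);
}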
/**
* @brief Frees all device-allocated memory used for metadata in elementwise kernel execution.
*
* @param d_inputs_arr Device array of input pointers.
* @param d_bools Device memory holding input flags.
* @param d_output_shape_strides Device buffer holding output shape and strides.
* @param input_size Number of input tensors.
* @param d_input_shapes Device array of input shape pointers.
* @param d_input_strides Device array of input stride pointers.
* @param stream CUDA stream for async deallocation.
* @return infiniStatus_t Status indicating success or failure.
*/
inline infiniStatus_t freeAllDevice(const void **d_inputs_arr,
const bool *d_bools,
const int8_t *d_output_shape_strides,
const size_t input_size,
const size_t **d_input_shapes,
const ptrdiff_t **d_input_strides,
cudaStream_t stream) const {
cudaFreeAsync((void *)d_inputs_arr, stream);
cudaFreeAsync((void *)d_bools, stream);
cudaFreeAsync((void *)d_output_shape_strides, stream);
cudaFreeAsync((void *)d_input_shapes, stream);
cudaFreeAsync((void *)d_input_strides, stream);
return INFINI_STATUS_SUCCESS;
}
};
template <typename... Args>
infiniStatus_t DeviceImpl::create(DeviceImpl **device_info,
Args &&...args) {
auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
*device_info = new DeviceImpl(opaque);
return INFINI_STATUS_SUCCESS;
}
/**
* @brief Launches elementwise operation where input types may differ.
*
* Dispatches to templated `calculateImpl` using specified output and input types.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tout Output data type.
* @tparam Tin... Input data types (must match Op::num_inputs).
* @tparam Args... Additional arguments passed to the operation.
* @param info Metadata describing tensor shapes, strides, etc.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args (UNUSED) Additional operation-specific arguments.
*/
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
static_assert(sizeof...(Tin) == N, "Input type count mismatch");
_opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>(
info, output, inputs,
std::make_index_sequence<N>{},
reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...);
}
/**
* @brief Launches elementwise operation where all input types are the same.
*
* Calls the corresponding templated `calculateImpl` with a unified input type.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tdata Data type for both input and output tensors.
* @tparam Args... Additional arguments passed to the operation.
* @param info Metadata describing tensor shapes, strides, etc.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args Additional operation-specific arguments.
*/
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
_opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>(
info, output, inputs,
std::make_index_sequence<N>{},
reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...);
cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream));
}
} // namespace op::elementwise::cuda
#endif // __INFINIOP_ELEMENTWISE_CUDA_H__
#ifndef __INFINIOP_ELEMENTWISE_CUDA_API_H__
#define __INFINIOP_ELEMENTWISE_CUDA_API_H__
#include "../elementwise.h"
namespace op::elementwise::cuda {
/**
 * @brief Define the methods and info needed by CUDA to perform an elementwise operation
*/
class DeviceImpl final {
struct Opaque;
std::shared_ptr<struct Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static infiniStatus_t create(
DeviceImpl **device_info,
Args &&...args);
/* Invoke elementwise operation when all inputs have the same dtype */
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
void calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
/* Invoke elementwise operation for different input types */
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
} // namespace op::elementwise::cuda
/**
* @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CUDA implementation
*/
#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR \
\
op::elementwise::ElementwiseInfo elementwise_info; \
CHECK_STATUS(op::elementwise::createElementwiseInfo(elementwise_info, out_desc, input_desc)); \
\
op::elementwise::cuda::DeviceImpl *device_impl; \
CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \
\
*desc_ptr = new Descriptor( \
dtype, \
std::move(elementwise_info), \
device_impl, \
handle->device, \
handle->device_id);
#endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__
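A sketch of how an operator would wire this API together, mirroring the swiglu call sites later in this commit; MyOp is hypothetical and the include is the repo's own header, so this is illustrative rather than standalone:

// Hypothetical operator using the CUDA elementwise API above.
#include "elementwise_cuda_api.cuh"
#include <vector>

struct MyOp {
    static constexpr size_t num_inputs = 2;
    template <typename T>
    __device__ T operator()(const T &a, const T &b) const { return a + b; }
};

infiniStatus_t run_my_op(op::elementwise::cuda::DeviceImpl *device,
                         const op::elementwise::ElementwiseInfo &info,
                         void *output,
                         const std::vector<const void *> &inputs,
                         void *stream) {
    // Same-dtype entry point; 256 threads per block as in the swiglu code.
    device->calculate<256, MyOp, float>(info, output, inputs, stream);
    return INFINI_STATUS_SUCCESS;
}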
@@ -4,47 +4,12 @@
#include "../operator.h"
#include "../tensor.h"
#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <numeric>
#include <vector>
#define DEVICE_IMPL(NAMESPACE) \
\
namespace op::elementwise::NAMESPACE { \
class DeviceImpl final { \
struct Opaque; \
std::unique_ptr<Opaque> _opaque; \
\
DeviceImpl(Opaque *opaque) : _opaque(opaque) {} \
\
public: \
~DeviceImpl() = default; \
\
template <typename... Args> \
static infiniStatus_t create( \
DeviceImpl **device_info, \
Args &&...args); \
\
/* Invoke elementwise operation when all inputs have the same type */ \
template <typename Op, typename Tdata, typename... Args> \
void calculate( \
const op::elementwise::ElementwiseInfo &info, \
void *output, \
const std::vector<const void *> &inputs, \
Args &&...args); \
\
/* Invoke elementwise operation for different input types */ \
template <typename Op, typename Tout, typename... Tin, \
typename... Args, \
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> \
void calculate( \
const op::elementwise::ElementwiseInfo &info, \
void *output, \
const std::vector<const void *> &inputs, \
Args &&...args); \
}; \
}
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
\
namespace op::OP::NAMESPACE { \
@@ -61,7 +26,7 @@
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \
_info(info), \
_info(std::move(info)), \
_device_info(device_info) {} \
\
public: \
@@ -87,12 +52,84 @@ struct ElementwiseInfo {
size_t output_size;
size_t ndim;
bool output_contiguous;
std::vector<bool> input_contiguous;
std::vector<bool> input_broadcasted;
std::vector<size_t> output_shape;
std::vector<std::vector<size_t>> input_shapes;
std::vector<ptrdiff_t> output_strides;
std::vector<std::vector<ptrdiff_t>> input_strides;
bool *input_contiguous;
bool *input_broadcasted;
size_t *output_shape;
size_t **input_shapes;
ptrdiff_t *output_strides;
ptrdiff_t **input_strides;
size_t input_size;
    // Zero-initialize so the destructor is safe on a default-constructed
    // (never-populated) instance.
    ElementwiseInfo()
        : output_size(0), ndim(0), output_contiguous(false),
          input_contiguous(nullptr), input_broadcasted(nullptr),
          output_shape(nullptr), input_shapes(nullptr),
          output_strides(nullptr), input_strides(nullptr),
          input_size(0) {}
// Destructor to free allocated memory
~ElementwiseInfo() {
delete[] input_contiguous;
delete[] input_broadcasted;
delete[] output_shape;
delete[] output_strides;
for (size_t i = 0; i < input_size; ++i) {
delete[] input_shapes[i];
delete[] input_strides[i];
}
delete[] input_shapes;
delete[] input_strides;
}
ElementwiseInfo(const ElementwiseInfo &other)
: output_size(other.output_size),
ndim(other.ndim),
output_contiguous(other.output_contiguous),
input_size(other.input_size) {
input_contiguous = new bool[input_size];
std::memcpy(input_contiguous, other.input_contiguous, input_size * sizeof(*input_contiguous));
input_broadcasted = new bool[input_size];
std::memcpy(input_broadcasted, other.input_broadcasted, input_size * sizeof(*input_broadcasted));
output_shape = new size_t[ndim];
std::memcpy(output_shape, other.output_shape, ndim * sizeof(*output_shape));
output_strides = new ptrdiff_t[ndim];
std::memcpy(output_strides, other.output_strides, ndim * sizeof(*output_strides));
input_shapes = new size_t *[input_size];
for (size_t i = 0; i < input_size; ++i) {
input_shapes[i] = new size_t[ndim];
std::memcpy(input_shapes[i], other.input_shapes[i], ndim * sizeof(*input_shapes[i]));
}
input_strides = new ptrdiff_t *[input_size];
for (size_t i = 0; i < input_size; ++i) {
input_strides[i] = new ptrdiff_t[ndim];
std::memcpy(input_strides[i], other.input_strides[i], ndim * sizeof(*input_strides[i]));
}
}
ElementwiseInfo(ElementwiseInfo &&other) noexcept
: output_size(other.output_size),
ndim(other.ndim),
output_contiguous(other.output_contiguous),
input_contiguous(other.input_contiguous),
input_broadcasted(other.input_broadcasted),
output_shape(other.output_shape),
input_shapes(other.input_shapes),
output_strides(other.output_strides),
input_strides(other.input_strides),
input_size(other.input_size) {
other.input_contiguous = nullptr;
other.input_broadcasted = nullptr;
other.output_shape = nullptr;
other.input_shapes = nullptr;
other.output_strides = nullptr;
other.input_strides = nullptr;
other.input_size = 0;
}
ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete;
};
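The move constructor nulls the source precisely so that only one object ever delete[]s each array; a minimal standalone analogue of the same ownership pattern:

#include <cstddef>
#include <cstring>
#include <utility>

// Minimal analogue of ElementwiseInfo's ownership rules:
// deep copy, pointer-stealing move, deleted assignment.
struct Buf {
    std::size_t n = 0;
    int *data = nullptr;
    Buf() = default;
    explicit Buf(std::size_t n_) : n(n_), data(new int[n_]()) {}
    Buf(const Buf &o) : n(o.n), data(new int[o.n]) {
        std::memcpy(data, o.data, n * sizeof(int));
    }
    Buf(Buf &&o) noexcept : n(o.n), data(o.data) {
        o.n = 0;
        o.data = nullptr; // the source must not free the stolen array
    }
    Buf &operator=(const Buf &) = delete;
    Buf &operator=(Buf &&) = delete;
    ~Buf() { delete[] data; }
};

int main() {
    Buf a(4);
    Buf b(std::move(a)); // a.data is now nullptr; only b frees the array
    Buf c(b);            // deep copy; b and c free distinct arrays
}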
inline infiniStatus_t createElementwiseInfo(
@@ -109,28 +146,37 @@ inline infiniStatus_t createElementwiseInfo(
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
const size_t input_size = input_descs.size();
const size_t out_ndim = output_desc->ndim();
// Initializing the ElementwiseInfo struct
info.input_size = input_descs.size();
info.ndim = output_desc->ndim();
info.output_size = output_desc->numel();
info.ndim = out_ndim;
info.output_contiguous = output_desc->isContiguous();
for (const auto &desc : input_descs) {
info.input_contiguous.emplace_back(desc->isContiguous());
}
for (size_t i = 0; i < input_size; ++i) {
const auto &desc = input_descs[i];
info.input_broadcasted.emplace_back(!info.input_contiguous[i] && (desc->ndim() != out_ndim || desc->hasBroadcastDim()));
}
info.output_shape = std::move(output_desc->shape());
info.output_strides = std::move(output_desc->strides());
for (const auto &desc : input_descs) {
info.input_shapes.emplace_back(desc->shape());
info.input_strides.emplace_back(desc->strides());
// Allocate memory for arrays
info.input_contiguous = new bool[info.input_size];
info.input_broadcasted = new bool[info.input_size];
info.output_shape = new size_t[info.ndim];
info.output_strides = new ptrdiff_t[info.ndim];
info.input_shapes = new size_t *[info.input_size];
info.input_strides = new ptrdiff_t *[info.input_size];
// Fill arrays
const auto output_shape = output_desc->shape();
const auto output_strides = output_desc->strides();
std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape));
std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides));
for (size_t i = 0; i < info.input_size; ++i) {
auto &desc = input_descs[i];
info.input_contiguous[i] = desc->isContiguous();
info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim());
info.input_shapes[i] = new size_t[desc->ndim()];
const auto &in_shape = desc->shape();
std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i]));
info.input_strides[i] = new ptrdiff_t[desc->ndim()];
const auto &in_strides = desc->strides();
std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i]));
}
return INFINI_STATUS_SUCCESS;
......
@@ -8,11 +8,13 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t up_desc,
infiniopTensorDescriptor_t gate_desc) {
std::vector<infiniopTensorDescriptor_t> input_desc) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &up_desc = input_desc.at(0);
const auto &gate_desc = input_desc.at(1);
const auto &out_shape = out_desc->shape();
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
@@ -21,35 +23,26 @@ infiniStatus_t Descriptor::create(
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
op::binary::BinaryInfo info;
CHECK_STATUS(op::binary::createBinaryInfo(info, out_desc, up_desc, gate_desc));
// Create descriptor
*desc_ptr = new Descriptor(
dtype,
std::move(info),
nullptr,
handle->device,
handle->device_id);
// create CPU elementwise descriptor
CREATE_ELEMENTWISE_CPU_DESCRIPTOR;
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *c,
const void *a,
const void *b,
void *output,
std::vector<const void *> inputs,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
op::common_cpu::binary_op::calculate<fp16_t, SwiGLUOp>(_info, c, a, b);
_device_info->calculate<SwiGLUOp, fp16_t>(_info, output, inputs, stream);
break;
case INFINI_DTYPE_F32:
op::common_cpu::binary_op::calculate<float, SwiGLUOp>(_info, c, a, b);
_device_info->calculate<SwiGLUOp, float>(_info, output, inputs, stream);
break;
case INFINI_DTYPE_F64:
op::common_cpu::binary_op::calculate<double, SwiGLUOp>(_info, c, a, b);
_device_info->calculate<SwiGLUOp, double>(_info, output, inputs, stream);
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
......
#ifndef __SWIGLU_CPU_H__
#define __SWIGLU_CPU_H__
#include "../../../binary/cpu/binary_cpu.h"
#include "../../../elementwise/cpu/elementwise_cpu.h"
BINARY_DESCRIPTOR(swiglu, cpu)
ELEMENTWISE_DESCRIPTOR(swiglu, cpu)
struct SwiGLUOp {
namespace op::swiglu::cpu {
typedef struct SwiGLUOp {
private:
template <typename T>
T sigmoid(const T &x) const {
@@ -13,10 +14,12 @@ private:
}
public:
static constexpr size_t num_inputs = 2;
template <typename T>
T operator()(const T &up, const T &gate) const {
return gate * sigmoid(gate) * up;
}
};
} SwiGLUOp;
} // namespace op::swiglu::cpu
#endif // __SWIGLU_CPU_H__
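A standalone numeric check of the functor's formula, out = gate * sigmoid(gate) * up, using a plain-C++ copy of the math above:

#include <cmath>
#include <cstddef>
#include <cstdio>

// Plain-C++ copy of the SwiGLU math above: out = gate * sigmoid(gate) * up.
struct SwiGLURef {
    static constexpr std::size_t num_inputs = 2;
    template <typename T>
    T operator()(const T &up, const T &gate) const {
        const T sig = T(1) / (T(1) + std::exp(-gate));
        return gate * sig * up;
    }
};

int main() {
    SwiGLURef op;
    std::printf("%f\n", op(2.0, 0.0)); // gate = 0 -> sigmoid = 0.5 -> 0.000000
    std::printf("%f\n", op(2.0, 1.0)); // sigmoid(1) ~ 0.731059 -> ~1.462117
}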
#include "swiglu_cuda.cuh"
#include "swiglu_cuda_internal.cuh"
namespace op::swiglu::cuda {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc) {
auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &up_desc = input_desc.at(0);
const auto &gate_desc = input_desc.at(1);
const auto &out_shape = out_desc->shape();
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
if (!SAME_VEC(out_shape, up_shape, gate_shape)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// create CUDA elementwise descriptor
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *output,
std::vector<const void *> inputs,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
_device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream);
break;
case INFINI_DTYPE_F32:
_device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream);
break;
case INFINI_DTYPE_F64:
_device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream);
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::swiglu::cuda
#ifndef __SWIGLU_CUDA_API_H__
#define __SWIGLU_CUDA_API_H__
#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
ELEMENTWISE_DESCRIPTOR(swiglu, cuda)
#endif // __SWIGLU_CUDA_API_H__
#ifndef __SWIGLU_CUDA_H__
#define __SWIGLU_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_fp16.h>
namespace op::swiglu::cuda {
typedef struct SwiGLUOp {
private:
template <typename T>
__device__ __forceinline__ T sigmoid(const T &x) const {
if constexpr (std::is_same_v<T, half2>) {
return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x))));
} else if constexpr (std::is_same_v<T, half>) {
return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x))))));
} else if constexpr (std::is_same_v<T, float>) {
return __frcp_rd(__fadd_rd(1, __expf(-x)));
} else {
return 1 / (1 + std::exp(-x));
}
}
public:
static constexpr size_t num_inputs = 2;
template <typename T>
__device__ __forceinline__ T operator()(const T &up, const T &gate) const {
if constexpr (std::is_same_v<T, half2>) {
return __hmul2(__hmul2(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, half>) {
return __hmul(__hmul(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, float>) {
return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), up);
} else {
return gate * sigmoid(gate) * up;
}
}
template <typename Tc, typename Ta, typename Tb>
__device__ __forceinline__ Tc operator()(const Ta &up, const Tb &gate) const {
if constexpr (std::is_same_v<Ta, half2>) {
return __hmul2(__hmul2(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<Ta, half> && std::is_same_v<Tb, float>) {
if constexpr (std::is_same_v<Tc, half>) {
return __float2half(__fmul_rd(__fmul_rd(gate, sigmoid(gate)), __half2float(up)));
} else {
return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), __half2float(up));
}
} else if constexpr (std::is_same_v<Ta, half>) {
return __hmul(__hmul(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<Ta, float>) {
return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), up);
} else {
return gate * sigmoid(gate) * up;
}
}
} SwiGLUOp;
} // namespace op::swiglu::cuda
#endif // __SWIGLU_CUDA_H__
@@ -5,6 +5,9 @@
#ifdef ENABLE_CPU_API
#include "cpu/swiglu_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/swiglu_cuda.cuh"
#endif
__C infiniStatus_t infiniopCreateSwiGLUDescriptor(
infiniopHandle_t handle,
@@ -19,19 +22,16 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
handle, \
reinterpret_cast<op::swiglu::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \
a_desc, \
b_desc)
{a_desc, \
b_desc})
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu:
return cudaCreateSwiGLUDescriptor((CudaHandle_t)handle,
(SwiGLUCudaDescriptor_t *)desc_ptr,
c_desc, a_desc, b_desc);
#ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
@@ -76,16 +76,15 @@ __C infiniStatus_t infiniopSwiGLU(
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc) \
->calculate(c, a, b, stream)
->calculate(c, {a, b}, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu:
return cudaSwiGLU((SwiGLUCudaDescriptor_t)desc, c, a, b, stream);
#ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
@@ -125,9 +124,8 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
#ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu:
return cudaDestroySwiGLUDescriptor((SwiGLUCudaDescriptor_t)desc);
#ifdef ENABLE_CUDA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
......
@@ -98,4 +98,6 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) {
}
}
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))
#endif
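For instance, CEIL_DIV(10, 4) == 3. A small standalone check, which also shows why the per-argument parentheses matter:

#include <cassert>

// Matches the (parenthesized) CEIL_DIV above.
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

int main() {
    assert(CEIL_DIV(10, 4) == 3);
    assert(CEIL_DIV(8, 4) == 2);
    // The parentheses around y matter for compound arguments:
    // with the unparenthesized form ((x + y - 1) / y),
    // CEIL_DIV(10, 2 * 2) would expand to ((10 + 2*2 - 1) / 2) * 2 == 12.
    assert(CEIL_DIV(10, 2 * 2) == 3);
    return 0;
}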
@@ -28,6 +28,7 @@ target("infiniop-cuda")
else
add_cuflags("-Xcompiler=-Wall", "-Xcompiler=-Werror")
add_cuflags("-Xcompiler=-fPIC")
add_cuflags("--extended-lambda")
add_culdflags("-Xcompiler=-fPIC")
add_cxxflags("-fPIC")
end
......
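--extended-lambda is what allows the __device__-annotated lambdas used in the elementwise kernels above. A minimal CUDA sketch that requires the flag (the build line is an assumption):

// Minimal illustration of why --extended-lambda is needed: a lambda
// marked __device__ is defined in host code and invoked inside a kernel.
// Build (assumed): nvcc --extended-lambda lambda_demo.cu
#include <cstdio>

template <typename F>
__global__ void apply(F f, int *out) {
    *out = f(21);
}

int main() {
    auto twice = [] __device__(int x) { return 2 * x; }; // needs --extended-lambda
    int *d_out = nullptr;
    cudaMalloc(&d_out, sizeof(int));
    apply<<<1, 1>>>(twice, d_out);
    int host = 0;
    cudaMemcpy(&host, d_out, sizeof(int), cudaMemcpyDeviceToHost);
    std::printf("%d\n", host); // 42
    cudaFree(d_out);
}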