Commit 6292da00 authored by Zimin Li's avatar Zimin Li
Browse files

issue/127: refactor elementwise framework, complete CUDA implementation,...

issue/127: refactor elementwise framework, complete CUDA implementation, refactor swiglu using the generic elementwise framework
parent 7105d13d
...@@ -21,10 +21,43 @@ ...@@ -21,10 +21,43 @@
handle->device, \ handle->device, \
handle->device_id); handle->device_id);
DEVICE_IMPL(cpu)
namespace op::elementwise::cpu { namespace op::elementwise::cpu {
/**
 * @brief CPU backend handle for the generic elementwise framework.
 *        Declarations only; the template definitions follow in this header's
 *        implementation section.
 */
class DeviceImpl final {
struct Opaque;
// Pimpl: shared_ptr tolerates the incomplete Opaque type here because its
// deleter is bound at construction time (unlike unique_ptr's default deleter).
std::shared_ptr<struct Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
// Factory: builds the backend-specific Opaque state into *device_info.
template <typename... Args>
static infiniStatus_t create(
DeviceImpl **device_info,
Args &&...args);
/* Invoke elementwise operation when all inputs have the same type */
// `stream` is accepted for signature parity with device backends (the CPU
// implementation below forwards without it); extra `args` are forwarded to
// Op::operator().
template <typename Op, typename Tdata, typename... Args>
void calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
/* Invoke elementwise operation for different input types */
// Overload enabled only when exactly Op::num_inputs input types are given.
template <typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
struct DeviceImpl::Opaque {}; struct DeviceImpl::Opaque {};
template <typename... Args> template <typename... Args>
...@@ -42,13 +75,13 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, ...@@ -42,13 +75,13 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output,
#pragma omp parallel for #pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) { for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data()); size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides);
auto get_input_idx = [&](size_t input_id) { auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i return info.input_contiguous[input_id] ? i
: (info.input_broadcasted[input_id] : (info.input_broadcasted[input_id]
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data()) ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id])
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data())); : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id]));
}; };
out[out_idx] = utils::cast<Tout>(Op{}(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...)); out[out_idx] = utils::cast<Tout>(Op{}(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
...@@ -60,6 +93,7 @@ template <typename Op, typename Tout, typename... Tin, typename... Args, std::en ...@@ -60,6 +93,7 @@ template <typename Op, typename Tout, typename... Tin, typename... Args, std::en
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
void *stream,
Args &&...args) { Args &&...args) {
static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch"); static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch");
...@@ -80,13 +114,13 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, ...@@ -80,13 +114,13 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
#pragma omp parallel for #pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) { for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data()); size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides);
auto get_input_idx = [&](size_t input_id) { auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i return info.input_contiguous[input_id] ? i
: (info.input_broadcasted[input_id] : (info.input_broadcasted[input_id]
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data()) ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id])
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data())); : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id]));
}; };
if constexpr (std::is_same_v<Tdata, fp16_t>) { if constexpr (std::is_same_v<Tdata, fp16_t>) {
...@@ -99,7 +133,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, ...@@ -99,7 +133,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
// Invoke elementwise operation when all inputs have the same type // Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args> template <typename Op, typename Tdata, typename... Args>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, Args &&...args) { void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, void *stream, Args &&...args) {
constexpr size_t N = Op::num_inputs; constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...); calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
} }
......
#ifndef __INFINIOP_ELEMENTWISE_CUDA_API_H__
#define __INFINIOP_ELEMENTWISE_CUDA_API_H__
#include "../elementwise.h"
namespace op::elementwise::cuda {
/**
 * @brief Define the methods and info needed by CUDA to perform elementwise operation
 *
 * Declaration-only class; template definitions live in the corresponding
 * .cuh implementation header. Mirrors the CPU DeviceImpl but adds a
 * BLOCK_SIZE template parameter for kernel launch configuration.
 */
class DeviceImpl final {
struct Opaque;
// Pimpl: shared_ptr tolerates the incomplete Opaque type because its deleter
// is captured at construction time.
std::shared_ptr<struct Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
// Factory: builds the backend-specific Opaque state into *device_info.
template <typename... Args>
static infiniStatus_t create(
DeviceImpl **device_info,
Args &&...args);
/* Invoke elementwise operation when all inputs have the same dtype */
// BLOCK_SIZE is the CUDA thread-block size; `stream` is the cudaStream_t
// (passed as void* to keep this header CUDA-free for non-CUDA includers).
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
void calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
/* Invoke elementwise operation for different input types */
// Overload enabled only when exactly Op::num_inputs input types are given.
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
} // namespace op::elementwise::cuda
/**
 * @brief Define the process for initializing a Descriptor of an elementwise operation
 * for its CUDA implementation
 *
 * Expands inside an operator's Descriptor::create and expects `out_desc`,
 * `input_desc`, `handle`, `dtype` and `desc_ptr` to be in scope. It builds
 * the ElementwiseInfo, creates the CUDA DeviceImpl from the handle's
 * internal state, and assigns a freshly allocated Descriptor to *desc_ptr.
 * The expansion ends in a complete statement, so a trailing semicolon at
 * the call site is optional.
 */
#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR \
\
op::elementwise::ElementwiseInfo elementwise_info; \
CHECK_STATUS(op::elementwise::createElementwiseInfo(elementwise_info, out_desc, input_desc)); \
\
op::elementwise::cuda::DeviceImpl *device_impl; \
CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \
\
*desc_ptr = new Descriptor( \
dtype, \
std::move(elementwise_info), \
device_impl, \
handle->device, \
handle->device_id);
#endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__
...@@ -4,47 +4,12 @@ ...@@ -4,47 +4,12 @@
#include "../operator.h" #include "../operator.h"
#include "../tensor.h" #include "../tensor.h"
#include <algorithm> #include <algorithm>
#include <cstring>
#include <iostream>
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include <vector> #include <vector>
// NOTE(review): generic per-namespace DeviceImpl declaration macro,
// superseded by the explicit per-backend class declarations (cpu/cuda
// headers), which use std::shared_ptr for the pimpl and take an extra
// `stream` parameter in calculate(). Be aware before reusing this macro:
// std::unique_ptr over the incomplete Opaque presumably requires Opaque to
// be complete wherever the inline defaulted destructor is instantiated —
// verify against each including translation unit.
#define DEVICE_IMPL(NAMESPACE) \
\
namespace op::elementwise::NAMESPACE { \
class DeviceImpl final { \
struct Opaque; \
std::unique_ptr<Opaque> _opaque; \
\
DeviceImpl(Opaque *opaque) : _opaque(opaque) {} \
\
public: \
~DeviceImpl() = default; \
\
template <typename... Args> \
static infiniStatus_t create( \
DeviceImpl **device_info, \
Args &&...args); \
\
/* Invoke elementwise operation when all inputs have the same type */ \
template <typename Op, typename Tdata, typename... Args> \
void calculate( \
const op::elementwise::ElementwiseInfo &info, \
void *output, \
const std::vector<const void *> &inputs, \
Args &&...args); \
\
/* Invoke elementwise operation for different input types */ \
template <typename Op, typename Tout, typename... Tin, \
typename... Args, \
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> \
void calculate( \
const op::elementwise::ElementwiseInfo &info, \
void *output, \
const std::vector<const void *> &inputs, \
Args &&...args); \
}; \
}
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \ #define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
\ \
namespace op::OP::NAMESPACE { \ namespace op::OP::NAMESPACE { \
...@@ -61,7 +26,7 @@ ...@@ -61,7 +26,7 @@
int device_id) \ int device_id) \
: InfiniopDescriptor{device_type, device_id}, \ : InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \ _dtype(dtype), \
_info(info), \ _info(std::move(info)), \
_device_info(device_info) {} \ _device_info(device_info) {} \
\ \
public: \ public: \
...@@ -87,12 +52,84 @@ struct ElementwiseInfo { ...@@ -87,12 +52,84 @@ struct ElementwiseInfo {
size_t output_size; size_t output_size;
size_t ndim; size_t ndim;
bool output_contiguous; bool output_contiguous;
std::vector<bool> input_contiguous; bool *input_contiguous;
std::vector<bool> input_broadcasted; bool *input_broadcasted;
std::vector<size_t> output_shape; size_t *output_shape;
std::vector<std::vector<size_t>> input_shapes; size_t **input_shapes;
std::vector<ptrdiff_t> output_strides; ptrdiff_t *output_strides;
std::vector<std::vector<ptrdiff_t>> input_strides; ptrdiff_t **input_strides;
size_t input_size;
ElementwiseInfo() = default;
// Destructor to free allocated memory.
// Relies on every pointer member being either a valid new[] allocation or
// nullptr (delete[] nullptr is a no-op), and on input_size matching the
// length of input_shapes/input_strides — a moved-from object satisfies
// this because the move constructor nulls the pointers and zeroes the count.
~ElementwiseInfo() {
delete[] input_contiguous;
delete[] input_broadcasted;
delete[] output_shape;
delete[] output_strides;
// Each input owns its own per-dimension shape and stride arrays.
for (size_t i = 0; i < input_size; ++i) {
delete[] input_shapes[i];
delete[] input_strides[i];
}
delete[] input_shapes;
delete[] input_strides;
}
// Deep copy: allocates fresh arrays and copies every per-input and
// per-dimension buffer from `other`.
// NOTE(review): assumes `other` is fully initialized (as produced by
// createElementwiseInfo). Copying a default-constructed or moved-from info
// would memcpy from null pointers — confirm callers never do that.
ElementwiseInfo(const ElementwiseInfo &other)
: output_size(other.output_size),
ndim(other.ndim),
output_contiguous(other.output_contiguous),
input_size(other.input_size) {
input_contiguous = new bool[input_size];
std::memcpy(input_contiguous, other.input_contiguous, input_size * sizeof(*input_contiguous));
input_broadcasted = new bool[input_size];
std::memcpy(input_broadcasted, other.input_broadcasted, input_size * sizeof(*input_broadcasted));
output_shape = new size_t[ndim];
std::memcpy(output_shape, other.output_shape, ndim * sizeof(*output_shape));
output_strides = new ptrdiff_t[ndim];
std::memcpy(output_strides, other.output_strides, ndim * sizeof(*output_strides));
// Per-input arrays are ndim-long like the output's (inputs are padded or
// broadcast to the output rank by createElementwiseInfo — TODO confirm).
input_shapes = new size_t *[input_size];
for (size_t i = 0; i < input_size; ++i) {
input_shapes[i] = new size_t[ndim];
std::memcpy(input_shapes[i], other.input_shapes[i], ndim * sizeof(*input_shapes[i]));
}
input_strides = new ptrdiff_t *[input_size];
for (size_t i = 0; i < input_size; ++i) {
input_strides[i] = new ptrdiff_t[ndim];
std::memcpy(input_strides[i], other.input_strides[i], ndim * sizeof(*input_strides[i]));
}
}
// Move constructor: steals all heap arrays from `other` and leaves it in an
// empty state that is safe to destroy (and safe to deep-copy).
// Fix: the previous version nulled the pointers but kept other.output_size
// and other.ndim, so copy-constructing from a moved-from info would
// allocate ndim-sized arrays and memcpy from null pointers.
ElementwiseInfo(ElementwiseInfo &&other) noexcept
    : output_size(other.output_size),
      ndim(other.ndim),
      output_contiguous(other.output_contiguous),
      input_contiguous(other.input_contiguous),
      input_broadcasted(other.input_broadcasted),
      output_shape(other.output_shape),
      input_shapes(other.input_shapes),
      output_strides(other.output_strides),
      input_strides(other.input_strides),
      input_size(other.input_size) {
    // Zero the counts as well as the pointers so the moved-from object is
    // fully empty, not merely null.
    other.output_size = 0;
    other.ndim = 0;
    other.input_contiguous = nullptr;
    other.input_broadcasted = nullptr;
    other.output_shape = nullptr;
    other.input_shapes = nullptr;
    other.output_strides = nullptr;
    other.input_strides = nullptr;
    other.input_size = 0;
}

// Assignment is intentionally disabled: an ElementwiseInfo is populated once
// by createElementwiseInfo and only copy/move-constructed afterwards.
ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete;
}; };
inline infiniStatus_t createElementwiseInfo( inline infiniStatus_t createElementwiseInfo(
...@@ -109,28 +146,37 @@ inline infiniStatus_t createElementwiseInfo( ...@@ -109,28 +146,37 @@ inline infiniStatus_t createElementwiseInfo(
return INFINI_STATUS_BAD_TENSOR_STRIDES; return INFINI_STATUS_BAD_TENSOR_STRIDES;
} }
const size_t input_size = input_descs.size(); info.input_size = input_descs.size();
const size_t out_ndim = output_desc->ndim(); info.ndim = output_desc->ndim();
// Intializing the ElementwiseInfo struct
info.output_size = output_desc->numel(); info.output_size = output_desc->numel();
info.ndim = out_ndim;
info.output_contiguous = output_desc->isContiguous(); info.output_contiguous = output_desc->isContiguous();
for (const auto &desc : input_descs) { // Allocate memory for arrays
info.input_contiguous.emplace_back(desc->isContiguous()); info.input_contiguous = new bool[info.input_size];
} info.input_broadcasted = new bool[info.input_size];
info.output_shape = new size_t[info.ndim];
for (size_t i = 0; i < input_size; ++i) { info.output_strides = new ptrdiff_t[info.ndim];
const auto &desc = input_descs[i]; info.input_shapes = new size_t *[info.input_size];
info.input_broadcasted.emplace_back(!info.input_contiguous[i] && (desc->ndim() != out_ndim || desc->hasBroadcastDim())); info.input_strides = new ptrdiff_t *[info.input_size];
}
// Fill arrays
info.output_shape = std::move(output_desc->shape()); const auto output_shape = output_desc->shape();
info.output_strides = std::move(output_desc->strides()); const auto output_strides = output_desc->strides();
for (const auto &desc : input_descs) { std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape));
info.input_shapes.emplace_back(desc->shape()); std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides));
info.input_strides.emplace_back(desc->strides());
for (size_t i = 0; i < info.input_size; ++i) {
auto &desc = input_descs[i];
info.input_contiguous[i] = desc->isContiguous();
info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim());
info.input_shapes[i] = new size_t[desc->ndim()];
const auto &in_shape = desc->shape();
std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i]));
info.input_strides[i] = new ptrdiff_t[desc->ndim()];
const auto &in_strides = desc->strides();
std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i]));
} }
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -8,11 +8,13 @@ infiniStatus_t Descriptor::create( ...@@ -8,11 +8,13 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle_, infiniopHandle_t handle_,
Descriptor **desc_ptr, Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t up_desc, std::vector<infiniopTensorDescriptor_t> input_desc) {
infiniopTensorDescriptor_t gate_desc) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto dtype = out_desc->dtype(); auto dtype = out_desc->dtype();
const auto &up_desc = input_desc.at(0);
const auto &gate_desc = input_desc.at(1);
const auto &out_shape = out_desc->shape(); const auto &out_shape = out_desc->shape();
const auto &up_shape = up_desc->shape(); const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape(); const auto &gate_shape = gate_desc->shape();
...@@ -21,35 +23,26 @@ infiniStatus_t Descriptor::create( ...@@ -21,35 +23,26 @@ infiniStatus_t Descriptor::create(
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape); CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
op::binary::BinaryInfo info; // create CPU elementwise descriptor
CHECK_STATUS(op::binary::createBinaryInfo(info, out_desc, up_desc, gate_desc)); CREATE_ELEMENTWISE_CPU_DESCRIPTOR;
// Create descriptor
*desc_ptr = new Descriptor(
dtype,
std::move(info),
nullptr,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *c, void *output,
const void *a, std::vector<const void *> inputs,
const void *b,
void *stream) const { void *stream) const {
switch (_dtype) { switch (_dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
op::common_cpu::binary_op::calculate<fp16_t, SwiGLUOp>(_info, c, a, b); _device_info->calculate<SwiGLUOp, fp16_t>(_info, output, inputs, stream);
break; break;
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
op::common_cpu::binary_op::calculate<float, SwiGLUOp>(_info, c, a, b); _device_info->calculate<SwiGLUOp, float>(_info, output, inputs, stream);
break; break;
case INFINI_DTYPE_F64: case INFINI_DTYPE_F64:
op::common_cpu::binary_op::calculate<double, SwiGLUOp>(_info, c, a, b); _device_info->calculate<SwiGLUOp, double>(_info, output, inputs, stream);
break; break;
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
......
#ifndef __SWIGLU_CPU_H__ #ifndef __SWIGLU_CPU_H__
#define __SWIGLU_CPU_H__ #define __SWIGLU_CPU_H__
#include "../../../binary/cpu/binary_cpu.h" #include "../../../elementwise/cpu/elementwise_cpu.h"
BINARY_DESCRIPTOR(swiglu, cpu) ELEMENTWISE_DESCRIPTOR(swiglu, cpu)
struct SwiGLUOp { namespace op::swiglu::cpu {
typedef struct SwiGLUOp {
private: private:
template <typename T> template <typename T>
T sigmoid(const T &x) const { T sigmoid(const T &x) const {
...@@ -13,10 +14,12 @@ private: ...@@ -13,10 +14,12 @@ private:
} }
public: public:
static constexpr size_t num_inputs = 2;
template <typename T> template <typename T>
T operator()(const T &up, const T &gate) const { T operator()(const T &up, const T &gate) const {
return gate * sigmoid(gate) * up; return gate * sigmoid(gate) * up;
} }
}; } SwiGLUOp;
} // namespace op::swiglu::cpu
#endif // __SWIGLU_CPU_H__ #endif // __SWIGLU_CPU_H__
#include "swiglu_cuda.cuh"
#include "swiglu_cuda_internal.cuh"
namespace op::swiglu::cuda {
// Defaulted here (out of line) so the generic Descriptor template's members
// are complete at the point of destruction.
Descriptor::~Descriptor() = default;

/// @brief Create a CUDA SwiGLU descriptor.
/// @param handle_     Opaque infiniop handle; must wrap a CUDA device handle.
/// @param desc_ptr    Receives the newly allocated Descriptor on success.
/// @param out_desc    Output tensor descriptor (its dtype drives dispatch).
/// @param input_desc  Exactly two inputs, ordered {up, gate}.
/// @return INFINI_STATUS_SUCCESS, or a BAD_TENSOR_* status on validation failure.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc) {
auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
auto dtype = out_desc->dtype();
// Fixed input order for SwiGLU: index 0 = up, index 1 = gate.
const auto &up_desc = input_desc.at(0);
const auto &gate_desc = input_desc.at(1);
const auto &out_shape = out_desc->shape();
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
// SwiGLU requires all three tensors to share one shape (no broadcasting).
if (!SAME_VEC(out_shape, up_shape, gate_shape)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// create CUDA elementwise descriptor (macro expands to a full statement,
// so no trailing semicolon is needed)
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR
return INFINI_STATUS_SUCCESS;
}
/// @brief Run SwiGLU on the GPU: output = up * gate * sigmoid(gate).
/// @param output  Device pointer to the output tensor.
/// @param inputs  Device pointers ordered {up, gate}.
/// @param stream  CUDA stream (as void*) the kernel is launched on.
/// @return INFINI_STATUS_SUCCESS, or BAD_TENSOR_DTYPE for unsupported dtypes.
infiniStatus_t Descriptor::calculate(
void *output,
std::vector<const void *> inputs,
void *stream) const {
// Dispatch the generic CUDA elementwise kernel (block size 256) on the
// output dtype recorded at descriptor creation.
if (_dtype == INFINI_DTYPE_F16) {
_device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream);
} else if (_dtype == INFINI_DTYPE_F32) {
_device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream);
} else if (_dtype == INFINI_DTYPE_F64) {
_device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::swiglu::cuda
#ifndef __SWIGLU_CUDA_API_H__
#define __SWIGLU_CUDA_API_H__
#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
// Instantiates op::swiglu::cuda::Descriptor from the generic elementwise
// descriptor template.
ELEMENTWISE_DESCRIPTOR(swiglu, cuda)
#endif // __SWIGLU_CUDA_API_H__
#ifndef __SWIGLU_CUDA_H__
#define __SWIGLU_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"

#include <cuda_fp16.h>

#include <cmath>
#include <type_traits>
namespace op::swiglu::cuda {
// SwiGLU functor for the generic CUDA elementwise framework:
// out = up * gate * sigmoid(gate). Branches pick CUDA intrinsics for
// half2 / half / float and fall back to plain arithmetic otherwise.
typedef struct SwiGLUOp {
private:
// Logistic sigmoid 1 / (1 + exp(-x)), specialized per type.
template <typename T>
__device__ __forceinline__ T sigmoid(const T &x) const {
if constexpr (std::is_same_v<T, half2>) {
return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x))));
} else if constexpr (std::is_same_v<T, half>) {
// exp computed in float for accuracy, then converted back to half.
return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x))))));
} else if constexpr (std::is_same_v<T, float>) {
// NOTE(review): __frcp_rd/__fadd_rd round toward -infinity; confirm the
// round-to-nearest (_rn) variants were not intended here.
return __frcp_rd(__fadd_rd(1, __expf(-x)));
} else {
// Generic path (e.g. double).
return 1 / (1 + std::exp(-x));
}
}
public:
// Number of input tensors consumed per output element: {up, gate}.
static constexpr size_t num_inputs = 2;
// Homogeneous overload: all operands share type T.
template <typename T>
__device__ __forceinline__ T operator()(const T &up, const T &gate) const {
if constexpr (std::is_same_v<T, half2>) {
return __hmul2(__hmul2(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, half>) {
return __hmul(__hmul(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, float>) {
// NOTE(review): __fmul_rd rounds toward -infinity — see sigmoid above.
return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), up);
} else {
return gate * sigmoid(gate) * up;
}
}
// Mixed-type overload: Tc = output type, Ta = up type, Tb = gate type.
template <typename Tc, typename Ta, typename Tb>
__device__ __forceinline__ Tc operator()(const Ta &up, const Tb &gate) const {
if constexpr (std::is_same_v<Ta, half2>) {
return __hmul2(__hmul2(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<Ta, half> && std::is_same_v<Tb, float>) {
// half up with float gate: compute in float, narrow only if Tc is half.
if constexpr (std::is_same_v<Tc, half>) {
return __float2half(__fmul_rd(__fmul_rd(gate, sigmoid(gate)), __half2float(up)));
} else {
return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), __half2float(up));
}
} else if constexpr (std::is_same_v<Ta, half>) {
return __hmul(__hmul(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<Ta, float>) {
return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), up);
} else {
return gate * sigmoid(gate) * up;
}
}
} SwiGLUOp;
} // namespace op::swiglu::cuda
#endif // __SWIGLU_CUDA_H__
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
#include "cpu/swiglu_cpu.h" #include "cpu/swiglu_cpu.h"
#endif #endif
#ifdef ENABLE_CUDA_API
#include "cuda/swiglu_cuda.cuh"
#endif
__C infiniStatus_t infiniopCreateSwiGLUDescriptor( __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -19,19 +22,16 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( ...@@ -19,19 +22,16 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
handle, \ handle, \
reinterpret_cast<op::swiglu::NAMESPACE::Descriptor **>(desc_ptr), \ reinterpret_cast<op::swiglu::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \ c_desc, \
a_desc, \ {a_desc, \
b_desc) b_desc})
switch (handle->device) { switch (handle->device) {
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu); CREATE(INFINI_DEVICE_CPU, cpu);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_CUDA_API
case DevNvGpu: CREATE(INFINI_DEVICE_NVIDIA, cuda);
return cudaCreateSwiGLUDescriptor((CudaHandle_t)handle,
(SwiGLUCudaDescriptor_t *)desc_ptr,
c_desc, a_desc, b_desc);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -76,16 +76,15 @@ __C infiniStatus_t infiniopSwiGLU( ...@@ -76,16 +76,15 @@ __C infiniStatus_t infiniopSwiGLU(
#define CALCULATE(CASE, NAMESPACE) \ #define CALCULATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc) \ return reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc) \
->calculate(c, a, b, stream) ->calculate(c, {a, b}, stream)
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu); CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_CUDA_API
case DevNvGpu: CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
return cudaSwiGLU((SwiGLUCudaDescriptor_t)desc, c, a, b, stream);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -125,9 +124,8 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { ...@@ -125,9 +124,8 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu); DELETE(INFINI_DEVICE_CPU, cpu);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_CUDA_API
case DevNvGpu: DELETE(INFINI_DEVICE_NVIDIA, cuda);
return cudaDestroySwiGLUDescriptor((SwiGLUCudaDescriptor_t)desc);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
......
...@@ -98,4 +98,6 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) { ...@@ -98,4 +98,6 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) {
} }
} }
// Integer ceiling division: smallest integer >= x / y for positive y.
// Arguments are fully parenthesized so expression arguments expand
// correctly — the previous form ((x + y - 1) / y) mis-evaluated calls such
// as CEIL_DIV(8, 2 + 2), which expanded to ((8 + 2 + 2 - 1) / 2 + 2) == 7.
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))
#endif #endif
...@@ -28,6 +28,7 @@ target("infiniop-cuda") ...@@ -28,6 +28,7 @@ target("infiniop-cuda")
else else
add_cuflags("-Xcompiler=-Wall", "-Xcompiler=-Werror") add_cuflags("-Xcompiler=-Wall", "-Xcompiler=-Werror")
add_cuflags("-Xcompiler=-fPIC") add_cuflags("-Xcompiler=-fPIC")
add_cuflags("--extended-lambda")
add_culdflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC")
add_cxxflags("-fPIC") add_cxxflags("-fPIC")
end end
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment