Commit 9cc0c416 authored by Zimin Li's avatar Zimin Li
Browse files

issue/127: Refactor ElementwiseInfo, refactor elementwise to use workspace for...

issue/127: Refactor ElementwiseInfo, refactor elementwise to use workspace for storing meta, fix misc. issues
parent 40fdded5
...@@ -11,7 +11,11 @@ __C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t hand ...@@ -11,7 +11,11 @@ __C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t hand
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc); infiniopTensorDescriptor_t b_desc);
__C __export infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc, __C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c, void *c,
void const *a, void const *a,
void const *b, void const *b,
......
...@@ -9,6 +9,10 @@ ...@@ -9,6 +9,10 @@
#define CUDA_BLOCK_SIZE_1024 1024 #define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512 #define CUDA_BLOCK_SIZE_512 512
#define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess)
namespace device::cuda {
// return the memory offset of original tensor, given the flattened index of broadcasted tensor // return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t __forceinline__ __device__ __host__ size_t
indexToReducedOffset( indexToReducedOffset(
...@@ -38,6 +42,7 @@ indexToOffset( ...@@ -38,6 +42,7 @@ indexToOffset(
} }
return res; return res;
} }
} // namespace device::cuda
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
#include <cuda_fp16.h> #include <cuda_fp16.h>
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
dtype, \ dtype, \
info_result.take(), \ info_result.take(), \
nullptr, \ nullptr, \
0, \
handle->device, \ handle->device, \
handle->device_id); handle->device_id);
...@@ -103,24 +104,34 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, Args &&...args) { ...@@ -103,24 +104,34 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, Args &&...args) {
} }
// Perform elementwise operation for different input types // Perform elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args,
void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, std::index_sequence<Is...>, Args &&...args) { std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
Args &&...args) {
Tout *out = reinterpret_cast<Tout *>(output); Tout *out = reinterpret_cast<Tout *>(output);
std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...}; std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...};
ptrdiff_t output_size = info.output_size; ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for #pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) { for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides); size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) { auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i return info.getInputContiguous()[input_id]
: (info.input_broadcasted[input_id] ? i
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id]) : (info.getInputBroadcasted()[input_id]
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id])); ? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
}; };
out[out_idx] = utils::cast<Tout>(Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...)); out[out_idx] = utils::cast<Tout>(
Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
} }
} }
...@@ -147,17 +158,20 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, ...@@ -147,17 +158,20 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
Tdata *out = reinterpret_cast<Tdata *>(output); Tdata *out = reinterpret_cast<Tdata *>(output);
std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...}; std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
const ptrdiff_t output_size = info.output_size; const ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for #pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) { for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides); size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) { auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i return info.getInputContiguous()[input_id]
: (info.input_broadcasted[input_id] ? i
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id]) : (info.getInputBroadcasted()[input_id]
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id])); ? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
}; };
if constexpr (std::is_same_v<Tdata, fp16_t>) { if constexpr (std::is_same_v<Tdata, fp16_t>) {
...@@ -170,7 +184,11 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, ...@@ -170,7 +184,11 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
// Invoke elementwise operation when all inputs have the same type // Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args> template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, void *stream, Args &&...args) { infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs; constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...); calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -31,6 +31,7 @@ public: ...@@ -31,6 +31,7 @@ public:
* @tparam Args... Additional arguments passed to the operation. * @tparam Args... Additional arguments passed to the operation.
* *
* @param info Metadata describing tensor shapes, strides, etc. * @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device. * @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory). * @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*). * @param stream CUDA stream (opaque void*).
...@@ -40,6 +41,7 @@ public: ...@@ -40,6 +41,7 @@ public:
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args> template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate( infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info, const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
void *stream, void *stream,
...@@ -56,6 +58,7 @@ public: ...@@ -56,6 +58,7 @@ public:
* @tparam Tin... Input data types (must match Op::num_inputs). * @tparam Tin... Input data types (must match Op::num_inputs).
* @tparam Args... Additional arguments passed to the operation. * @tparam Args... Additional arguments passed to the operation.
* @param info Metadata describing tensor shapes, strides, etc. * @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device. * @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory). * @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*). * @param stream CUDA stream (opaque void*).
...@@ -67,6 +70,7 @@ public: ...@@ -67,6 +70,7 @@ public:
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculate( infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info, const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
void *stream, void *stream,
...@@ -82,14 +86,17 @@ public: ...@@ -82,14 +86,17 @@ public:
\ \
auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \ auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \
CHECK_RESULT(info_result); \ CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\ \
op::elementwise::cuda::DeviceImpl *device_impl; \ op::elementwise::cuda::DeviceImpl *device_impl; \
CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \ CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \
\ \
*desc_ptr = new Descriptor( \ *desc_ptr = new Descriptor( \
dtype, \ dtype, \
std::move(info_result.take()), \ std::move(info), \
device_impl, \ device_impl, \
workspace_size, \
handle->device, \ handle->device, \
handle->device_id); handle->device_id);
......
...@@ -19,21 +19,26 @@ ...@@ -19,21 +19,26 @@
infiniDtype_t _dtype; \ infiniDtype_t _dtype; \
op::elementwise::ElementwiseInfo _info; \ op::elementwise::ElementwiseInfo _info; \
std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \ std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
size_t _workspace_size; \
\ \
Descriptor( \ Descriptor( \
infiniDtype_t dtype, \ infiniDtype_t dtype, \
op::elementwise::ElementwiseInfo info, \ op::elementwise::ElementwiseInfo info, \
op::elementwise::NAMESPACE::DeviceImpl *device_info, \ op::elementwise::NAMESPACE::DeviceImpl *device_info, \
size_t workspace_size, \
infiniDevice_t device_type, \ infiniDevice_t device_type, \
int device_id) \ int device_id) \
: InfiniopDescriptor{device_type, device_id}, \ : InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \ _dtype(dtype), \
_info(std::move(info)), \ _info(std::move(info)), \
_device_info(device_info) {} \ _device_info(device_info), \
_workspace_size(workspace_size) {} \
\ \
public: \ public: \
~Descriptor(); \ ~Descriptor(); \
\ \
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \ static infiniStatus_t create( \
infiniopHandle_t handle, \ infiniopHandle_t handle, \
Descriptor **desc_ptr, \ Descriptor **desc_ptr, \
...@@ -41,6 +46,7 @@ ...@@ -41,6 +46,7 @@
std::vector<infiniopTensorDescriptor_t> input_descs); \ std::vector<infiniopTensorDescriptor_t> input_descs); \
\ \
infiniStatus_t calculate( \ infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *output, \ void *output, \
std::vector<const void *> inputs, \ std::vector<const void *> inputs, \
void *stream) const; \ void *stream) const; \
...@@ -62,57 +68,70 @@ namespace op::elementwise { ...@@ -62,57 +68,70 @@ namespace op::elementwise {
*/ */
struct ElementwiseInfo { struct ElementwiseInfo {
private: private:
ElementwiseInfo() = default; std::vector<int8_t> _meta;
size_t _output_size;
size_t _input_size;
size_t _ndim;
bool _output_contiguous;
ElementwiseInfo(std::vector<int8_t> meta,
size_t output_size,
size_t input_size,
size_t ndim,
bool output_contiguous)
: _meta(std::move(meta)), _output_size(output_size),
_input_size(input_size), _ndim(ndim),
_output_contiguous(output_contiguous) {}
public: public:
size_t output_size; inline size_t getMetaMemSize() const {
size_t ndim; return _meta.size();
bool output_contiguous; }
bool *input_contiguous; inline const int8_t *getMetaStart() const {
bool *input_broadcasted; return _meta.data();
size_t *output_shape; }
size_t **input_shapes; inline size_t getOutputSize() const {
ptrdiff_t *output_strides; return _output_size;
ptrdiff_t **input_strides; }
size_t input_size; inline size_t getInputSize() const {
return _input_size;
~ElementwiseInfo() { }
delete[] input_contiguous; inline size_t getNdim() const {
delete[] input_broadcasted; return _ndim;
delete[] output_shape; }
delete[] output_strides; inline bool isOutputContiguous() const {
return _output_contiguous;
for (size_t i = 0; i < input_size; ++i) { }
delete[] input_shapes[i]; inline const size_t *getOutputShape() const {
delete[] input_strides[i]; return reinterpret_cast<const size_t *>(_meta.data());
} }
delete[] input_shapes; inline const ptrdiff_t *getOutputStrides() const {
delete[] input_strides; return reinterpret_cast<const ptrdiff_t *>(getOutputShape() + _ndim);
} }
inline const size_t *getAllInputShapes() const {
ElementwiseInfo(ElementwiseInfo &&other) noexcept return reinterpret_cast<const size_t *>(getOutputStrides() + _ndim);
: output_size(other.output_size), }
ndim(other.ndim), inline const size_t *getInputShape(const size_t &index) const {
output_contiguous(other.output_contiguous), if (index < _input_size) {
input_contiguous(other.input_contiguous), return reinterpret_cast<const size_t *>(getAllInputShapes() + index * _ndim);
input_broadcasted(other.input_broadcasted), }
output_shape(other.output_shape), return nullptr;
input_shapes(other.input_shapes), }
output_strides(other.output_strides), inline const ptrdiff_t *getAllInputStrides() const {
input_strides(other.input_strides), return reinterpret_cast<const ptrdiff_t *>(getAllInputShapes() + _input_size * _ndim);
input_size(other.input_size) { }
other.input_contiguous = nullptr; inline const ptrdiff_t *getInputStrides(const size_t &index) const {
other.input_broadcasted = nullptr; if (index < _input_size) {
other.output_shape = nullptr; return reinterpret_cast<const ptrdiff_t *>(getAllInputStrides() + index * _ndim);
other.input_shapes = nullptr; }
other.output_strides = nullptr; return nullptr;
other.input_strides = nullptr; }
other.input_size = 0; inline const bool *getInputContiguous() const {
} return reinterpret_cast<const bool *>(getAllInputStrides() + _input_size * _ndim);
}
ElementwiseInfo(const ElementwiseInfo &other) = delete; inline const bool *getInputBroadcasted() const {
ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete; return reinterpret_cast<const bool *>(getInputContiguous() + _input_size);
ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete; }
using ResultType = utils::Result<ElementwiseInfo>; using ResultType = utils::Result<ElementwiseInfo>;
...@@ -136,40 +155,48 @@ public: ...@@ -136,40 +155,48 @@ public:
return INFINI_STATUS_BAD_TENSOR_STRIDES; return INFINI_STATUS_BAD_TENSOR_STRIDES;
} }
ElementwiseInfo info; auto input_size = input_descs.size();
info.input_size = input_descs.size(); auto ndim = output_desc->ndim();
info.ndim = output_desc->ndim(); auto output_size = output_desc->numel();
info.output_size = output_desc->numel(); auto output_contiguous = output_desc->isContiguous();
info.output_contiguous = output_desc->isContiguous();
// Allocate memory for meta
// Allocate memory for arrays auto shape_unit = output_desc->dim(0);
info.input_contiguous = new bool[info.input_size]; auto stride_unit = output_desc->stride(0);
info.input_broadcasted = new bool[info.input_size]; size_t meta_mem_size = ndim * (sizeof(shape_unit) + sizeof(stride_unit))
info.output_shape = new size_t[info.ndim]; + input_size * ndim * sizeof(shape_unit)
info.output_strides = new ptrdiff_t[info.ndim]; + input_size * ndim * sizeof(stride_unit)
info.input_shapes = new size_t *[info.input_size]; + 2 * input_size * sizeof(bool);
info.input_strides = new ptrdiff_t *[info.input_size]; std::vector<int8_t> meta(meta_mem_size);
int8_t *meta_ptr = meta.data();
// Fill arrays
const auto output_shape = output_desc->shape(); const auto output_shape = output_desc->shape();
const auto output_strides = output_desc->strides(); const auto output_strides = output_desc->strides();
std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape));
std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides));
for (size_t i = 0; i < info.input_size; ++i) { // Pointers to the sections within _meta
auto &desc = input_descs[i]; size_t *output_shape_p = reinterpret_cast<size_t *>(meta_ptr);
info.input_contiguous[i] = desc->isContiguous(); ptrdiff_t *output_strides_p = reinterpret_cast<ptrdiff_t *>(output_shape_p + ndim);
info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim()); size_t *input_shapes = reinterpret_cast<size_t *>(output_strides_p + ndim);
ptrdiff_t *input_strides = reinterpret_cast<ptrdiff_t *>(input_shapes + input_size * ndim);
bool *input_contiguous = reinterpret_cast<bool *>(input_strides + input_size * ndim);
bool *input_broadcasted = input_contiguous + input_size;
info.input_shapes[i] = new size_t[desc->ndim()]; // Copy output shape and strides
const auto &in_shape = desc->shape(); std::memcpy(output_shape_p, output_shape.data(), ndim * sizeof(*output_shape_p));
std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i])); std::memcpy(output_strides_p, output_strides.data(), ndim * sizeof(*output_strides_p));
info.input_strides[i] = new ptrdiff_t[desc->ndim()]; // Copy input shapes, strides, contiguous, and broadcasted flags
const auto &in_strides = desc->strides(); for (size_t i = 0; i < input_size; ++i) {
std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i])); auto &desc = input_descs[i];
const auto in_shape = desc->shape();
const auto in_strides = desc->strides();
std::memcpy(input_shapes + i * ndim, in_shape.data(), ndim * sizeof(*input_shapes));
std::memcpy(input_strides + i * ndim, in_strides.data(), ndim * sizeof(*input_strides));
input_contiguous[i] = desc->isContiguous();
input_broadcasted[i] = !input_contiguous[i] && (desc->ndim() != ndim || desc->hasBroadcastDim());
} }
ElementwiseInfo info(std::move(meta), output_size, input_size, ndim, output_contiguous);
return ResultType(std::move(info)); return ResultType(std::move(info));
} }
}; };
......
...@@ -30,6 +30,8 @@ infiniStatus_t Descriptor::create( ...@@ -30,6 +30,8 @@ infiniStatus_t Descriptor::create(
} }
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output, void *output,
std::vector<const void *> inputs, std::vector<const void *> inputs,
void *stream) const { void *stream) const {
......
...@@ -10,7 +10,7 @@ typedef struct SwiGLUOp { ...@@ -10,7 +10,7 @@ typedef struct SwiGLUOp {
private: private:
template <typename T> template <typename T>
T sigmoid(const T &x) const { T sigmoid(const T &x) const {
return 1 / (1 + std::exp(-x)); return T(1) / (T(1) + std::exp(-x));
} }
public: public:
......
...@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create( ...@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create(
const auto &gate_shape = gate_desc->shape(); const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
if (!SAME_VEC(out_shape, up_shape, gate_shape)) { CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// create CUDA elementwise descriptor // create CUDA elementwise descriptor
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR CREATE_ELEMENTWISE_CUDA_DESCRIPTOR
...@@ -32,17 +30,23 @@ infiniStatus_t Descriptor::create( ...@@ -32,17 +30,23 @@ infiniStatus_t Descriptor::create(
} }
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output, void *output,
std::vector<const void *> inputs, std::vector<const void *> inputs,
void *stream) const { void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) { switch (_dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
return _device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream); return _device_info->calculate<256, SwiGLUOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
return _device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream); return _device_info->calculate<256, SwiGLUOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64: case INFINI_DTYPE_F64:
return _device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream); return _device_info->calculate<256, SwiGLUOp, double>(_info, workspace, output, inputs, stream);
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
...@@ -66,8 +66,49 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( ...@@ -66,8 +66,49 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
#undef CREATE #undef CREATE
} }
// Query the workspace size (in bytes) that infiniopSwiGLU requires for the
// given descriptor. Dispatches on desc->device_type to the matching backend.
//
// @param desc  SwiGLU descriptor previously created by
//              infiniopCreateSwiGLUDescriptor.
// @param size  Out-parameter receiving the workspace byte count.
//              NOTE(review): neither desc nor size is null-checked here —
//              presumably validated by the caller; confirm the API contract.
// @return INFINI_STATUS_SUCCESS when a compiled backend matches the device,
//         INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED otherwise.
__C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size) {

// Expands to one switch case that casts desc to the backend's Descriptor
// and reads its cached workspaceSize().
#define GET(CASE, NAMESPACE)                                                                  \
    case CASE:                                                                                \
        *size = reinterpret_cast<op::swiglu::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_CUDA_API
        GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_CAMBRICON_MLU
    // Legacy backends below forward to their own C entry points rather than
    // the Descriptor class; note they use the older Dev* enum spelling —
    // NOTE(review): confirm these enumerators coexist with INFINI_DEVICE_*.
    case DevCambriconMlu: {
        return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size);
    }
#endif
#ifdef ENABLE_ASCEND_API
        GET(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
    case DevMetaxGpu: {
        return macaGetSwiGLUWorkspaceSize((SwiGLUMacaDescriptor_t)desc, size);
    }
#endif
#ifdef ENABLE_MTHREADS_GPU
    case DevMthreadsGpu: {
        return musaGetSwiGLUWorkspaceSize((SwiGLUMusaDescriptor_t)desc, size);
    }
#endif
    }

#undef GET

    // No compiled backend matched the descriptor's device type.
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopSwiGLU( __C infiniStatus_t infiniopSwiGLU(
infiniopSwiGLUDescriptor_t desc, infiniopSwiGLUDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c, void *c,
const void *a, const void *a,
const void *b, const void *b,
...@@ -76,7 +117,7 @@ __C infiniStatus_t infiniopSwiGLU( ...@@ -76,7 +117,7 @@ __C infiniStatus_t infiniopSwiGLU(
#define CALCULATE(CASE, NAMESPACE) \ #define CALCULATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc) \ return reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc) \
->calculate(c, {a, b}, stream) ->calculate(workspace, workspace_size, c, {a, b}, stream)
switch (desc->device_type) { switch (desc->device_type) {
......
import torch import torch
import ctypes import ctypes
from ctypes import POINTER, Structure, c_int32, c_void_p from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
from libinfiniop import ( from libinfiniop import (
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
...@@ -14,6 +14,7 @@ from libinfiniop import ( ...@@ -14,6 +14,7 @@ from libinfiniop import (
debug, debug,
get_tolerance, get_tolerance,
profile_operation, profile_operation,
create_workspace
) )
from enum import Enum, auto from enum import Enum, auto
...@@ -160,10 +161,19 @@ def test( ...@@ -160,10 +161,19 @@ def test(
for tensor in [a_tensor, b_tensor, c_tensor]: for tensor in [a_tensor, b_tensor, c_tensor]:
tensor.destroyDesc(lib) tensor.destroyDesc(lib)
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetSwiGLUWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, c.device)
def lib_swiglu(): def lib_swiglu():
check_error( check_error(
lib.infiniopSwiGLU( lib.infiniopSwiGLU(
descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
c_tensor.data, a_tensor.data, b_tensor.data, None
) )
) )
...@@ -196,10 +206,18 @@ if __name__ == "__main__": ...@@ -196,10 +206,18 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
] ]
lib.infiniopGetSwiGLUWorkspaceSize.restype = c_int32
lib.infiniopGetSwiGLUWorkspaceSize.argtypes = [
infiniopSwiGLUDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopSwiGLU.restype = c_int32 lib.infiniopSwiGLU.restype = c_int32
lib.infiniopSwiGLU.argtypes = [ lib.infiniopSwiGLU.argtypes = [
infiniopSwiGLUDescriptor_t, infiniopSwiGLUDescriptor_t,
c_void_p, c_void_p,
c_uint64,
c_void_p,
c_void_p, c_void_p,
c_void_p, c_void_p,
c_void_p, c_void_p,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment