Commit 9cc0c416 authored by Zimin Li's avatar Zimin Li
Browse files

issue/127: Refactor ElementwiseInfo, refactor elementwise to use workspace for...

issue/127: Refactor ElementwiseInfo, refactor elementwise to use workspace for storing meta, fix misc. issues
parent 40fdded5
......@@ -11,7 +11,11 @@ __C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t hand
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc);
__C __export infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
void const *a,
void const *b,
......
......@@ -9,6 +9,10 @@
#define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512
#define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess)
namespace device::cuda {
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t
indexToReducedOffset(
......@@ -38,6 +42,7 @@ indexToOffset(
}
return res;
}
} // namespace device::cuda
#ifdef ENABLE_CUDA_API
#include <cuda_fp16.h>
......
......@@ -18,6 +18,7 @@
dtype, \
info_result.take(), \
nullptr, \
0, \
handle->device, \
handle->device_id);
......@@ -103,24 +104,34 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, Args &&...args) {
}
// Perform elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, std::index_sequence<Is...>, Args &&...args) {
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
Args &&...args) {
Tout *out = reinterpret_cast<Tout *>(output);
std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...};
ptrdiff_t output_size = info.output_size;
ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides);
size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i
: (info.input_broadcasted[input_id]
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id])
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id]));
return info.getInputContiguous()[input_id]
? i
: (info.getInputBroadcasted()[input_id]
? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
};
out[out_idx] = utils::cast<Tout>(Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
out[out_idx] = utils::cast<Tout>(
Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
}
}
......@@ -147,17 +158,20 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
Tdata *out = reinterpret_cast<Tdata *>(output);
std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
const ptrdiff_t output_size = info.output_size;
const ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides);
size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i
: (info.input_broadcasted[input_id]
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id])
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id]));
return info.getInputContiguous()[input_id]
? i
: (info.getInputBroadcasted()[input_id]
? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
};
if constexpr (std::is_same_v<Tdata, fp16_t>) {
......@@ -170,7 +184,11 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
// Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, void *stream, Args &&...args) {
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
......
......@@ -31,6 +31,7 @@ public:
* @tparam Args... Additional arguments passed to the operation.
*
* @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
......@@ -40,6 +41,7 @@ public:
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
......@@ -56,6 +58,7 @@ public:
* @tparam Tin... Input data types (must match Op::num_inputs).
* @tparam Args... Additional arguments passed to the operation.
* @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
......@@ -67,6 +70,7 @@ public:
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
......@@ -82,14 +86,17 @@ public:
\
auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
op::elementwise::cuda::DeviceImpl *device_impl; \
CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \
\
*desc_ptr = new Descriptor( \
dtype, \
std::move(info_result.take()), \
std::move(info), \
device_impl, \
workspace_size, \
handle->device, \
handle->device_id);
......
......@@ -19,21 +19,26 @@
infiniDtype_t _dtype; \
op::elementwise::ElementwiseInfo _info; \
std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
size_t _workspace_size; \
\
Descriptor( \
infiniDtype_t dtype, \
op::elementwise::ElementwiseInfo info, \
op::elementwise::NAMESPACE::DeviceImpl *device_info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \
_info(std::move(info)), \
_device_info(device_info) {} \
_device_info(device_info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
......@@ -41,6 +46,7 @@
std::vector<infiniopTensorDescriptor_t> input_descs); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *output, \
std::vector<const void *> inputs, \
void *stream) const; \
......@@ -62,57 +68,70 @@ namespace op::elementwise {
*/
struct ElementwiseInfo {
private:
ElementwiseInfo() = default;
std::vector<int8_t> _meta;
size_t _output_size;
size_t _input_size;
size_t _ndim;
bool _output_contiguous;
ElementwiseInfo(std::vector<int8_t> meta,
size_t output_size,
size_t input_size,
size_t ndim,
bool output_contiguous)
: _meta(std::move(meta)), _output_size(output_size),
_input_size(input_size), _ndim(ndim),
_output_contiguous(output_contiguous) {}
public:
size_t output_size;
size_t ndim;
bool output_contiguous;
bool *input_contiguous;
bool *input_broadcasted;
size_t *output_shape;
size_t **input_shapes;
ptrdiff_t *output_strides;
ptrdiff_t **input_strides;
size_t input_size;
~ElementwiseInfo() {
delete[] input_contiguous;
delete[] input_broadcasted;
delete[] output_shape;
delete[] output_strides;
for (size_t i = 0; i < input_size; ++i) {
delete[] input_shapes[i];
delete[] input_strides[i];
inline size_t getMetaMemSize() const {
return _meta.size();
}
inline const int8_t *getMetaStart() const {
return _meta.data();
}
inline size_t getOutputSize() const {
return _output_size;
}
inline size_t getInputSize() const {
return _input_size;
}
inline size_t getNdim() const {
return _ndim;
}
inline bool isOutputContiguous() const {
return _output_contiguous;
}
inline const size_t *getOutputShape() const {
return reinterpret_cast<const size_t *>(_meta.data());
}
inline const ptrdiff_t *getOutputStrides() const {
return reinterpret_cast<const ptrdiff_t *>(getOutputShape() + _ndim);
}
inline const size_t *getAllInputShapes() const {
return reinterpret_cast<const size_t *>(getOutputStrides() + _ndim);
}
inline const size_t *getInputShape(const size_t &index) const {
if (index < _input_size) {
return reinterpret_cast<const size_t *>(getAllInputShapes() + index * _ndim);
}
return nullptr;
}
inline const ptrdiff_t *getAllInputStrides() const {
return reinterpret_cast<const ptrdiff_t *>(getAllInputShapes() + _input_size * _ndim);
}
inline const ptrdiff_t *getInputStrides(const size_t &index) const {
if (index < _input_size) {
return reinterpret_cast<const ptrdiff_t *>(getAllInputStrides() + index * _ndim);
}
delete[] input_shapes;
delete[] input_strides;
}
ElementwiseInfo(ElementwiseInfo &&other) noexcept
: output_size(other.output_size),
ndim(other.ndim),
output_contiguous(other.output_contiguous),
input_contiguous(other.input_contiguous),
input_broadcasted(other.input_broadcasted),
output_shape(other.output_shape),
input_shapes(other.input_shapes),
output_strides(other.output_strides),
input_strides(other.input_strides),
input_size(other.input_size) {
other.input_contiguous = nullptr;
other.input_broadcasted = nullptr;
other.output_shape = nullptr;
other.input_shapes = nullptr;
other.output_strides = nullptr;
other.input_strides = nullptr;
other.input_size = 0;
}
ElementwiseInfo(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete;
return nullptr;
}
inline const bool *getInputContiguous() const {
return reinterpret_cast<const bool *>(getAllInputStrides() + _input_size * _ndim);
}
inline const bool *getInputBroadcasted() const {
return reinterpret_cast<const bool *>(getInputContiguous() + _input_size);
}
using ResultType = utils::Result<ElementwiseInfo>;
......@@ -136,40 +155,48 @@ public:
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
ElementwiseInfo info;
info.input_size = input_descs.size();
info.ndim = output_desc->ndim();
info.output_size = output_desc->numel();
info.output_contiguous = output_desc->isContiguous();
// Allocate memory for arrays
info.input_contiguous = new bool[info.input_size];
info.input_broadcasted = new bool[info.input_size];
info.output_shape = new size_t[info.ndim];
info.output_strides = new ptrdiff_t[info.ndim];
info.input_shapes = new size_t *[info.input_size];
info.input_strides = new ptrdiff_t *[info.input_size];
// Fill arrays
auto input_size = input_descs.size();
auto ndim = output_desc->ndim();
auto output_size = output_desc->numel();
auto output_contiguous = output_desc->isContiguous();
// Allocate memory for meta
auto shape_unit = output_desc->dim(0);
auto stride_unit = output_desc->stride(0);
size_t meta_mem_size = ndim * (sizeof(shape_unit) + sizeof(stride_unit))
+ input_size * ndim * sizeof(shape_unit)
+ input_size * ndim * sizeof(stride_unit)
+ 2 * input_size * sizeof(bool);
std::vector<int8_t> meta(meta_mem_size);
int8_t *meta_ptr = meta.data();
const auto output_shape = output_desc->shape();
const auto output_strides = output_desc->strides();
std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape));
std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides));
for (size_t i = 0; i < info.input_size; ++i) {
auto &desc = input_descs[i];
info.input_contiguous[i] = desc->isContiguous();
info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim());
// Pointers to the sections within _meta
size_t *output_shape_p = reinterpret_cast<size_t *>(meta_ptr);
ptrdiff_t *output_strides_p = reinterpret_cast<ptrdiff_t *>(output_shape_p + ndim);
size_t *input_shapes = reinterpret_cast<size_t *>(output_strides_p + ndim);
ptrdiff_t *input_strides = reinterpret_cast<ptrdiff_t *>(input_shapes + input_size * ndim);
bool *input_contiguous = reinterpret_cast<bool *>(input_strides + input_size * ndim);
bool *input_broadcasted = input_contiguous + input_size;
info.input_shapes[i] = new size_t[desc->ndim()];
const auto &in_shape = desc->shape();
std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i]));
// Copy output shape and strides
std::memcpy(output_shape_p, output_shape.data(), ndim * sizeof(*output_shape_p));
std::memcpy(output_strides_p, output_strides.data(), ndim * sizeof(*output_strides_p));
info.input_strides[i] = new ptrdiff_t[desc->ndim()];
const auto &in_strides = desc->strides();
std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i]));
// Copy input shapes, strides, contiguous, and broadcasted flags
for (size_t i = 0; i < input_size; ++i) {
auto &desc = input_descs[i];
const auto in_shape = desc->shape();
const auto in_strides = desc->strides();
std::memcpy(input_shapes + i * ndim, in_shape.data(), ndim * sizeof(*input_shapes));
std::memcpy(input_strides + i * ndim, in_strides.data(), ndim * sizeof(*input_strides));
input_contiguous[i] = desc->isContiguous();
input_broadcasted[i] = !input_contiguous[i] && (desc->ndim() != ndim || desc->hasBroadcastDim());
}
ElementwiseInfo info(std::move(meta), output_size, input_size, ndim, output_contiguous);
return ResultType(std::move(info));
}
};
......
......@@ -30,6 +30,8 @@ infiniStatus_t Descriptor::create(
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
......
......@@ -10,7 +10,7 @@ typedef struct SwiGLUOp {
private:
// Logistic function: sigma(x) = 1 / (1 + e^{-x}).
// NOTE(review): the merged hunk left the stale `return 1 / (1 + ...)` line
// above the typed version, making the fix unreachable dead code — keep only
// the typed form. Using T(1) instead of the int literal 1 keeps the whole
// expression in T, avoiding int->T mixed-arithmetic issues for narrow
// floating types (e.g. half-like types without implicit int conversions).
template <typename T>
T sigmoid(const T &x) const {
    return T(1) / (T(1) + std::exp(-x));
}
public:
......
......@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create(
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
if (!SAME_VEC(out_shape, up_shape, gate_shape)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
// create CUDA elementwise descriptor
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR
......@@ -32,17 +30,23 @@ infiniStatus_t Descriptor::create(
}
// Launch the SwiGLU elementwise kernel for the descriptor's dtype.
//
// @param workspace       Device buffer used by the backend (sized per
//                        workspaceSize(); holds packed ElementwiseInfo meta).
// @param workspace_size  Size of `workspace` in bytes; validated below.
// @param output          Output tensor buffer (device memory).
// @param inputs          Input tensor pointers (device memory).
// @param stream          Backend stream handle (opaque void*).
//
// NOTE(review): the merged hunk duplicated each `case` with the pre-refactor
// call that omitted `workspace`; the stale line executed first, so the
// workspace was never forwarded. Only the workspace-passing calls remain.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    // Reject undersized buffers up front rather than corrupting memory later.
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, SwiGLUOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, SwiGLUOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, SwiGLUOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
......
......@@ -66,8 +66,49 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
#undef CREATE
}
// Query the device workspace size required by a SwiGLU descriptor.
// For the refactored elementwise backends this covers the packed
// ElementwiseInfo metadata plus one device pointer per input (see the
// CUDA descriptor creation: getMetaMemSize() + getInputSize() * sizeof(void*));
// other backends report their own requirement via legacy entry points.
__C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size) {
// Dispatch helper: cast the opaque handle to the backend-specific
// Descriptor and read its cached workspace size.
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::swiglu::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_CAMBRICON_MLU
// Cambricon still routes through its legacy C entry point, not GET.
case DevCambriconMlu: {
return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
// Metax legacy entry point.
case DevMetaxGpu: {
return macaGetSwiGLUWorkspaceSize((SwiGLUMacaDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
// Moore Threads legacy entry point.
case DevMthreadsGpu: {
return musaGetSwiGLUWorkspaceSize((SwiGLUMusaDescriptor_t)desc, size);
}
#endif
}
#undef GET
// No compiled-in backend matched the descriptor's device type.
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopSwiGLU(
infiniopSwiGLUDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
const void *a,
const void *b,
......@@ -76,7 +117,7 @@ __C infiniStatus_t infiniopSwiGLU(
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc) \
->calculate(c, {a, b}, stream)
->calculate(workspace, workspace_size, c, {a, b}, stream)
switch (desc->device_type) {
......
import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_void_p
from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
from libinfiniop import (
infiniopHandle_t,
infiniopTensorDescriptor_t,
......@@ -14,6 +14,7 @@ from libinfiniop import (
debug,
get_tolerance,
profile_operation,
create_workspace
)
from enum import Enum, auto
......@@ -160,10 +161,19 @@ def test(
for tensor in [a_tensor, b_tensor, c_tensor]:
tensor.destroyDesc(lib)
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetSwiGLUWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, c.device)
def lib_swiglu():
check_error(
lib.infiniopSwiGLU(
descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
c_tensor.data, a_tensor.data, b_tensor.data, None
)
)
......@@ -196,10 +206,18 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t,
]
lib.infiniopGetSwiGLUWorkspaceSize.restype = c_int32
lib.infiniopGetSwiGLUWorkspaceSize.argtypes = [
infiniopSwiGLUDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopSwiGLU.restype = c_int32
lib.infiniopSwiGLU.argtypes = [
infiniopSwiGLUDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment