Commit 8b760951 authored by zhangyue's avatar zhangyue
Browse files

Merge branch 'main' of https://github.com/InfiniTensor/InfiniCore into issue-385

parents eb3972eb d4b03cf7
...@@ -8,12 +8,17 @@ namespace device::bang { ...@@ -8,12 +8,17 @@ namespace device::bang {
Handle::Handle(infiniDevice_t device, int device_id) Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id}, : InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>()) {} _internal(std::make_shared<Handle::Internal>(device_id)) {}
auto Handle::internal() const -> const std::shared_ptr<Internal> & { auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal; return _internal;
} }
// Caches per-device topology (cluster count, cores per cluster) at handle
// construction so kernel launch configuration does not have to query the
// driver on every call.
// NOTE(review): the cnrtDeviceGetAttribute return codes are discarded —
// confirm whether CNRT_CHECK is available in this translation unit and
// should wrap these calls.
Handle::Internal::Internal(int device_id) {
    cnrtDeviceGetAttribute(&_cluster_count, cnrtAttrClusterCount, device_id);
    cnrtDeviceGetAttribute(&_core_per_cluster, cnrtAttrMcorePerCluster, device_id);
}
infiniStatus_t Handle::Internal::useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const { infiniStatus_t Handle::Internal::useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const {
auto handle = cnnl_handles.pop(); auto handle = cnnl_handles.pop();
if (!handle) { if (!handle) {
...@@ -25,6 +30,9 @@ infiniStatus_t Handle::Internal::useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_ ...@@ -25,6 +30,9 @@ infiniStatus_t Handle::Internal::useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// Accessors for the topology values cached by the Internal constructor.
int Handle::Internal::getCorePerCluster() const { return _core_per_cluster; }
int Handle::Internal::getClusterCount() const { return _cluster_count; }
cnnlDataType_t getCnnlDtype(infiniDtype_t dt) { cnnlDataType_t getCnnlDtype(infiniDtype_t dt) {
switch (dt) { switch (dt) {
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
......
#ifndef __INFINIOP_BANG_KERNEL_COMMON_H__
#define __INFINIOP_BANG_KERNEL_COMMON_H__
// Include Cambricon CNNL and CNRT headers for MLU (Machine Learning Unit) specific functions
#include "cnnl.h"
#include "cnrt.h"
namespace device::bang::kernel {
/**
* @brief Converts a flattened index to a reduced offset considering broadcasting.
*
* This function is used when dealing with broadcasted tensors where the input
* has been broadcast to match the output shape. It calculates the offset in
* the original (non-broadcasted) tensor.
*
* @param flat_index The flattened index in the output tensor
* @param ndim Number of dimensions
* @param broadcasted_strides Strides of the broadcasted tensor
* @param target_strides Strides of the original (non-broadcasted) tensor
* @return size_t Offset in the original tensor's memory
*/
inline __mlu_device__ size_t indexToReducedOffset(
    size_t flat_index,
    size_t ndim,
    const ptrdiff_t *broadcasted_strides,
    const ptrdiff_t *target_strides) {
    size_t offset = 0;
    size_t rem = flat_index;
    // Decompose the flat index dimension by dimension against the broadcast
    // strides, accumulating each coordinate against the original strides.
    for (size_t dim = 0; dim < ndim; ++dim) {
        const size_t coord = rem / broadcasted_strides[dim];
        offset += coord * target_strides[dim];
        rem %= broadcasted_strides[dim];
    }
    return offset;
}
/**
* @brief Converts a flattened index to a memory offset considering tensor striding.
*
* This is the general case for non-contiguous tensors where elements are not
* stored sequentially in memory.
*
* @param flat_index The flattened index in the tensor
* @param ndim Number of dimensions
* @param shape Tensor shape
* @param strides Tensor strides (in elements)
* @return size_t Offset in the tensor's memory
*/
inline __mlu_device__ size_t indexToOffset(
    size_t flat_index,
    size_t ndim,
    const size_t *shape,
    const ptrdiff_t *strides) {
    size_t offset = 0;
    size_t rem = flat_index;
    // Walk dimensions from innermost to outermost, peeling off one
    // coordinate per dimension and weighting it by that dimension's stride.
    for (size_t d = ndim; d > 0; --d) {
        const size_t i = d - 1;
        offset += (rem % shape[i]) * strides[i];
        rem /= shape[i];
    }
    return offset;
}
/**
* @brief Helper struct for computing input tensor indices considering broadcasting and striding.
*
* This is particularly useful for operations where inputs may be broadcasted
* to match the output shape, or may have non-contiguous memory layouts.
*/
struct InputIndexer {
    size_t idx;                        // Base index for this task
    size_t ndim;                       // Number of dimensions
    const bool *input_contiguous;      // Per-input contiguity flags
    const bool *input_broadcasted;     // Per-input broadcast flags
    const size_t *input_shapes;        // Concatenated input shapes
    const ptrdiff_t *input_strides;    // Concatenated input strides
    const ptrdiff_t *output_strides;   // Output tensor strides

    /**
     * @brief Computes the memory offset of an input tensor element.
     *
     * @param input_id Input tensor ID.
     * @param element_idx Element index relative to this task's base index.
     * @return size_t Memory offset in the input tensor.
     */
    __mlu_device__ size_t operator()(size_t input_id, size_t element_idx) const {
        const size_t global_idx = idx + element_idx;
        // Contiguous input: flat index equals the memory offset.
        if (input_contiguous[input_id]) {
            return global_idx;
        }
        const ptrdiff_t *in_strides = input_strides + input_id * ndim;
        // Broadcast input: translate through the output strides.
        if (input_broadcasted[input_id]) {
            return indexToReducedOffset(global_idx, ndim, output_strides, in_strides);
        }
        // General strided (non-contiguous) input.
        return indexToOffset(global_idx, ndim, input_shapes + input_id * ndim, in_strides);
    }
};
/**
* @brief Computes output tensor index considering striding.
*
* @param idx Linear index.
* @param is_contiguous Whether output is contiguous.
* @param ndim Number of dimensions.
* @param shape Output tensor shape.
* @param strides Output tensor strides.
* @return size_t Memory offset in output tensor.
*/
inline __mlu_device__ size_t
getOutputIndex(size_t idx,
               bool is_contiguous,
               size_t ndim,
               const size_t *shape,
               const ptrdiff_t *strides) {
    // Contiguous output: the flat index is already the memory offset.
    if (is_contiguous) {
        return idx;
    }
    return indexToOffset(idx, ndim, shape, strides);
}
/**
* @brief Calculates optimal chunk size for memory operations based on tensor contiguity.
*
* This function doesn't handle tensors with non-standard strides, which
* require more general optimizations not specific to Cambricon.
*
* @param global_idx_ Starting global index.
* @param ndim Number of dimensions.
* @param shape Tensor shape.
* @param strides Tensor strides.
* @param max_len Maximum allowed chunk size.
* @return size_t Optimal chunk size for memory operations.
*/
// Marked inline: this function is defined in a header, so without inline it
// would violate the one-definition rule (multiple-definition link errors)
// when the header is included from several translation units. All sibling
// helpers in this header are already inline.
inline __mlu_device__ size_t calculateChunkSize(
    size_t global_idx_,
    size_t ndim,
    const size_t *shape,
    const ptrdiff_t *strides,
    size_t max_len) {
    // Find the outermost dimension from which the tail of the tensor is
    // laid out contiguously (stride pattern 1, shape[n-1], ...).
    int last_contiguous_dim = -1;
    ptrdiff_t expected_stride = 1;
    for (int i = (int)ndim - 1; i >= 0; --i) {
        if (strides[i] != expected_stride) {
            break;
        }
        last_contiguous_dim = i;
        if (i > 0) {
            expected_stride *= shape[i];
        }
    }
    // No contiguous tail at all: elements must be copied one at a time.
    if (last_contiguous_dim < 0) {
        return 1;
    }
    // Calculate position in the contiguous block
    size_t global_idx = global_idx_;
    size_t pos_in_block = 0;
    size_t block_size = 1;
    for (int i = (int)ndim - 1; i >= last_contiguous_dim; --i) {
        size_t dim_idx = global_idx % shape[i];
        pos_in_block += dim_idx * block_size;
        block_size *= shape[i];
        global_idx /= shape[i];
    }
    // The chunk can run at most to the end of the contiguous block.
    size_t remaining_in_block = block_size - pos_in_block;
    return std::min(max_len, remaining_in_block);
}
/**
* @brief Helper function for non-contiguous memory copy
*
* @param dst Destination buffer
* @param src Source buffer
* @param direction Memory copy direction (GDRAM2NRAM or NRAM2GDRAM)
* @param indexer Input indexer helper (for input copies)
* @param input_idx Input tensor index (for input copies)
* @param processed Number of elements already processed
* @param curr_batch Current batch size
* @param start_idx Starting index for this task
* @param ndim Number of dimensions
* @param shape Tensor shape
* @param strides Tensor strides
* @param is_input_copy Whether this is an input copy operation
*/
template <typename Tdata>
__mlu_device__ void nonContiguousMemcpy(
    Tdata *dst,
    Tdata *src,
    mluMemcpyDirection_t direction,
    InputIndexer &indexer,
    size_t input_idx,
    size_t processed,
    size_t curr_batch,
    size_t start_idx,
    size_t ndim,
    const size_t *shape,
    const ptrdiff_t *strides,
    bool is_input_copy) {
    // Copy the batch in the largest contiguous runs the strided layout allows.
    size_t done = 0;
    for (size_t left = curr_batch; left > 0;) {
        const size_t logical_idx = start_idx + processed + done;
        // Resolve the strided memory offset of the next element to transfer.
        size_t mem_offset;
        if (is_input_copy) {
            mem_offset = indexer(input_idx, processed + done);
        } else {
            // Output side is known to be non-contiguous here.
            mem_offset = getOutputIndex(logical_idx, false, ndim, shape, strides);
        }
        const size_t run = calculateChunkSize(logical_idx, ndim, shape, strides, left);
        // Inputs stage into a packed NRAM buffer; outputs scatter back out.
        Tdata *to = dst + (is_input_copy ? done : mem_offset);
        Tdata *from = src + (is_input_copy ? mem_offset : done);
        __memcpy_async(to, from, run * sizeof(Tdata), direction);
        done += run;
        left -= run;
    }
}
} // namespace device::bang::kernel
#endif
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define __COMMON_BANG_H__ #define __COMMON_BANG_H__
#include "../../../utils.h" #include "../../../utils.h"
#include "../../tensor.h"
#include "../pool.h" #include "../pool.h"
#include "bang_handle.h" #include "bang_handle.h"
#include "cnnl.h" #include "cnnl.h"
...@@ -10,16 +11,27 @@ ...@@ -10,16 +11,27 @@
#define CHECK_BANG(API) CHECK_INTERNAL(API, CNNL_STATUS_SUCCESS) #define CHECK_BANG(API) CHECK_INTERNAL(API, CNNL_STATUS_SUCCESS)
// Usable NRAM per core, in bytes. Parenthesized so the macro behaves as a
// single value in expressions (e.g. `x % NRAM_MAX_SIZE` would otherwise
// expand to `x % 1024 * 240` and mis-associate).
#define NRAM_MAX_SIZE (1024 * 240)
// Alignment (bytes) used when carving buffers out of NRAM.
constexpr size_t ALIGN_SIZE = 128;
namespace device::bang { namespace device::bang {
class Handle::Internal { class Handle::Internal {
Pool<cnnlHandle_t> cnnl_handles; Pool<cnnlHandle_t> cnnl_handles;
int _core_per_cluster;
int _cluster_count;
template <typename T> template <typename T>
using Fn = std::function<infiniStatus_t(T)>; using Fn = std::function<infiniStatus_t(T)>;
public: public:
Internal(int);
infiniStatus_t useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const; infiniStatus_t useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const;
int getCorePerCluster() const;
int getClusterCount() const;
}; };
cnnlDataType_t getCnnlDtype(infiniDtype_t dt); cnnlDataType_t getCnnlDtype(infiniDtype_t dt);
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "ascend/ascend_handle.h" #include "ascend/ascend_handle.h"
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
#include "musa/musa_handle.h" #include "moore/moore_handle.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
#include "kunlun/kunlun_handle.h" #include "kunlun/kunlun_handle.h"
...@@ -54,7 +54,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) { ...@@ -54,7 +54,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
CREATE(INFINI_DEVICE_ASCEND, ascend); CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa); CREATE(INFINI_DEVICE_MOORE, moore);
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun); CREATE(INFINI_DEVICE_KUNLUN, kunlun);
...@@ -94,7 +94,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) { ...@@ -94,7 +94,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
DELETE(INFINI_DEVICE_ASCEND, ascend); DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, musa); DELETE(INFINI_DEVICE_MOORE, moore);
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun); DELETE(INFINI_DEVICE_KUNLUN, kunlun);
......
#include "../../../utils.h" #include "../../../utils.h"
#include "../pool.h" #include "../pool.h"
#include "musa_handle.h" #include "moore_handle.h"
#include <mublas.h> #include <mublas.h>
#include <mudnn.h> #include <mudnn.h>
#include <musa.h> #include <musa.h>
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS) #define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS)
#define CHECK_MUDNN(API) CHECK_INTERNAL((int)API, (int)::musa::dnn::Status::SUCCESS) #define CHECK_MUDNN(API) CHECK_INTERNAL((int)API, (int)::musa::dnn::Status::SUCCESS)
namespace device::musa { namespace device::moore {
class Handle::Internal { class Handle::Internal {
Pool<std::unique_ptr<mublasHandle_t>> mublas_handles; Pool<std::unique_ptr<mublasHandle_t>> mublas_handles;
...@@ -39,4 +39,4 @@ public: ...@@ -39,4 +39,4 @@ public:
int gridSizeZ() const; int gridSizeZ() const;
}; };
} // namespace device::musa } // namespace device::moore
#include "common_musa.h" #include "moore_common.h"
namespace device::musa { namespace device::moore {
Handle::Handle(infiniDevice_t device, int device_id) Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id}, : InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>(device_id)) {} _internal(std::make_shared<Handle::Internal>(device_id)) {}
...@@ -67,4 +67,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) { ...@@ -67,4 +67,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace device::musa } // namespace device::moore
#ifndef __INFINIOP_MUSA_HANDLE_H__ #ifndef __INFINIOP_MOORE_HANDLE_H__
#define __INFINIOP_MUSA_HANDLE_H__ #define __INFINIOP_MOORE_HANDLE_H__
#include "../../handle.h" #include "../../handle.h"
#include <memory> #include <memory>
namespace device::musa { namespace device::moore {
struct Handle : public InfiniopHandle { struct Handle : public InfiniopHandle {
Handle(int device_id); Handle(int device_id);
class Internal; class Internal;
...@@ -20,6 +20,6 @@ private: ...@@ -20,6 +20,6 @@ private:
std::shared_ptr<Internal> _internal; std::shared_ptr<Internal> _internal;
}; };
} // namespace device::musa } // namespace device::moore
#endif // __INFINIOP_MUSA_HANDLE_H__ #endif // __INFINIOP_MOORE_HANDLE_H__
#define INFINIOP_MOORE_KERNEL __global__ void
#include <musa_bf16.h>
#include <musa_fp16.h>
// Posible maximum number of threads per block for MUSA architectures
// Used for picking correct kernel launch configuration
#define MOORE_BLOCK_SIZE_2048 2048
#define MOORE_BLOCK_SIZE_1024 1024
#define MOORE_BLOCK_SIZE_512 512
#define CHECK_MOORE(API) CHECK_INTERNAL(API, musaSuccess)
using cuda_bfloat16 = mt_bfloat16;
using cuda_bfloat162 = mt_bfloat162;
namespace device::moore {
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t
indexToReducedOffset(
    size_t flat_index,
    size_t ndim,
    const ptrdiff_t *broadcasted_strides,
    const ptrdiff_t *target_strides) {
    size_t offset = 0;
    // Peel one coordinate per dimension from the flat index using the
    // broadcast strides, weighting it by the original tensor's strides.
    for (size_t d = 0; d < ndim; ++d) {
        offset += flat_index / broadcasted_strides[d] * target_strides[d];
        flat_index %= broadcasted_strides[d];
    }
    return offset;
}
// get the memory offset of the given element in a tensor given its flat index
__forceinline__ __device__ __host__ size_t
indexToOffset(
    size_t flat_index,
    size_t ndim,
    const size_t *shape,
    const ptrdiff_t *strides) {
    size_t offset = 0;
    size_t rem = flat_index;
    // Innermost-to-outermost decomposition of the flat index.
    for (size_t d = ndim; d-- > 0;) {
        offset += (rem % shape[d]) * strides[d];
        rem /= shape[d];
    }
    return offset;
}
} // namespace device::moore
// float overload: dispatches directly to the single-precision expf.
__forceinline__ __device__ float
exp_(const float val) {
    return expf(val);
}
// NOTE(review): `exp(val)` here likely resolves to the double-precision
// overload, silently dropping long double precision; device support for
// long double on MUSA is also unclear — confirm this overload is ever
// instantiated.
__forceinline__ __device__ long double
exp_(const long double val) {
    return exp(val);
}
// double overload: dispatches to the double-precision exp.
__forceinline__ __device__ double
exp_(const double val) {
    return exp(val);
}
// <musa_bf16.h> may not support hexp
__forceinline__ __device__ __half
exp_(const __half x) {
    // Round-trip through float since a native half-precision hexp may be
    // unavailable in <musa_bf16.h>.
    const float widened = __half2float(x);
    return __float2half(expf(widened));
}
// <musa_bf16.h> may not support hexp
__forceinline__ __device__ __mt_bfloat16
exp_(const __mt_bfloat16 x) {
    // Compute in float precision; a bf16-native exp may not exist in the toolkit.
    const float widened = __bfloat162float(x);
    return __float2bfloat16(expf(widened));
}
#ifndef __INFINIOP_ELEMENTWISE_BANG_H__
#define __INFINIOP_ELEMENTWISE_BANG_H__
#include "../../../utils.h"
#include "../../devices/bang/common_bang.h"
#include "elementwise_bang_api.h"
namespace op::elementwise::bang {
/**
* @brief Opaque implementation structure for BANG device operations.
*
* Contains device-specific resources and implementation methods.
*/
struct DeviceImpl::Opaque {
    // Shared BANG handle internals (device topology, CNNL handle pool).
    std::shared_ptr<device::bang::Handle::Internal> internal;

    /**
     * @brief Constructs an Opaque instance with device handle internals.
     *
     * @param internal_ Shared pointer to BANG device handle internals.
     */
    Opaque(const std::shared_ptr<device::bang::Handle::Internal> &internal_)
        : internal(internal_) {}

    /**
     * @brief Implements elementwise calculation for BANG device.
     *
     * Copies operation metadata into the device workspace, launches the
     * operator's kernel, and synchronizes the queue before returning.
     *
     * @tparam N Number of input tensors.
     * @tparam Op Operator functor type.
     * @tparam Tdata Data type for inputs and output.
     * @tparam Args Additional arguments for the operator.
     *
     * @param info Elementwise operation metadata (shapes, strides, etc.).
     * @param workspace Device workspace memory.
     * @param output Output tensor buffer.
     * @param inputs Vector of input tensor pointers.
     * @param queue BANG queue for asynchronous execution.
     * @param args Additional arguments for the operator.
     * @return infiniStatus_t Status indicating success or failure.
     */
    template <size_t N, typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
                                 void *workspace,
                                 void *output,
                                 const std::vector<const void *> &inputs,
                                 cnrtQueue_t queue,
                                 Args &&...args) {
        auto output_size = info.getOutputSize();
        // Nothing to compute for an empty output tensor.
        if (output_size == 0) {
            return INFINI_STATUS_SUCCESS;
        }
        // Device pointers for metadata
        const void **d_inputs_arr = nullptr;
        const bool *d_input_contiguous = nullptr;
        const bool *d_input_broadcasted = nullptr;
        const size_t *d_output_shape = nullptr;
        const ptrdiff_t *d_output_strides = nullptr;
        const size_t *d_input_shapes = nullptr;
        const ptrdiff_t *d_input_strides = nullptr;
        // Copy metadata to device and setup pointers
        CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
                                     d_input_contiguous, d_input_broadcasted,
                                     d_output_shape, d_output_strides,
                                     d_input_shapes, d_input_strides));
        // Launch the elementwise kernel
        Op::template launch<Tdata>(
            output_size,
            info.getNdim(),
            info.isOutputContiguous(),
            reinterpret_cast<const void *>(d_input_contiguous),
            reinterpret_cast<const void *>(d_input_broadcasted),
            reinterpret_cast<const void *>(d_output_shape),
            reinterpret_cast<const void *>(d_input_shapes),
            reinterpret_cast<const void *>(d_output_strides),
            reinterpret_cast<const void *>(d_input_strides),
            output,
            reinterpret_cast<const void *const *>(d_inputs_arr),
            queue,
            internal,
            args...);
        // Synchronize queue to ensure completion before returning.
        CNRT_CHECK(cnrtQueueSync(queue));
        return INFINI_STATUS_SUCCESS;
    }

private:
    /**
     * @brief Transfers elementwise operation metadata to device memory.
     *
     * Device workspace layout after this call (per the pointer setup below):
     *   [ input ptr array (N * sizeof(void*)) | output_shape | output_strides
     *     | input_shapes | input_strides | input_contiguous | input_broadcasted ]
     *
     * @tparam N Number of input tensors.
     *
     * @param info Elementwise operation metadata.
     * @param workspace Device workspace memory.
     * @param h_inputs_arr Host array of input pointers.
     * @param d_inputs_arr Output reference to device input pointers.
     * @param d_input_contiguous Output reference to contiguous flags.
     * @param d_input_broadcasted Output reference to broadcast flags.
     * @param d_output_shape Output reference to output shape.
     * @param d_output_strides Output reference to output strides.
     * @param d_input_shapes Output reference to input shapes.
     * @param d_input_strides Output reference to input strides.
     * @return infiniStatus_t Status indicating success or failure.
     */
    template <size_t N>
    infiniStatus_t infoToDevice(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        const void *const *h_inputs_arr,
        const void **&d_inputs_arr,
        const bool *&d_input_contiguous,
        const bool *&d_input_broadcasted,
        const size_t *&d_output_shape,
        const ptrdiff_t *&d_output_strides,
        const size_t *&d_input_shapes,
        const ptrdiff_t *&d_input_strides) const {
        constexpr auto input_size = N;
        const auto ndim = info.getNdim();
        constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
        const int8_t *info_meta_start = info.getMetaStart();
        // Metadata region begins right after the input pointer array.
        const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
        // Copy input pointer array and metadata to device
        CNRT_CHECK(cnrtMemcpy(workspace, (void *)h_inputs_arr, input_arr_size, CNRT_MEM_TRANS_DIR_HOST2DEV));
        CNRT_CHECK(cnrtMemcpy((void *)d_meta_start, (void *)info_meta_start, info.getMetaMemSize(), CNRT_MEM_TRANS_DIR_HOST2DEV));
        // Setup pointers to device memory regions
        d_inputs_arr = reinterpret_cast<const void **>(workspace);
        d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
        d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
        d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
        d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
        d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
        d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
        return INFINI_STATUS_SUCCESS;
    }
};
/**
* @brief Creates a DeviceImpl instance for BANG device.
*
* @tparam Args Argument types for Opaque construction.
* @param args Arguments forwarded to Opaque constructor.
* @return utils::Result<DeviceImpl*> Result containing new DeviceImpl instance.
*/
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
    // Build the opaque device state, then hand it to a fresh DeviceImpl.
    return utils::Result<DeviceImpl *>(
        new DeviceImpl(std::make_shared<Opaque>(std::forward<Args>(args)...)));
}
/**
* @brief Calculates elementwise operation for BANG device.
*
* @tparam Op Operator functor type.
* @tparam Tdata Data type for inputs and output.
* @tparam Args Additional arguments for the operator.
*
* @param info Elementwise operation metadata.
* @param workspace Device workspace memory.
* @param output Output tensor buffer.
* @param inputs Vector of input tensor pointers.
* @param queue BANG queue (as void*).
* @param args Additional arguments for the operator.
* @return infiniStatus_t Status indicating success or failure.
*/
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
                                     void *workspace,
                                     void *output,
                                     const std::vector<const void *> &inputs,
                                     void *queue,
                                     Args &&...args) {
    // Recover the typed BANG queue and delegate with the operator's input count.
    auto bang_queue = reinterpret_cast<cnrtQueue_t>(queue);
    constexpr size_t N = Op::num_inputs;
    return _opaque->calculateImpl<N, Op, Tdata>(
        info, workspace, output, inputs, bang_queue,
        std::forward<Args>(args)...);
}
} // namespace op::elementwise::bang
/**
* @brief Macro for declaring BANG kernel interface.
*
* @param OpName Name of the elementwise operation.
*/
#define LAUNCH_ELEMENTWISE_KERNEL(OpName) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
cnrtQueue_t queue, \
const std::shared_ptr<device::bang::Handle::Internal> &internal, \
Args... args);
#endif // __INFINIOP_ELEMENTWISE_BANG_H__
#ifndef __INFINIOP_ELEMENTWISE_BANG_API_H__
#define __INFINIOP_ELEMENTWISE_BANG_API_H__
#include "../elementwise.h"
namespace op::elementwise::bang {
/**
* @brief BANG device implementation for elementwise operations.
*
* Provides interface for creating and executing elementwise operations on BANG devices.
*/
class DeviceImpl final {
    // Device-specific state; defined where the BANG headers are visible.
    struct Opaque;
    std::shared_ptr<Opaque> _opaque;

    // Private: instances are created through the create() factory below.
    DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}

public:
    ~DeviceImpl() = default;

    /**
     * @brief Creates a DeviceImpl instance.
     *
     * @tparam Args Argument types for construction.
     * @param args Arguments forwarded to implementation.
     * @return utils::Result<DeviceImpl*> Result containing new instance.
     */
    template <typename... Args>
    static utils::Result<DeviceImpl *> create(Args &&...args);

    /**
     * @brief Executes elementwise operation on BANG device.
     *
     * @tparam Op Operator functor type.
     * @tparam Tdata Data type for inputs and output.
     * @tparam Args Additional arguments for the operator.
     *
     * @param info Elementwise operation metadata.
     * @param workspace Device workspace memory.
     * @param output Output tensor buffer.
     * @param inputs Vector of input tensor pointers.
     * @param queue BANG queue (as void*).
     * @param args Additional arguments for the operator.
     * @return infiniStatus_t Status indicating success or failure.
     */
    template <typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculate(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        void *output,
        const std::vector<const void *> &inputs,
        void *queue,
        Args &&...args);
};
} // namespace op::elementwise::bang
/**
* @brief Macro for creating BANG elementwise operation descriptor.
*
* @param HANDLE Device handle.
* @param DTYPE Output data type.
* @param OUT_DESC Output tensor descriptor.
* @param INPUT_DESC_VEC Vector of input tensor descriptors.
*/
#define CREATE_ELEMENTWISE_BANG_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::bang::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif // __INFINIOP_ELEMENTWISE_BANG_API_H__
#ifndef __INFINIOP_ELEMENTWISE_BANG_KERNEL_MLU__
#define __INFINIOP_ELEMENTWISE_BANG_KERNEL_MLU__
#include "../../devices/bang/bang_kernel_common.h"
#include "../../devices/bang/common_bang.h"

#include <utility>
using namespace device::bang::kernel;
/**
* @brief Core elementwise operation implementation for BANG device.
*
* @tparam N Number of input tensors.
* @tparam Op Operator functor type.
* @tparam Tdata Data type for inputs and output.
* @tparam Args Additional arguments for operator.
*
* @param typed_inputs Array of typed input pointers.
* @param output Output tensor pointer.
* @param nram_buf NRAM buffer for temporary storage.
* @param input_indexes Precomputed input indexes.
* @param output_index Starting output index.
* @param num_elements Number of elements to process.
* @param output_contiguous Whether output is contiguous.
* @param input_contiguous Array indicating input contiguity.
* @param ndim Number of dimensions.
* @param input_shape Input shape in global memory.
* @param input_strides Input strides in global memory.
* @param output_shape Output shape in global memory.
* @param output_strides Output strides in global memory.
* @param indexer Input indexer helper.
* @param start_idx Starting index for this task.
* @param args Additional arguments for operator.
*/
/**
 * @brief Invokes the operator, expanding the input-buffer array into
 *        individual arguments.
 *
 * Lets launchOp support any input count N instead of hard-coding two
 * buffers at the call site (the previous code read input_buffers[1]
 * unconditionally, which is out of bounds whenever N != 2).
 */
template <typename Op, typename Tdata, size_t... Is, typename... Args>
__mlu_device__ void invokeOpOnBuffers(
    Tdata *output_buffer,
    Tdata *const *input_buffers,
    size_t count,
    std::index_sequence<Is...>,
    Args... args) {
    Op op;
    op(output_buffer, input_buffers[Is]..., count, args...);
}

template <size_t N, typename Op, typename Tdata, typename... Args>
__mlu_device__ void launchOp(
    Tdata **typed_inputs,
    Tdata *output,
    Tdata *nram_buf,
    size_t *input_indexes,
    size_t output_index,
    size_t num_elements,
    bool output_contiguous,
    const bool *input_contiguous,
    const bool *input_broadcasted,
    size_t ndim,
    const size_t *input_shapes,
    const ptrdiff_t *input_strides,
    const size_t *output_shape,
    const ptrdiff_t *output_strides,
    InputIndexer indexer,
    size_t start_idx,
    Args... args) {
    static_assert(N == Op::num_inputs, "template N is not equal to Op::num_inputs!");
    // NRAM memory planning: reserve alignment slack, then split the usable
    // space evenly between the N input staging buffers and one output buffer.
    const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * (N + 1));
    const size_t max_batch = nram_usable / ((N + 1) * sizeof(Tdata));
    // The aligned base address is loop-invariant; compute it once.
    Tdata *aligned_buf = reinterpret_cast<Tdata *>(
        (reinterpret_cast<size_t>(nram_buf) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
    size_t processed = 0;
    while (processed < num_elements) {
        size_t curr_batch = std::min(max_batch, num_elements - processed);
        // 1. Copy input data to NRAM
        Tdata *input_buffers[N];
        for (size_t i = 0; i < N; ++i) {
            input_buffers[i] = aligned_buf + i * max_batch;
            if (input_contiguous[i]) {
                // Contiguous case - bulk copy
                __memcpy_async(input_buffers[i],
                               typed_inputs[i] + input_indexes[i] + processed,
                               curr_batch * sizeof(Tdata),
                               GDRAM2NRAM);
            } else {
                // Non-contiguous case - copy in contiguous chunks
                nonContiguousMemcpy<Tdata>(
                    input_buffers[i],
                    typed_inputs[i],
                    GDRAM2NRAM,
                    indexer,
                    i,
                    processed,
                    curr_batch,
                    start_idx,
                    ndim,
                    input_shapes + i * ndim,
                    input_strides + i * ndim,
                    true);
            }
        }
        __sync_io();
        // 2. Execute the operation on all N staged inputs.
        Tdata *output_buffer = aligned_buf + N * max_batch;
        invokeOpOnBuffers<Op, Tdata>(output_buffer, input_buffers, curr_batch,
                                     std::make_index_sequence<N>{}, args...);
        __sync_compute();
        // 3. Write back results
        if (output_contiguous) {
            // Contiguous output - bulk copy
            __memcpy_async(output + output_index + processed,
                           output_buffer,
                           curr_batch * sizeof(Tdata),
                           NRAM2GDRAM);
        } else {
            // Non-contiguous output - copy in contiguous chunks
            nonContiguousMemcpy<Tdata>(
                output,
                output_buffer,
                NRAM2GDRAM,
                indexer,
                0, // unused for output
                processed,
                curr_batch,
                start_idx,
                ndim,
                output_shape,
                output_strides,
                false);
        }
        processed += curr_batch;
    }
}
/**
* @brief BANG kernel for elementwise operations.
*
* @tparam N Number of input tensors.
* @tparam Op Operator functor type.
* @tparam Tdata Data type for inputs and output.
* @tparam Args Additional arguments for operator.
*
* @param output_size Total output elements.
* @param ndim Number of dimensions.
* @param output_contiguous Whether output is contiguous.
* @param input_contiguous Input contiguity flags in global memory.
* @param input_broadcasted Input broadcast flags in global memory.
* @param output_shape Output shape in global memory.
* @param input_shapes Input shapes in global memory.
* @param output_strides Output strides in global memory.
* @param input_strides Input strides in global memory.
* @param output Output tensor pointer.
* @param inputs Array of input pointers.
* @param args Additional arguments for operator.
*/
template <size_t N, typename Op, typename Tdata, typename... Args>
__mlu_global__ void elementwiseKernel(
    size_t output_size,
    size_t ndim,
    bool output_contiguous,
    const bool *input_contiguous,
    const bool *input_broadcasted,
    const size_t *output_shape,
    const size_t *input_shapes,
    const ptrdiff_t *output_strides,
    const ptrdiff_t *input_strides,
    Tdata *output,
    const void *const *inputs,
    Args... args) {
    // Cast input pointers to the correct type
    Tdata *typed_inputs[N];
    for (size_t i = 0; i < N; ++i) {
        typed_inputs[i] = reinterpret_cast<Tdata *>(const_cast<void *>(inputs[i]));
    }
    // Calculate workload per task: each task takes one contiguous slice of
    // the flattened output (taskDim/taskId are BANG launch built-ins).
    size_t elements_per_task = (output_size + taskDim - 1) / taskDim;
    size_t start_idx = taskId * elements_per_task;
    size_t end_idx = std::min(start_idx + elements_per_task, output_size);
    size_t num_elements = end_idx > start_idx ? end_idx - start_idx : 0;
    // Tail tasks past the end of the output have nothing to do.
    if (num_elements == 0) {
        return;
    }
    // Allocate NRAM buffer (shared by all inputs and output)
    __nram__ Tdata nram_buf[NRAM_MAX_SIZE / sizeof(Tdata)];
    // Memory offset of this task's first output element.
    size_t output_index = getOutputIndex(start_idx, output_contiguous,
                                         ndim, output_shape, output_strides);
    // Create input indexer anchored at this task's base index.
    InputIndexer indexer{
        static_cast<size_t>(start_idx),
        ndim,
        input_contiguous,
        input_broadcasted,
        input_shapes,
        input_strides,
        output_strides};
    // Memory offset of each input's first element (used by launchOp for the
    // contiguous-input bulk-copy fast path).
    size_t input_indexes[N];
    for (size_t i = 0; i < N; ++i) {
        input_indexes[i] = indexer(i, 0);
    }
    // Launch the operation with all required parameters
    launchOp<N, Op, Tdata>(typed_inputs, output, nram_buf, input_indexes,
                           output_index, num_elements, output_contiguous,
                           input_contiguous, input_broadcasted, ndim,
                           input_shapes, input_strides, output_shape,
                           output_strides, indexer, start_idx, args...);
}
/**
* @brief Intermediate layer that determines optimal launch configuration before calling elementwiseKernel.
*
* @tparam N Number of input tensors.
* @tparam Op Operator functor type.
* @tparam Tdata Data type for inputs and output.
* @tparam Args Additional arguments for operator.
*/
template <size_t N, typename Op, typename Tdata, typename... Args>
void launchElementwiseKernelWrapper(
    size_t output_size,
    size_t ndim,
    bool output_contiguous,
    const bool *input_contiguous,
    const bool *input_broadcasted,
    const size_t *output_shape,
    const size_t *input_shapes,
    const ptrdiff_t *output_strides,
    const ptrdiff_t *input_strides,
    Tdata *output,
    const void *const *inputs,
    cnrtQueue_t queue,
    const std::shared_ptr<device::bang::Handle::Internal> &internal,
    Args... args) {
    // Query the cached hardware topology from the handle internals.
    const int cores = internal->getCorePerCluster();
    const int clusters = internal->getClusterCount();
    // Launch one task per core across every cluster.
    cnrtDim3_t dim;
    dim.x = cores;
    dim.y = clusters;
    dim.z = 1;
    // Large contiguous workloads get a UNION launch; everything else runs BLOCK.
    const bool prefer_union = output_contiguous && output_size > 1024 * 1024;
    cnrtFunctionType_t func_type = prefer_union ? CNRT_FUNC_TYPE_UNION1
                                                : CNRT_FUNC_TYPE_BLOCK;
    elementwiseKernel<N, Op, Tdata><<<dim, func_type, queue>>>(
        output_size, ndim, output_contiguous,
        input_contiguous, input_broadcasted,
        output_shape, input_shapes,
        output_strides, input_strides,
        output, inputs, args...);
}
/**
* @brief Macro for implementing elementwise kernel launch.
*
* @param OpName Name of the operation.
* @param Op Operator functor type.
*/
// Defines launch<OpName>Kernel, a type-erased shim: the metadata arrays
// arrive as void* (so one declaration macro fits every operator) and are
// cast back to their real element types before delegating to
// launchElementwiseKernelWrapper.
#define LAUNCH_ELEMENTWISE_KERNEL_IMPL(OpName, Op)                               \
    template <typename Tdata, typename... Args>                                  \
    void launch##OpName##Kernel(                                                 \
        size_t output_size,                                                      \
        size_t ndim,                                                             \
        bool output_contiguous,                                                  \
        const void *input_contiguous,                                            \
        const void *input_broadcasted,                                           \
        const void *output_shape,                                                \
        const void *input_shapes,                                                \
        const void *output_strides,                                              \
        const void *input_strides,                                               \
        void *output,                                                            \
        const void *const *inputs,                                               \
        cnrtQueue_t queue,                                                       \
        const std::shared_ptr<device::bang::Handle::Internal> &internal,         \
        Args... args) {                                                          \
        launchElementwiseKernelWrapper<Op::num_inputs, Op, Tdata>(               \
            output_size, ndim, output_contiguous,                                \
            reinterpret_cast<const bool *>(input_contiguous),                    \
            reinterpret_cast<const bool *>(input_broadcasted),                   \
            reinterpret_cast<const size_t *>(output_shape),                      \
            reinterpret_cast<const size_t *>(input_shapes),                      \
            reinterpret_cast<const ptrdiff_t *>(output_strides),                 \
            reinterpret_cast<const ptrdiff_t *>(input_strides),                  \
            reinterpret_cast<Tdata *>(output), inputs, queue, internal, args...); \
    }
/**
 * @brief Macro for instantiating the elementwise kernel for specific types.
 *
 * @param OpName Name of the operation.
 * @param T Data type.
 * @param ... Additional template arguments.
 */
// Explicitly instantiates launch<OpName>Kernel for data type T (plus any
// extra argument types). Relies on the GNU ", ##__VA_ARGS__" extension to
// swallow the trailing comma when no extra arguments are supplied.
#define LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(OpName, T, ...)            \
    template void launch##OpName##Kernel<T, ##__VA_ARGS__>(             \
        size_t output_size,                                             \
        size_t ndim,                                                    \
        bool output_contiguous,                                         \
        const void *input_contiguous,                                   \
        const void *input_broadcasted,                                  \
        const void *output_shape,                                       \
        const void *input_shapes,                                       \
        const void *output_strides,                                     \
        const void *input_strides,                                      \
        void *output,                                                   \
        const void *const *inputs,                                      \
        cnrtQueue_t queue,                                              \
        const std::shared_ptr<device::bang::Handle::Internal> &internal, \
        ##__VA_ARGS__);
#endif
...@@ -227,9 +227,6 @@ private: ...@@ -227,9 +227,6 @@ private:
CHECK_KUNLUN(xpu_memcpy_async(workspace, h_inputs_arr, input_arr_size, XPU_HOST_TO_DEVICE, stream)); CHECK_KUNLUN(xpu_memcpy_async(workspace, h_inputs_arr, input_arr_size, XPU_HOST_TO_DEVICE, stream));
CHECK_KUNLUN(xpu_memcpy_async((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), XPU_HOST_TO_DEVICE, stream)); CHECK_KUNLUN(xpu_memcpy_async((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), XPU_HOST_TO_DEVICE, stream));
xpu_wait(stream);
// xpu_wait(stream);
// offset/assign the pointers // offset/assign the pointers
d_inputs_arr = reinterpret_cast<__global_ptr__ const void **>(workspace); d_inputs_arr = reinterpret_cast<__global_ptr__ const void **>(workspace);
d_output_shape = reinterpret_cast<__global_ptr__ const size_t *>(d_meta_start); d_output_shape = reinterpret_cast<__global_ptr__ const size_t *>(d_meta_start);
......
#ifndef __INFINIOP_ELEMENTWISE_MOORE_H__
#define __INFINIOP_ELEMENTWISE_MOORE_H__
#include "../../../utils.h"
#include "../../devices/moore/moore_common.h"
#include "../../devices/moore/moore_kernel_common.h"
#include "elementwise_moore_api.h"
namespace op::elementwise::moore {
// View a type-erased input pointer as a typed, read-only pointer.
template <typename T>
__device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
    return static_cast<const T *>(ptr);
}
// Map a flat element index to the output's memory offset. Contiguous
// outputs use the index directly; otherwise walk shape/strides.
__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
                                                 const size_t *shape, const ptrdiff_t *strides) {
    if (is_contiguous) {
        return idx;
    }
    return device::moore::indexToOffset(idx, ndim, shape, strides);
}
// Resolves, per input tensor, the memory offset corresponding to the flat
// output index `idx`, honoring contiguous, broadcasted, and generic
// strided layouts. Field order matters: callers aggregate-initialize.
struct InputIndexer {
    size_t idx;
    size_t ndim;
    const bool *input_contiguous;
    const bool *input_broadcasted;
    const size_t *input_shapes;
    const ptrdiff_t *input_strides;
    const ptrdiff_t *output_strides;

    __device__ __forceinline__ size_t operator()(size_t input_id) const {
        // Contiguous inputs address directly by the flat index.
        if (input_contiguous[input_id]) {
            return idx;
        }
        const ptrdiff_t *strides = input_strides + input_id * ndim;
        // Broadcasted inputs map through the output's stride layout.
        if (input_broadcasted[input_id]) {
            return device::moore::indexToReducedOffset(idx, ndim, output_strides, strides);
        }
        // General strided input: walk its own shape/strides.
        return device::moore::indexToOffset(idx, ndim, input_shapes + input_id * ndim, strides);
    }
};
// Invoke `f` with one std::integral_constant per index in the sequence,
// letting the callee expand a compile-time pack of input slots.
template <typename Callable, size_t... Idx>
__device__ __forceinline__ void unpackInputsAndApply(Callable &&f, std::index_sequence<Idx...>) {
    f(std::integral_constant<size_t, Idx>{}...);
}
// Uniform-dtype elementwise kernel: every input shares element type Tdata.
// One thread per output element; `offset` shifts the global index so the
// host can launch the grid repeatedly over outputs larger than one grid.
template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_MOORE_KERNEL elementwiseKernel(
    size_t output_size,
    size_t ndim,
    bool output_contiguous,
    const bool *__restrict__ input_contiguous,
    const bool *__restrict__ input_broadcasted,
    const size_t *__restrict__ output_shape,
    const size_t *__restrict__ input_shapes,
    const ptrdiff_t *__restrict__ output_strides,
    const ptrdiff_t *__restrict__ input_strides,
    Tdata *output,
    const void *const *inputs,
    size_t offset,
    Args... args) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
    if (idx < output_size) {
        // View the type-erased input pointer array as typed pointers.
        const Tdata *const *typed_inputs = reinterpret_cast<const Tdata *const *>(inputs);
        // Translate the flat index into per-tensor memory offsets.
        size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
        InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};
        // Expand the N input slots at compile time and apply the operator.
        unpackInputsAndApply(
            [&](auto... Is) {
                output[out_idx] = Op{}(typed_inputs[Is.value][indexer(Is.value)]..., std::forward<Args>(args)...);
            },
            std::make_index_sequence<N>{});
    }
}
// Mixed-dtype elementwise kernel: input i is read as type Tin[i] and the
// result is stored as Tout. One thread per output element; `offset` shifts
// the flat index so the host can cover outputs larger than one grid.
template <typename Op, typename Tout, typename... Tin>
INFINIOP_MOORE_KERNEL elementwiseKernel(
    size_t output_size,
    size_t ndim,
    bool output_contiguous,
    const bool *__restrict__ input_contiguous,
    const bool *__restrict__ input_broadcasted,
    const size_t *__restrict__ output_shape,
    const size_t *__restrict__ input_shapes,
    const ptrdiff_t *__restrict__ output_strides,
    const ptrdiff_t *__restrict__ input_strides,
    Tout *output,
    const void *const *__restrict__ inputs,
    size_t offset) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
    if (idx < output_size) {
        size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
        InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};
        // Expand one typed load per input and invoke the operator's
        // explicitly-typed call operator.
        unpackInputsAndApply(
            [&](auto... Is) {
                output[out_idx] = Op{}.template operator()<Tout, Tin...>(
                    (typedInputPtr<Tin>(inputs[Is.value])[indexer(Is.value)])...);
            },
            std::index_sequence_for<Tin...>{});
    }
}
// Opaque device-side implementation state for the Moore (MUSA) backend.
// Holds the handle internals (hardware limits) and the kernel-launch
// plumbing shared by every elementwise operator.
struct DeviceImpl::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;

    Opaque(const std::shared_ptr<device::moore::Handle::Internal> &internal)
        : internal(internal) {}

    // Uniform-dtype path: all inputs and the output share Tdata.
    template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
                                 void *workspace,
                                 void *output,
                                 const std::vector<const void *> &inputs,
                                 musaStream_t stream,
                                 Args &&...args) {
        return launchElementwiseKernel<BLOCK_SIZE, N>(
            info, workspace,
            reinterpret_cast<Tdata *>(output), inputs,
            elementwiseKernel<N, Op, Tdata, Args...>,
            stream,
            std::forward<Args>(args)...);
    }

    // Mixed-dtype path: per-input types Tin... with output type Tout.
    // NOTE(review): trailing `args` are accepted here but not forwarded —
    // the Tout/Tin kernel takes no extra arguments; confirm this is intended.
    template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args,
              std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
    infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
                                 void *workspace,
                                 void *output,
                                 const std::vector<const void *> &inputs,
                                 musaStream_t stream,
                                 Args &&...args) {
        return launchElementwiseKernel<BLOCK_SIZE, N>(
            info, workspace,
            reinterpret_cast<Tout *>(output), inputs,
            elementwiseKernel<Op, Tout, Tin...>,
            stream);
    }

private:
    // Stages the host input-pointer array and the packed metadata blob in
    // device workspace, then hands back typed device pointers into it.
    // Blob layout after the N input pointers (see the offset arithmetic
    // below): output_shape | output_strides | input_shapes | input_strides |
    // input_contiguous | input_broadcasted.
    template <size_t N>
    infiniStatus_t infoToDevice(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        const void *const *h_inputs_arr,
        const void **&d_inputs_arr,
        const bool *&d_input_contiguous,
        const bool *&d_input_broadcasted,
        const size_t *&d_output_shape,
        const ptrdiff_t *&d_output_strides,
        const size_t *&d_input_shapes,
        const ptrdiff_t *&d_input_strides,
        musaStream_t stream) const {
        constexpr auto input_size = N;
        const auto ndim = info.getNdim();
        constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
        const int8_t *info_meta_start = info.getMetaStart();
        const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
        // copy the input pointer array and meta to device
        CHECK_MOORE(musaMemcpyAsync(workspace, h_inputs_arr, input_arr_size, musaMemcpyHostToDevice, stream));
        CHECK_MOORE(musaMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), musaMemcpyHostToDevice, stream));
        // offset/assign the pointers
        d_inputs_arr = reinterpret_cast<const void **>(workspace);
        d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
        d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
        d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
        d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
        d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
        d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
        return INFINI_STATUS_SUCCESS;
    }

    // Copies metadata to the device, sizes the launch from hardware limits,
    // and walks the output in grid-sized slices via the kernel's `offset`
    // parameter.
    template <uint32_t BLOCK_SIZE, size_t N, typename KernelFunc, typename Tout, typename... Args>
    infiniStatus_t launchElementwiseKernel(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        Tout *output,
        const std::vector<const void *> &inputs,
        KernelFunc kernel_func,
        musaStream_t stream,
        Args &&...args) {
        auto output_size = info.getOutputSize();
        if (output_size == 0) {
            return INFINI_STATUS_SUCCESS; // nothing to compute
        }
        // Device pointers (all point into `workspace` after infoToDevice).
        const void **d_inputs_arr = nullptr;
        const bool *d_input_contiguous = nullptr;
        const bool *d_input_broadcasted = nullptr;
        const size_t *d_output_shape = nullptr;
        const ptrdiff_t *d_output_strides = nullptr;
        const size_t *d_input_shapes = nullptr;
        const ptrdiff_t *d_input_strides = nullptr;
        CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
                                     d_input_contiguous, d_input_broadcasted,
                                     d_output_shape, d_output_strides,
                                     d_input_shapes, d_input_strides, stream));
        // Clamp the launch shape to the device's capabilities.
        dim3 blockDims(std::min(BLOCK_SIZE, static_cast<uint32_t>(internal->maxThreadsPerBlock())));
        dim3 gridDims(std::min(uint32_t(CEIL_DIV(output_size, blockDims.x)), static_cast<uint32_t>(internal->gridSizeX())));
        size_t step = gridDims.x * blockDims.x;
        // One launch per grid-sized slice of the output.
        // NOTE(review): `args` are forwarded on every iteration; safe only
        // if the extra arguments are copyable scalars — confirm at call sites.
        for (size_t i = 0; i < output_size; i += step) {
            kernel_func<<<gridDims, blockDims, 0, stream>>>(
                output_size, info.getNdim(), info.isOutputContiguous(),
                d_input_contiguous, d_input_broadcasted,
                d_output_shape, d_input_shapes,
                d_output_strides, d_input_strides,
                output, reinterpret_cast<const void **>(d_inputs_arr),
                i, std::forward<Args>(args)...);
        }
        return INFINI_STATUS_SUCCESS;
    }
};
// Factory: builds a DeviceImpl that owns a freshly constructed Opaque state.
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
    return utils::Result<DeviceImpl *>(
        new DeviceImpl(std::make_shared<Opaque>(std::forward<Args>(args)...)));
}
/* Invoke elementwise operation for different input types */
template <uint32_t BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
          std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
                                     void *workspace,
                                     void *output,
                                     const std::vector<const void *> &inputs,
                                     void *stream,
                                     Args &&...args) {
    // One Tin per operator input; the enable_if on the declaration already
    // enforces this, the static_assert documents it at the definition.
    static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch");
    return _opaque->calculateImpl<BLOCK_SIZE, Op::num_inputs, Op, Tout, Tin...>(
        info, workspace, output, inputs,
        reinterpret_cast<musaStream_t>(stream),
        std::forward<Args>(args)...);
}
/* Invoke elementwise operation when all inputs have the same dtype */
template <uint32_t BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
                                     void *workspace,
                                     void *output,
                                     const std::vector<const void *> &inputs,
                                     void *stream,
                                     Args &&...args) {
    // Forward to the uniform-dtype implementation with N = Op::num_inputs.
    return _opaque->calculateImpl<BLOCK_SIZE, Op::num_inputs, Op, Tdata>(
        info, workspace, output, inputs,
        reinterpret_cast<musaStream_t>(stream),
        std::forward<Args>(args)...);
}
} // namespace op::elementwise::moore
#endif
#ifndef __INFINIOP_ELEMENTWISE_MOORE_API_H__
#define __INFINIOP_ELEMENTWISE_MOORE_API_H__
#include "../elementwise.h"
namespace op::elementwise::moore {
/// Device-side dispatcher for elementwise operators on the Moore backend.
/// Construction goes through `create`; the implementation lives in the
/// pimpl'd `Opaque` struct defined in the implementation header.
class DeviceImpl final {
    struct Opaque;
    std::shared_ptr<Opaque> _opaque;

    DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}

public:
    ~DeviceImpl() = default;

    /// Factory: forwards `args` to the Opaque constructor.
    template <typename... Args>
    static utils::Result<DeviceImpl *> create(Args &&...args);

    /// Run an elementwise op where all operands share dtype Tdata.
    template <uint32_t BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculate(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        void *output,
        const std::vector<const void *> &inputs,
        void *stream,
        Args &&...args);

    /// Run an elementwise op with per-input dtypes Tin... and output Tout.
    /// Enabled only when one Tin is supplied per operator input.
    template <uint32_t BLOCK_SIZE, typename Op, typename Tout, typename... Tin,
              typename... Args,
              std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
    infiniStatus_t calculate(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        void *output,
        const std::vector<const void *> &inputs,
        void *stream,
        Args &&...args);
};
} // namespace op::elementwise::moore
// Common tail for every Moore elementwise Descriptor::create: builds the
// ElementwiseInfo, sizes the device workspace (metadata blob plus one
// pointer per input), constructs the DeviceImpl, and publishes the new
// Descriptor through `desc_ptr` (captured from the enclosing scope).
#define CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC)          \
                                                                                              \
    auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC);    \
    CHECK_RESULT(info_result);                                                                \
    auto info = info_result.take();                                                           \
    auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *);       \
                                                                                              \
    auto device_impl_result = op::elementwise::moore::DeviceImpl::create(HANDLE->internal()); \
    CHECK_RESULT(device_impl_result);                                                         \
                                                                                              \
    *desc_ptr = new Descriptor(                                                               \
        DTYPE,                                                                                \
        std::move(info),                                                                      \
        std::move(device_impl_result.take()),                                                 \
        workspace_size,                                                                       \
        HANDLE->device,                                                                       \
        HANDLE->device_id);
#endif // __INFINIOP_ELEMENTWISE_MOORE_API_H__
#ifndef __ADD_BANG_API_H__
#define __ADD_BANG_API_H__
#include "../../../elementwise/bang/elementwise_bang.h"
ELEMENTWISE_DESCRIPTOR(add, bang)
#endif // __ADD_BANG_API_H__
#include "add_bang.h"
// Operator Interface Declaration
LAUNCH_ELEMENTWISE_KERNEL(Add)
namespace op::add::bang {
// Host-side Add operator tag for the BANG backend.
// `launch` type-erases its arguments and forwards them to the
// launchAddKernel<Tdata> shim declared by LAUNCH_ELEMENTWISE_KERNEL(Add).
typedef struct AddOp {
    // Add consumes exactly two input tensors.
    static constexpr size_t num_inputs = 2;
    template <typename Tdata, typename... Args>
    static infiniStatus_t launch(Args... args) {
        launchAddKernel<Tdata>(args...);
        return INFINI_STATUS_SUCCESS;
    }
} AddOp;
Descriptor::~Descriptor() = default;
// Validates operand descriptors for Add and builds the BANG elementwise
// descriptor. Requires two inputs whose shapes match the output exactly.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &lhs_desc = input_desc_vec.at(0);
    const auto &rhs_desc = input_desc_vec.at(1);
    const auto &out_shape = out_desc->shape();
    const auto &lhs_shape = lhs_desc->shape();
    const auto &rhs_shape = rhs_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
    CHECK_SAME_SHAPE(out_shape, lhs_shape, rhs_shape);

    // create Bang elementwise descriptor
    CREATE_ELEMENTWISE_BANG_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}
// Dispatches the Add computation on the element dtype recorded at create
// time. Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when the caller's
// workspace is smaller than requested.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *queue) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    if (_dtype == INFINI_DTYPE_F16) {
        return _device_info->calculate<AddOp, half>(_info, workspace, output, inputs, queue);
    }
    if (_dtype == INFINI_DTYPE_BF16) {
        return _device_info->calculate<AddOp, bfloat16_t>(_info, workspace, output, inputs, queue);
    }
    if (_dtype == INFINI_DTYPE_F32) {
        return _device_info->calculate<AddOp, float>(_info, workspace, output, inputs, queue);
    }
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::add::bang
#ifndef __ADD_BANG_INTERNAL_H__
#define __ADD_BANG_INTERNAL_H__
#include "../../../elementwise/bang/elementwise_bang_kernel.h"
// Device-side Add operator for the BANG backend: writes out[i] = a[i] + b[i]
// for `num_elements` elements staged in NRAM.
typedef struct AddOp {
public:
    // Add consumes exactly two input tensors.
    static constexpr size_t num_inputs = 2;
    template <typename T>
    __mlu_device__ void operator()(T *out, const T *a, const T *b, size_t num_elements) const {
        if constexpr (std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float>) {
            // Vectorized elementwise add via the BANG intrinsic.
            __bang_add(out, a, b, num_elements);
        } else {
            // Scalar fallback. The previous code wrote `out = a + b;`, which
            // only reassigned the local pointer (adding two pointers is
            // ill-formed and nothing was ever stored); add element by element.
            for (size_t i = 0; i < num_elements; ++i) {
                out[i] = a[i] + b[i];
            }
        }
    }
} AddOp;
LAUNCH_ELEMENTWISE_KERNEL_IMPL(Add, AddOp)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, half)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, bfloat16_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, float)
#endif // __ADD_BANG_INTERNAL_H__
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
#include "kunlun/add_kunlun.h" #include "kunlun/add_kunlun.h"
#endif #endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/add_bang.h"
#endif
__C infiniStatus_t infiniopCreateAddDescriptor( __C infiniStatus_t infiniopCreateAddDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -48,6 +51,9 @@ __C infiniStatus_t infiniopCreateAddDescriptor( ...@@ -48,6 +51,9 @@ __C infiniStatus_t infiniopCreateAddDescriptor(
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun); CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -78,6 +84,9 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz ...@@ -78,6 +84,9 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun); GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif #endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -118,6 +127,9 @@ __C infiniStatus_t infiniopAdd( ...@@ -118,6 +127,9 @@ __C infiniStatus_t infiniopAdd(
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -151,6 +163,9 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) { ...@@ -151,6 +163,9 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun); DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
#ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#ifndef __CLIP_KUNLUN_API_H__
#define __CLIP_KUNLUN_API_H__
#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"
ELEMENTWISE_DESCRIPTOR(clip, kunlun)
#endif // __CLIP_KUNLUN_API_H__
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "clip_kunlun.h"
#include "kernel.h"
namespace op::elementwise::kunlun {
using ClipOp = op::clip::kunlun::ClipOp;
INSTANTIATE_ELEMENTWISE_KERNEL(ClipOp::num_inputs, ClipOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(ClipOp::num_inputs, ClipOp, half);
INSTANTIATE_ELEMENTWISE_KERNEL(ClipOp::num_inputs, ClipOp, bfloat16_t);
} // namespace op::elementwise::kunlun
namespace op::clip::kunlun {
Descriptor::~Descriptor() = default;
// Validates operand descriptors for Clip and builds the Kunlun elementwise
// descriptor. Clip takes three inputs: the data tensor plus elementwise
// lower and upper bounds, all matching the output shape exactly.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &data_desc = input_desc_vec.at(0);
    const auto &lo_desc = input_desc_vec.at(1);
    const auto &hi_desc = input_desc_vec.at(2);
    const auto &result_shape = out_desc->shape();
    const auto &data_shape = data_desc->shape();
    const auto &lo_shape = lo_desc->shape();
    const auto &hi_shape = hi_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);

    // No broadcasting: every operand must match the output shape.
    CHECK_SAME_SHAPE(result_shape, data_shape);
    CHECK_SAME_SHAPE(result_shape, lo_shape);
    CHECK_SAME_SHAPE(result_shape, hi_shape);

    CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}
// Dispatches the Clip computation on the element dtype recorded at create
// time (BLOCK_SIZE template argument fixed at 8, as in the original).
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    if (_dtype == INFINI_DTYPE_F16) {
        return _device_info->calculate<8, ClipOp, half>(_info, workspace, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_BF16) {
        return _device_info->calculate<8, ClipOp, bfloat16_t>(_info, workspace, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_F32) {
        return _device_info->calculate<8, ClipOp, float>(_info, workspace, output, inputs, stream);
    }
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::clip::kunlun
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment