Commit 507be07e authored by YdrMaster's avatar YdrMaster
Browse files

issue/291/style: 所有 maca 改为 metax


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent e4605f7c
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "./ascend/infiniccl_ascend.h" #include "./ascend/infiniccl_ascend.h"
#include "./cuda/infiniccl_cuda.h" #include "./cuda/infiniccl_cuda.h"
#include "./maca/infiniccl_maca.h" #include "./metax/infiniccl_metax.h"
__C infiniStatus_t infinicclCommInitAll( __C infiniStatus_t infinicclCommInitAll(
infiniDevice_t device_type, infiniDevice_t device_type,
...@@ -17,7 +17,7 @@ __C infiniStatus_t infinicclCommInitAll( ...@@ -17,7 +17,7 @@ __C infiniStatus_t infinicclCommInitAll(
switch (device_type) { switch (device_type) {
COMM_INIT_ALL(INFINI_DEVICE_NVIDIA, cuda) COMM_INIT_ALL(INFINI_DEVICE_NVIDIA, cuda)
COMM_INIT_ALL(INFINI_DEVICE_ASCEND, ascend) COMM_INIT_ALL(INFINI_DEVICE_ASCEND, ascend)
COMM_INIT_ALL(INFINI_DEVICE_METAX, maca) COMM_INIT_ALL(INFINI_DEVICE_METAX, metax)
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -37,7 +37,7 @@ __C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) { ...@@ -37,7 +37,7 @@ __C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) {
switch (comm->device_type) { switch (comm->device_type) {
COMM_DESTROY(INFINI_DEVICE_NVIDIA, cuda) COMM_DESTROY(INFINI_DEVICE_NVIDIA, cuda)
COMM_DESTROY(INFINI_DEVICE_ASCEND, ascend) COMM_DESTROY(INFINI_DEVICE_ASCEND, ascend)
COMM_DESTROY(INFINI_DEVICE_METAX, maca) COMM_DESTROY(INFINI_DEVICE_METAX, metax)
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -65,7 +65,7 @@ __C infiniStatus_t infinicclAllReduce( ...@@ -65,7 +65,7 @@ __C infiniStatus_t infinicclAllReduce(
switch (comm->device_type) { switch (comm->device_type) {
ALL_REDUCE(INFINI_DEVICE_NVIDIA, cuda) ALL_REDUCE(INFINI_DEVICE_NVIDIA, cuda)
ALL_REDUCE(INFINI_DEVICE_ASCEND, ascend) ALL_REDUCE(INFINI_DEVICE_ASCEND, ascend)
ALL_REDUCE(INFINI_DEVICE_METAX, maca) ALL_REDUCE(INFINI_DEVICE_METAX, metax)
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#include "infiniccl_maca.h" #include "infiniccl_metax.h"
#include "../../utils.h" #include "../../utils.h"
...@@ -51,7 +51,7 @@ inline hcclComm_t getHcclComm(infinicclComm_t comm) { ...@@ -51,7 +51,7 @@ inline hcclComm_t getHcclComm(infinicclComm_t comm) {
return static_cast<hcclComm_t>(comm->comm); return static_cast<hcclComm_t>(comm->comm);
} }
namespace infiniccl::maca { namespace infiniccl::metax {
infiniStatus_t commInitAll( infiniStatus_t commInitAll(
infinicclComm_t *comms, infinicclComm_t *comms,
...@@ -92,4 +92,4 @@ infiniStatus_t allReduce( ...@@ -92,4 +92,4 @@ infiniStatus_t allReduce(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace infiniccl::maca } // namespace infiniccl::metax
#ifndef INFINICCL_MACA_H_ #ifndef INFINICCL_METAX_H_
#define INFINICCL_MACA_H_ #define INFINICCL_METAX_H_
#include "../infiniccl_impl.h" #include "../infiniccl_impl.h"
#if defined(ENABLE_METAX_API) && defined(ENABLE_CCL) #if defined(ENABLE_METAX_API) && defined(ENABLE_CCL)
INFINICCL_DEVICE_API_IMPL(maca) INFINICCL_DEVICE_API_IMPL(metax)
#else #else
INFINICCL_DEVICE_API_NOOP(maca) INFINICCL_DEVICE_API_NOOP(metax)
#endif #endif
#endif /* INFINICCL_MACA_H_ */ #endif /* INFINICCL_METAX_H_ */
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "kunlun/kunlun_handle.h" #include "kunlun/kunlun_handle.h"
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
#include "maca/maca_handle.h" #include "metax/metax_handle.h"
#endif #endif
__C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) { __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
...@@ -57,7 +57,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) { ...@@ -57,7 +57,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
CREATE(INFINI_DEVICE_KUNLUN, kunlun); CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, maca); CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
default: default:
...@@ -94,7 +94,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) { ...@@ -94,7 +94,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
DELETE(INFINI_DEVICE_KUNLUN, kunlun); DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, maca); DELETE(INFINI_DEVICE_METAX, metax);
#endif #endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#include "../../../utils.h" #include "../../../utils.h"
#include "../pool.h" #include "../pool.h"
#include "maca_handle.h" #include "metax_handle.h"
#include <hcblas/hcblas.h> #include <hcblas/hcblas.h>
#include <hcdnn/hcdnn.h> #include <hcdnn/hcdnn.h>
#include <memory> #include <memory>
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#define CHECK_MCBLAS(API) CHECK_INTERNAL(API, HCBLAS_STATUS_SUCCESS) #define CHECK_MCBLAS(API) CHECK_INTERNAL(API, HCBLAS_STATUS_SUCCESS)
#define CHECK_MCDNN(API) CHECK_INTERNAL(API, HCDNN_STATUS_SUCCESS) #define CHECK_MCDNN(API) CHECK_INTERNAL(API, HCDNN_STATUS_SUCCESS)
namespace device::maca { namespace device::metax {
class Handle::Internal { class Handle::Internal {
Pool<hcblasHandle_t> mcblas_handles; Pool<hcblasHandle_t> mcblas_handles;
...@@ -39,4 +39,4 @@ public: ...@@ -39,4 +39,4 @@ public:
hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt); hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt);
} // namespace device::maca } // namespace device::metax
#include "common_maca.h" #include "metax_common.h"
namespace device::maca { namespace device::metax {
Handle::Handle(infiniDevice_t device, int device_id) Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id}, : InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>(device_id)) {} _internal(std::make_shared<Handle::Internal>(device_id)) {}
...@@ -83,4 +83,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) { ...@@ -83,4 +83,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace device::maca } // namespace device::metax
#ifndef __INFINIOP_MACA_HANDLE_H__ #ifndef __INFINIOP_METAX_HANDLE_H__
#define __INFINIOP_MACA_HANDLE_H__ #define __INFINIOP_METAX_HANDLE_H__
#include "../../handle.h" #include "../../handle.h"
#include <memory> #include <memory>
namespace device::maca { namespace device::metax {
struct Handle : public InfiniopHandle { struct Handle : public InfiniopHandle {
Handle(int device_id); Handle(int device_id);
class Internal; class Internal;
...@@ -20,6 +20,6 @@ private: ...@@ -20,6 +20,6 @@ private:
std::shared_ptr<Internal> _internal; std::shared_ptr<Internal> _internal;
}; };
} // namespace device::maca } // namespace device::metax
#endif // __INFINIOP_MACA_HANDLE_H__ #endif // __INFINIOP_METAX_HANDLE_H__
#define INFINIOP_MACA_KERNEL __global__ void #define INFINIOP_METAX_KERNEL __global__ void
// Posible maximum number of threads per block for MACA architectures // Posible maximum number of threads per block for METAX architectures
// Used for picking correct kernel launch configuration // Used for picking correct kernel launch configuration
#define MACA_BLOCK_SIZE_1024 1024 #define METAX_BLOCK_SIZE_1024 1024
#define MACA_BLOCK_SIZE_512 512 #define METAX_BLOCK_SIZE_512 512
#define CHECK_MACA(API) CHECK_INTERNAL(API, hcSuccess) #define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess)
using cuda_bfloat16 = hpcc_bfloat16; using cuda_bfloat16 = hpcc_bfloat16;
using cuda_bfloat162 = hpcc_bfloat162; using cuda_bfloat162 = hpcc_bfloat162;
namespace device::maca { namespace device::metax {
// return the memory offset of original tensor, given the flattened index of broadcasted tensor // return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t __forceinline__ __device__ __host__ size_t
...@@ -41,7 +41,7 @@ indexToOffset( ...@@ -41,7 +41,7 @@ indexToOffset(
} }
return res; return res;
} }
} // namespace device::maca } // namespace device::metax
__forceinline__ __device__ float __forceinline__ __device__ float
exp_(const float val) { exp_(const float val) {
......
#ifndef __INFINIOP_ELEMENTWISE_MACA_H__ #ifndef __INFINIOP_ELEMENTWISE_METAX_H__
#define __INFINIOP_ELEMENTWISE_MACA_H__ #define __INFINIOP_ELEMENTWISE_METAX_H__
#include "../../../utils.h" #include "../../../utils.h"
#include "../../devices/maca/common_maca.h" #include "../../devices/metax/metax_common.h"
#include "../../devices/maca/maca_kernel_common.h" #include "../../devices/metax/metax_kernel_common.h"
#include "elementwise_maca_api.h" #include "elementwise_metax_api.h"
namespace op::elementwise::maca { namespace op::elementwise::metax {
template <typename T> template <typename T>
__device__ __forceinline__ const T *typedInputPtr(const void *ptr) { __device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
return reinterpret_cast<const T *>(ptr); return reinterpret_cast<const T *>(ptr);
...@@ -14,7 +14,7 @@ __device__ __forceinline__ const T *typedInputPtr(const void *ptr) { ...@@ -14,7 +14,7 @@ __device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim, __device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
const size_t *shape, const ptrdiff_t *strides) { const size_t *shape, const ptrdiff_t *strides) {
return is_contiguous ? idx : device::maca::indexToOffset(idx, ndim, shape, strides); return is_contiguous ? idx : device::metax::indexToOffset(idx, ndim, shape, strides);
} }
struct InputIndexer { struct InputIndexer {
...@@ -30,8 +30,8 @@ struct InputIndexer { ...@@ -30,8 +30,8 @@ struct InputIndexer {
return input_contiguous[input_id] return input_contiguous[input_id]
? idx ? idx
: (input_broadcasted[input_id] : (input_broadcasted[input_id]
? device::maca::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim) ? device::metax::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::maca::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim)); : device::metax::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
} }
}; };
...@@ -41,7 +41,7 @@ __device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence< ...@@ -41,7 +41,7 @@ __device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence<
} }
template <size_t N, typename Op, typename Tdata, typename... Args> template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_MACA_KERNEL elementwiseKernel( INFINIOP_METAX_KERNEL elementwiseKernel(
size_t output_size, size_t output_size,
size_t ndim, size_t ndim,
bool output_contiguous, bool output_contiguous,
...@@ -72,7 +72,7 @@ INFINIOP_MACA_KERNEL elementwiseKernel( ...@@ -72,7 +72,7 @@ INFINIOP_MACA_KERNEL elementwiseKernel(
} }
template <typename Op, typename Tout, typename... Tin> template <typename Op, typename Tout, typename... Tin>
INFINIOP_MACA_KERNEL elementwiseKernel( INFINIOP_METAX_KERNEL elementwiseKernel(
size_t output_size, size_t output_size,
size_t ndim, size_t ndim,
bool output_contiguous, bool output_contiguous,
...@@ -102,9 +102,9 @@ INFINIOP_MACA_KERNEL elementwiseKernel( ...@@ -102,9 +102,9 @@ INFINIOP_MACA_KERNEL elementwiseKernel(
} }
struct DeviceImpl::Opaque { struct DeviceImpl::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal; std::shared_ptr<device::metax::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::maca::Handle::Internal> &internal) Opaque(const std::shared_ptr<device::metax::Handle::Internal> &internal)
: internal(internal) {} : internal(internal) {}
template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args> template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args>
...@@ -159,8 +159,8 @@ private: ...@@ -159,8 +159,8 @@ private:
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size; const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device // copy the input pointer array and meta to device
CHECK_MACA(hcMemcpyAsync(workspace, h_inputs_arr, input_arr_size, hcMemcpyHostToDevice, stream)); CHECK_METAX(hcMemcpyAsync(workspace, h_inputs_arr, input_arr_size, hcMemcpyHostToDevice, stream));
CHECK_MACA(hcMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), hcMemcpyHostToDevice, stream)); CHECK_METAX(hcMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), hcMemcpyHostToDevice, stream));
// offset/assign the pointers // offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace); d_inputs_arr = reinterpret_cast<const void **>(workspace);
...@@ -259,6 +259,6 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf ...@@ -259,6 +259,6 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf
std::forward<Args>(args)...); std::forward<Args>(args)...);
} }
} // namespace op::elementwise::maca } // namespace op::elementwise::metax
#endif #endif
#ifndef __INFINIOP_ELEMENTWISE_MACA_API_H__ #ifndef __INFINIOP_ELEMENTWISE_METAX_API_H__
#define __INFINIOP_ELEMENTWISE_MACA_API_H__ #define __INFINIOP_ELEMENTWISE_METAX_API_H__
#include "../elementwise.h" #include "../elementwise.h"
namespace op::elementwise::maca { namespace op::elementwise::metax {
class DeviceImpl final { class DeviceImpl final {
struct Opaque; struct Opaque;
...@@ -37,23 +37,23 @@ public: ...@@ -37,23 +37,23 @@ public:
void *stream, void *stream,
Args &&...args); Args &&...args);
}; };
} // namespace op::elementwise::maca } // namespace op::elementwise::metax
#define CREATE_ELEMENTWISE_MACA_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \ #define CREATE_ELEMENTWISE_METAX_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\ \
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \ auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \ CHECK_RESULT(info_result); \
auto info = info_result.take(); \ auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \ auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\ \
auto device_impl_result = op::elementwise::maca::DeviceImpl::create(HANDLE->internal()); \ auto device_impl_result = op::elementwise::metax::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \ CHECK_RESULT(device_impl_result); \
\ \
*desc_ptr = new Descriptor( \ *desc_ptr = new Descriptor( \
DTYPE, \ DTYPE, \
std::move(info), \ std::move(info), \
std::move(device_impl_result.take()), \ std::move(device_impl_result.take()), \
workspace_size, \ workspace_size, \
HANDLE->device, \ HANDLE->device, \
HANDLE->device_id); HANDLE->device_id);
#endif // __INFINIOP_ELEMENTWISE_MACA_API_H__ #endif // __INFINIOP_ELEMENTWISE_METAX_API_H__
#include "../../../devices/maca/common_maca.h" #include "../../../devices/metax/metax_common.h"
#include "causal_softmax_metax.h" #include "causal_softmax_metax.h"
#include <hccub/block/block_reduce.cuh> #include <hccub/block/block_reduce.cuh>
#include "../../../devices/maca/maca_kernel_common.h" #include "../../../devices/metax/metax_kernel_common.h"
#include "../../../reduce/cuda/reduce.cuh" #include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh" #include "../cuda/kernel.cuh"
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute> template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_MACA_KERNEL causalSoftmax( INFINIOP_METAX_KERNEL causalSoftmax(
Tdata *y, const Tdata *x, Tdata *y, const Tdata *x,
size_t batch, size_t height, size_t width, size_t batch, size_t height, size_t width,
ptrdiff_t y_stride_b, ptrdiff_t y_stride_h, ptrdiff_t y_stride_b, ptrdiff_t y_stride_h,
...@@ -20,7 +20,7 @@ INFINIOP_MACA_KERNEL causalSoftmax( ...@@ -20,7 +20,7 @@ INFINIOP_MACA_KERNEL causalSoftmax(
namespace op::causal_softmax::metax { namespace op::causal_softmax::metax {
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal; std::shared_ptr<device::metax::Handle::Internal> internal;
}; };
Descriptor::~Descriptor() { Descriptor::~Descriptor() {
...@@ -35,7 +35,7 @@ infiniStatus_t Descriptor::create( ...@@ -35,7 +35,7 @@ infiniStatus_t Descriptor::create(
auto info = CausalSoftmaxInfo::create(y_desc, x_desc); auto info = CausalSoftmaxInfo::create(y_desc, x_desc);
CHECK_RESULT(info); CHECK_RESULT(info);
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::maca::Handle *>(handle)->internal()}, new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id); info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -76,12 +76,12 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, ...@@ -76,12 +76,12 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
const void *x, const void *x,
void *stream_) const { void *stream_) const {
hcStream_t stream = (hcStream_t)stream_; hcStream_t stream = (hcStream_t)stream_;
if (_opaque->internal->maxThreadsPerBlock() == MACA_BLOCK_SIZE_1024) { if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MACA_BLOCK_SIZE_1024>( CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
_info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MACA_BLOCK_SIZE_512) { } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MACA_BLOCK_SIZE_512>( CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
_info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream)); _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
} else { } else {
......
#include "gemm_metax.h" #include "gemm_metax.h"
#include "../../../devices/maca/common_maca.h" #include "../../../devices/metax/metax_common.h"
#include "../../../devices/maca/maca_handle.h" #include "../../../devices/metax/metax_handle.h"
namespace op::gemm::metax { namespace op::gemm::metax {
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal; std::shared_ptr<device::metax::Handle::Internal> internal;
}; };
Descriptor::~Descriptor() { Descriptor::~Descriptor() {
...@@ -18,7 +18,7 @@ infiniStatus_t Descriptor::create( ...@@ -18,7 +18,7 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) { infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::maca::Handle *>(handle_); auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
auto dtype = c_desc->dtype(); auto dtype = c_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
......
#ifndef __GEMM_MACA_H__ #ifndef __GEMM_METAX_H__
#define __GEMM_MACA_H__ #define __GEMM_METAX_H__
#include "../gemm.h" #include "../gemm.h"
DESCRIPTOR(metax) DESCRIPTOR(metax)
#endif // __GEMM_MACA_H__ #endif // __GEMM_METAX_H__
#include "../../../devices/maca/maca_kernel_common.h" #include "../../../devices/metax/metax_kernel_common.h"
#include "infinicore.h" #include "infinicore.h"
#include <hccub/device/device_radix_sort.cuh> #include <hccub/device/device_radix_sort.cuh>
#include <hccub/device/device_reduce.cuh> #include <hccub/device/device_reduce.cuh>
...@@ -62,7 +62,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) { ...@@ -62,7 +62,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
const auto n = static_cast<int>(n_); const auto n = static_cast<int>(n_);
size_t argmax; size_t argmax;
CHECK_MACA(argMax_<Tval>( CHECK_METAX(argMax_<Tval>(
nullptr, nullptr, n, nullptr, nullptr, n,
nullptr, argmax, nullptr, argmax,
nullptr)); nullptr));
...@@ -77,7 +77,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) { ...@@ -77,7 +77,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
size_random += align256(sizeof(Tidx) * n); size_random += align256(sizeof(Tidx) * n);
// cub device api // cub device api
size_t size_radix_sort; size_t size_radix_sort;
CHECK_MACA((radixSort<Tval, Tidx>( CHECK_METAX((radixSort<Tval, Tidx>(
nullptr, size_radix_sort, nullptr, size_radix_sort,
nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
...@@ -85,7 +85,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) { ...@@ -85,7 +85,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
nullptr))); nullptr)));
size_t size_inclusive_sum; size_t size_inclusive_sum;
CHECK_MACA(inclusiveSum<Tval>( CHECK_METAX(inclusiveSum<Tval>(
nullptr, size_inclusive_sum, nullptr, size_inclusive_sum,
nullptr, n, nullptr, n,
nullptr)); nullptr));
...@@ -233,7 +233,7 @@ struct Algo { ...@@ -233,7 +233,7 @@ struct Algo {
auto grid = (n + block - 1) / block; auto grid = (n + block - 1) / block;
// sort // sort
fillIndices<<<grid, block, 0, stream>>>(indices, n); fillIndices<<<grid, block, 0, stream>>>(indices, n);
CHECK_MACA(radixSort( CHECK_METAX(radixSort(
workspace_, workspace_size, workspace_, workspace_size,
logits, sorted, logits, sorted,
indices, indices_out, indices, indices_out,
...@@ -243,7 +243,7 @@ struct Algo { ...@@ -243,7 +243,7 @@ struct Algo {
partialSoftmaxKernel<<<grid, block, 0, stream>>>(sorted, n, temperature); partialSoftmaxKernel<<<grid, block, 0, stream>>>(sorted, n, temperature);
setSoftmaxMaxKernel<<<1, 1, 0, stream>>>(sorted); setSoftmaxMaxKernel<<<1, 1, 0, stream>>>(sorted);
// sum // sum
CHECK_MACA(inclusiveSum( CHECK_METAX(inclusiveSum(
workspace_, workspace, workspace_, workspace,
sorted, n, sorted, n,
stream)); stream));
......
#ifndef __RANDOM_SAMPLE_MACA_H__ #ifndef __RANDOM_SAMPLE_METAX_H__
#define __RANDOM_SAMPLE_MACA_H__ #define __RANDOM_SAMPLE_METAX_H__
#include "../random_sample.h" #include "../random_sample.h"
DESCRIPTOR(metax) DESCRIPTOR(metax)
#endif // __RANDOM_SAMPLE_MACA_H__ #endif // __RANDOM_SAMPLE_METAX_H__
#include "../../../devices/maca/common_maca.h" #include "../../../devices/metax/metax_common.h"
#include "../../../devices/maca/maca_handle.h" #include "../../../devices/metax/metax_handle.h"
#include "../info.h" #include "../info.h"
#include "random_sample_kernel.h" #include "random_sample_kernel.h"
#include "random_sample_metax.h" #include "random_sample_metax.h"
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace op::random_sample::metax { namespace op::random_sample::metax {
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal; std::shared_ptr<device::metax::Handle::Internal> internal;
}; };
Descriptor::~Descriptor() { Descriptor::~Descriptor() {
...@@ -19,7 +19,7 @@ infiniStatus_t Descriptor::create( ...@@ -19,7 +19,7 @@ infiniStatus_t Descriptor::create(
Descriptor **desc_ptr, Descriptor **desc_ptr,
infiniopTensorDescriptor_t result_desc, infiniopTensorDescriptor_t result_desc,
infiniopTensorDescriptor_t probs_desc) { infiniopTensorDescriptor_t probs_desc) {
auto handle = reinterpret_cast<device::maca::Handle *>(handle_); auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
auto result = RandomSampleInfo::create(result_desc, probs_desc); auto result = RandomSampleInfo::create(result_desc, probs_desc);
CHECK_RESULT(result); CHECK_RESULT(result);
...@@ -100,4 +100,4 @@ infiniStatus_t Descriptor::calculate( ...@@ -100,4 +100,4 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace op::random_sample::maca } // namespace op::random_sample::metax
#ifndef __REARRANGE_MACA_KERNEL_H__ #ifndef __REARRANGE_METAX_KERNEL_H__
#define __REARRANGE_MACA_KERNEL_H__ #define __REARRANGE_METAX_KERNEL_H__
#include "../../../devices/maca/common_maca.h" #include "../../../devices/metax/metax_common.h"
#include "../../../devices/maca/maca_kernel_common.h" #include "../../../devices/metax/metax_kernel_common.h"
#define ARRAY_TYPE_STRIDE ptrdiff_t #define ARRAY_TYPE_STRIDE ptrdiff_t
#define ARRAY_TYPE_SIZE size_t #define ARRAY_TYPE_SIZE size_t
...@@ -328,4 +328,4 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) { ...@@ -328,4 +328,4 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
return utils::Result<void *>(kernel_func); return utils::Result<void *>(kernel_func);
} }
#endif // __REARRANGE_MACA_KERNEL_H__ #endif // __REARRANGE_METAX_KERNEL_H__
#ifndef __REARRANGE_MACA_H__ #ifndef __REARRANGE_METAX_H__
#define __REARRANGE_MACA_H__ #define __REARRANGE_METAX_H__
#include "../rearrange.h" #include "../rearrange.h"
DESCRIPTOR(metax) DESCRIPTOR(metax)
#endif // __REARRANGE_MACA_H__ #endif // __REARRANGE_METAX_H__
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace op::rearrange::metax { namespace op::rearrange::metax {
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal; std::shared_ptr<device::metax::Handle::Internal> internal;
}; };
Descriptor::~Descriptor() { Descriptor::~Descriptor() {
...@@ -47,7 +47,7 @@ infiniStatus_t Descriptor::create( ...@@ -47,7 +47,7 @@ infiniStatus_t Descriptor::create(
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
std::move(*meta), std::move(*meta),
new Opaque{reinterpret_cast<device::maca::Handle *>(handle)->internal()}, new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
handle->device, handle->device_id); handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -429,18 +429,18 @@ infiniStatus_t launchKernel( ...@@ -429,18 +429,18 @@ infiniStatus_t launchKernel(
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *y, void *y,
const void *x, const void *x,
void *stream) const { void *stream_) const {
auto maca_stream = reinterpret_cast<hcStream_t>(stream); auto stream = reinterpret_cast<hcStream_t>(stream_);
// 如果没有维度,直接进行内存拷贝 // 如果没有维度,直接进行内存拷贝
if (_meta.ndim() == 0) { if (_meta.ndim() == 0) {
auto err = hcMemcpyAsync(y, x, _meta.unit(), hcMemcpyDeviceToDevice, maca_stream); auto err = hcMemcpyAsync(y, x, _meta.unit(), hcMemcpyDeviceToDevice, stream);
if (err != hcSuccess) { if (err != hcSuccess) {
return INFINI_STATUS_INTERNAL_ERROR; return INFINI_STATUS_INTERNAL_ERROR;
} }
CHECK_OR_RETURN(hcMemcpyAsync(y, x, _meta.unit(), hcMemcpyDeviceToDevice, maca_stream) == hcSuccess, CHECK_OR_RETURN(hcMemcpyAsync(y, x, _meta.unit(), hcMemcpyDeviceToDevice, stream) == hcSuccess,
INFINI_STATUS_INTERNAL_ERROR); INFINI_STATUS_INTERNAL_ERROR);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -449,7 +449,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -449,7 +449,7 @@ infiniStatus_t Descriptor::calculate(
int max_threads = _opaque->internal->maxThreadsPerBlock(); int max_threads = _opaque->internal->maxThreadsPerBlock();
// 准备参数 // 准备参数
auto params_result = prepareRearrangeParams(_meta, std::min(MACA_BLOCK_SIZE_1024, max_threads)); auto params_result = prepareRearrangeParams(_meta, std::min(METAX_BLOCK_SIZE_1024, max_threads));
CHECK_RESULT(params_result); CHECK_RESULT(params_result);
auto params = params_result.take(); auto params = params_result.take();
...@@ -469,10 +469,10 @@ infiniStatus_t Descriptor::calculate( ...@@ -469,10 +469,10 @@ infiniStatus_t Descriptor::calculate(
size_t block_size = params.block_len_total; size_t block_size = params.block_len_total;
if (block_size <= MACA_BLOCK_SIZE_512) { if (block_size <= METAX_BLOCK_SIZE_512) {
status = launchKernel<MACA_BLOCK_SIZE_512>(y, x, grid_size, params, _meta.unit(), maca_stream); status = launchKernel<METAX_BLOCK_SIZE_512>(y, x, grid_size, params, _meta.unit(), stream);
} else if (block_size <= MACA_BLOCK_SIZE_1024) { } else if (block_size <= METAX_BLOCK_SIZE_1024) {
status = launchKernel<MACA_BLOCK_SIZE_1024>(y, x, grid_size, params, _meta.unit(), maca_stream); status = launchKernel<METAX_BLOCK_SIZE_1024>(y, x, grid_size, params, _meta.unit(), stream);
} else { } else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
} }
......
#ifndef __RMS_NORM_MACA_CUH__ #ifndef __RMS_NORM_METAX_CUH__
#define __RMS_NORM_MACA_CUH__ #define __RMS_NORM_METAX_CUH__
#include "../rms_norm.h" #include "../rms_norm.h"
DESCRIPTOR(maca) DESCRIPTOR(metax)
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment