Unverified Commit 65bed07e authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #322 from YdrMaster/main

issue/291/style: 所有 maca 改为 metax
parents e4605f7c 507be07e
......@@ -2,7 +2,7 @@
#include "./ascend/infiniccl_ascend.h"
#include "./cuda/infiniccl_cuda.h"
#include "./maca/infiniccl_maca.h"
#include "./metax/infiniccl_metax.h"
__C infiniStatus_t infinicclCommInitAll(
infiniDevice_t device_type,
......@@ -17,7 +17,7 @@ __C infiniStatus_t infinicclCommInitAll(
switch (device_type) {
COMM_INIT_ALL(INFINI_DEVICE_NVIDIA, cuda)
COMM_INIT_ALL(INFINI_DEVICE_ASCEND, ascend)
COMM_INIT_ALL(INFINI_DEVICE_METAX, maca)
COMM_INIT_ALL(INFINI_DEVICE_METAX, metax)
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -37,7 +37,7 @@ __C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) {
switch (comm->device_type) {
COMM_DESTROY(INFINI_DEVICE_NVIDIA, cuda)
COMM_DESTROY(INFINI_DEVICE_ASCEND, ascend)
COMM_DESTROY(INFINI_DEVICE_METAX, maca)
COMM_DESTROY(INFINI_DEVICE_METAX, metax)
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -65,7 +65,7 @@ __C infiniStatus_t infinicclAllReduce(
switch (comm->device_type) {
ALL_REDUCE(INFINI_DEVICE_NVIDIA, cuda)
ALL_REDUCE(INFINI_DEVICE_ASCEND, ascend)
ALL_REDUCE(INFINI_DEVICE_METAX, maca)
ALL_REDUCE(INFINI_DEVICE_METAX, metax)
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#include "infiniccl_maca.h"
#include "infiniccl_metax.h"
#include "../../utils.h"
......@@ -51,7 +51,7 @@ inline hcclComm_t getHcclComm(infinicclComm_t comm) {
return static_cast<hcclComm_t>(comm->comm);
}
namespace infiniccl::maca {
namespace infiniccl::metax {
infiniStatus_t commInitAll(
infinicclComm_t *comms,
......@@ -92,4 +92,4 @@ infiniStatus_t allReduce(
return INFINI_STATUS_SUCCESS;
}
} // namespace infiniccl::maca
} // namespace infiniccl::metax
#ifndef INFINICCL_MACA_H_
#define INFINICCL_MACA_H_
#ifndef INFINICCL_METAX_H_
#define INFINICCL_METAX_H_
#include "../infiniccl_impl.h"
#if defined(ENABLE_METAX_API) && defined(ENABLE_CCL)
INFINICCL_DEVICE_API_IMPL(maca)
INFINICCL_DEVICE_API_IMPL(metax)
#else
INFINICCL_DEVICE_API_NOOP(maca)
INFINICCL_DEVICE_API_NOOP(metax)
#endif
#endif /* INFINICCL_MACA_H_ */
#endif /* INFINICCL_METAX_H_ */
......@@ -21,7 +21,7 @@
#include "kunlun/kunlun_handle.h"
#endif
#ifdef ENABLE_METAX_API
#include "maca/maca_handle.h"
#include "metax/metax_handle.h"
#endif
__C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
......@@ -57,7 +57,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, maca);
CREATE(INFINI_DEVICE_METAX, metax);
#endif
default:
......@@ -94,7 +94,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, maca);
DELETE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#include "../../../utils.h"
#include "../pool.h"
#include "maca_handle.h"
#include "metax_handle.h"
#include <hcblas/hcblas.h>
#include <hcdnn/hcdnn.h>
#include <memory>
......@@ -8,7 +8,7 @@
#define CHECK_MCBLAS(API) CHECK_INTERNAL(API, HCBLAS_STATUS_SUCCESS)
#define CHECK_MCDNN(API) CHECK_INTERNAL(API, HCDNN_STATUS_SUCCESS)
namespace device::maca {
namespace device::metax {
class Handle::Internal {
Pool<hcblasHandle_t> mcblas_handles;
......@@ -39,4 +39,4 @@ public:
hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt);
} // namespace device::maca
} // namespace device::metax
#include "common_maca.h"
#include "metax_common.h"
namespace device::maca {
namespace device::metax {
Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>(device_id)) {}
......@@ -83,4 +83,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
return INFINI_STATUS_SUCCESS;
}
} // namespace device::maca
} // namespace device::metax
#ifndef __INFINIOP_MACA_HANDLE_H__
#define __INFINIOP_MACA_HANDLE_H__
#ifndef __INFINIOP_METAX_HANDLE_H__
#define __INFINIOP_METAX_HANDLE_H__
#include "../../handle.h"
#include <memory>
namespace device::maca {
namespace device::metax {
struct Handle : public InfiniopHandle {
Handle(int device_id);
class Internal;
......@@ -20,6 +20,6 @@ private:
std::shared_ptr<Internal> _internal;
};
} // namespace device::maca
} // namespace device::metax
#endif // __INFINIOP_MACA_HANDLE_H__
#endif // __INFINIOP_METAX_HANDLE_H__
#define INFINIOP_MACA_KERNEL __global__ void
#define INFINIOP_METAX_KERNEL __global__ void
// Posible maximum number of threads per block for MACA architectures
// Posible maximum number of threads per block for METAX architectures
// Used for picking correct kernel launch configuration
#define MACA_BLOCK_SIZE_1024 1024
#define MACA_BLOCK_SIZE_512 512
#define METAX_BLOCK_SIZE_1024 1024
#define METAX_BLOCK_SIZE_512 512
#define CHECK_MACA(API) CHECK_INTERNAL(API, hcSuccess)
#define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess)
using cuda_bfloat16 = hpcc_bfloat16;
using cuda_bfloat162 = hpcc_bfloat162;
namespace device::maca {
namespace device::metax {
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t
......@@ -41,7 +41,7 @@ indexToOffset(
}
return res;
}
} // namespace device::maca
} // namespace device::metax
__forceinline__ __device__ float
exp_(const float val) {
......
#ifndef __INFINIOP_ELEMENTWISE_MACA_H__
#define __INFINIOP_ELEMENTWISE_MACA_H__
#ifndef __INFINIOP_ELEMENTWISE_METAX_H__
#define __INFINIOP_ELEMENTWISE_METAX_H__
#include "../../../utils.h"
#include "../../devices/maca/common_maca.h"
#include "../../devices/maca/maca_kernel_common.h"
#include "elementwise_maca_api.h"
#include "../../devices/metax/metax_common.h"
#include "../../devices/metax/metax_kernel_common.h"
#include "elementwise_metax_api.h"
namespace op::elementwise::maca {
namespace op::elementwise::metax {
template <typename T>
__device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
return reinterpret_cast<const T *>(ptr);
......@@ -14,7 +14,7 @@ __device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
const size_t *shape, const ptrdiff_t *strides) {
return is_contiguous ? idx : device::maca::indexToOffset(idx, ndim, shape, strides);
return is_contiguous ? idx : device::metax::indexToOffset(idx, ndim, shape, strides);
}
struct InputIndexer {
......@@ -30,8 +30,8 @@ struct InputIndexer {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? device::maca::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::maca::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
? device::metax::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::metax::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}
};
......@@ -41,7 +41,7 @@ __device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence<
}
template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_MACA_KERNEL elementwiseKernel(
INFINIOP_METAX_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
......@@ -72,7 +72,7 @@ INFINIOP_MACA_KERNEL elementwiseKernel(
}
template <typename Op, typename Tout, typename... Tin>
INFINIOP_MACA_KERNEL elementwiseKernel(
INFINIOP_METAX_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
......@@ -102,9 +102,9 @@ INFINIOP_MACA_KERNEL elementwiseKernel(
}
struct DeviceImpl::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
std::shared_ptr<device::metax::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::maca::Handle::Internal> &internal)
Opaque(const std::shared_ptr<device::metax::Handle::Internal> &internal)
: internal(internal) {}
template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args>
......@@ -159,8 +159,8 @@ private:
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device
CHECK_MACA(hcMemcpyAsync(workspace, h_inputs_arr, input_arr_size, hcMemcpyHostToDevice, stream));
CHECK_MACA(hcMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(workspace, h_inputs_arr, input_arr_size, hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), hcMemcpyHostToDevice, stream));
// offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace);
......@@ -259,6 +259,6 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf
std::forward<Args>(args)...);
}
} // namespace op::elementwise::maca
} // namespace op::elementwise::metax
#endif
#ifndef __INFINIOP_ELEMENTWISE_MACA_API_H__
#define __INFINIOP_ELEMENTWISE_MACA_API_H__
#ifndef __INFINIOP_ELEMENTWISE_METAX_API_H__
#define __INFINIOP_ELEMENTWISE_METAX_API_H__
#include "../elementwise.h"
namespace op::elementwise::maca {
namespace op::elementwise::metax {
class DeviceImpl final {
struct Opaque;
......@@ -37,23 +37,23 @@ public:
void *stream,
Args &&...args);
};
} // namespace op::elementwise::maca
#define CREATE_ELEMENTWISE_MACA_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::maca::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
} // namespace op::elementwise::metax
#define CREATE_ELEMENTWISE_METAX_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::metax::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif // __INFINIOP_ELEMENTWISE_MACA_API_H__
#endif // __INFINIOP_ELEMENTWISE_METAX_API_H__
#include "../../../devices/maca/common_maca.h"
#include "../../../devices/metax/metax_common.h"
#include "causal_softmax_metax.h"
#include <hccub/block/block_reduce.cuh>
#include "../../../devices/maca/maca_kernel_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_MACA_KERNEL causalSoftmax(
INFINIOP_METAX_KERNEL causalSoftmax(
Tdata *y, const Tdata *x,
size_t batch, size_t height, size_t width,
ptrdiff_t y_stride_b, ptrdiff_t y_stride_h,
......@@ -20,7 +20,7 @@ INFINIOP_MACA_KERNEL causalSoftmax(
namespace op::causal_softmax::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
......@@ -35,7 +35,7 @@ infiniStatus_t Descriptor::create(
auto info = CausalSoftmaxInfo::create(y_desc, x_desc);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::maca::Handle *>(handle)->internal()},
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
......@@ -76,12 +76,12 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
const void *x,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
if (_opaque->internal->maxThreadsPerBlock() == MACA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MACA_BLOCK_SIZE_1024>(
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
_info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MACA_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MACA_BLOCK_SIZE_512>(
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
_info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
} else {
......
#include "gemm_metax.h"
#include "../../../devices/maca/common_maca.h"
#include "../../../devices/maca/maca_handle.h"
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_handle.h"
namespace op::gemm::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
......@@ -18,7 +18,7 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::maca::Handle *>(handle_);
auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
auto dtype = c_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
......
#ifndef __GEMM_MACA_H__
#define __GEMM_MACA_H__
#ifndef __GEMM_METAX_H__
#define __GEMM_METAX_H__
#include "../gemm.h"
DESCRIPTOR(metax)
#endif // __GEMM_MACA_H__
#endif // __GEMM_METAX_H__
#include "../../../devices/maca/maca_kernel_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "infinicore.h"
#include <hccub/device/device_radix_sort.cuh>
#include <hccub/device/device_reduce.cuh>
......@@ -62,7 +62,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
const auto n = static_cast<int>(n_);
size_t argmax;
CHECK_MACA(argMax_<Tval>(
CHECK_METAX(argMax_<Tval>(
nullptr, nullptr, n,
nullptr, argmax,
nullptr));
......@@ -77,7 +77,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
size_random += align256(sizeof(Tidx) * n);
// cub device api
size_t size_radix_sort;
CHECK_MACA((radixSort<Tval, Tidx>(
CHECK_METAX((radixSort<Tval, Tidx>(
nullptr, size_radix_sort,
nullptr, nullptr,
nullptr, nullptr,
......@@ -85,7 +85,7 @@ utils::Result<size_t> calculateWorkspace(size_t n_) {
nullptr)));
size_t size_inclusive_sum;
CHECK_MACA(inclusiveSum<Tval>(
CHECK_METAX(inclusiveSum<Tval>(
nullptr, size_inclusive_sum,
nullptr, n,
nullptr));
......@@ -233,7 +233,7 @@ struct Algo {
auto grid = (n + block - 1) / block;
// sort
fillIndices<<<grid, block, 0, stream>>>(indices, n);
CHECK_MACA(radixSort(
CHECK_METAX(radixSort(
workspace_, workspace_size,
logits, sorted,
indices, indices_out,
......@@ -243,7 +243,7 @@ struct Algo {
partialSoftmaxKernel<<<grid, block, 0, stream>>>(sorted, n, temperature);
setSoftmaxMaxKernel<<<1, 1, 0, stream>>>(sorted);
// sum
CHECK_MACA(inclusiveSum(
CHECK_METAX(inclusiveSum(
workspace_, workspace,
sorted, n,
stream));
......
#ifndef __RANDOM_SAMPLE_MACA_H__
#define __RANDOM_SAMPLE_MACA_H__
#ifndef __RANDOM_SAMPLE_METAX_H__
#define __RANDOM_SAMPLE_METAX_H__
#include "../random_sample.h"
DESCRIPTOR(metax)
#endif // __RANDOM_SAMPLE_MACA_H__
#endif // __RANDOM_SAMPLE_METAX_H__
#include "../../../devices/maca/common_maca.h"
#include "../../../devices/maca/maca_handle.h"
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_handle.h"
#include "../info.h"
#include "random_sample_kernel.h"
#include "random_sample_metax.h"
......@@ -7,7 +7,7 @@
namespace op::random_sample::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
......@@ -19,7 +19,7 @@ infiniStatus_t Descriptor::create(
Descriptor **desc_ptr,
infiniopTensorDescriptor_t result_desc,
infiniopTensorDescriptor_t probs_desc) {
auto handle = reinterpret_cast<device::maca::Handle *>(handle_);
auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
auto result = RandomSampleInfo::create(result_desc, probs_desc);
CHECK_RESULT(result);
......@@ -100,4 +100,4 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_SUCCESS;
}
} // namespace op::random_sample::maca
} // namespace op::random_sample::metax
#ifndef __REARRANGE_MACA_KERNEL_H__
#define __REARRANGE_MACA_KERNEL_H__
#ifndef __REARRANGE_METAX_KERNEL_H__
#define __REARRANGE_METAX_KERNEL_H__
#include "../../../devices/maca/common_maca.h"
#include "../../../devices/maca/maca_kernel_common.h"
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#define ARRAY_TYPE_STRIDE ptrdiff_t
#define ARRAY_TYPE_SIZE size_t
......@@ -328,4 +328,4 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
return utils::Result<void *>(kernel_func);
}
#endif // __REARRANGE_MACA_KERNEL_H__
#endif // __REARRANGE_METAX_KERNEL_H__
#ifndef __REARRANGE_MACA_H__
#define __REARRANGE_MACA_H__
#ifndef __REARRANGE_METAX_H__
#define __REARRANGE_METAX_H__
#include "../rearrange.h"
DESCRIPTOR(metax)
#endif // __REARRANGE_MACA_H__
#endif // __REARRANGE_METAX_H__
......@@ -10,7 +10,7 @@
namespace op::rearrange::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
......@@ -47,7 +47,7 @@ infiniStatus_t Descriptor::create(
*desc_ptr = new Descriptor(
std::move(*meta),
new Opaque{reinterpret_cast<device::maca::Handle *>(handle)->internal()},
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
......@@ -429,18 +429,18 @@ infiniStatus_t launchKernel(
infiniStatus_t Descriptor::calculate(
void *y,
const void *x,
void *stream) const {
void *stream_) const {
auto maca_stream = reinterpret_cast<hcStream_t>(stream);
auto stream = reinterpret_cast<hcStream_t>(stream_);
// 如果没有维度,直接进行内存拷贝
if (_meta.ndim() == 0) {
auto err = hcMemcpyAsync(y, x, _meta.unit(), hcMemcpyDeviceToDevice, maca_stream);
auto err = hcMemcpyAsync(y, x, _meta.unit(), hcMemcpyDeviceToDevice, stream);
if (err != hcSuccess) {
return INFINI_STATUS_INTERNAL_ERROR;
}
CHECK_OR_RETURN(hcMemcpyAsync(y, x, _meta.unit(), hcMemcpyDeviceToDevice, maca_stream) == hcSuccess,
CHECK_OR_RETURN(hcMemcpyAsync(y, x, _meta.unit(), hcMemcpyDeviceToDevice, stream) == hcSuccess,
INFINI_STATUS_INTERNAL_ERROR);
return INFINI_STATUS_SUCCESS;
}
......@@ -449,7 +449,7 @@ infiniStatus_t Descriptor::calculate(
int max_threads = _opaque->internal->maxThreadsPerBlock();
// 准备参数
auto params_result = prepareRearrangeParams(_meta, std::min(MACA_BLOCK_SIZE_1024, max_threads));
auto params_result = prepareRearrangeParams(_meta, std::min(METAX_BLOCK_SIZE_1024, max_threads));
CHECK_RESULT(params_result);
auto params = params_result.take();
......@@ -469,10 +469,10 @@ infiniStatus_t Descriptor::calculate(
size_t block_size = params.block_len_total;
if (block_size <= MACA_BLOCK_SIZE_512) {
status = launchKernel<MACA_BLOCK_SIZE_512>(y, x, grid_size, params, _meta.unit(), maca_stream);
} else if (block_size <= MACA_BLOCK_SIZE_1024) {
status = launchKernel<MACA_BLOCK_SIZE_1024>(y, x, grid_size, params, _meta.unit(), maca_stream);
if (block_size <= METAX_BLOCK_SIZE_512) {
status = launchKernel<METAX_BLOCK_SIZE_512>(y, x, grid_size, params, _meta.unit(), stream);
} else if (block_size <= METAX_BLOCK_SIZE_1024) {
status = launchKernel<METAX_BLOCK_SIZE_1024>(y, x, grid_size, params, _meta.unit(), stream);
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
......
#ifndef __RMS_NORM_MACA_CUH__
#define __RMS_NORM_MACA_CUH__
#ifndef __RMS_NORM_METAX_CUH__
#define __RMS_NORM_METAX_CUH__
#include "../rms_norm.h"
DESCRIPTOR(maca)
DESCRIPTOR(metax)
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment