Unverified Commit 8d09630a authored by gongchensu's avatar gongchensu Committed by GitHub
Browse files

Merge branch 'demo131' into Issue/862

parents ab52dead 012df56c
......@@ -6,6 +6,7 @@
// Posible maximum number of threads per block for MUSA architectures
// Used for picking correct kernel launch configuration
#define MOORE_BLOCK_SIZE_4096 4096
#define MOORE_BLOCK_SIZE_2048 2048
#define MOORE_BLOCK_SIZE_1024 1024
#define MOORE_BLOCK_SIZE_512 512
......@@ -16,6 +17,8 @@ using cuda_bfloat16 = mt_bfloat16;
using cuda_bfloat162 = mt_bfloat162;
using cuda_fp8_e4m3 = __mt_fp8_e4m3;
using __nv_bfloat16 = __mt_bfloat16;
namespace device::moore {
// get the memory offset of the given element in a tensor given its flat index
......
......@@ -23,6 +23,10 @@ Handle::Internal::Internal(int device_id) {
_grid_size[0] = prop.maxGridSize[0];
_grid_size[1] = prop.maxGridSize[1];
_grid_size[2] = prop.maxGridSize[2];
this->useCublas(nullptr, [](cublasHandle_t handle) { return INFINI_STATUS_SUCCESS; });
#ifdef ENABLE_CUDNN_API
this->useCudnn(nullptr, [](cudnnHandle_t handle) { return INFINI_STATUS_SUCCESS; });
#endif
}
infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {
......@@ -106,6 +110,18 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
} // namespace iluvatar
namespace ali {
Handle::Handle(int device_id)
: nvidia::Handle(INFINI_DEVICE_ALI, device_id) {}
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
*handle_ptr = new Handle(device_id);
return INFINI_STATUS_SUCCESS;
}
} // namespace ali
namespace qy {
Handle::Handle(int device_id)
......
......@@ -35,6 +35,17 @@ public:
} // namespace iluvatar
namespace ali {
struct Handle : public nvidia::Handle {
Handle(int device_id);
public:
static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
};
} // namespace ali
namespace qy {
struct Handle : public nvidia::Handle {
......
......@@ -16,6 +16,7 @@
// Posible maximum number of threads per block for CUDA architectures
// Used for picking correct kernel launch configuration
#define CUDA_BLOCK_SIZE_4096 4096
#define CUDA_BLOCK_SIZE_2048 2048
#define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512
......@@ -54,7 +55,7 @@ exp_(const float val) {
return expf(val);
}
#if !defined(ENABLE_ILUVATAR_API) && !defined(ENABLE_QY_API) && !defined(ENABLE_HYGON_API)
#if !defined(ENABLE_ILUVATAR_API) && !defined(ENABLE_QY_API) && !defined(ENABLE_HYGON_API) && !defined(ENABLE_ALI_API)
__forceinline__ __device__ long double
exp_(const long double val) {
return expl(val);
......
......@@ -127,8 +127,8 @@ private:
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// Copy input pointer array and metadata to device
CNRT_CHECK(cnrtMemcpy(workspace, (void *)h_inputs_arr, input_arr_size, CNRT_MEM_TRANS_DIR_HOST2DEV));
CNRT_CHECK(cnrtMemcpy((void *)d_meta_start, (void *)info_meta_start, info.getMetaMemSize(), CNRT_MEM_TRANS_DIR_HOST2DEV));
CNRT_CHECK(cnrtMemcpy(workspace, (void *)h_inputs_arr, input_arr_size, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtMemcpy((void *)d_meta_start, (void *)info_meta_start, info.getMetaMemSize(), cnrtMemcpyHostToDev));
// Setup pointers to device memory regions
d_inputs_arr = reinterpret_cast<const void **>(workspace);
......
......@@ -248,10 +248,10 @@ void launchElementwiseKernelWrapper(
dim.z = 1;
// Choose kernel type based on problem characteristics
cnrtFunctionType_t func_type = CNRT_FUNC_TYPE_BLOCK;
cnrtFunctionType_t func_type = cnrtFuncTypeBlock;
if (output_size > 1024 * 1024 && output_contiguous) {
// For large contiguous operations, use UNION type
func_type = CNRT_FUNC_TYPE_UNION1;
func_type = cnrtFuncTypeUnion1;
}
// Launch the kernel with optimal configuration
......
import concurrent.futures
import functools
import inspect
import itertools
......@@ -16,40 +17,28 @@ BUILD_DIRECTORY_PATH = (
def build(premake, constexpr_param_grid, caller, op_name, output_dir):
headers = []
all_param_names = []
combinations = []
launches = []
for combination in _generate_param_value_combinations(constexpr_param_grid):
arrangement, application, tensors = premake(**combination)
with concurrent.futures.ProcessPoolExecutor() as executor:
futures = []
for param_name, param_value in combination.items():
if isinstance(param_value, str):
combination[param_name] = (
f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}"
)
for combination in tuple(
_generate_param_value_combinations(constexpr_param_grid)
):
future = executor.submit(
_make, premake, combination, caller, op_name, output_dir
)
combination = {f"{name}_": value for name, value in combination.items()}
futures.append(future)
kernel_name = f"{op_name}_{_generate_suffix(combination.values())}"
for future in concurrent.futures.as_completed(futures):
header, param_names, combination, launch = future.result()
ninetoothed.make(
arrangement,
application,
tensors,
caller=caller,
kernel_name=kernel_name,
output_dir=output_dir,
)
header = output_dir / f"{kernel_name}.h"
param_names = ("stream",) + tuple(
inspect.signature(application).parameters.keys()
)
launch = f""" if ({_generate_condition(combination)})
return launch_{kernel_name}({", ".join(param_names)});"""
headers.append(header)
all_param_names.append(param_names)
launches.append(launch)
headers.append(header)
all_param_names.append(param_names)
combinations.append(combination)
launches.append(launch)
includes = "\n".join(f'#include "{header}"' for header in headers)
......@@ -64,7 +53,7 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir):
"NineToothedStream",
] + ["NineToothedTensor" for _ in range(len(param_names) - 1)]
for param_name in combination:
for param_name in functools.reduce(lambda x, y: x | y, combinations, {}):
param_names.append(param_name)
param_types.append("int")
......@@ -97,6 +86,36 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir):
(BUILD_DIRECTORY_PATH / header_file_name).write_text(header_content)
def _make(premake, combination, caller, op_name, output_dir):
arrangement, application, tensors = premake(**combination)
for param_name, param_value in combination.items():
if isinstance(param_value, str):
combination[param_name] = (
f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}"
)
combination = {f"{name}_": value for name, value in combination.items()}
kernel_name = f"{op_name}_{_generate_suffix(combination.values())}"
ninetoothed.make(
arrangement,
application,
tensors,
caller=caller,
kernel_name=kernel_name,
output_dir=output_dir,
)
header = output_dir / f"{kernel_name}.h"
param_names = ("stream",) + tuple(inspect.signature(application).parameters.keys())
launch = f""" if ({_generate_condition(combination)})
return launch_{kernel_name}({", ".join(param_names)});"""
return header, param_names, combination, launch
def _generate_condition(combination):
return " && ".join(f"{param} == {value}" for param, value in combination.items())
......
#ifndef __NINETOOTHED_UTILS__
#define __NINETOOTHED_UTILS__
#include <initializer_list>
#include <limits>
#include <type_traits>
#include <vector>
namespace ninetoothed {
template <typename T = float>
class Tensor {
public:
using Data = decltype(NineToothedTensor::data);
using Size = std::remove_pointer_t<decltype(NineToothedTensor::shape)>;
using Stride = std::remove_pointer_t<decltype(NineToothedTensor::strides)>;
template <typename Shape, typename Strides>
Tensor(const void *data, Shape shape, Strides strides) : data_{data}, shape_{shape}, strides_{strides}, ndim_{shape_.size()} {}
Tensor(const void *data, std::initializer_list<Size> shape, std::initializer_list<Stride> strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {}
Tensor(const void *data, const Size *shape, const Stride *strides, Size ndim) : data_{data}, shape_{shape, shape + ndim}, strides_{strides, strides + ndim}, ndim_{shape_.size()} {}
Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {}
operator NineToothedTensor() { return {const_cast<Data>(data_), shape_.data(), strides_.data()}; }
template <typename Shape>
Tensor expand(const Shape &sizes) const {
auto new_ndim{sizes.size()};
decltype(shape_) shape(new_ndim, 1);
decltype(strides_) strides(new_ndim, 0);
auto num_new_dims{new_ndim - ndim_};
for (auto dim{decltype(ndim_){0}}; dim < ndim_; ++dim) {
shape[dim + num_new_dims] = shape_[dim];
strides[dim + num_new_dims] = strides_[dim];
}
for (auto dim{decltype(new_ndim){0}}; dim < new_ndim; ++dim) {
if (sizes[dim] == std::numeric_limits<std::remove_reference_t<decltype(sizes[dim])>>::max() || shape[dim] != 1) {
continue;
}
shape[dim] = sizes[dim];
strides[dim] = 0;
}
return {data_, shape, strides};
}
Tensor expand_as(const Tensor &other) const {
return expand(other.shape_);
}
private:
const void *data_{nullptr};
std::vector<Size> shape_;
std::vector<Stride> strides_;
Size ndim_{0};
T value_{0};
};
} // namespace ninetoothed
#endif
......@@ -31,7 +31,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
......@@ -59,6 +59,10 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<AddOp, bfloat16_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_F32:
return _device_info->calculate<AddOp, float>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_I32:
return _device_info->calculate<AddOp, int32_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_I64:
return _device_info->calculate<AddOp, int64_t>(_info, workspace, output, inputs, queue);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -8,7 +8,7 @@ public:
static constexpr size_t num_inputs = 2;
template <typename T>
__mlu_device__ void operator()(T *out, const T *a, const T *b, size_t num_elements) const {
if constexpr (std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float>) {
if constexpr (std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float> || std::is_same_v<T, int32_t> || std::is_same_v<T, int64_t>) {
__bang_add(out, a, b, num_elements);
} else {
out = a + b;
......@@ -21,5 +21,7 @@ LAUNCH_ELEMENTWISE_KERNEL_IMPL(Add, AddOp)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, half)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, bfloat16_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, float)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, int32_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, int64_t)
#endif // __ADD_BANG_INTERNAL_H__
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/add_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/add_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -48,6 +48,9 @@ __C infiniStatus_t infiniopCreateAddDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -91,6 +94,9 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -142,6 +148,9 @@ __C infiniStatus_t infiniopAdd(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -187,6 +196,9 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
......
#ifndef ADD_RMS_NORM_H
#define ADD_RMS_NORM_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::add_rms_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
AddRMSNormInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
AddRMSNormInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t residual_out_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \
float epsilon); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
void *residual_out, \
const void *a, \
const void *b, \
const void *weight, \
void *stream) const; \
}; \
}
#endif // ADD_RMS_NORM_H
#include "add_rms_norm_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"
namespace op::add_rms_norm::cpu {
Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon) {
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename T>
infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const T *w) {
const size_t batch_size = info->shape[0];
const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
const size_t dim = info->dim();
const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);
#pragma omp parallel for
for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
const size_t i = block_idx / nhead; // batch index
const size_t j = block_idx % nhead; // head index
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];
// Compute add(a, b) once and store it
T sum_squared = (T)0;
for (size_t k = 0; k < dim; k++) {
T sum_val = a_ptr[k] + b_ptr[k];
residual_out_ptr[k] = sum_val; // Store add result
sum_squared += sum_val * sum_val;
}
// Compute RMS: 1 / (sqrt(mean(sum^2) + eps))
// Note: mean = sum_squared / dim
T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon));
// Apply normalization: y = (a + b) * w * rms
// Reuse stored values from residual_out
for (size_t k = 0; k < dim; k++) {
y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
}
}
return INFINI_STATUS_SUCCESS;
}
template <typename T, typename Tw>
infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const Tw *w) {
static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
"T must be fp16_t or bf16_t");
const size_t batch_size = info->shape[0];
const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
const size_t dim = info->dim();
const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);
#pragma omp parallel for
for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
const size_t i = block_idx / nhead; // batch index
const size_t j = block_idx % nhead; // head index
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];
// Compute sum of squares for RMS normalization and store add result
float sum_squared = 0.0f;
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
sum_squared += sum_val * sum_val;
}
// Compute RMS: 1 / (sqrt(sum/dim + eps))
float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon);
// Apply normalization: y = (a + b) * w * rms
// Reuse stored values from residual_out
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(residual_out_ptr[k]);
float val;
if constexpr (std::is_same<Tw, float>::value) {
val = sum_val * w[k] * rms;
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
val = sum_val * utils::cast<float>(w[k]) * rms;
} else {
std::abort();
}
y_ptr[k] = utils::cast<T>(val);
}
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight));
} else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight));
} else if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_BF16) {
if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight));
} else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight));
} else if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (float *)residual_out, (const float *)a, (const float *)b, (const float *)weight));
} else if (_info.atype == INFINI_DTYPE_F64) {
CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (double *)residual_out, (const double *)a, (const double *)b, (const double *)weight));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::cpu
#ifndef __ADD_RMS_NORM_CPU_H__
#define __ADD_RMS_NORM_CPU_H__
#include "../add_rms_norm.h"
DESCRIPTOR(cpu)
#endif
#ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
#define __ADD_RMS_NORM_CUDA_KERNEL_H__
#include <cub/block/block_reduce.cuh>
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
float epsilon) {
// Each block takes care of one head in one batch
// Each thread deals with every block_size element in the row
size_t batch_idx = blockIdx.x / nhead;
size_t head_idx = blockIdx.x % nhead;
auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
auto a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
auto b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
auto w_ptr = w;
Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
// Compute add(a, b) and sum of squares in one pass
Tcompute sum_squared = 0;
for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
residual_out_ptr[i] = Tdata(sum_val); // Store add result
sum_squared += sum_val * sum_val;
}
// Block-reduce sum of squares
using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
__shared__ typename BlockReduce::TempStorage temp_storage;
sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
// Thread_0 computes RMS=1/sqrt(ss/dim+epsilon) and stores in shared memory
__shared__ Tcompute rms;
if (threadIdx.x == 0) {
rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
}
__syncthreads();
// Apply normalization: y = (a + b) * w * rms
// Reuse stored values from residual_out
for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
Tcompute sum_val = Tcompute(residual_out_ptr[i]); // Reuse stored value
y_ptr[i] = Tdata(sum_val * Tcompute(w_ptr[i]) * rms);
}
}
#endif
#ifndef __ADD_RMS_NORM_INFO_H__
#define __ADD_RMS_NORM_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::add_rms_norm {
class AddRMSNormInfo {
AddRMSNormInfo() = default;
public:
infiniDtype_t wtype;
infiniDtype_t atype;
float epsilon;
std::vector<size_t> shape;
std::vector<ptrdiff_t> y_strides;
std::vector<ptrdiff_t> residual_out_strides;
std::vector<ptrdiff_t> a_strides;
std::vector<ptrdiff_t> b_strides;
bool has_residual_out;
size_t ndim() const { return shape.size(); }
size_t dim() const { return shape[ndim() - 1]; }
static utils::Result<AddRMSNormInfo> create(
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon) {
auto atype = y_desc->dtype();
auto wtype = weight_desc->dtype();
// Check that all input tensors have the same dtype
if (a_desc->dtype() != atype || b_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) {
// For half-precision types (FP16/BF16), weights can be the same half-precision type or FP32
if (wtype != atype && wtype != INFINI_DTYPE_F32 && wtype != INFINI_DTYPE_BF16 && wtype != INFINI_DTYPE_F16) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
// For FP32/FP64, activations and weights must be of the same type
if (atype != wtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
const size_t y_ndim = y_desc->ndim();
const size_t a_ndim = a_desc->ndim();
const size_t b_ndim = b_desc->ndim();
const size_t w_ndim = weight_desc->ndim();
if (y_ndim != a_ndim || y_ndim != b_ndim || w_ndim != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch = 1;
size_t nhead = 1;
size_t dim = 0;
if (y_ndim == 2) {
batch = y_desc->dim(0);
dim = y_desc->dim(1);
if (a_desc->dim(0) != batch || a_desc->dim(1) != dim || b_desc->dim(0) != batch || b_desc->dim(1) != dim || weight_desc->dim(0) != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
} else if (y_ndim == 3) {
batch = y_desc->dim(0);
nhead = y_desc->dim(1);
dim = y_desc->dim(2);
if (a_desc->dim(0) != batch || a_desc->dim(1) != nhead || a_desc->dim(2) != dim || b_desc->dim(0) != batch || b_desc->dim(1) != nhead || b_desc->dim(2) != dim || weight_desc->dim(0) != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Check contiguity of the last dimension
if (y_desc->stride(y_ndim - 1) != 1 || a_desc->stride(a_ndim - 1) != 1 || b_desc->stride(b_ndim - 1) != 1 || weight_desc->stride(w_ndim - 1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
// residual_out_desc is required (always needed for fused operator)
if (residual_out_desc == nullptr) {
return INFINI_STATUS_BAD_PARAM;
}
const size_t residual_out_ndim = residual_out_desc->ndim();
if (residual_out_ndim != y_ndim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (residual_out_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Check shape matches
for (size_t i = 0; i < y_ndim; i++) {
if (residual_out_desc->dim(i) != y_desc->dim(i)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
AddRMSNormInfo info;
info.wtype = wtype;
info.atype = atype;
info.epsilon = epsilon;
info.shape = y_desc->shape();
info.y_strides = y_desc->strides();
info.a_strides = a_desc->strides();
info.b_strides = b_desc->strides();
info.has_residual_out = true; // Always true now
info.residual_out_strides = residual_out_desc->strides();
return utils::Result<AddRMSNormInfo>(info);
}
};
} // namespace op::add_rms_norm
#endif // __ADD_RMS_NORM_INFO_H__
#ifndef __ADD_RMS_NORM_METAX_CUH__
#define __ADD_RMS_NORM_METAX_CUH__
#include "../add_rms_norm.h"
DESCRIPTOR(metax)
#endif
#include "../../../devices/metax/metax_common.h"
#include "add_rms_norm_metax.cuh"
#include "../../../devices/metax/metax_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// Kernel function template for add_rms_norm on Metax platform
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_METAX_KERNEL add_rmsnormKernel(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
float epsilon) {
add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, dim, epsilon);
}
namespace op::add_rms_norm::metax {
// Internal opaque structure for Metax device handle
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Destructor
Descriptor::~Descriptor() {
delete _opaque;
}
// Create descriptor for add_rms_norm operator
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon) {
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// Launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
uint32_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
const void *w, infiniDtype_t wtype,
float epsilon,
hcStream_t stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon)
// Handle different data type combinations following Metax pattern
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL
return INFINI_STATUS_SUCCESS;
}
// Main calculation function
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *stream) const {
// Check workspace size
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
// Extract tensor strides and dimensions
auto stride_a_batch = _info.a_strides[0];
auto stride_a_nhead = _info.a_strides[1];
auto stride_b_batch = _info.b_strides[0];
auto stride_b_nhead = _info.b_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto stride_residual_out_batch = _info.residual_out_strides[0];
auto stride_residual_out_nhead = _info.residual_out_strides[1];
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto stream_ = reinterpret_cast<hcStream_t>(stream);
// Launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_2048>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_4096>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::metax
#ifndef __ADD_RMS_NORM_MOORE_H__
#define __ADD_RMS_NORM_MOORE_H__
#include "../add_rms_norm.h"
DESCRIPTOR(moore)
#endif
#include "../../../devices/moore/moore_common.h"
#include "add_rms_norm_moore.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// Kernel function template for add_rms_norm on Moore platform
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MOORE_KERNEL add_rmsnormKernel(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
float epsilon) {
add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, dim, epsilon);
}
namespace op::add_rms_norm::moore {
// Internal opaque structure for Moore device handle
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
// Destructor
Descriptor::~Descriptor() {
delete _opaque;
}
// Create descriptor for add_rms_norm operator
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon) {
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// Launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
uint32_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
const void *w, infiniDtype_t wtype,
float epsilon,
musaStream_t musa_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, musa_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon)
// Handle different data type combinations
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __mt_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__mt_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__mt_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL
return INFINI_STATUS_SUCCESS;
}
// Main calculation function
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *stream) const {
// Check workspace size
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
// Extract tensor strides and dimensions
auto stride_a_batch = _info.a_strides[0];
auto stride_a_nhead = _info.a_strides[1];
auto stride_b_batch = _info.b_strides[0];
auto stride_b_nhead = _info.b_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto stride_residual_out_batch = _info.residual_out_strides[0];
auto stride_residual_out_nhead = _info.residual_out_strides[1];
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// Launch kernel with appropriate block size based on device capability
if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, musa_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::moore
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment