Commit 8848a764 authored by YdrMaster
Browse files

issue/121/fix: 修改 cuda 和昇腾的 rms norm


Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent fd5d90c9
...@@ -32,9 +32,10 @@ infiniStatus_t Descriptor::create( ...@@ -32,9 +32,10 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc, infiniopTensorDescriptor_t w_desc,
float epsilon) { float epsilon) {
RMSNormInfo info;
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle); auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon)); CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size = 0; size_t workspace_size = 0;
aclOpExecutor *executor = nullptr; aclOpExecutor *executor = nullptr;
...@@ -65,14 +66,22 @@ infiniStatus_t Descriptor::create( ...@@ -65,14 +66,22 @@ infiniStatus_t Descriptor::create(
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(epsilon), ty, trstd, &workspace_size, &executor)); CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(epsilon), ty, trstd, &workspace_size, &executor));
aclSetAclOpExecutorRepeatable(executor); aclSetAclOpExecutorRepeatable(executor);
size_t allWorkspaceSize = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType); auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
*desc_ptr = new Descriptor(new Opaque{executor, y, x, w, rstd, workspace_size}, info, allWorkspaceSize, handle_ascend->device, handle_ascend->device_id); size_t all_workspace_size = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType);
*desc_ptr = new Descriptor(
new Opaque{executor, y, x, w, rstd, workspace_size},
std::move(info),
all_workspace_size,
handle_ascend->device, handle_ascend->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, void *y, infiniStatus_t Descriptor::calculate(
const void *x, const void *w, void *stream) { void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
if (workspace_size < workspaceSize()) { if (workspace_size < workspaceSize()) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE; return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
} }
......
#include "../../../devices/cuda/cuda_common.cuh" #include "../../../devices/cuda/cuda_common.cuh"
#include "rms_norm_cuda.cuh" #include "rms_norm_cuda.cuh"
#include "rms_norm_kernel.cuh" #include "rms_norm_kernel.cuh"
#include <memory>
#include <stdint.h>
namespace op::rms_norm::cuda { namespace op::rms_norm::cuda {
...@@ -21,8 +19,9 @@ infiniStatus_t Descriptor::create( ...@@ -21,8 +19,9 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc, infiniopTensorDescriptor_t w_desc,
float epsilon) { float epsilon) {
RMSNormInfo info; auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon)); CHECK_RESULT(result);
auto info = result.take();
// only support contiguous last dimension // only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) { if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
...@@ -31,7 +30,9 @@ infiniStatus_t Descriptor::create( ...@@ -31,7 +30,9 @@ infiniStatus_t Descriptor::create(
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::cuda::Handle *>(handle)->internal()}, new Opaque{reinterpret_cast<device::cuda::Handle *>(handle)->internal()},
info, 0, handle->device, handle->device_id); std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -70,8 +71,11 @@ infiniStatus_t launchKernel( ...@@ -70,8 +71,11 @@ infiniStatus_t launchKernel(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, infiniStatus_t Descriptor::calculate(
void *y, const void *x, const void *w, void *stream) { void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
if (workspace_size < _workspace_size) { if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE; return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment