Unverified commit c4b3a157, authored by PanZezhong1725, committed by GitHub
Browse files

Merge pull request #143 from YdrMaster/main

issue/121/fix: 修改 cuda 和昇腾的 rms norm
parents fb83addc 8848a764
......@@ -32,9 +32,10 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
RMSNormInfo info;
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon));
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size = 0;
aclOpExecutor *executor = nullptr;
......@@ -65,14 +66,22 @@ infiniStatus_t Descriptor::create(
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(epsilon), ty, trstd, &workspace_size, &executor));
aclSetAclOpExecutorRepeatable(executor);
size_t allWorkspaceSize = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType);
*desc_ptr = new Descriptor(new Opaque{executor, y, x, w, rstd, workspace_size}, info, allWorkspaceSize, handle_ascend->device, handle_ascend->device_id);
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
size_t all_workspace_size = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType);
*desc_ptr = new Descriptor(
new Opaque{executor, y, x, w, rstd, workspace_size},
std::move(info),
all_workspace_size,
handle_ascend->device, handle_ascend->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, void *y,
const void *x, const void *w, void *stream) {
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
if (workspace_size < workspaceSize()) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
......
#include "../../../devices/cuda/cuda_common.cuh"
#include "rms_norm_cuda.cuh"
#include "rms_norm_kernel.cuh"
#include <memory>
#include <stdint.h>
namespace op::rms_norm::cuda {
......@@ -21,8 +19,9 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
RMSNormInfo info;
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon));
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
// only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
......@@ -31,7 +30,9 @@ infiniStatus_t Descriptor::create(
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::cuda::Handle *>(handle)->internal()},
info, 0, handle->device, handle->device_id);
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
......@@ -70,8 +71,11 @@ infiniStatus_t launchKernel(
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y, const void *x, const void *w, void *stream) {
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment