Commit 305b5a09 authored by zhuwenwen
Browse files

增加input tensor size>2^31支持

parent 1a493a24
...@@ -342,7 +342,7 @@ __global__ void fused_add_rms_kernel_opt(scalar_t* input,scalar_t* residual,scal ...@@ -342,7 +342,7 @@ __global__ void fused_add_rms_kernel_opt(scalar_t* input,scalar_t* residual,scal
scalar_t intput_vec[Vec]; scalar_t intput_vec[Vec];
scalar_t residual_vec[Vec]; scalar_t residual_vec[Vec];
T_ACC trstd; T_ACC trstd;
int idx = i * tcol + j; int64_t idx = i * tcol + j;
idx*=Vec; idx*=Vec;
if (j < tcol) { if (j < tcol) {
*(LoadT*)intput_vec = *(LoadT*)(input+idx); *(LoadT*)intput_vec = *(LoadT*)(input+idx);
...@@ -381,7 +381,7 @@ __global__ void fused_rms_kernel_opt(scalar_t* input,scalar_t* output,scalar_t* ...@@ -381,7 +381,7 @@ __global__ void fused_rms_kernel_opt(scalar_t* input,scalar_t* output,scalar_t*
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>; using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec]; scalar_t intput_vec[Vec];
T_ACC trstd; T_ACC trstd;
int idx = i * tcol + j; int64_t idx = i * tcol + j;
idx*=Vec; idx*=Vec;
if (j < tcol) { if (j < tcol) {
*(LoadT*)intput_vec = *(LoadT*)(input+idx); *(LoadT*)intput_vec = *(LoadT*)(input+idx);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment