Commit 305b5a09 authored by zhuwenwen
Browse files

增加input tensor size>2^31支持 (Add support for input tensor size > 2^31)

parent 1a493a24
......@@ -342,7 +342,7 @@ __global__ void fused_add_rms_kernel_opt(scalar_t* input,scalar_t* residual,scal
scalar_t intput_vec[Vec];
scalar_t residual_vec[Vec];
T_ACC trstd;
-int idx = i * tcol + j;
+int64_t idx = i * tcol + j;
idx*=Vec;
if (j < tcol) {
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
......@@ -381,7 +381,7 @@ __global__ void fused_rms_kernel_opt(scalar_t* input,scalar_t* output,scalar_t*
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec];
T_ACC trstd;
-int idx = i * tcol + j;
+int64_t idx = i * tcol + j;
idx*=Vec;
if (j < tcol) {
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment