Commit 37ee5700 authored by zhuwenwen
Browse files

解决数据量过大,导致int32索引越界的问题

parent 8cd246bd
......@@ -30,11 +30,11 @@ __global__ void act_and_mul_kernel_vectorize1(
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int token_idx = blockIdx.x;
const int64_t token_idx= blockIdx.x;
int idx = threadIdx.x * VEC;
if (idx < d) {
const int x_index = token_idx * 2 * d + idx;
const int y_index = token_idx * d + idx;
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
......@@ -45,9 +45,7 @@ __global__ void act_and_mul_kernel_vectorize1(
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
const scalar_t t_x1 = VLLM_LDG(&r_x1[i]);
const scalar_t t_x2 = VLLM_LDG(&r_x2[i]);
r_y[i] = ACT_FN(t_x1) * t_x2;
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
......@@ -59,11 +57,11 @@ __global__ void act_and_mul_kernel_vectorize2(
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int token_idx = blockIdx.x;
const int64_t token_idx = blockIdx.x;
int idx = threadIdx.x * VEC;
for (; idx < d; idx += blockDim.x * VEC) {
const int x_index = token_idx * 2 * d + idx;
const int y_index = token_idx * d + idx;
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
......@@ -74,9 +72,7 @@ __global__ void act_and_mul_kernel_vectorize2(
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
const scalar_t t_x1 = VLLM_LDG(&r_x1[i]);
const scalar_t t_x2 = VLLM_LDG(&r_x2[i]);
r_y[i] = ACT_FN(t_x1) * t_x2;
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment