Commit dfdc05ae authored by zhangshao's avatar zhangshao
Browse files

解决数据量过大,导致int32索引越界的问题

parent 066e63c2
...@@ -30,11 +30,11 @@ __global__ void act_and_mul_kernel_vectorize1( ...@@ -30,11 +30,11 @@ __global__ void act_and_mul_kernel_vectorize1(
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>; using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int token_idx = blockIdx.x; const int64_t token_idx= blockIdx.x;
int idx = threadIdx.x * VEC; int idx = threadIdx.x * VEC;
if (idx < d) { if (idx < d) {
const int x_index = token_idx * 2 * d + idx; const int64_t x_index = token_idx * 2 * d + idx;
const int y_index = token_idx * d + idx; const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index); VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d); VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index); VecType* y = (VecType*)(out + y_index);
...@@ -45,9 +45,7 @@ __global__ void act_and_mul_kernel_vectorize1( ...@@ -45,9 +45,7 @@ __global__ void act_and_mul_kernel_vectorize1(
*(VecType*)r_x2 = *x2; *(VecType*)r_x2 = *x2;
#pragma unroll #pragma unroll
for (int i = 0; i < VEC; i++) { for (int i = 0; i < VEC; i++) {
const scalar_t t_x1 = VLLM_LDG(&r_x1[i]); r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
const scalar_t t_x2 = VLLM_LDG(&r_x2[i]);
r_y[i] = ACT_FN(t_x1) * t_x2;
} }
*y = *(VecType*)r_y; *y = *(VecType*)r_y;
} }
...@@ -59,11 +57,11 @@ __global__ void act_and_mul_kernel_vectorize2( ...@@ -59,11 +57,11 @@ __global__ void act_and_mul_kernel_vectorize2(
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>; using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
int idx = threadIdx.x * VEC; int idx = threadIdx.x * VEC;
for (; idx < d; idx += blockDim.x * VEC) { for (; idx < d; idx += blockDim.x * VEC) {
const int x_index = token_idx * 2 * d + idx; const int64_t x_index = token_idx * 2 * d + idx;
const int y_index = token_idx * d + idx; const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index); VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d); VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index); VecType* y = (VecType*)(out + y_index);
...@@ -74,9 +72,7 @@ __global__ void act_and_mul_kernel_vectorize2( ...@@ -74,9 +72,7 @@ __global__ void act_and_mul_kernel_vectorize2(
*(VecType*)r_x2 = *x2; *(VecType*)r_x2 = *x2;
#pragma unroll #pragma unroll
for (int i = 0; i < VEC; i++) { for (int i = 0; i < VEC; i++) {
const scalar_t t_x1 = VLLM_LDG(&r_x1[i]); r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
const scalar_t t_x2 = VLLM_LDG(&r_x2[i]);
r_y[i] = ACT_FN(t_x1) * t_x2;
} }
*y = *(VecType*)r_y; *y = *(VecType*)r_y;
} }
...@@ -229,4 +225,4 @@ void gelu_fast(torch::Tensor& out, // [..., d] ...@@ -229,4 +225,4 @@ void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d] torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment