Unverified commit 3ecfdc37, authored by Andreas Karatzas, committed by GitHub
Browse files

[ROCm][GPTQ][Bugfix] Fix GPTQ GEMM kernel output zeroing race condition (#30719)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent 45c1ca1c
...@@ -233,11 +233,6 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( ...@@ -233,11 +233,6 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
// Zero output // Zero output
if (n >= size_n) return; if (n >= size_n) return;
if (blockIdx.z == 0) {
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads(); __syncthreads();
// Find initial group // Find initial group
...@@ -372,11 +367,6 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( ...@@ -372,11 +367,6 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
// Zero output // Zero output
if (n >= size_n) return; if (n >= size_n) return;
if (blockIdx.z == 0) {
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads(); __syncthreads();
// Find initial group // Find initial group
...@@ -494,11 +484,6 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( ...@@ -494,11 +484,6 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
// Zero output // Zero output
if (n >= size_n) return; if (n >= size_n) return;
if (blockIdx.z == 0) {
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads(); __syncthreads();
// Find initial group // Find initial group
...@@ -623,11 +608,6 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( ...@@ -623,11 +608,6 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
// Zero output // Zero output
if (n >= size_n) return; if (n >= size_n) return;
if (blockIdx.z == 0) {
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads(); __syncthreads();
// Find initial group // Find initial group
...@@ -1224,9 +1204,6 @@ __global__ void gemm_half_q_half_alt_4bit_kernel( ...@@ -1224,9 +1204,6 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
__halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4)); __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4));
} }
if (blockIdx.z == 0) {
for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0);
}
__syncthreads(); __syncthreads();
int i = width * h + w; int i = width * h + w;
...@@ -1319,9 +1296,6 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( ...@@ -1319,9 +1296,6 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
} }
} }
if (blockIdx.z == 0) {
for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0);
}
__syncthreads(); __syncthreads();
int i = width * h + w; int i = width * h + w;
...@@ -1857,7 +1831,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -1857,7 +1831,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
bool use_exllama, bool use_v2_format, int64_t bit) { bool use_exllama, bool use_v2_format, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); at::Tensor c = torch::zeros({a.size(0), b_q_weight.size(1)}, options);
at::Tensor temp_dq = torch::empty( at::Tensor temp_dq = torch::empty(
{b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment