Commit 2221f4ce authored by Tim Dettmers's avatar Tim Dettmers
Browse files

Fixed potential memory leak.

parent 490153b2
...@@ -3561,7 +3561,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc ...@@ -3561,7 +3561,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
if(row_B < M) if(row_B < M)
{ {
if((inner_idx_halved + num_values_8bit) < K) if((inner_idx_halved + num_values_8bit) < (K/2))
{ {
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
} }
...@@ -3569,15 +3569,21 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc ...@@ -3569,15 +3569,21 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
{ {
#pragma unroll #pragma unroll
for(int j = 0; j < (num_values_8bit); j++) for(int j = 0; j < (num_values_8bit); j++)
if((inner_idx_halved) + j < K) if((inner_idx_halved) + j < (K/2))
local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
else else
local_B_4bit[j] = 0b01110111; local_B_4bit[j] = 0b01110111;
} }
} }
else
{
#pragma unroll
for(int j = 0; j < (num_values_8bit); j++)
local_B_4bit[j] = 0b01110111;
}
#pragma unroll #pragma unroll
for(int k = 0; k < num_values_4bit; k++) for(int k = 0; k < num_values_8bit; k++)
{ {
#if __CUDA_ARCH__ >= 800 #if __CUDA_ARCH__ >= 800
local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax; local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
...@@ -3625,7 +3631,6 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc ...@@ -3625,7 +3631,6 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
local_C += ((float)local_A[k]*(float)local_B[k]); local_C += ((float)local_A[k]*(float)local_B[k]);
#endif #endif
} }
} }
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment