Unverified Commit e73b7dfd authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Bugfix] fix `an illegal memory access was encountered` of marlin kernel + act_order (#18245)

parent 7fdfa015
......@@ -1767,6 +1767,8 @@ __global__ void Marlin(
if constexpr (has_act_order) {
slice_k_start += tb_k * stages;
if (slice_k_start < prob_k) {
slice_k_start_shared_fetch += tb_k * stages;
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
......@@ -1780,6 +1782,7 @@ __global__ void Marlin(
__syncthreads();
}
}
}
if (slice_iters == 0) {
break;
}
......
......@@ -1588,6 +1588,8 @@ __global__ void Marlin(
if constexpr (has_act_order) {
slice_k_start += tb_k * stages;
if (slice_k_start < prob_k) {
slice_k_start_shared_fetch += tb_k * stages;
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
......@@ -1596,10 +1598,12 @@ __global__ void Marlin(
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
fetch_act_order_scales_to_shared(false, first_group_id,
last_group_id);
__syncthreads();
}
}
}
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
......
......@@ -2,7 +2,7 @@ gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
#gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq, TheBloke/Llama-2-7B-GPTQ, main
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment