Unverified Commit 3eb4a800 authored by AniZpZ, committed by GitHub

Fix AWQ Dequant and Weight Loading of deepseek v2 (#6842)

parent e7261315
@@ -2137,8 +2137,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                     ):
                         q_a_proj_weight = cached_a_proj[q_a_proj_name]
                         kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                        cat_dim = 0
+                        if (
+                            self.quant_config.get_name() == "awq"
+                            or self.quant_config.get_name() == "moe_wna16"
+                        ):
+                            cat_dim = 1
                         fused_weight = torch.cat(
-                            [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                         )
                         param_name = (
                             name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
...
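Note (illustrative, not part of the diff): the concat axis differs because unquantized Linear weights are stored as [out_features, in_features], so fusing q_a_proj and kv_a_proj_with_mqa stacks along dim 0, whereas AWQ-style checkpoints store qweight, qzeros, and scales transposed and column-packed, roughly [in_features, out_features // pack_factor], so the fused output dimension becomes dim 1. A minimal sketch with made-up DeepSeek-V2-like sizes:

import torch

hidden = 5120                 # hypothetical hidden size
q_out, kv_out = 1536, 576     # hypothetical q_a_proj / kv_a_proj_with_mqa output widths

# fp16 checkpoint: weights are [out_features, in_features], fuse along dim 0
q_fp16 = torch.empty(q_out, hidden)
kv_fp16 = torch.empty(kv_out, hidden)
fused_fp16 = torch.cat([q_fp16, kv_fp16], dim=0)    # [q_out + kv_out, hidden]

# AWQ checkpoint: qweight packs 8 int4 values per int32 along the output axis
q_awq = torch.empty(hidden, q_out // 8, dtype=torch.int32)
kv_awq = torch.empty(hidden, kv_out // 8, dtype=torch.int32)
fused_awq = torch.cat([q_awq, kv_awq], dim=1)       # [hidden, (q_out + kv_out) // 8]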
@@ -130,10 +130,12 @@ __global__ void __launch_bounds__(256) dequantize_weights(
     int* __restrict__ qzeros,
     OutputT* __restrict__ output,
     int group_size,
-    int qweight_cols) {
+    int qweight_cols,
+    int qweight_rows) {
 #if CUDA_VERSION >= 12000
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   int row = blockIdx.y * blockDim.y + threadIdx.y;
+  if (col >= qweight_cols || row >= qweight_rows) return;
   int group_idx = row / group_size;
   int scale_offset = 8 * col + group_idx * qweight_cols * 8;
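For context on the indexing the new guard protects (illustrative, assuming the usual AWQ 4-bit layout): each int32 in qweight packs eight 4-bit values, so one packed column expands to eight dequantized columns, and scales/zeros are stored once per group of group_size rows. That is where the factor of 8 and the row / group_size term in scale_offset come from, as in this Python sketch of the same arithmetic:

def scale_offset(row, col, group_size, qweight_cols):
    # flat index of scales[group_idx, col * 8] in a row-major
    # [rows // group_size, qweight_cols * 8] scales tensor
    group_idx = row // group_size
    return 8 * col + group_idx * qweight_cols * 8

# row 300 with group_size 128 lands in group 2; packed column 5 maps to output columns 40..47
assert scale_offset(300, 5, group_size=128, qweight_cols=64) == 8 * 5 + 2 * 64 * 8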
@@ -188,8 +190,8 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
   int x_num_threads = 16;
   int y_num_threads = 16;
-  int x_blocks = qweight_cols / x_num_threads;
-  int y_blocks = qweight_rows / y_num_threads;
+  int x_blocks = (qweight_cols + x_num_threads - 1) / x_num_threads;
+  int y_blocks = (qweight_rows + y_num_threads - 1) / y_num_threads;
   const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
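The rounded-up grid matters whenever a dimension is not a multiple of the 16-thread block size; the old integer division silently dropped the remainder, so part of the matrix was never dequantized. A quick worked example (illustrative) of the change, paired with the bounds guard added in the kernel above:

def ceil_div(n, d):
    return (n + d - 1) // d

qweight_cols, x_num_threads = 72, 16
print(qweight_cols // x_num_threads)           # old: 4 blocks -> 64 threads, columns 64..71 skipped
print(ceil_div(qweight_cols, x_num_threads))   # new: 5 blocks -> 80 threads, the guard trims the last 8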
@@ -206,13 +208,13 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
   if (scales.scalar_type() == at::ScalarType::Half) {
     auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
     auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
-    dequantize_weights<half>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+    dequantize_weights<half><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   } else {
     auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
     auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
-    dequantize_weights<__nv_bfloat16>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+    dequantize_weights<__nv_bfloat16><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   }
   return output;
...
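The kernel writes an output of shape [qweight_rows, qweight_cols * 8] in the same dtype as scales (half or bfloat16). As a rough mental model of what it computes, here is a PyTorch sketch of grouped AWQ dequantization; it is illustrative only and ignores AWQ's interleaved nibble order within each int32, which a real reference implementation has to undo:

import torch

def awq_dequantize_ref(qweight, scales, qzeros, group_size=128):
    shifts = torch.arange(0, 32, 4, device=qweight.device)    # 8 nibbles per int32
    w = ((qweight.unsqueeze(-1) >> shifts) & 0xF).flatten(1)  # [rows, cols * 8]
    z = ((qzeros.unsqueeze(-1) >> shifts) & 0xF).flatten(1)   # [rows // group_size, cols * 8]
    z = z.repeat_interleave(group_size, dim=0)
    s = scales.repeat_interleave(group_size, dim=0)
    return (w - z).to(scales.dtype) * s                       # [rows, cols * 8]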
@@ -67,8 +67,8 @@ def sglang_awq_dequantize(
     "qweight_row,qweight_col,is_bf16_act",
     list(
         itertools.product(
-            [3584, 18944, 128, 256, 512, 1024],
-            [448, 576, 4736, 16, 32, 64, 128],
+            [3584, 18944, 128, 256, 512, 1024, 1536],
+            [448, 576, 4736, 16, 32, 64, 128, 72],
             [True, False],
         )
     ),
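Of the newly added shapes, 72 is the only dimension in the sweep that is not a multiple of the 16-thread block size, so it is the case that exercises the ceiling-division grid and bounds guard above (1536 appears to mirror a DeepSeek-V2-style projection width). A quick illustrative check:

rows = [3584, 18944, 128, 256, 512, 1024, 1536]
cols = [448, 576, 4736, 16, 32, 64, 128, 72]
print([r for r in rows if r % 16], [c for c in cols if c % 16])   # -> [] [72]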
@@ -77,7 +77,6 @@ def test_awq_dequant_compare_implementations(
     qweight_row: int, qweight_col: int, is_bf16_act: bool
 ):
     device = torch.device("cuda")
-
     qweight = torch.randint(
         0,
         torch.iinfo(torch.int32).max,
...