Unverified Commit d200972e authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Bugfix] Marlin 2:4 temp fix for large M dim (>256) (#10464)


Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
parent d5b68aba
...@@ -910,13 +910,16 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, ...@@ -910,13 +910,16 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
// than better compute utilization // than better compute utilization
thread_k = 128; thread_k = 128;
thread_m = 128; thread_m = 128;
} else if (prob_n <= 256) { } else {
thread_k = 64; thread_k = 64;
thread_m = 256; thread_m = 256;
} else {
thread_k = 32;
thread_m = 512;
} }
// Also had
// if prob_n > 256
// thread_k = 32;
// thread_m = 512;
// but this is broken,
// TODO(Lucas, Alex M): figure out why
} }
int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
...@@ -1079,6 +1082,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -1079,6 +1082,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
// Verify A device and strides // Verify A device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
TORCH_CHECK(a.dtype() == torch::kFloat16,
"A is not float16, currently only float16 is supported");
// Verify B device and strides // Verify B device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
...@@ -1091,6 +1096,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -1091,6 +1096,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
// Verify scales device and strides // Verify scales device and strides
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
TORCH_CHECK(b_scales.dtype() == torch::kFloat16,
"A is not float16, currently only float16 is supported");
// Alloc C matrix // Alloc C matrix
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
......
...@@ -50,6 +50,8 @@ MNK_FACTORS = [ ...@@ -50,6 +50,8 @@ MNK_FACTORS = [
(13, 17, 67), (13, 17, 67),
(26, 37, 13), (26, 37, 13),
(67, 13, 11), (67, 13, 11),
(257, 13, 11),
(658, 13, 11),
] ]
DTYPES = [torch.float16, torch.bfloat16] DTYPES = [torch.float16, torch.bfloat16]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment