Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d200972e
Unverified
Commit
d200972e
authored
Nov 19, 2024
by
Lucas Wilkinson
Committed by
GitHub
Nov 19, 2024
Browse files
[Bugfix] Marlin 2:4 temp fix for large M dim (>256) (#10464)
Signed-off-by:
Lucas Wilkinson
<
lwilkinson@neuralmagic.com
>
parent
d5b68aba
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
4 deletions
+13
-4
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+11
-4
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+2
-0
No files found.
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
View file @
d200972e
...
@@ -910,13 +910,16 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
...
@@ -910,13 +910,16 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
// than better compute utilization
// than better compute utilization
thread_k
=
128
;
thread_k
=
128
;
thread_m
=
128
;
thread_m
=
128
;
}
else
if
(
prob_n
<=
256
)
{
}
else
{
thread_k
=
64
;
thread_k
=
64
;
thread_m
=
256
;
thread_m
=
256
;
}
else
{
thread_k
=
32
;
thread_m
=
512
;
}
}
// Also had
// if prob_n > 256
// thread_k = 32;
// thread_m = 512;
// but this is broken,
// TODO(Lucas, Alex M): figure out why
}
}
int
thread_k_blocks
=
thread_k
/
32
;
// 2:4 version with m16n8k32 instruction
int
thread_k_blocks
=
thread_k
/
32
;
// 2:4 version with m16n8k32 instruction
...
@@ -1079,6 +1082,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1079,6 +1082,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
// Verify A device and strides
// Verify A device and strides
TORCH_CHECK
(
a
.
device
().
is_cuda
(),
"A is not on GPU"
);
TORCH_CHECK
(
a
.
device
().
is_cuda
(),
"A is not on GPU"
);
TORCH_CHECK
(
a
.
is_contiguous
(),
"A is not contiguous"
);
TORCH_CHECK
(
a
.
is_contiguous
(),
"A is not contiguous"
);
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat16
,
"A is not float16, currently only float16 is supported"
);
// Verify B device and strides
// Verify B device and strides
TORCH_CHECK
(
b_q_weight
.
device
().
is_cuda
(),
"b_q_weight is not on GPU"
);
TORCH_CHECK
(
b_q_weight
.
device
().
is_cuda
(),
"b_q_weight is not on GPU"
);
...
@@ -1091,6 +1096,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1091,6 +1096,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
// Verify scales device and strides
// Verify scales device and strides
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat16
,
"A is not float16, currently only float16 is supported"
);
// Alloc C matrix
// Alloc C matrix
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
...
...
tests/kernels/test_marlin_gemm.py
View file @
d200972e
...
@@ -50,6 +50,8 @@ MNK_FACTORS = [
...
@@ -50,6 +50,8 @@ MNK_FACTORS = [
(
13
,
17
,
67
),
(
13
,
17
,
67
),
(
26
,
37
,
13
),
(
26
,
37
,
13
),
(
67
,
13
,
11
),
(
67
,
13
,
11
),
(
257
,
13
,
11
),
(
658
,
13
,
11
),
]
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment