Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e288df06
"docs/vscode:/vscode.git/clone" did not exist on "de1cb38769e9eb9812fa425c4fdfcf8faa3c420e"
Unverified
Commit
e288df06
authored
May 08, 2024
by
alexm-nm
Committed by
GitHub
May 08, 2024
Browse files
[Bugfix] Fine-tune gptq_marlin configs to be more similar to marlin (#4626)
parent
8b9241be
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
13 deletions
+35
-13
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+35
-13
No files found.
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
e288df06
...
@@ -115,7 +115,8 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
...
@@ -115,7 +115,8 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
return
res
;
return
res
;
}
}
// Constructs destination register by taking bytes from 2 sources (based on mask)
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template
<
int
start_byte
,
int
mask
>
template
<
int
start_byte
,
int
mask
>
__device__
inline
uint32_t
prmt
(
uint32_t
a
)
{
__device__
inline
uint32_t
prmt
(
uint32_t
a
)
{
uint32_t
res
;
uint32_t
res
;
...
@@ -933,9 +934,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -933,9 +934,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
};
};
// Since multiple threadblocks may process parts of the same column slice, we
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
partitioning
// finally have to globally reduce over the results. As the striped
// minimizes the number of such reductions and our outputs are
usually rather
//
partitioning
minimizes the number of such reductions and our outputs are
// small, we perform this reduction serially in L2 cache.
//
usually rather
small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// maximize L2 cache utilization in this step. To do this, we write out
...
@@ -1275,13 +1276,22 @@ typedef struct {
...
@@ -1275,13 +1276,22 @@ typedef struct {
thread_config_t
tb_cfg
;
thread_config_t
tb_cfg
;
}
exec_config_t
;
}
exec_config_t
;
thread_config_t
thread_configs
[]
=
{
thread_config_t
small_batch_
thread_configs
[]
=
{
// Ordered by priority
// Ordered by priority
// thread_k, thread_n, num_threads
// thread_k, thread_n, num_threads
{
64
,
256
,
256
},
// Default (max cache usage)
{
128
,
128
,
256
},
{
64
,
128
,
128
},
// Reduce N, reduce warps
{
64
,
128
,
128
},
{
128
,
64
,
128
},
// Reduce N more, but increase K
{
128
,
64
,
128
},
};
thread_config_t
large_batch_thread_configs
[]
=
{
// Ordered by priority
// thread_k, thread_n, num_threads
{
64
,
256
,
256
},
{
64
,
128
,
128
},
{
128
,
64
,
128
},
};
};
...
@@ -1397,13 +1407,23 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
...
@@ -1397,13 +1407,23 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int
max_shared_mem
)
{
int
max_shared_mem
)
{
int
max_m_blocks
=
4
;
int
max_m_blocks
=
4
;
while
(
max_m_blocks
>
0
)
{
while
(
max_m_blocks
>
0
)
{
for
(
auto
th_config
:
thread_configs
)
{
if
(
prob_m
<=
16
)
{
for
(
auto
th_config
:
small_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
))
{
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
}
}
else
{
for
(
auto
th_config
:
large_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
printf
(
"WARNING: Marlin kernel is reducing max_m_blocks due to small SM "
printf
(
"WARNING: Marlin kernel is reducing max_m_blocks due to small SM "
"GPU cache. This may "
"GPU cache. This may "
...
@@ -1574,10 +1594,12 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
...
@@ -1574,10 +1594,12 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
}
}
CALL_IF
(
4
,
32
,
2
,
256
)
CALL_IF
(
4
,
32
,
2
,
256
)
CALL_IF
(
4
,
16
,
4
,
256
)
CALL_IF
(
4
,
16
,
4
,
256
)
CALL_IF
(
4
,
8
,
8
,
256
)
CALL_IF
(
4
,
8
,
4
,
128
)
CALL_IF
(
4
,
8
,
4
,
128
)
CALL_IF
(
4
,
4
,
8
,
128
)
CALL_IF
(
4
,
4
,
8
,
128
)
CALL_IF
(
8
,
32
,
2
,
256
)
CALL_IF
(
8
,
32
,
2
,
256
)
CALL_IF
(
8
,
16
,
4
,
256
)
CALL_IF
(
8
,
16
,
4
,
256
)
CALL_IF
(
8
,
8
,
8
,
256
)
CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
4
,
8
,
128
)
CALL_IF
(
8
,
4
,
8
,
128
)
else
{
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment