Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
61a97c32
Unverified
Commit
61a97c32
authored
Jul 29, 2024
by
Tyler Michael Smith
Committed by
GitHub
Jul 30, 2024
Browse files
[Kernel] Fix marlin divide-by-zero warnings (#6904)
parent
4fbf4aa1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
39 deletions
+58
-39
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+32
-29
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
+13
-5
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+13
-5
No files found.
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
61a97c32
...
@@ -1128,44 +1128,47 @@ __global__ void Marlin(
...
@@ -1128,44 +1128,47 @@ __global__ void Marlin(
};
};
auto
fetch_zp_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
auto
fetch_zp_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
if
constexpr
(
!
has_zp
)
{
if
constexpr
(
has_zp
)
{
return
;
// This code does not handle group_blocks == 0,
}
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert
(
group_blocks
!=
0
);
int
pipe
=
full_pipe
%
stages
;
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
==
-
1
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
}
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
else
{
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
}
}
};
};
...
...
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
View file @
61a97c32
...
@@ -452,10 +452,15 @@ __global__ void Marlin(
...
@@ -452,10 +452,15 @@ __global__ void Marlin(
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
}
// Only fetch scales if this tile starts a new group
// Only fetch scales if this tile starts a new group
if
(
group_blocks
!=
-
1
&&
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
constexpr
(
group_blocks
!=
-
1
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
// This assumes group_blocks >= thread_k_blocks
if
(
s_sh_wr_pred
)
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
s
[
s_gl_rd
]);
// and would need to be modified to support smaller groups.
s_gl_rd
+=
s_gl_rd_delta
;
static_assert
(
group_blocks
>=
thread_k_blocks
);
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
(
s_sh_wr_pred
)
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
s
[
s_gl_rd
]);
s_gl_rd
+=
s_gl_rd_delta
;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// Insert a fence even when we are winding down the pipeline to ensure that
...
@@ -480,7 +485,10 @@ __global__ void Marlin(
...
@@ -480,7 +485,10 @@ __global__ void Marlin(
// however, this does not seem to be a significant bottleneck, while some
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
// the compiler and correspondingly a noticeable drop in performance.
if
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
!=
-
1
)
{
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert
(
group_blocks
>=
thread_k_blocks
);
int4
*
sh_s_stage
=
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
...
...
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
View file @
61a97c32
...
@@ -404,10 +404,15 @@ __global__ void Marlin_24(
...
@@ -404,10 +404,15 @@ __global__ void Marlin_24(
meta_ptr
[
i
]
+=
m_gl_rd_delta_o
;
meta_ptr
[
i
]
+=
m_gl_rd_delta_o
;
}
}
// Only fetch scales if this tile starts a new group
// Only fetch scales if this tile starts a new group
if
(
group_blocks
!=
-
1
&&
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
constexpr
(
group_blocks
!=
-
1
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
// This assumes group_blocks >= thread_k_blocks
if
(
s_sh_wr_pred
)
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
s
[
s_gl_rd
]);
// and would need to be modified to support smaller groups.
s_gl_rd
+=
s_gl_rd_delta
;
static_assert
(
group_blocks
>=
thread_k_blocks
);
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
(
s_sh_wr_pred
)
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
s
[
s_gl_rd
]);
s_gl_rd
+=
s_gl_rd_delta
;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// Insert a fence even when we are winding down the pipeline to ensure that
...
@@ -432,7 +437,10 @@ __global__ void Marlin_24(
...
@@ -432,7 +437,10 @@ __global__ void Marlin_24(
// however, this does not seem to be a significant bottleneck, while some
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
// the compiler and correspondingly a noticeable drop in performance.
if
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
!=
-
1
)
{
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert
(
group_blocks
>=
thread_k_blocks
);
int4
*
sh_s_stage
=
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment