Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3c068c63
Unverified
Commit
3c068c63
authored
Sep 17, 2025
by
czhu-cohere
Committed by
GitHub
Sep 17, 2025
Browse files
[Kernel] Faster pre-processing time for W4A8 (#23972)
Signed-off-by:
czhu-cohere
<
conway.zhu@cohere.com
>
parent
f20c3b09
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
71 additions
and
1 deletion
+71
-1
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
+71
-1
No files found.
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
View file @
3c068c63
...
@@ -25,6 +25,8 @@
...
@@ -25,6 +25,8 @@
#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include <cuda_runtime.h>
namespace
vllm
::
cutlass_w4a8
{
namespace
vllm
::
cutlass_w4a8
{
using
namespace
cute
;
using
namespace
cute
;
...
@@ -393,6 +395,71 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
...
@@ -393,6 +395,71 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
return
packed_scales
;
return
packed_scales
;
}
}
/*
GPU-accelerated implementation of cutlass::unified_encode_int4b.
Constructs a lookup table in constant memory to map 8 bits
(two 4-bit values) at a time. Assumes memory is contiguous
and pointers are 16-byte aligned.
*/
__constant__
uint8_t
kNibbleLUT
[
256
];
__global__
void
unified_encode_int4b_device
(
const
uint8_t
*
in
,
uint8_t
*
out
,
size_t
nbytes
)
{
constexpr
size_t
V
=
sizeof
(
uint4
);
// 16 bytes
const
size_t
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
nthreads
=
size_t
(
gridDim
.
x
)
*
blockDim
.
x
;
const
size_t
nvec
=
nbytes
/
V
;
// 1-D grid-stride loop over 16-byte chunks
for
(
size_t
vec
=
tid
;
vec
<
nvec
;
vec
+=
nthreads
)
{
uint4
v
=
reinterpret_cast
<
const
uint4
*>
(
in
)[
vec
];
uint8_t
*
b
=
reinterpret_cast
<
uint8_t
*>
(
&
v
);
#pragma unroll
for
(
int
i
=
0
;
i
<
int
(
V
);
++
i
)
b
[
i
]
=
kNibbleLUT
[
b
[
i
]];
reinterpret_cast
<
uint4
*>
(
out
)[
vec
]
=
v
;
}
}
static
bool
upload_lut
()
{
std
::
array
<
uint8_t
,
256
>
lut
{};
auto
map_nib
=
[](
uint8_t
v
)
->
uint8_t
{
// 1..7 -> (8 - v); keep 0 and 8..15
return
(
v
==
0
||
(
v
&
0x8
))
?
v
:
uint8_t
(
8
-
v
);
};
for
(
int
b
=
0
;
b
<
256
;
++
b
)
{
uint8_t
lo
=
b
&
0xF
;
uint8_t
hi
=
(
b
>>
4
)
&
0xF
;
lut
[
b
]
=
uint8_t
((
map_nib
(
hi
)
<<
4
)
|
map_nib
(
lo
));
}
cudaError_t
e
=
cudaMemcpyToSymbol
(
kNibbleLUT
,
lut
.
data
(),
lut
.
size
(),
/*offset=*/
0
,
cudaMemcpyHostToDevice
);
return
(
e
==
cudaSuccess
);
}
static
bool
unified_encode_int4b
(
cutlass
::
int4b_t
const
*
in
,
cutlass
::
int4b_t
*
out
,
size_t
num_int4_elems
)
{
// Build/upload LUT
if
(
!
upload_lut
())
return
false
;
static_assert
(
sizeof
(
typename
cutlass
::
int4b_t
::
Storage
)
==
1
,
"int4 storage must be 1 byte"
);
const
size_t
nbytes
=
num_int4_elems
>>
1
;
auto
*
in_bytes
=
reinterpret_cast
<
uint8_t
const
*>
(
in
);
auto
*
out_bytes
=
reinterpret_cast
<
uint8_t
*>
(
out
);
// kernel launch params
constexpr
int
block
=
256
;
const
size_t
nvec
=
nbytes
/
sizeof
(
uint4
);
// # of 16B vectors
int
grid
=
int
((
nvec
+
block
-
1
)
/
block
);
if
(
grid
==
0
)
grid
=
1
;
// ensure we still cover the tail in the kernel
unified_encode_int4b_device
<<<
grid
,
block
>>>
(
in_bytes
,
out_bytes
,
nbytes
);
cudaError_t
err
=
cudaGetLastError
();
return
(
err
==
cudaSuccess
);
}
torch
::
Tensor
encode_and_reorder_int4b
(
torch
::
Tensor
const
&
B
)
{
torch
::
Tensor
encode_and_reorder_int4b
(
torch
::
Tensor
const
&
B
)
{
TORCH_CHECK
(
B
.
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
B
.
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
B
.
dim
()
==
2
);
TORCH_CHECK
(
B
.
dim
()
==
2
);
...
@@ -401,6 +468,7 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
...
@@ -401,6 +468,7 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
int
k
=
B
.
size
(
0
)
*
PackFactor
;
// logical k
int
k
=
B
.
size
(
0
)
*
PackFactor
;
// logical k
int
n
=
B
.
size
(
1
);
int
n
=
B
.
size
(
1
);
TORCH_CHECK
((
n
*
k
)
%
32
==
0
,
"need multiples of 32 int4s for 16B chunks"
);
auto
B_ptr
=
static_cast
<
QuantType
const
*>
(
B
.
const_data_ptr
());
auto
B_ptr
=
static_cast
<
QuantType
const
*>
(
B
.
const_data_ptr
());
auto
B_packed_ptr
=
static_cast
<
QuantType
*>
(
B_packed
.
data_ptr
());
auto
B_packed_ptr
=
static_cast
<
QuantType
*>
(
B_packed
.
data_ptr
());
...
@@ -409,7 +477,9 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
...
@@ -409,7 +477,9 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
LayoutB_Reordered
layout_B_reordered
=
LayoutB_Reordered
layout_B_reordered
=
cute
::
tile_to_shape
(
LayoutAtomQuant
{},
shape_B
);
cute
::
tile_to_shape
(
LayoutAtomQuant
{},
shape_B
);
cutlass
::
unified_encode_int4b
(
B_ptr
,
B_packed_ptr
,
n
*
k
);
bool
ok
=
vllm
::
cutlass_w4a8
::
unified_encode_int4b
(
B_ptr
,
B_packed_ptr
,
n
*
k
);
TORCH_CHECK
(
ok
,
"unified_encode_int4b failed"
);
cutlass
::
reorder_tensor
(
B_packed_ptr
,
layout_B
,
layout_B_reordered
);
cutlass
::
reorder_tensor
(
B_packed_ptr
,
layout_B
,
layout_B_reordered
);
return
B_packed
;
return
B_packed
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment